lib/eps/naive_bayes.rb in eps-0.3.0 vs lib/eps/naive_bayes.rb in eps-0.3.1

- old
+ new

@@ -1,97 +1,13 @@ module Eps class NaiveBayes < BaseEstimator attr_reader :probabilities def accuracy - Eps::Metrics.accuracy(@train_set.label, predict(@train_set)) + Eps::Metrics.accuracy(@train_set.label, predict(@train_set), weight: @train_set.weight) end - # pmml - - def self.load_pmml(data) - super do |data| - # TODO more validation - node = data.css("NaiveBayesModel") - - prior = {} - node.css("BayesOutput TargetValueCount").each do |n| - prior[n.attribute("value").value] = n.attribute("count").value.to_f - end - - legacy = false - - conditional = {} - features = {} - node.css("BayesInput").each do |n| - prob = {} - - # numeric - n.css("TargetValueStat").each do |n2| - n3 = n2.css("GaussianDistribution") - prob[n2.attribute("value").value] = { - mean: n3.attribute("mean").value.to_f, - stdev: Math.sqrt(n3.attribute("variance").value.to_f) - } - end - - # detect bad form in Eps < 0.3 - bad_format = n.css("PairCounts").map { |n2| n2.attribute("value").value } == prior.keys - - # categorical - n.css("PairCounts").each do |n2| - if bad_format - n2.css("TargetValueCount").each do |n3| - prob[n3.attribute("value").value] ||= {} - prob[n3.attribute("value").value][n2.attribute("value").value] = BigDecimal(n3.attribute("count").value) - end - else - boom = {} - n2.css("TargetValueCount").each do |n3| - boom[n3.attribute("value").value] = BigDecimal(n3.attribute("count").value) - end - prob[n2.attribute("value").value] = boom - end - end - - if bad_format - legacy = true - prob.each do |k, v| - prior.keys.each do |k| - v[k] ||= 0.0 - end - end - end - - name = n.attribute("fieldName").value - conditional[name] = prob - features[name] = n.css("TargetValueStat").any? ? "numeric" : "categorical" - end - - target = node.css("BayesOutput").attribute("fieldName").value - - probabilities = { - prior: prior, - conditional: conditional - } - - # get derived fields - derived = {} - data.css("DerivedField").each do |n| - name = n.attribute("name").value - field = n.css("NormDiscrete").attribute("field").value - value = n.css("NormDiscrete").attribute("value").value - features.delete(name) - features[field] = "derived" - derived[field] ||= {} - derived[field][name] = value - end - - Evaluators::NaiveBayes.new(probabilities: probabilities, features: features, derived: derived, legacy: legacy) - end - end - private # TODO better summary def _summary(extended: false) str = String.new("") @@ -103,10 +19,11 @@ def _train(smoothing: 1, **options) raise "Target must be strings" if @target_type != "categorical" check_missing_value(@train_set) check_missing_value(@validation_set) if @validation_set + raise ArgumentError, "weight not supported" if @train_set.weight data = @train_set prep_text_features(data) @@ -181,63 +98,9 @@ prior: prior, conditional: conditional } Evaluators::NaiveBayes.new(probabilities: probabilities, features: @features) - end - - def generate_pmml - data_fields = {} - data_fields[@target] = probabilities[:prior].keys - probabilities[:conditional].each do |k, v| - if @features[k] == "categorical" - data_fields[k] = v.keys - else - data_fields[k] = nil - end - end - - build_pmml(data_fields) do |xml| - xml.NaiveBayesModel(functionName: "classification", threshold: 0.001) do - xml.MiningSchema do - data_fields.each do |k, _| - xml.MiningField(name: k) - end - end - xml.BayesInputs do - probabilities[:conditional].each do |k, v| - xml.BayesInput(fieldName: k) do - if @features[k] == "categorical" - v.sort_by { |k2, _| k2 }.each do |k2, v2| - xml.PairCounts(value: k2) do - xml.TargetValueCounts do - v2.sort_by { |k2, _| k2 }.each do |k3, v3| - xml.TargetValueCount(value: k3, count: v3) - end - end - end - end - else - xml.TargetValueStats do - v.sort_by { |k2, _| k2 }.each do |k2, v2| - xml.TargetValueStat(value: k2) do - xml.GaussianDistribution(mean: v2[:mean], variance: v2[:stdev]**2) - end - end - end - end - end - end - end - xml.BayesOutput(fieldName: "target") do - xml.TargetValueCounts do - probabilities[:prior].sort_by { |k, _| k }.each do |k, v| - xml.TargetValueCount(value: k, count: v) - end - end - end - end - end end def group_count(arr, start) arr.inject(start) { |h, e| h[e] += 1; h } end