lib/eps/naive_bayes.rb in eps-0.3.0 vs lib/eps/naive_bayes.rb in eps-0.3.1
- old
+ new
@@ -1,97 +1,13 @@
module Eps
class NaiveBayes < BaseEstimator
attr_reader :probabilities
def accuracy
- Eps::Metrics.accuracy(@train_set.label, predict(@train_set))
+ Eps::Metrics.accuracy(@train_set.label, predict(@train_set), weight: @train_set.weight)
end
- # pmml
-
- def self.load_pmml(data)
- super do |data|
- # TODO more validation
- node = data.css("NaiveBayesModel")
-
- prior = {}
- node.css("BayesOutput TargetValueCount").each do |n|
- prior[n.attribute("value").value] = n.attribute("count").value.to_f
- end
-
- legacy = false
-
- conditional = {}
- features = {}
- node.css("BayesInput").each do |n|
- prob = {}
-
- # numeric
- n.css("TargetValueStat").each do |n2|
- n3 = n2.css("GaussianDistribution")
- prob[n2.attribute("value").value] = {
- mean: n3.attribute("mean").value.to_f,
- stdev: Math.sqrt(n3.attribute("variance").value.to_f)
- }
- end
-
- # detect bad form in Eps < 0.3
- bad_format = n.css("PairCounts").map { |n2| n2.attribute("value").value } == prior.keys
-
- # categorical
- n.css("PairCounts").each do |n2|
- if bad_format
- n2.css("TargetValueCount").each do |n3|
- prob[n3.attribute("value").value] ||= {}
- prob[n3.attribute("value").value][n2.attribute("value").value] = BigDecimal(n3.attribute("count").value)
- end
- else
- boom = {}
- n2.css("TargetValueCount").each do |n3|
- boom[n3.attribute("value").value] = BigDecimal(n3.attribute("count").value)
- end
- prob[n2.attribute("value").value] = boom
- end
- end
-
- if bad_format
- legacy = true
- prob.each do |k, v|
- prior.keys.each do |k|
- v[k] ||= 0.0
- end
- end
- end
-
- name = n.attribute("fieldName").value
- conditional[name] = prob
- features[name] = n.css("TargetValueStat").any? ? "numeric" : "categorical"
- end
-
- target = node.css("BayesOutput").attribute("fieldName").value
-
- probabilities = {
- prior: prior,
- conditional: conditional
- }
-
- # get derived fields
- derived = {}
- data.css("DerivedField").each do |n|
- name = n.attribute("name").value
- field = n.css("NormDiscrete").attribute("field").value
- value = n.css("NormDiscrete").attribute("value").value
- features.delete(name)
- features[field] = "derived"
- derived[field] ||= {}
- derived[field][name] = value
- end
-
- Evaluators::NaiveBayes.new(probabilities: probabilities, features: features, derived: derived, legacy: legacy)
- end
- end
-
private
# TODO better summary
def _summary(extended: false)
str = String.new("")
@@ -103,10 +19,11 @@
def _train(smoothing: 1, **options)
raise "Target must be strings" if @target_type != "categorical"
check_missing_value(@train_set)
check_missing_value(@validation_set) if @validation_set
+ raise ArgumentError, "weight not supported" if @train_set.weight
data = @train_set
prep_text_features(data)
@@ -181,63 +98,9 @@
prior: prior,
conditional: conditional
}
Evaluators::NaiveBayes.new(probabilities: probabilities, features: @features)
- end
-
- def generate_pmml
- data_fields = {}
- data_fields[@target] = probabilities[:prior].keys
- probabilities[:conditional].each do |k, v|
- if @features[k] == "categorical"
- data_fields[k] = v.keys
- else
- data_fields[k] = nil
- end
- end
-
- build_pmml(data_fields) do |xml|
- xml.NaiveBayesModel(functionName: "classification", threshold: 0.001) do
- xml.MiningSchema do
- data_fields.each do |k, _|
- xml.MiningField(name: k)
- end
- end
- xml.BayesInputs do
- probabilities[:conditional].each do |k, v|
- xml.BayesInput(fieldName: k) do
- if @features[k] == "categorical"
- v.sort_by { |k2, _| k2 }.each do |k2, v2|
- xml.PairCounts(value: k2) do
- xml.TargetValueCounts do
- v2.sort_by { |k2, _| k2 }.each do |k3, v3|
- xml.TargetValueCount(value: k3, count: v3)
- end
- end
- end
- end
- else
- xml.TargetValueStats do
- v.sort_by { |k2, _| k2 }.each do |k2, v2|
- xml.TargetValueStat(value: k2) do
- xml.GaussianDistribution(mean: v2[:mean], variance: v2[:stdev]**2)
- end
- end
- end
- end
- end
- end
- end
- xml.BayesOutput(fieldName: "target") do
- xml.TargetValueCounts do
- probabilities[:prior].sort_by { |k, _| k }.each do |k, v|
- xml.TargetValueCount(value: k, count: v)
- end
- end
- end
- end
- end
end
def group_count(arr, start)
arr.inject(start) { |h, e| h[e] += 1; h }
end