lib/eps/lightgbm.rb in eps-0.3.0 vs lib/eps/lightgbm.rb in eps-0.3.1
- removed (eps-0.3.0)
+ added (eps-0.3.1)
@@ -1,41 +1,7 @@
-require "eps/pmml_generators/lightgbm"
-
module Eps
class LightGBM < BaseEstimator
- include PmmlGenerators::LightGBM
-
- def self.load_pmml(data)
- super do |data|
- objective = data.css("MiningModel").first.attribute("functionName").value
- if objective == "classification"
- labels = data.css("RegressionModel OutputField").map { |n| n.attribute("value").value }
- objective = labels.size > 2 ? "multiclass" : "binary"
- end
-
- features = {}
- text_features, derived_fields = extract_text_features(data, features)
- node = data.css("DataDictionary").first
- node.css("DataField")[1..-1].to_a.each do |node|
- features[node.attribute("name").value] =
- if node.attribute("optype").value == "categorical"
- "categorical"
- else
- "numeric"
- end
- end
-
- trees = []
- data.css("Segmentation TreeModel").each do |tree|
- node = find_nodes(tree.css("Node").first, derived_fields)
- trees << node
- end
-
- Evaluators::LightGBM.new(trees: trees, objective: objective, labels: labels, features: features, text_features: text_features)
- end
- end
-
private
def _summary(extended: false)
str = String.new("")
importance = @booster.feature_importance
@@ -49,52 +15,20 @@
end
end
str
end
- def self.find_nodes(xml, derived_fields)
- score = BigDecimal(xml.attribute("score").value).to_f
-
- elements = xml.elements
- xml_predicate = elements.first
-
- predicate =
- if xml_predicate.name == "True"
- nil
- elsif xml_predicate.name == "SimpleSetPredicate"
- operator = "in"
- value = xml_predicate.css("Array").text.scan(/"(.+?)(?<!\\)"|(\S+)/).flatten.compact.map { |v| v.gsub('\"', '"') }
- field = xml_predicate.attribute("field").value
- field = derived_fields[field] if derived_fields[field]
- {
- field: field,
- operator: operator,
- value: value
- }
- else
- operator = xml_predicate.attribute("operator").value
- value = xml_predicate.attribute("value").value
- value = BigDecimal(value).to_f if operator == "greaterThan"
- field = xml_predicate.attribute("field").value
- field = derived_fields[field] if derived_fields[field]
- {
- field: field,
- operator: operator,
- value: value
- }
- end
-
- children = elements[1..-1].map { |n| find_nodes(n, derived_fields) }
-
- Evaluators::Node.new(score: score, predicate: predicate, children: children)
- end
-
def _train(verbose: nil, early_stopping: nil)
train_set = @train_set
validation_set = @validation_set.dup
summary_label = train_set.label
+ # create check set
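+ # sample up to 100 random row indexes (duplicates dropped) from the validation set if present, otherwise the training set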
+ evaluator_set = validation_set || train_set
+ check_idx = 100.times.map { rand(evaluator_set.size) }.uniq
+ evaluator_set = evaluator_set[check_idx]
+
# objective
objective =
if @target_type == "numeric"
"regression"
else
@@ -133,12 +67,12 @@
params[:min_data_in_leaf] = 1
end
# create datasets
categorical_idx = @features.values.map.with_index.select { |type, _| type == "categorical" }.map(&:last)
- train_ds = ::LightGBM::Dataset.new(train_set.map_rows(&:to_a), label: train_set.label, categorical_feature: categorical_idx, params: params)
- validation_ds = ::LightGBM::Dataset.new(validation_set.map_rows(&:to_a), label: validation_set.label, categorical_feature: categorical_idx, params: params, reference: train_ds) if validation_set
+ train_ds = ::LightGBM::Dataset.new(train_set.map_rows(&:to_a), label: train_set.label, weight: train_set.weight, categorical_feature: categorical_idx, params: params)
+ validation_ds = ::LightGBM::Dataset.new(validation_set.map_rows(&:to_a), label: validation_set.label, weight: validation_set.weight, categorical_feature: categorical_idx, params: params, reference: train_ds) if validation_set
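+ # per-row weights are passed to LightGBM along with the labels and categorical feature indexes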
# train
valid_sets = [train_ds]
valid_sets << validation_ds if validation_ds
valid_names = ["training"]
@@ -174,14 +108,40 @@
@booster = booster
# reset pmml
@pmml = nil
- Evaluators::LightGBM.new(trees: trees, objective: objective, labels: labels, features: @features, text_features: @text_features)
+ evaluator = Evaluators::LightGBM.new(trees: trees, objective: objective, labels: labels, features: @features, text_features: @text_features)
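+ # check the parsed-tree evaluator against the native booster on the sampled rows before returning it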
+ booster_set = validation_set ? validation_set[check_idx] : train_set[check_idx]
+ check_evaluator(objective, labels, booster, booster_set, evaluator, evaluator_set)
+ evaluator
end
- def evaluator_class
- PmmlLoaders::LightGBM
+ # compare a subset of predictions to check for possible bugs in evaluator
+ # NOTE LightGBM must use double data type for prediction input for these to be consistent
+ def check_evaluator(objective, labels, booster, booster_set, evaluator, evaluator_set)
+ expected = @booster.predict(booster_set.map_rows(&:to_a))
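+ # map raw booster output to labels: argmax over class probabilities for multiclass, 0.5 threshold for binary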
+ if objective == "multiclass"
+ expected.map! do |v|
+ labels[v.map.with_index.max_by { |v2, _| v2 }.last]
+ end
+ elsif objective == "binary"
+ expected.map! { |v| labels[v >= 0.5 ? 1 : 0] }
+ end
+ actual = evaluator.predict(evaluator_set)
+
+ regression = objective == "regression"
+ bad_observations = []
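+ # regression predictions must agree within a small absolute tolerance; classification labels must match exactly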
+ expected.zip(actual).each_with_index do |(exp, act), i|
+ success = regression ? (act - exp).abs < 0.001 : act == exp
+ unless success
+ bad_observations << {expected: exp, actual: act, data_point: evaluator_set[i].map(&:itself).first}
+ end
+ end
+
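+ # any mismatch means the evaluator does not reproduce the booster's predictions, so raise instead of returning a bad model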
+ if bad_observations.any?
+ raise "Bug detected in evaluator. Please report an issue. Bad data points: #{bad_observations.inspect}"
+ end
end
# for evaluator
def parse_tree(node)