lib/eps/lightgbm.rb in eps-0.3.0 vs lib/eps/lightgbm.rb in eps-0.3.1

- old
+ new

@@ -1,41 +1,7 @@
-require "eps/pmml_generators/lightgbm"
-
 module Eps
   class LightGBM < BaseEstimator
-    include PmmlGenerators::LightGBM
-
-    def self.load_pmml(data)
-      super do |data|
-        objective = data.css("MiningModel").first.attribute("functionName").value
-        if objective == "classification"
-          labels = data.css("RegressionModel OutputField").map { |n| n.attribute("value").value }
-          objective = labels.size > 2 ? "multiclass" : "binary"
-        end
-
-        features = {}
-        text_features, derived_fields = extract_text_features(data, features)
-        node = data.css("DataDictionary").first
-        node.css("DataField")[1..-1].to_a.each do |node|
-          features[node.attribute("name").value] =
-            if node.attribute("optype").value == "categorical"
-              "categorical"
-            else
-              "numeric"
-            end
-        end
-
-        trees = []
-        data.css("Segmentation TreeModel").each do |tree|
-          node = find_nodes(tree.css("Node").first, derived_fields)
-          trees << node
-        end
-
-        Evaluators::LightGBM.new(trees: trees, objective: objective, labels: labels, features: features, text_features: text_features)
-      end
-    end
-
     private

     def _summary(extended: false)
       str = String.new("")
       importance = @booster.feature_importance
@@ -49,52 +15,20 @@
         end
       end
       str
     end

-    def self.find_nodes(xml, derived_fields)
-      score = BigDecimal(xml.attribute("score").value).to_f
-
-      elements = xml.elements
-      xml_predicate = elements.first
-
-      predicate =
-        if xml_predicate.name == "True"
-          nil
-        elsif xml_predicate.name == "SimpleSetPredicate"
-          operator = "in"
-          value = xml_predicate.css("Array").text.scan(/"(.+?)(?<!\\)"|(\S+)/).flatten.compact.map { |v| v.gsub('\"', '"') }
-          field = xml_predicate.attribute("field").value
-          field = derived_fields[field] if derived_fields[field]
-          {
-            field: field,
-            operator: operator,
-            value: value
-          }
-        else
-          operator = xml_predicate.attribute("operator").value
-          value = xml_predicate.attribute("value").value
-          value = BigDecimal(value).to_f if operator == "greaterThan"
-          field = xml_predicate.attribute("field").value
-          field = derived_fields[field] if derived_fields[field]
-          {
-            field: field,
-            operator: operator,
-            value: value
-          }
-        end
-
-      children = elements[1..-1].map { |n| find_nodes(n, derived_fields) }
-
-      Evaluators::Node.new(score: score, predicate: predicate, children: children)
-    end
-
     def _train(verbose: nil, early_stopping: nil)
       train_set = @train_set
       validation_set = @validation_set.dup
       summary_label = train_set.label

+      # create check set
+      evaluator_set = validation_set || train_set
+      check_idx = 100.times.map { rand(evaluator_set.size) }.uniq
+      evaluator_set = evaluator_set[check_idx]
+
       # objective
       objective =
         if @target_type == "numeric"
           "regression"
         else
@@ -133,12 +67,12 @@
         params[:min_data_in_leaf] = 1
       end

       # create datasets
       categorical_idx = @features.values.map.with_index.select { |type, _| type == "categorical" }.map(&:last)
-      train_ds = ::LightGBM::Dataset.new(train_set.map_rows(&:to_a), label: train_set.label, categorical_feature: categorical_idx, params: params)
-      validation_ds = ::LightGBM::Dataset.new(validation_set.map_rows(&:to_a), label: validation_set.label, categorical_feature: categorical_idx, params: params, reference: train_ds) if validation_set
+      train_ds = ::LightGBM::Dataset.new(train_set.map_rows(&:to_a), label: train_set.label, weight: train_set.weight, categorical_feature: categorical_idx, params: params)
+      validation_ds = ::LightGBM::Dataset.new(validation_set.map_rows(&:to_a), label: validation_set.label, weight: validation_set.weight, categorical_feature: categorical_idx, params: params, reference: train_ds) if validation_set

       # train
       valid_sets = [train_ds]
       valid_sets << validation_ds if validation_ds
       valid_names = ["training"]
@@ -174,14 +108,40 @@
       @booster = booster

       # reset pmml
       @pmml = nil

-      Evaluators::LightGBM.new(trees: trees, objective: objective, labels: labels, features: @features, text_features: @text_features)
+      evaluator = Evaluators::LightGBM.new(trees: trees, objective: objective, labels: labels, features: @features, text_features: @text_features)
+      booster_set = validation_set ? validation_set[check_idx] : train_set[check_idx]
+      check_evaluator(objective, labels, booster, booster_set, evaluator, evaluator_set)
+      evaluator
     end

-    def evaluator_class
-      PmmlLoaders::LightGBM
+    # compare a subset of predictions to check for possible bugs in evaluator
+    # NOTE LightGBM must use double data type for prediction input for these to be consistent
+    def check_evaluator(objective, labels, booster, booster_set, evaluator, evaluator_set)
+      expected = @booster.predict(booster_set.map_rows(&:to_a))
+      if objective == "multiclass"
+        expected.map! do |v|
+          labels[v.map.with_index.max_by { |v2, _| v2 }.last]
+        end
+      elsif objective == "binary"
+        expected.map! { |v| labels[v >= 0.5 ? 1 : 0] }
+      end
+      actual = evaluator.predict(evaluator_set)
+
+      regression = objective == "regression"
+      bad_observations = []
+      expected.zip(actual).each_with_index do |(exp, act), i|
+        success = regression ? (act - exp).abs < 0.001 : act == exp
+        unless success
+          bad_observations << {expected: exp, actual: act, data_point: evaluator_set[i].map(&:itself).first}
+        end
+      end
+
+      if bad_observations.any?
+        raise "Bug detected in evaluator. Please report an issue. Bad data points: #{bad_observations.inspect}"
+      end
     end

     # for evaluator
     def parse_tree(node)
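
In short, 0.3.1 removes the PMML loading code (load_pmml, find_nodes, evaluator_class) from this class, passes per-row weights through to the LightGBM datasets, and adds a post-training safety net: predictions from the native LightGBM booster are compared against the pure-Ruby evaluator on a random sample of up to 100 rows, and training raises if they disagree. Below is a minimal standalone sketch of that comparison logic, not eps's API; the method name and inputs are hypothetical, and only the 0.001 regression tolerance mirrors the diff above.

# Hypothetical illustration of the check added in 0.3.1: compare predictions
# from two sources row by row and collect any observations where they disagree.
def check_consistency(objective, booster_predictions, evaluator_predictions, rows)
  regression = objective == "regression"
  bad = []
  booster_predictions.zip(evaluator_predictions).each_with_index do |(expected, actual), i|
    # numeric targets allow a small tolerance; class labels must match exactly
    ok = regression ? (actual - expected).abs < 0.001 : actual == expected
    bad << {expected: expected, actual: actual, data_point: rows[i]} unless ok
  end
  raise "Bug detected in evaluator. Bad data points: #{bad.inspect}" if bad.any?
end

# made-up example values: the 0.0004 difference is within tolerance, so no error
check_consistency("regression", [1.0, 2.5], [1.0004, 2.5], [{x: 1}, {x: 2}])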