lib/eps/base_estimator.rb in eps-0.3.0 vs lib/eps/base_estimator.rb in eps-0.3.1

- old
+ new

@@ -1,8 +1,11 @@ module Eps class BaseEstimator def initialize(data = nil, y = nil, **options) + @options = options.dup + # TODO better pattern - don't pass most options to train + options.delete(:intercept) train(data, y, **options) if data end def predict(data) singular = data.is_a?(Hash) @@ -26,26 +29,23 @@ predictions = @evaluator.predict(data) singular ? predictions.first : predictions end - def evaluate(data, y = nil, target: nil) - data, target = prep_data(data, y, target || @target) - Eps.metrics(data.label, predict(data)) + def evaluate(data, y = nil, target: nil, weight: nil) + data, target = prep_data(data, y, target || @target, weight) + Eps.metrics(data.label, predict(data), weight: data.weight) end def to_pmml - (@pmml ||= generate_pmml).to_xml + @pmml ||= PMML.generate(self) end - def self.load_pmml(data) - if data.is_a?(String) - data = Nokogiri::XML(data) { |config| config.strict } - end + def self.load_pmml(pmml) model = new - model.instance_variable_set("@pmml", data) # cache data - model.instance_variable_set("@evaluator", yield(data)) + model.instance_variable_set("@evaluator", PMML.load(pmml)) + model.instance_variable_set("@pmml", pmml.respond_to?(:to_xml) ? pmml.to_xml : pmml) # cache data model end def summary(extended: false) str = String.new("") @@ -55,74 +55,35 @@ y_pred = predict(@validation_set) case @target_type when "numeric" metric_name = "RMSE" - v = Metrics.rmse(y_true, y_pred) + v = Metrics.rmse(y_true, y_pred, weight: @validation_set.weight) metric_value = v.round >= 1000 ? v.round.to_s : "%.3g" % v else metric_name = "accuracy" - metric_value = "%.1f%%" % (100 * Metrics.accuracy(y_true, y_pred)).round(1) + metric_value = "%.1f%%" % (100 * Metrics.accuracy(y_true, y_pred, weight: @validation_set.weight)).round(1) end str << "Validation %s: %s\n\n" % [metric_name, metric_value] end str << _summary(extended: extended) str end - # private - def self.extract_text_features(data, features) - # updates features object - vocabulary = {} - function_mapping = {} - derived_fields = {} - data.css("LocalTransformations DerivedField, TransformationDictionary DerivedField").each do |n| - name = n.attribute("name")&.value - field = n.css("FieldRef").attribute("field").value - value = n.css("Constant").text - - field = field[10..-2] if field =~ /\Alowercase\(.+\)\z/ - next if value.empty? - - (vocabulary[field] ||= []) << value - - function_mapping[field] = n.css("Apply").attribute("function").value - - derived_fields[name] = [field, value] - end - - functions = {} - data.css("TransformationDictionary DefineFunction").each do |n| - name = n.attribute("name").value - text_index = n.css("TextIndex") - functions[name] = { - tokenizer: Regexp.new(text_index.attribute("wordSeparatorCharacterRE").value), - case_sensitive: text_index.attribute("isCaseSensitive")&.value == "true" - } - end - - text_features = {} - function_mapping.each do |field, function| - text_features[field] = functions[function].merge(vocabulary: vocabulary[field]) - features[field] = "text" - end - - [text_features, derived_fields] - end - private - def train(data, y = nil, target: nil, split: nil, validation_set: nil, verbose: nil, text_features: nil, early_stopping: nil) - data, @target = prep_data(data, y, target) + def train(data, y = nil, target: nil, weight: nil, split: nil, validation_set: nil, verbose: nil, text_features: nil, early_stopping: nil) + data, @target = prep_data(data, y, target, weight) @target_type = Utils.column_type(data.label, @target) if split.nil? split = data.size >= 30 end # cross validation + # TODO adjust based on weight if split && !validation_set split = {} if split == true split = {column: split} unless split.is_a?(Hash) split_p = 1 - (split[:validation_size] || 0.25) @@ -191,12 +152,13 @@ @train_set = data[train_idx] validation_set = data[validation_idx] else @train_set = data.dup if validation_set - validation_set = Eps::DataFrame.new(validation_set) - validation_set.label = validation_set.columns.delete(@target) + raise "Target required for validation set" unless target + raise "Weight required for validation set" if data.weight && !weight + validation_set, _ = prep_data(validation_set, nil, @target, weight) end end raise "No data in training set" if @train_set.empty? raise "No data in validation set" if validation_set && validation_set.empty? @@ -208,16 +170,31 @@ @pmml = nil nil end - def prep_data(data, y, target) + def prep_data(data, y, target, weight) data = Eps::DataFrame.new(data) + + # target target = (target || "target").to_s y ||= data.columns.delete(target) check_missing(y, target) data.label = y.to_a + + # weight + if weight + weight = + if weight.respond_to?(:to_a) + weight.to_a + else + data.columns.delete(weight.to_s) + end + check_missing(weight, "weight") + data.weight = weight.to_a + end + check_data(data) [data, target] end def prep_text_features(train_set) @@ -249,10 +226,11 @@ end def check_data(data) raise "No data" if data.empty? raise "Number of data points differs from target" if data.size != data.label.size + raise "Number of data points differs from weight" if data.weight && data.size != data.weight.size end def check_missing(c, name) raise ArgumentError, "Missing column: #{name}" if !c raise ArgumentError, "Missing values in column #{name}" if c.any?(&:nil?) @@ -271,81 +249,9 @@ else k.join("=") end else k - end - end - - # pmml - - def build_pmml(data_fields) - Nokogiri::XML::Builder.new do |xml| - xml.PMML(version: "4.4", xmlns: "http://www.dmg.org/PMML-4_4", "xmlns:xsi" => "http://www.w3.org/2001/XMLSchema-instance") do - pmml_header(xml) - pmml_data_dictionary(xml, data_fields) - pmml_transformation_dictionary(xml) - yield xml - end - end - end - - def pmml_header(xml) - xml.Header do - xml.Application(name: "Eps", version: Eps::VERSION) - # xml.Timestamp Time.now.utc.iso8601 - end - end - - def pmml_data_dictionary(xml, data_fields) - xml.DataDictionary do - data_fields.each do |k, vs| - case @features[k] - when "categorical", nil - xml.DataField(name: k, optype: "categorical", dataType: "string") do - vs.map(&:to_s).sort.each do |v| - xml.Value(value: v) - end - end - when "text" - xml.DataField(name: k, optype: "categorical", dataType: "string") - else - xml.DataField(name: k, optype: "continuous", dataType: "double") - end - end - end - end - - def pmml_transformation_dictionary(xml) - if @text_features.any? - xml.TransformationDictionary do - @text_features.each do |k, text_options| - xml.DefineFunction(name: "#{k}Transform", optype: "continuous") do - xml.ParameterField(name: "text") - xml.ParameterField(name: "term") - xml.TextIndex(textField: "text", localTermWeights: "termFrequency", wordSeparatorCharacterRE: text_options[:tokenizer].source, isCaseSensitive: !!text_options[:case_sensitive]) do - xml.FieldRef(field: "term") - end - end - end - end - end - end - - def pmml_local_transformations(xml) - if @text_features.any? - xml.LocalTransformations do - @text_features.each do |k, _| - @text_encoders[k].vocabulary.each do |v| - xml.DerivedField(name: display_field([k, v]), optype: "continuous", dataType: "integer") do - xml.Apply(function: "#{k}Transform") do - xml.FieldRef(field: k) - xml.Constant v - end - end - end - end - end end end end end