lib/eps/base_estimator.rb in eps-0.3.0 vs lib/eps/base_estimator.rb in eps-0.3.1
- old
+ new
@@ -1,8 +1,11 @@
module Eps
class BaseEstimator
def initialize(data = nil, y = nil, **options)
+ @options = options.dup
+ # TODO better pattern - don't pass most options to train
+ options.delete(:intercept)
train(data, y, **options) if data
end
def predict(data)
singular = data.is_a?(Hash)
@@ -26,26 +29,23 @@
predictions = @evaluator.predict(data)
singular ? predictions.first : predictions
end
- def evaluate(data, y = nil, target: nil)
- data, target = prep_data(data, y, target || @target)
- Eps.metrics(data.label, predict(data))
+ def evaluate(data, y = nil, target: nil, weight: nil)
+ data, target = prep_data(data, y, target || @target, weight)
+ Eps.metrics(data.label, predict(data), weight: data.weight)
end
def to_pmml
- (@pmml ||= generate_pmml).to_xml
+ @pmml ||= PMML.generate(self)
end
- def self.load_pmml(data)
- if data.is_a?(String)
- data = Nokogiri::XML(data) { |config| config.strict }
- end
+ def self.load_pmml(pmml)
model = new
- model.instance_variable_set("@pmml", data) # cache data
- model.instance_variable_set("@evaluator", yield(data))
+ model.instance_variable_set("@evaluator", PMML.load(pmml))
+ model.instance_variable_set("@pmml", pmml.respond_to?(:to_xml) ? pmml.to_xml : pmml) # cache data
model
end
def summary(extended: false)
str = String.new("")
@@ -55,74 +55,35 @@
y_pred = predict(@validation_set)
case @target_type
when "numeric"
metric_name = "RMSE"
- v = Metrics.rmse(y_true, y_pred)
+ v = Metrics.rmse(y_true, y_pred, weight: @validation_set.weight)
metric_value = v.round >= 1000 ? v.round.to_s : "%.3g" % v
else
metric_name = "accuracy"
- metric_value = "%.1f%%" % (100 * Metrics.accuracy(y_true, y_pred)).round(1)
+ metric_value = "%.1f%%" % (100 * Metrics.accuracy(y_true, y_pred, weight: @validation_set.weight)).round(1)
end
str << "Validation %s: %s\n\n" % [metric_name, metric_value]
end
str << _summary(extended: extended)
str
end
- # private
- def self.extract_text_features(data, features)
- # updates features object
- vocabulary = {}
- function_mapping = {}
- derived_fields = {}
- data.css("LocalTransformations DerivedField, TransformationDictionary DerivedField").each do |n|
- name = n.attribute("name")&.value
- field = n.css("FieldRef").attribute("field").value
- value = n.css("Constant").text
-
- field = field[10..-2] if field =~ /\Alowercase\(.+\)\z/
- next if value.empty?
-
- (vocabulary[field] ||= []) << value
-
- function_mapping[field] = n.css("Apply").attribute("function").value
-
- derived_fields[name] = [field, value]
- end
-
- functions = {}
- data.css("TransformationDictionary DefineFunction").each do |n|
- name = n.attribute("name").value
- text_index = n.css("TextIndex")
- functions[name] = {
- tokenizer: Regexp.new(text_index.attribute("wordSeparatorCharacterRE").value),
- case_sensitive: text_index.attribute("isCaseSensitive")&.value == "true"
- }
- end
-
- text_features = {}
- function_mapping.each do |field, function|
- text_features[field] = functions[function].merge(vocabulary: vocabulary[field])
- features[field] = "text"
- end
-
- [text_features, derived_fields]
- end
-
private
- def train(data, y = nil, target: nil, split: nil, validation_set: nil, verbose: nil, text_features: nil, early_stopping: nil)
- data, @target = prep_data(data, y, target)
+ def train(data, y = nil, target: nil, weight: nil, split: nil, validation_set: nil, verbose: nil, text_features: nil, early_stopping: nil)
+ data, @target = prep_data(data, y, target, weight)
@target_type = Utils.column_type(data.label, @target)
if split.nil?
split = data.size >= 30
end
# cross validation
+ # TODO adjust based on weight
if split && !validation_set
split = {} if split == true
split = {column: split} unless split.is_a?(Hash)
split_p = 1 - (split[:validation_size] || 0.25)
@@ -191,12 +152,13 @@
@train_set = data[train_idx]
validation_set = data[validation_idx]
else
@train_set = data.dup
if validation_set
- validation_set = Eps::DataFrame.new(validation_set)
- validation_set.label = validation_set.columns.delete(@target)
+ raise "Target required for validation set" unless target
+ raise "Weight required for validation set" if data.weight && !weight
+ validation_set, _ = prep_data(validation_set, nil, @target, weight)
end
end
raise "No data in training set" if @train_set.empty?
raise "No data in validation set" if validation_set && validation_set.empty?
@@ -208,16 +170,31 @@
@pmml = nil
nil
end
- def prep_data(data, y, target)
+ def prep_data(data, y, target, weight)
data = Eps::DataFrame.new(data)
+
+ # target
target = (target || "target").to_s
y ||= data.columns.delete(target)
check_missing(y, target)
data.label = y.to_a
+
+ # weight
+ if weight
+ weight =
+ if weight.respond_to?(:to_a)
+ weight.to_a
+ else
+ data.columns.delete(weight.to_s)
+ end
+ check_missing(weight, "weight")
+ data.weight = weight.to_a
+ end
+
check_data(data)
[data, target]
end
def prep_text_features(train_set)
@@ -249,10 +226,11 @@
end
def check_data(data)
raise "No data" if data.empty?
raise "Number of data points differs from target" if data.size != data.label.size
+ raise "Number of data points differs from weight" if data.weight && data.size != data.weight.size
end
def check_missing(c, name)
raise ArgumentError, "Missing column: #{name}" if !c
raise ArgumentError, "Missing values in column #{name}" if c.any?(&:nil?)
@@ -271,81 +249,9 @@
else
k.join("=")
end
else
k
- end
- end
-
- # pmml
-
- def build_pmml(data_fields)
- Nokogiri::XML::Builder.new do |xml|
- xml.PMML(version: "4.4", xmlns: "http://www.dmg.org/PMML-4_4", "xmlns:xsi" => "http://www.w3.org/2001/XMLSchema-instance") do
- pmml_header(xml)
- pmml_data_dictionary(xml, data_fields)
- pmml_transformation_dictionary(xml)
- yield xml
- end
- end
- end
-
- def pmml_header(xml)
- xml.Header do
- xml.Application(name: "Eps", version: Eps::VERSION)
- # xml.Timestamp Time.now.utc.iso8601
- end
- end
-
- def pmml_data_dictionary(xml, data_fields)
- xml.DataDictionary do
- data_fields.each do |k, vs|
- case @features[k]
- when "categorical", nil
- xml.DataField(name: k, optype: "categorical", dataType: "string") do
- vs.map(&:to_s).sort.each do |v|
- xml.Value(value: v)
- end
- end
- when "text"
- xml.DataField(name: k, optype: "categorical", dataType: "string")
- else
- xml.DataField(name: k, optype: "continuous", dataType: "double")
- end
- end
- end
- end
-
- def pmml_transformation_dictionary(xml)
- if @text_features.any?
- xml.TransformationDictionary do
- @text_features.each do |k, text_options|
- xml.DefineFunction(name: "#{k}Transform", optype: "continuous") do
- xml.ParameterField(name: "text")
- xml.ParameterField(name: "term")
- xml.TextIndex(textField: "text", localTermWeights: "termFrequency", wordSeparatorCharacterRE: text_options[:tokenizer].source, isCaseSensitive: !!text_options[:case_sensitive]) do
- xml.FieldRef(field: "term")
- end
- end
- end
- end
- end
- end
-
- def pmml_local_transformations(xml)
- if @text_features.any?
- xml.LocalTransformations do
- @text_features.each do |k, _|
- @text_encoders[k].vocabulary.each do |v|
- xml.DerivedField(name: display_field([k, v]), optype: "continuous", dataType: "integer") do
- xml.Apply(function: "#{k}Transform") do
- xml.FieldRef(field: k)
- xml.Constant v
- end
- end
- end
- end
- end
end
end
end
end