lib/eps/base_estimator.rb in eps-0.2.1 vs lib/eps/base_estimator.rb in eps-0.3.0
- old
+ new
@@ -1,82 +1,351 @@
module Eps
class BaseEstimator
- def train(data, y, target: nil, **options)
- # TODO more performant conversion
- if daru?(data)
- x = data.dup
- x = x.delete_vector(target) if target
- else
- x = data.map(&:dup)
- x.each { |r| r.delete(target) } if target
- end
+ def initialize(data = nil, y = nil, **options)
+ train(data, y, **options) if data
+ end
- y = prep_y(y.to_a)
+ def predict(data)
+ singular = data.is_a?(Hash)
+ data = [data] if singular
- if x.size != y.size
- raise "Number of samples differs from target"
+ data = Eps::DataFrame.new(data)
+
+ @evaluator.features.each do |k, type|
+ values = data.columns[k]
+ raise ArgumentError, "Missing column: #{k}" if !values
+ column_type = Utils.column_type(values.compact, k) if values
+
+ if !column_type.nil?
+ if (type == "numeric" && column_type != "numeric") || (type != "numeric" && column_type != "categorical")
+ raise ArgumentError, "Bad type for column #{k}: Expected #{type} but got #{column_type}"
+ end
+ end
+ # TODO check for unknown values for categorical features
end
- @x = x
- @y = y
- @target = target || "target"
+ predictions = @evaluator.predict(data)
+
+ singular ? predictions.first : predictions
end
- def predict(x)
- singular = !(x.is_a?(Array) || daru?(x))
- x = [x] if singular
+ def evaluate(data, y = nil, target: nil)
+ data, target = prep_data(data, y, target || @target)
+ Eps.metrics(data.label, predict(data))
+ end
- pred = _predict(x)
+ def to_pmml
+ (@pmml ||= generate_pmml).to_xml
+ end
- singular ? pred[0] : pred
+ def self.load_pmml(data)
+ if data.is_a?(String)
+ data = Nokogiri::XML(data) { |config| config.strict }
+ end
+ model = new
+ model.instance_variable_set("@pmml", data) # cache data
+ model.instance_variable_set("@evaluator", yield(data))
+ model
end
- def evaluate(data, y = nil, target: nil)
- target ||= @target
- raise ArgumentError, "missing target" if !target && !y
+ def summary(extended: false)
+ str = String.new("")
- actual = y
- actual ||=
- if daru?(data)
- data[target].to_a
+ if @validation_set
+ y_true = @validation_set.label
+ y_pred = predict(@validation_set)
+
+ case @target_type
+ when "numeric"
+ metric_name = "RMSE"
+ v = Metrics.rmse(y_true, y_pred)
+ metric_value = v.round >= 1000 ? v.round.to_s : "%.3g" % v
else
- data.map { |v| v[target] }
+ metric_name = "accuracy"
+ metric_value = "%.1f%%" % (100 * Metrics.accuracy(y_true, y_pred)).round(1)
end
+ str << "Validation %s: %s\n\n" % [metric_name, metric_value]
+ end
- actual = prep_y(actual)
- estimated = predict(data)
+ str << _summary(extended: extended)
+ str
+ end
- self.class.metrics(actual, estimated)
+ # private
+ def self.extract_text_features(data, features)
+ # updates features object
+ vocabulary = {}
+ function_mapping = {}
+ derived_fields = {}
+ data.css("LocalTransformations DerivedField, TransformationDictionary DerivedField").each do |n|
+ name = n.attribute("name")&.value
+ field = n.css("FieldRef").attribute("field").value
+ value = n.css("Constant").text
+
+ field = field[10..-2] if field =~ /\Alowercase\(.+\)\z/
+ next if value.empty?
+
+ (vocabulary[field] ||= []) << value
+
+ function_mapping[field] = n.css("Apply").attribute("function").value
+
+ derived_fields[name] = [field, value]
+ end
+
+ functions = {}
+ data.css("TransformationDictionary DefineFunction").each do |n|
+ name = n.attribute("name").value
+ text_index = n.css("TextIndex")
+ functions[name] = {
+ tokenizer: Regexp.new(text_index.attribute("wordSeparatorCharacterRE").value),
+ case_sensitive: text_index.attribute("isCaseSensitive")&.value == "true"
+ }
+ end
+
+ text_features = {}
+ function_mapping.each do |field, function|
+ text_features[field] = functions[function].merge(vocabulary: vocabulary[field])
+ features[field] = "text"
+ end
+
+ [text_features, derived_fields]
end
private
- def categorical?(v)
- !v.is_a?(Numeric)
+ def train(data, y = nil, target: nil, split: nil, validation_set: nil, verbose: nil, text_features: nil, early_stopping: nil)
+ data, @target = prep_data(data, y, target)
+ @target_type = Utils.column_type(data.label, @target)
+
+ if split.nil?
+ split = data.size >= 30
+ end
+
+ # cross validation
+ if split && !validation_set
+ split = {} if split == true
+ split = {column: split} unless split.is_a?(Hash)
+
+ split_p = 1 - (split[:validation_size] || 0.25)
+ if split[:column]
+ split_column = split[:column].to_s
+ times = data.columns.delete(split_column)
+ check_missing(times, split_column)
+ split_index = (times.size * split_p).round
+ split_time = split[:value] || times.sort[split_index]
+ train_idx, validation_idx = (0...data.size).to_a.partition { |i| times[i] < split_time }
+ else
+ if split[:shuffle] != false
+ rng = Random.new(0) # seed random number generator
+ train_idx, validation_idx = (0...data.size).to_a.partition { rng.rand < split_p }
+ else
+ split_index = (data.size * split_p).round
+ train_idx, validation_idx = (0...data.size).to_a.partition { |i| i < split_index }
+ end
+ end
+ end
+
+ # determine feature types
+ @features = {}
+ data.columns.each do |k, v|
+ @features[k] = Utils.column_type(v.compact, k)
+ end
+
+ # determine text features if not specified
+ if text_features.nil?
+ text_features = []
+
+ @features.each do |k, type|
+ next if type != "categorical"
+
+ values = data.columns[k].compact
+
+ next unless values.first.is_a?(String) # not boolean
+
+ values = values.reject(&:empty?)
+ count = values.count
+
+ # check if spaces
+ # two spaces is rough approximation for 3 words
+ # TODO make more performant
+ if values.count { |v| v.count(" ") >= 2 } > 0.5 * count
+ text_features << k
+ end
+ end
+ end
+
+ # prep text features
+ @text_features = {}
+ (text_features || {}).each do |k, v|
+ @features[k.to_s] = "text"
+
+ # same output as scikit-learn CountVectorizer
+ # except for max_features
+ @text_features[k.to_s] = {
+ tokenizer: /\W+/,
+ min_length: 2,
+ max_features: 100
+ }.merge(v || {})
+ end
+
+ if split && !validation_set
+ @train_set = data[train_idx]
+ validation_set = data[validation_idx]
+ else
+ @train_set = data.dup
+ if validation_set
+ validation_set = Eps::DataFrame.new(validation_set)
+ validation_set.label = validation_set.columns.delete(@target)
+ end
+ end
+
+ raise "No data in training set" if @train_set.empty?
+ raise "No data in validation set" if validation_set && validation_set.empty?
+
+ @validation_set = validation_set
+ @evaluator = _train(verbose: verbose, early_stopping: early_stopping)
+
+ # reset pmml
+ @pmml = nil
+
+ nil
end
- def daru?(x)
- defined?(Daru) && x.is_a?(Daru::DataFrame)
+ def prep_data(data, y, target)
+ data = Eps::DataFrame.new(data)
+ target = (target || "target").to_s
+ y ||= data.columns.delete(target)
+ check_missing(y, target)
+ data.label = y.to_a
+ check_data(data)
+ [data, target]
end
- def flip_target(target)
- target.is_a?(String) ? target.to_sym : target.to_s
+ def prep_text_features(train_set)
+ @text_encoders = {}
+ @text_features.each do |k, v|
+ # reset vocabulary
+ v.delete(:vocabulary)
+
+ # TODO determine max features automatically
+ # start based on number of rows
+ encoder = Eps::TextEncoder.new(v)
+ counts = encoder.fit(train_set.columns.delete(k))
+ encoder.vocabulary.each do |word|
+ train_set.columns[[k, word]] = [0] * counts.size
+ end
+ counts.each_with_index do |ci, i|
+ ci.each do |word, count|
+ word_key = [k, word]
+ train_set.columns[word_key][i] = 1 if train_set.columns.key?(word_key)
+ end
+ end
+ @text_encoders[k] = encoder
+
+ # update vocabulary
+ v[:vocabulary] = encoder.vocabulary
+ end
+
+ raise "No features left" if train_set.columns.empty?
end
- def prep_y(y)
- y.each do |yi|
- raise "Target missing in data" if yi.nil?
+ def check_data(data)
+ raise "No data" if data.empty?
+ raise "Number of data points differs from target" if data.size != data.label.size
+ end
+
+ def check_missing(c, name)
+ raise ArgumentError, "Missing column: #{name}" if !c
+ raise ArgumentError, "Missing values in column #{name}" if c.any?(&:nil?)
+ end
+
+ def check_missing_value(df)
+ df.columns.each do |k, v|
+ check_missing(v, k)
end
- y
end
- # determine if target is a string or symbol
- def prep_target(target, data)
- if daru?(data)
- data.has_vector?(target) ? target : flip_target(target)
+ def display_field(k)
+ if k.is_a?(Array)
+ if @features[k.first] == "text"
+ "#{k.first}(#{k.last})"
+ else
+ k.join("=")
+ end
else
- x = data[0] || {}
- x[target] ? target : flip_target(target)
+ k
+ end
+ end
+
+ # pmml
+
+ def build_pmml(data_fields)
+ Nokogiri::XML::Builder.new do |xml|
+ xml.PMML(version: "4.4", xmlns: "http://www.dmg.org/PMML-4_4", "xmlns:xsi" => "http://www.w3.org/2001/XMLSchema-instance") do
+ pmml_header(xml)
+ pmml_data_dictionary(xml, data_fields)
+ pmml_transformation_dictionary(xml)
+ yield xml
+ end
+ end
+ end
+
+ def pmml_header(xml)
+ xml.Header do
+ xml.Application(name: "Eps", version: Eps::VERSION)
+ # xml.Timestamp Time.now.utc.iso8601
+ end
+ end
+
+ def pmml_data_dictionary(xml, data_fields)
+ xml.DataDictionary do
+ data_fields.each do |k, vs|
+ case @features[k]
+ when "categorical", nil
+ xml.DataField(name: k, optype: "categorical", dataType: "string") do
+ vs.map(&:to_s).sort.each do |v|
+ xml.Value(value: v)
+ end
+ end
+ when "text"
+ xml.DataField(name: k, optype: "categorical", dataType: "string")
+ else
+ xml.DataField(name: k, optype: "continuous", dataType: "double")
+ end
+ end
+ end
+ end
+
+ def pmml_transformation_dictionary(xml)
+ if @text_features.any?
+ xml.TransformationDictionary do
+ @text_features.each do |k, text_options|
+ xml.DefineFunction(name: "#{k}Transform", optype: "continuous") do
+ xml.ParameterField(name: "text")
+ xml.ParameterField(name: "term")
+ xml.TextIndex(textField: "text", localTermWeights: "termFrequency", wordSeparatorCharacterRE: text_options[:tokenizer].source, isCaseSensitive: !!text_options[:case_sensitive]) do
+ xml.FieldRef(field: "term")
+ end
+ end
+ end
+ end
+ end
+ end
+
+ def pmml_local_transformations(xml)
+ if @text_features.any?
+ xml.LocalTransformations do
+ @text_features.each do |k, _|
+ @text_encoders[k].vocabulary.each do |v|
+ xml.DerivedField(name: display_field([k, v]), optype: "continuous", dataType: "integer") do
+ xml.Apply(function: "#{k}Transform") do
+ xml.FieldRef(field: k)
+ xml.Constant v
+ end
+ end
+ end
+ end
+ end
end
end
end
end