Sha256: c3db46f111e65dd189aeeb518f84bda2af990acda61296dbce8c1297a0f2320e
Contents?: true
Size: 1.65 KB
Versions: 1
Compression:
Stored size: 1.65 KB
Contents
module Eps class BaseEstimator def train(data, y, target: nil, **options) # TODO more performant conversion if daru?(data) x = data.dup x = x.delete_vector(target) if target else x = data.map(&:dup) x.each { |r| r.delete(target) } if target end y = prep_y(y.to_a) if x.size != y.size raise "Number of samples differs from target" end @x = x @y = y @target = target || "target" end def predict(x) singular = !(x.is_a?(Array) || daru?(x)) x = [x] if singular pred = _predict(x) singular ? pred[0] : pred end def evaluate(data, y = nil, target: nil) target ||= @target raise ArgumentError, "missing target" if !target && !y actual = y actual ||= if daru?(data) data[target].to_a else data.map { |v| v[target] } end actual = prep_y(actual) estimated = predict(data) self.class.metrics(actual, estimated) end private def categorical?(v) !v.is_a?(Numeric) end def daru?(x) defined?(Daru) && x.is_a?(Daru::DataFrame) end def flip_target(target) target.is_a?(String) ? target.to_sym : target.to_s end def prep_y(y) y.each do |yi| raise "Target missing in data" if yi.nil? end y end # determine if target is a string or symbol def prep_target(target, data) if daru?(data) data.has_vector?(target) ? target : flip_target(target) else x = data[0] || {} x[target] ? target : flip_target(target) end end end end
Version data entries
1 entries across 1 versions & 1 rubygems
Version | Path |
---|---|
eps-0.2.1 | lib/eps/base_estimator.rb |