Sha256: 0344207964712dd3c97db661b345f116722303bc564ae542250664c035e599c7

Contents?: true

Size: 438 Bytes

Versions: 2

Compression:

Stored size: 438 Bytes

Contents

module VowpalWabbit
  class Classifier < Model
    def initialize(**params)
      super({loss_function: "logistic"}.merge(params))
    end

    def predict(x)
      predictions = super
      predictions.map { |v| v >= 0 ? 1 : -1 }
    end

    def score(x, y = nil)
      y_pred, y = predict_for_score(x, y)
      y_pred.map! { |v| v >= 0 ? 1 : -1 }
      y_pred.zip(y).select { |yp, yt| yp == yt }.count / y.count.to_f
    end
  end
end
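
The thresholding and accuracy steps used by `predict` and `score` above can be tried in isolation. The following is a minimal sketch that does not load the gem. The helper names `to_labels` and `accuracy` and the sample values are hypothetical, chosen only for illustration. It assumes the raw model outputs are plain numeric scores.

```ruby
# Hypothetical helpers for illustration; they are not part of the vowpalwabbit gem.

# Map raw scores to class labels: non-negative scores become 1, negative become -1.
def to_labels(raw_scores)
  raw_scores.map { |v| v >= 0 ? 1 : -1 }
end

# Fraction of predicted labels that equal the true labels.
def accuracy(y_pred, y_true)
  matches = y_pred.zip(y_true).count { |yp, yt| yp == yt }
  matches / y_true.count.to_f
end

raw = [0.8, -0.3, 0.0, -1.2]          # made-up raw scores
labels = to_labels(raw)               # a score of exactly 0 maps to 1
acc = accuracy(labels, [1, -1, -1, -1])
puts labels.inspect
puts acc
```

Note that a score of exactly 0 is assigned to the positive class, which mirrors the `v >= 0` comparison in the listing.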

Version data entries

2 entries across 2 versions and 1 rubygem

Version Path
vowpalwabbit-0.1.1 lib/vowpalwabbit/classifier.rb
vowpalwabbit-0.1.0 lib/vowpalwabbit/classifier.rb