Sha256: abe6480cf83bf18b0fc8a74f99c3564f6c123ea92b66d4d837cb71aeecfe0162

Size: 432 Bytes

Versions: 3

Compression:

Stored size: 432 Bytes

Contents

module VowpalWabbit
  class Classifier < Model
    def initialize(**params)
      super(loss_function: "logistic", **params)
    end

    def predict(x)
      predictions = super
      predictions.map { |v| v >= 0 ? 1 : -1 }
    end

    def score(x, y = nil)
      y_pred, y = predict_for_score(x, y)
      y_pred.map! { |v| v >= 0 ? 1 : -1 }
      y_pred.zip(y).select { |yp, yt| yp == yt }.count / y.count.to_f
    end
  end
end
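The two overridden methods above do the same thing. `predict` thresholds the model's raw outputs at zero and returns `1` or `-1`. `score` reports the fraction of thresholded predictions that match the true labels, which is accuracy. Under logistic loss, a raw margin of 0 corresponds to a sigmoid probability of 0.5, so the cut-off is the usual 50% decision boundary. This assumes the raw prediction is the untransformed margin.

The sketch below reproduces only that logic so it can run without the native Vowpal Wabbit library. The module `LabelScoring` and its methods `to_labels` and `accuracy` are hypothetical and not part of the gem. The sketch also assumes that true labels are encoded as `1` and `-1`, which is an inference from the comparison in `score` and is not stated in the listing. With `0`/`1` labels, every `0` label would count as a miss.

```ruby
# Hypothetical standalone helpers mirroring Classifier#predict and Classifier#score.
# They are NOT part of the vowpalwabbit gem; they only illustrate the same logic.
module LabelScoring
  # Map raw model scores to class labels: non-negative -> 1, negative -> -1.
  # A score of exactly 0.0 maps to 1, as in the gem's `v >= 0 ? 1 : -1`.
  def self.to_labels(raw_scores)
    raw_scores.map { |v| v >= 0 ? 1 : -1 }
  end

  # Accuracy: the fraction of thresholded predictions equal to the true labels.
  # Assumes y_true uses the same 1 / -1 encoding as to_labels.
  def self.accuracy(raw_scores, y_true)
    raise ArgumentError, "length mismatch" unless raw_scores.size == y_true.size
    raise ArgumentError, "empty input" if y_true.empty?

    y_pred = to_labels(raw_scores)
    # Enumerable#count with a block counts matching pairs without building
    # an intermediate array (equivalent to select { ... }.count).
    y_pred.zip(y_true).count { |yp, yt| yp == yt } / y_true.size.to_f
  end
end

raw    = [0.8, -1.2, 0.0, -0.3]
labels = [1, -1, -1, -1]
p LabelScoring.to_labels(raw)        # => [1, -1, 1, -1]
p LabelScoring.accuracy(raw, labels) # => 0.75
```

In this example, three of the four thresholded predictions match. The third score of `0.0` becomes `1` and disagrees with its `-1` label, which shows how the `>= 0` boundary treats ties.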

Version data entries

3 entries across 3 versions of 1 gem

Version             Path
vowpalwabbit-0.2.0  lib/vowpalwabbit/classifier.rb
vowpalwabbit-0.1.3  lib/vowpalwabbit/classifier.rb
vowpalwabbit-0.1.2  lib/vowpalwabbit/classifier.rb