Sha256: 69a7645fc5adc4ba96b0b63c13be817606f04c9300b2436cb8e728d26c57f796

Contents?: true

Size: 425 Bytes

Versions: 1

Compression:

Stored size: 425 Bytes

Contents

module VowpalWabbit
  class Classifier < Model
    def initialize(**params)
      super(loss_function: "logistic", **params)
    end

    def predict(x)
      predictions = super
      predictions.map { |v| v >= 0 ? 1 : -1 }
    end

    def score(x, y = nil)
      y_pred, y = predict_for_score(x, y)
      y_pred.map! { |v| v >= 0 ? 1 : -1 }
      y_pred.zip(y).count { |yp, yt| yp == yt } / y.count.to_f
    end
  end
end
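The listing above turns raw model outputs into -1/1 labels by thresholding at 0, and `score` reports the fraction of labels that match the true values. A minimal standalone sketch of those two steps follows. It does not load the gem. The helper names `to_labels` and `accuracy` and the sample data are hypothetical and used only for illustration.

```ruby
# Hypothetical helpers for illustration; they are not part of the vowpalwabbit gem.

# Map raw scores to -1/1 labels. A score of exactly 0 counts as positive,
# matching the `v >= 0 ? 1 : -1` rule in the listing.
def to_labels(raw_scores)
  raw_scores.map { |v| v >= 0 ? 1 : -1 }
end

# Fraction of predicted labels equal to the true labels.
# `to_f` forces float division so the result is not truncated to 0 or 1.
def accuracy(y_pred, y_true)
  y_pred.zip(y_true).count { |yp, yt| yp == yt } / y_true.count.to_f
end

raw   = [0.8, -0.3, 0.0, -1.2] # made-up raw scores
truth = [1, -1, -1, -1]        # made-up true labels

labels = to_labels(raw)
puts labels.inspect          # [1, -1, 1, -1]
puts accuracy(labels, truth) # 0.75
```

The third sample shows the boundary case: a raw score of 0.0 is labeled 1, so it counts as a miss against a true label of -1.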

Version data entries

1 entry across 1 version & 1 rubygem

Version Path
vowpalwabbit-0.3.0 lib/vowpalwabbit/classifier.rb