lib/svmkit/linear_model/svc.rb in svmkit-0.2.5 vs lib/svmkit/linear_model/svc.rb in svmkit-0.2.6
- old
+ new
@@ -1,5 +1,7 @@
+# frozen_string_literal: true
+
require 'svmkit/base/base_estimator'
require 'svmkit/base/classifier'
module SVMKit
# This module consists of the classes that implement generalized linear models.
@@ -114,20 +116,9 @@
#
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the labels.
# @return [Numo::Int32] (shape: [n_samples]) Predicted class label per sample.
def predict(x)
Numo::Int32.cast(decision_function(x).map { |v| v >= 0 ? 1 : -1 })
- end
-
- # Claculate the mean accuracy of the given testing data.
- #
- # @param x [Numo::DFloat] (shape: [n_samples, n_features]) Testing data.
- # @param y [Numo::Int32] (shape: [n_samples]) True labels for testing data.
- # @return [Float] Mean accuracy
- def score(x, y)
- p = predict(x)
- n_hits = (y.to_a.map.with_index { |l, n| l == p[n] ? 1 : 0 }).inject(:+)
- n_hits / y.size.to_f
end
# Dump marshal data.
# @return [Hash] The marshal data about SVC.
def marshal_dump