lib/svmkit/model_selection/cross_validation.rb in svmkit-0.2.9 vs lib/svmkit/model_selection/cross_validation.rb in svmkit-0.3.0
- old
+ new
@@ -1,8 +1,14 @@
# frozen_string_literal: true
+require 'svmkit/validation'
+require 'svmkit/base/base_estimator'
+require 'svmkit/base/classifier'
+require 'svmkit/base/regressor'
require 'svmkit/base/splitter'
+require 'svmkit/base/evaluator'
+require 'svmkit/evaluation_measure/log_loss'
module SVMKit
# This module consists of the classes for model validation techniques.
module ModelSelection
# CrossValidation is a class that evaluates a given classifier with cross-validation method.
@@ -49,32 +55,38 @@
end
# Perform the evaluation of the given estimator with the cross-validation method.
#
# @param x [Numo::DFloat] (shape: [n_samples, n_features])
- # The dataset to be used to evaluate the classifier.
- # @param y [Numo::Int32] (shape: [n_samples])
- # The labels to be used to evaluate the classifier.
+ # The dataset to be used to evaluate the estimator.
+ # @param y [Numo::Int32 / Numo::DFloat] (shape: [n_samples] / [n_samples, n_outputs])
+ # The labels to be used to evaluate the classifier / The target values to be used to evaluate the regressor.
# @return [Hash] The report summarizing the results of cross-validation.
# * :fit_time (Array<Float>) The calculation times of fitting the estimator for each split.
# * :test_score (Array<Float>) The scores of testing dataset for each split.
# * :train_score (Array<Float>) The scores of training dataset for each split. This entry is nil if
# return_train_score is false.
def perform(x, y)
SVMKit::Validation.check_sample_array(x)
- SVMKit::Validation.check_label_array(y)
- SVMKit::Validation.check_sample_label_size(x, y)
+ if @estimator.is_a?(SVMKit::Base::Classifier)
+ SVMKit::Validation.check_label_array(y)
+ SVMKit::Validation.check_sample_label_size(x, y)
+ end
+ if @estimator.is_a?(SVMKit::Base::Regressor)
+ SVMKit::Validation.check_tvalue_array(y)
+ SVMKit::Validation.check_sample_tvalue_size(x, y)
+ end
# Initialize the report of cross validation.
report = { test_score: [], train_score: nil, fit_time: [] }
report[:train_score] = [] if @return_train_score
# Evaluate the estimator on each split.
@splitter.split(x, y).each do |train_ids, test_ids|
# Split dataset into training and testing dataset.
feature_ids = !kernel_machine? || train_ids
train_x = x[train_ids, feature_ids]
- train_y = y[train_ids]
+ train_y = y.shape[1].nil? ? y[train_ids] : y[train_ids, true]
test_x = x[test_ids, feature_ids]
- test_y = y[test_ids]
+ test_y = y.shape[1].nil? ? y[test_ids] : y[test_ids, true]
# Fit the estimator.
start_time = Time.now.to_i
@estimator.fit(train_x, train_y)
# Calculate scores and prepare the report.
report[:fit_time].push(Time.now.to_i - start_time)