lib/svmkit/model_selection/cross_validation.rb in svmkit-0.2.9 vs lib/svmkit/model_selection/cross_validation.rb in svmkit-0.3.0

- old
+ new

@@ -1,8 +1,14 @@ # frozen_string_literal: true +require 'svmkit/validation' +require 'svmkit/base/base_estimator' +require 'svmkit/base/classifier' +require 'svmkit/base/regressor' require 'svmkit/base/splitter' +require 'svmkit/base/evaluator' +require 'svmkit/evaluation_measure/log_loss' module SVMKit # This module consists of the classes for model validation techniques. module ModelSelection # CrossValidation is a class that evaluates a given classifier with cross-validation method. @@ -49,32 +55,38 @@ end # Perform the evalution of given classifier with cross-validation method. # # @param x [Numo::DFloat] (shape: [n_samples, n_features]) - # The dataset to be used to evaluate the classifier. - # @param y [Numo::Int32] (shape: [n_samples]) - # The labels to be used to evaluate the classifier. + # The dataset to be used to evaluate the estimator. + # @param y [Numo::Int32 / Numo::DFloat] (shape: [n_samples] / [n_samples, n_outputs]) + # The labels to be used to evaluate the classifier / The target values to be used to evaluate the regressor. # @return [Hash] The report summarizing the results of cross-validation. # * :fit_time (Array<Float>) The calculation times of fitting the estimator for each split. # * :test_score (Array<Float>) The scores of testing dataset for each split. # * :train_score (Array<Float>) The scores of training dataset for each split. This option is nil if # the return_train_score is false. def perform(x, y) SVMKit::Validation.check_sample_array(x) - SVMKit::Validation.check_label_array(y) - SVMKit::Validation.check_sample_label_size(x, y) + if @estimator.is_a?(SVMKit::Base::Classifier) + SVMKit::Validation.check_label_array(y) + SVMKit::Validation.check_sample_label_size(x, y) + end + if @estimator.is_a?(SVMKit::Base::Regressor) + SVMKit::Validation.check_tvalue_array(y) + SVMKit::Validation.check_sample_tvalue_size(x, y) + end # Initialize the report of cross validation. report = { test_score: [], train_score: nil, fit_time: [] } report[:train_score] = [] if @return_train_score # Evaluate the estimator on each split. @splitter.split(x, y).each do |train_ids, test_ids| # Split dataset into training and testing dataset. feature_ids = !kernel_machine? || train_ids train_x = x[train_ids, feature_ids] - train_y = y[train_ids] + train_y = y.shape[1].nil? ? y[train_ids] : y[train_ids, true] test_x = x[test_ids, feature_ids] - test_y = y[test_ids] + test_y = y.shape[1].nil? ? y[test_ids] : y[test_ids, true] # Fit the estimator. start_time = Time.now.to_i @estimator.fit(train_x, train_y) # Calculate scores and prepare the report. report[:fit_time].push(Time.now.to_i - start_time)