lib/svmkit/model_selection/k_fold.rb in svmkit-0.2.7 vs lib/svmkit/model_selection/k_fold.rb in svmkit-0.2.8

- old
+ new

@@ -30,10 +30,14 @@ # # @param n_splits [Integer] The number of folds. # @param shuffle [Boolean] The flag indicating whether to shuffle the dataset. # @param random_seed [Integer] The seed value used to initialize the random generator. def initialize(n_splits: 3, shuffle: false, random_seed: nil) + SVMKit::Validation.check_params_integer(n_splits: n_splits) + SVMKit::Validation.check_params_boolean(shuffle: shuffle) + SVMKit::Validation.check_params_type_or_nil(Integer, random_seed: random_seed) + @n_splits = n_splits @shuffle = shuffle @random_seed = random_seed @random_seed ||= srand @rng = Random.new(@random_seed) @@ -41,14 +45,12 @@ # Generate data indices for K-fold cross validation. # # @param x [Numo::DFloat] (shape: [n_samples, n_features]) # The dataset to be used to generate data indices for K-fold cross validation. - # @param y [Numo::Int32] (shape: [n_samples]) - # The labels to be used to generate data indices for stratified K-fold cross validation. - # This argument exists to unify the interface between the K-fold methods, it is not used in the method. # @return [Array] The set of data indices for constructing the training and testing dataset in each fold. - def split(x, y) # rubocop:disable Lint/UnusedMethodArgument + def split(x, _y = nil) + SVMKit::Validation.check_sample_array(x) # Initialize and check some variables. n_samples, = x.shape unless @n_splits.between?(2, n_samples) raise ArgumentError, 'The value of n_splits must be not less than 2 and not more than the number of samples.'