lib/svmkit/model_selection/k_fold.rb in svmkit-0.2.7 vs lib/svmkit/model_selection/k_fold.rb in svmkit-0.2.8
- old
+ new
@@ -30,10 +30,14 @@
#
# @param n_splits [Integer] The number of folds.
# @param shuffle [Boolean] The flag indicating whether to shuffle the dataset.
# @param random_seed [Integer] The seed value using to initialize the random generator.
def initialize(n_splits: 3, shuffle: false, random_seed: nil)
+ SVMKit::Validation.check_params_integer(n_splits: n_splits)
+ SVMKit::Validation.check_params_boolean(shuffle: shuffle)
+ SVMKit::Validation.check_params_type_or_nil(Integer, random_seed: random_seed)
+
@n_splits = n_splits
@shuffle = shuffle
@random_seed = random_seed
@random_seed ||= srand
@rng = Random.new(@random_seed)
@@ -41,14 +45,12 @@
# Generate data indices for K-fold cross validation.
#
# @param x [Numo::DFloat] (shape: [n_samples, n_features])
# The dataset to be used to generate data indices for K-fold cross validation.
- # @param y [Numo::Int32] (shape: [n_samples])
- # The labels to be used to generate data indices for stratified K-fold cross validation.
- # This argument exists to unify the interface between the K-fold methods, it is not used in the method.
# @return [Array] The set of data indices for constructing the training and testing dataset in each fold.
- def split(x, y) # rubocop:disable Lint/UnusedMethodArgument
+ def split(x, _y = nil)
+ SVMKit::Validation.check_sample_array(x)
# Initialize and check some variables.
n_samples, = x.shape
unless @n_splits.between?(2, n_samples)
raise ArgumentError,
'The value of n_splits must be not less than 2 and not more than the number of samples.'