lib/svmkit/ensemble/random_forest_classifier.rb in svmkit-0.2.8 vs lib/svmkit/ensemble/random_forest_classifier.rb in svmkit-0.2.9

- old
+ new

@@ -49,14 +49,16 @@ # @param random_seed [Integer] The seed value using to initialize the random generator. # It is used to randomly determine the order of features when deciding spliting point. def initialize(n_estimators: 10, criterion: 'gini', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil, random_seed: nil) SVMKit::Validation.check_params_type_or_nil(Integer, max_depth: max_depth, max_leaf_nodes: max_leaf_nodes, - max_features: max_features, random_seed: random_seed) + max_features: max_features, random_seed: random_seed) SVMKit::Validation.check_params_integer(n_estimators: n_estimators, min_samples_leaf: min_samples_leaf) SVMKit::Validation.check_params_string(criterion: criterion) - + SVMKit::Validation.check_params_positive(n_estimators: n_estimators, max_depth: max_depth, + max_leaf_nodes: max_leaf_nodes, min_samples_leaf: min_samples_leaf, + max_features: max_features) @params = {} @params[:n_estimators] = n_estimators @params[:criterion] = criterion @params[:max_depth] = max_depth @params[:max_leaf_nodes] = max_leaf_nodes @@ -76,9 +78,10 @@ # @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model. # @return [RandomForestClassifier] The learned classifier itself. def fit(x, y) SVMKit::Validation.check_sample_array(x) SVMKit::Validation.check_label_array(y) + SVMKit::Validation.check_sample_label_size(x, y) # Initialize some variables. n_samples, n_features = x.shape @params[:max_features] = n_features unless @params[:max_features].is_a?(Integer) @params[:max_features] = [[1, @params[:max_features]].max, Math.sqrt(n_features).to_i].min @classes = Numo::Int32.asarray(y.to_a.uniq.sort)