lib/svmkit/ensemble/random_forest_classifier.rb in svmkit-0.2.8 vs lib/svmkit/ensemble/random_forest_classifier.rb in svmkit-0.2.9
- old
+ new
@@ -49,14 +49,16 @@
# @param random_seed [Integer] The seed value using to initialize the random generator.
# It is used to randomly determine the order of features when deciding spliting point.
def initialize(n_estimators: 10, criterion: 'gini', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1,
max_features: nil, random_seed: nil)
SVMKit::Validation.check_params_type_or_nil(Integer, max_depth: max_depth, max_leaf_nodes: max_leaf_nodes,
- max_features: max_features, random_seed: random_seed)
+ max_features: max_features, random_seed: random_seed)
SVMKit::Validation.check_params_integer(n_estimators: n_estimators, min_samples_leaf: min_samples_leaf)
SVMKit::Validation.check_params_string(criterion: criterion)
-
+ SVMKit::Validation.check_params_positive(n_estimators: n_estimators, max_depth: max_depth,
+ max_leaf_nodes: max_leaf_nodes, min_samples_leaf: min_samples_leaf,
+ max_features: max_features)
@params = {}
@params[:n_estimators] = n_estimators
@params[:criterion] = criterion
@params[:max_depth] = max_depth
@params[:max_leaf_nodes] = max_leaf_nodes
@@ -76,9 +78,10 @@
# @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
# @return [RandomForestClassifier] The learned classifier itself.
def fit(x, y)
SVMKit::Validation.check_sample_array(x)
SVMKit::Validation.check_label_array(y)
+ SVMKit::Validation.check_sample_label_size(x, y)
# Initialize some variables.
n_samples, n_features = x.shape
@params[:max_features] = n_features unless @params[:max_features].is_a?(Integer)
@params[:max_features] = [[1, @params[:max_features]].max, Math.sqrt(n_features).to_i].min
@classes = Numo::Int32.asarray(y.to_a.uniq.sort)