lib/svmkit/ensemble/random_forest_classifier.rb in svmkit-0.7.0 vs lib/svmkit/ensemble/random_forest_classifier.rb in svmkit-0.7.1
- old
+ new
@@ -1,8 +1,9 @@
# frozen_string_literal: true
require 'svmkit/validation'
+require 'svmkit/values'
require 'svmkit/base/base_estimator'
require 'svmkit/base/classifier'
require 'svmkit/tree/decision_tree_classifier'
module SVMKit
@@ -18,10 +19,11 @@
# results = estimator.predict(testing_samples)
#
class RandomForestClassifier
include Base::BaseEstimator
include Base::Classifier
+ include Validation
# Return the set of estimators.
# @return [Array<DecisionTreeClassifier>]
attr_reader :estimators
@@ -48,19 +50,20 @@
# @param min_samples_leaf [Integer] The minimum number of samples at a leaf node.
# @param max_features [Integer] The number of features to consider when searching optimal split point.
# If nil is given, split process considers all features.
# @param random_seed [Integer] The seed value using to initialize the random generator.
# It is used to randomly determine the order of features when deciding spliting point.
- def initialize(n_estimators: 10, criterion: 'gini', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1,
+ def initialize(n_estimators: 10,
+ criterion: 'gini', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1,
max_features: nil, random_seed: nil)
- SVMKit::Validation.check_params_type_or_nil(Integer, max_depth: max_depth, max_leaf_nodes: max_leaf_nodes,
- max_features: max_features, random_seed: random_seed)
- SVMKit::Validation.check_params_integer(n_estimators: n_estimators, min_samples_leaf: min_samples_leaf)
- SVMKit::Validation.check_params_string(criterion: criterion)
- SVMKit::Validation.check_params_positive(n_estimators: n_estimators, max_depth: max_depth,
- max_leaf_nodes: max_leaf_nodes, min_samples_leaf: min_samples_leaf,
- max_features: max_features)
+ check_params_type_or_nil(Integer, max_depth: max_depth, max_leaf_nodes: max_leaf_nodes,
+ max_features: max_features, random_seed: random_seed)
+ check_params_integer(n_estimators: n_estimators, min_samples_leaf: min_samples_leaf)
+ check_params_string(criterion: criterion)
+ check_params_positive(n_estimators: n_estimators, max_depth: max_depth,
+ max_leaf_nodes: max_leaf_nodes, min_samples_leaf: min_samples_leaf,
+ max_features: max_features)
@params = {}
@params[:n_estimators] = n_estimators
@params[:criterion] = criterion
@params[:max_depth] = max_depth
@params[:max_leaf_nodes] = max_leaf_nodes
@@ -78,13 +81,13 @@
#
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
# @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
# @return [RandomForestClassifier] The learned classifier itself.
def fit(x, y)
- SVMKit::Validation.check_sample_array(x)
- SVMKit::Validation.check_label_array(y)
- SVMKit::Validation.check_sample_label_size(x, y)
+ check_sample_array(x)
+ check_label_array(y)
+ check_sample_label_size(x, y)
# Initialize some variables.
n_samples, n_features = x.shape
@params[:max_features] = Math.sqrt(n_features).to_i unless @params[:max_features].is_a?(Integer)
@params[:max_features] = [[1, @params[:max_features]].max, n_features].min
@classes = Numo::Int32.asarray(y.to_a.uniq.sort)
@@ -92,11 +95,11 @@
# Construct forest.
@estimators = Array.new(@params[:n_estimators]) do
tree = Tree::DecisionTreeClassifier.new(
criterion: @params[:criterion], max_depth: @params[:max_depth],
max_leaf_nodes: @params[:max_leaf_nodes], min_samples_leaf: @params[:min_samples_leaf],
- max_features: @params[:max_features], random_seed: @rng.rand(int_max)
+ max_features: @params[:max_features], random_seed: @rng.rand(SVMKit::Values::int_max)
)
bootstrap_ids = Array.new(n_samples) { @rng.rand(0...n_samples) }
tree.fit(x[bootstrap_ids, true], y[bootstrap_ids])
@feature_importances += tree.feature_importances
tree
@@ -108,11 +111,11 @@
# Predict class labels for samples.
#
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the labels.
# @return [Numo::Int32] (shape: [n_samples]) Predicted class label per sample.
def predict(x)
- SVMKit::Validation.check_sample_array(x)
+ check_sample_array(x)
n_samples, = x.shape
n_classes = @classes.size
classes_arr = @classes.to_a
ballot_box = Numo::DFloat.zeros(n_samples, n_classes)
@estimators.each do |tree|
@@ -128,11 +131,11 @@
# Predict probability for samples.
#
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the probailities.
# @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted probability of each class per sample.
def predict_proba(x)
- SVMKit::Validation.check_sample_array(x)
+ check_sample_array(x)
n_samples, = x.shape
n_classes = @classes.size
classes_arr = @classes.to_a
ballot_box = Numo::DFloat.zeros(n_samples, n_classes)
@estimators.each do |tree|
@@ -148,11 +151,11 @@
# Return the index of the leaf that each sample reached.
#
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the labels.
# @return [Numo::Int32] (shape: [n_samples, n_estimators]) Leaf index for sample.
def apply(x)
- SVMKit::Validation.check_sample_array(x)
+ check_sample_array(x)
Numo::Int32[*Array.new(@params[:n_estimators]) { |n| @estimators[n].apply(x) }].transpose
end
# Dump marshal data.
# @return [Hash] The marshal data about RandomForestClassifier.
@@ -171,15 +174,9 @@
@estimators = obj[:estimators]
@classes = obj[:classes]
@feature_importances = obj[:feature_importances]
@rng = obj[:rng]
nil
- end
-
- private
-
- def int_max
- @int_max ||= 2**([42].pack('i').size * 16 - 2) - 1
end
end
end
end