lib/svmkit/ensemble/random_forest_classifier.rb in svmkit-0.6.2 vs lib/svmkit/ensemble/random_forest_classifier.rb in svmkit-0.6.3
- old
+ new
@@ -83,26 +83,26 @@
SVMKit::Validation.check_sample_array(x)
SVMKit::Validation.check_label_array(y)
SVMKit::Validation.check_sample_label_size(x, y)
# Initialize some variables.
n_samples, n_features = x.shape
- @params[:max_features] = n_features unless @params[:max_features].is_a?(Integer)
- @params[:max_features] = [[1, @params[:max_features]].max, Math.sqrt(n_features).to_i].min
+ @params[:max_features] = Math.sqrt(n_features).to_i unless @params[:max_features].is_a?(Integer)
+ @params[:max_features] = [[1, @params[:max_features]].max, n_features].min
@classes = Numo::Int32.asarray(y.to_a.uniq.sort)
+ @feature_importances = Numo::DFloat.zeros(n_features)
# Construct forest.
- @estimators = Array.new(@params[:n_estimators]) do |_n|
+ @estimators = Array.new(@params[:n_estimators]) do
tree = Tree::DecisionTreeClassifier.new(
criterion: @params[:criterion], max_depth: @params[:max_depth],
max_leaf_nodes: @params[:max_leaf_nodes], min_samples_leaf: @params[:min_samples_leaf],
- max_features: @params[:max_features], random_seed: @params[:random_seed]
+ max_features: @params[:max_features], random_seed: @rng.rand(int_max)
)
bootstrap_ids = Array.new(n_samples) { @rng.rand(0...n_samples) }
tree.fit(x[bootstrap_ids, true], y[bootstrap_ids])
+ @feature_importances += tree.feature_importances
+ tree
end
- # Calculate feature importances.
- @feature_importances = Numo::DFloat.zeros(n_features)
- @estimators.each { |tree| @feature_importances += tree.feature_importances }
@feature_importances /= @feature_importances.sum
self
end
# Predict class labels for samples.
@@ -155,12 +155,15 @@
end
# Dump marshal data.
# @return [Hash] The marshal data about RandomForestClassifier.
def marshal_dump # Packs @params, @estimators, @classes, @feature_importances and @rng into a single Hash.
- { params: @params, estimators: @estimators, classes: @classes,
- feature_importances: @feature_importances, rng: @rng }
+ { params: @params, # the added lines hold the same five key/value pairs as the removed ones, one pair per line
+ estimators: @estimators,
+ classes: @classes,
+ feature_importances: @feature_importances,
+ rng: @rng }
end
# Load marshal data.
# @return [nil]
def marshal_load(obj)
@@ -168,9 +171,15 @@
@estimators = obj[:estimators]
@classes = obj[:classes]
@feature_importances = obj[:feature_importances]
@rng = obj[:rng]
nil
+ end
+
+ private
+
+ def int_max
+ @int_max ||= 2**([42].pack('i').size * 16 - 2) - 1
end
end
end
end