lib/svmkit/ensemble/random_forest_classifier.rb in svmkit-0.6.2 vs lib/svmkit/ensemble/random_forest_classifier.rb in svmkit-0.6.3

- old
+ new

@@ -83,26 +83,26 @@ SVMKit::Validation.check_sample_array(x) SVMKit::Validation.check_label_array(y) SVMKit::Validation.check_sample_label_size(x, y) # Initialize some variables. n_samples, n_features = x.shape - @params[:max_features] = n_features unless @params[:max_features].is_a?(Integer) - @params[:max_features] = [[1, @params[:max_features]].max, Math.sqrt(n_features).to_i].min + @params[:max_features] = Math.sqrt(n_features).to_i unless @params[:max_features].is_a?(Integer) + @params[:max_features] = [[1, @params[:max_features]].max, n_features].min @classes = Numo::Int32.asarray(y.to_a.uniq.sort) + @feature_importances = Numo::DFloat.zeros(n_features) # Construct forest. - @estimators = Array.new(@params[:n_estimators]) do |_n| + @estimators = Array.new(@params[:n_estimators]) do tree = Tree::DecisionTreeClassifier.new( criterion: @params[:criterion], max_depth: @params[:max_depth], max_leaf_nodes: @params[:max_leaf_nodes], min_samples_leaf: @params[:min_samples_leaf], - max_features: @params[:max_features], random_seed: @params[:random_seed] + max_features: @params[:max_features], random_seed: @rng.rand(int_max) ) bootstrap_ids = Array.new(n_samples) { @rng.rand(0...n_samples) } tree.fit(x[bootstrap_ids, true], y[bootstrap_ids]) + @feature_importances += tree.feature_importances + tree end - # Calculate feature importances. - @feature_importances = Numo::DFloat.zeros(n_features) - @estimators.each { |tree| @feature_importances += tree.feature_importances } @feature_importances /= @feature_importances.sum self end # Predict class labels for samples. @@ -155,12 +155,15 @@ end # Dump marshal data. # @return [Hash] The marshal data about RandomForestClassifier. def marshal_dump - { params: @params, estimators: @estimators, classes: @classes, - feature_importances: @feature_importances, rng: @rng } + { params: @params, + estimators: @estimators, + classes: @classes, + feature_importances: @feature_importances, + rng: @rng } end # Load marshal data. # @return [nil] def marshal_load(obj) @@ -168,9 +171,15 @@ @estimators = obj[:estimators] @classes = obj[:classes] @feature_importances = obj[:feature_importances] @rng = obj[:rng] nil + end + + private + + def int_max + @int_max ||= 2**([42].pack('i').size * 16 - 2) - 1 end end end end