lib/rumale/tree/base_decision_tree.rb in rumale-0.12.0 vs lib/rumale/tree/base_decision_tree.rb in rumale-0.12.1

- old
+ new

@@ -60,18 +60,20 @@ end end def build_tree(x, y) y = y.expand_dims(1).dup if y.shape[1].nil? + @feature_ids = Array.new(x.shape[1]) { |v| v } @tree = grow_node(0, x, y, impurity(y)) + @feature_ids = nil nil end - def grow_node(depth, x, y, whole_impurity) + def grow_node(depth, x, y, impurity) # intialize node. n_samples, n_features = x.shape - node = Node.new(depth: depth, impurity: whole_impurity, n_samples: n_samples) + node = Node.new(depth: depth, impurity: impurity, n_samples: n_samples) # terminate growing. unless @params[:max_leaf_nodes].nil? return nil if @n_leaves >= @params[:max_leaf_nodes] end @@ -85,11 +87,11 @@ return put_leaf(node, y) if stop_growing?(y) # calculate optimal parameters. feature_id, left_imp, right_imp, threshold, gain = - rand_ids(n_features).map { |n| [n, *best_split(x[true, n], y, whole_impurity)] }.max_by(&:last) + rand_ids.map { |n| [n, *best_split(x[true, n], y, impurity)] }.max_by(&:last) return put_leaf(node, y) if gain.nil? || gain.zero? left_ids = x[true, feature_id].le(threshold).where right_ids = x[true, feature_id].gt(threshold).where @@ -110,11 +112,11 @@ def put_leaf(_node, _y) raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}." end - def rand_ids(n) - [*0...n].sample(@params[:max_features], random: @sub_rng) + def rand_ids + @feature_ids.sample(@params[:max_features], random: @sub_rng) end def best_split(_features, _y, _impurity) raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}." end