lib/rumale/tree/base_decision_tree.rb in rumale-0.12.0 vs lib/rumale/tree/base_decision_tree.rb in rumale-0.12.1
- old
+ new
@@ -60,18 +60,20 @@
end
end
# Grows the decision tree for samples x and targets y and stores the root node in @tree.
# When y.shape[1] is nil, y is replaced by a copy with an extra axis inserted at position 1.
# In the new version, @feature_ids holds the column indices 0...x.shape[1] for the
# duration of the growth (rand_ids samples from it) and is reset to nil afterwards.
# Always returns nil; the result is exposed through @tree only.
def build_tree(x, y)
y = y.expand_dims(1).dup if y.shape[1].nil?
+ @feature_ids = Array.new(x.shape[1]) { |v| v }
@tree = grow_node(0, x, y, impurity(y))
+ @feature_ids = nil
nil
end
- def grow_node(depth, x, y, whole_impurity)
+ def grow_node(depth, x, y, impurity)
# initialize node.
n_samples, n_features = x.shape
- node = Node.new(depth: depth, impurity: whole_impurity, n_samples: n_samples)
+ node = Node.new(depth: depth, impurity: impurity, n_samples: n_samples)
# terminate growing.
unless @params[:max_leaf_nodes].nil?
return nil if @n_leaves >= @params[:max_leaf_nodes]
end
@@ -85,11 +87,11 @@
return put_leaf(node, y) if stop_growing?(y)
# calculate optimal parameters.
feature_id, left_imp, right_imp, threshold, gain =
- rand_ids(n_features).map { |n| [n, *best_split(x[true, n], y, whole_impurity)] }.max_by(&:last)
+ rand_ids.map { |n| [n, *best_split(x[true, n], y, impurity)] }.max_by(&:last)
return put_leaf(node, y) if gain.nil? || gain.zero?
left_ids = x[true, feature_id].le(threshold).where
right_ids = x[true, feature_id].gt(threshold).where
@@ -110,11 +112,11 @@
# Abstract hook that concrete tree classes have to override.
# grow_node calls it with the current node and the targets y whenever it stops
# splitting, and returns whatever this method returns.
# The base implementation ignores both arguments.
# @raise [NotImplementedError] always, naming this method and the receiver's class.
def put_leaf(_node, _y)
  message = "#{__method__} has to be implemented in #{self.class}."
  raise NotImplementedError, message
end
# Picks @params[:max_features] feature indices at random, using @sub_rng as the random source.
# Old version: took the feature count n and built the index array [*0...n] on every call.
# New version: takes no argument and samples from @feature_ids, which build_tree fills
# before growing the tree and sets back to nil afterwards.
- def rand_ids(n)
- [*0...n].sample(@params[:max_features], random: @sub_rng)
+ def rand_ids
+ @feature_ids.sample(@params[:max_features], random: @sub_rng)
end
# Abstract hook that concrete tree classes have to override.
# grow_node calls it with one feature column, the targets y and the node impurity,
# and unpacks the result as (left_imp, right_imp, threshold, gain).
# The base implementation ignores all three arguments.
# @raise [NotImplementedError] always, naming this method and the receiver's class.
def best_split(_features, _y, _impurity)
  message = "#{__method__} has to be implemented in #{self.class}."
  raise NotImplementedError, message
end