lib/rumale/tree/decision_tree_regressor.rb in rumale-tree-0.27.0 vs lib/rumale/tree/decision_tree_regressor.rb in rumale-tree-0.28.0

- old
+ new

@@ -66,10 +66,11 @@ n_samples, n_features = x.shape @params[:max_features] = n_features if @params[:max_features].nil? @params[:max_features] = [@params[:max_features], n_features].min @n_leaves = 0 @leaf_values = [] + @feature_ids = Array.new(x.shape[1]) { |v| v } @sub_rng = @rng.dup build_tree(x, y) eval_importance(n_samples, n_features) @leaf_values = Numo::DFloat.cast(@leaf_values) @leaf_values = @leaf_values.flatten.dup if @leaf_values.shape[1] == 1 @@ -86,12 +87,14 @@ @leaf_values.shape[1].nil? ? @leaf_values[apply(x)].dup : @leaf_values[apply(x), true].dup end private - def stop_growing?(y) - y.to_a.uniq.size == 1 + def build_tree(x, y) + y = y.expand_dims(1).dup if y.shape[1].nil? + @tree = grow_node(0, x, y, impurity(y)) + nil end def put_leaf(node, y) node.probs = nil node.leaf = true @@ -104,10 +107,10 @@ def best_split(f, y, impurity) find_split_params(@params[:criterion], impurity, f.sort_index, f, y) end def impurity(y) - node_impurity(@params[:criterion], y.to_a) + node_impurity(@params[:criterion], y) end end end end