lib/rumale/tree/decision_tree_regressor.rb in rumale-tree-0.27.0 vs lib/rumale/tree/decision_tree_regressor.rb in rumale-tree-0.28.0
- old
+ new
@@ -66,10 +66,11 @@
n_samples, n_features = x.shape
@params[:max_features] = n_features if @params[:max_features].nil?
@params[:max_features] = [@params[:max_features], n_features].min
@n_leaves = 0
@leaf_values = []
+ @feature_ids = Array.new(x.shape[1]) { |v| v }
@sub_rng = @rng.dup
build_tree(x, y)
eval_importance(n_samples, n_features)
@leaf_values = Numo::DFloat.cast(@leaf_values)
@leaf_values = @leaf_values.flatten.dup if @leaf_values.shape[1] == 1
@@ -86,12 +87,14 @@
@leaf_values.shape[1].nil? ? @leaf_values[apply(x)].dup : @leaf_values[apply(x), true].dup
end
private
- def stop_growing?(y)
- y.to_a.uniq.size == 1
+ def build_tree(x, y)
+ y = y.expand_dims(1).dup if y.shape[1].nil?
+ @tree = grow_node(0, x, y, impurity(y))
+ nil
end
def put_leaf(node, y)
node.probs = nil
node.leaf = true
@@ -104,10 +107,10 @@
def best_split(f, y, impurity)
find_split_params(@params[:criterion], impurity, f.sort_index, f, y)
end
def impurity(y)
- node_impurity(@params[:criterion], y.to_a)
+ node_impurity(@params[:criterion], y)
end
end
end
end