lib/rumale/tree/gradient_tree_regressor.rb in rumale-0.23.0 vs lib/rumale/tree/gradient_tree_regressor.rb in rumale-0.23.1
- old
+ new
@@ -1,10 +1,10 @@
# frozen_string_literal: true
-require 'rumale/rumale'
require 'rumale/base/base_estimator'
require 'rumale/base/regressor'
+require 'rumale/rumaleext'
require 'rumale/tree/node'
module Rumale
module Tree
# GradientTreeRegressor is a class that implements decision tree for regression with exact gredy algorithm.
@@ -112,24 +112,28 @@
#
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the labels.
# @return [Numo::Int32] (shape: [n_samples]) Leaf index for sample.
def apply(x)
x = check_convert_sample_array(x)
- Numo::Int32[*(Array.new(x.shape[0]) { |n| apply_at_node(@tree, x[n, true]) })]
+ Numo::Int32[*(Array.new(x.shape[0]) { |n| partial_apply(@tree, x[n, true]) })]
end
private
- def apply_at_node(node, sample)
- return node.leaf_id if node.leaf
- return apply_at_node(node.left, sample) if node.right.nil?
- return apply_at_node(node.right, sample) if node.left.nil?
-
- if sample[node.feature_id] <= node.threshold
- apply_at_node(node.left, sample)
- else
- apply_at_node(node.right, sample)
+ def partial_apply(tree, sample)
+ node = tree
+ until node.leaf
+ # :nocov:
+ node = if node.right.nil?
+ node.left
+ elsif node.left.nil?
+ node.right
+ # :nocov:
+ else
+ sample[node.feature_id] <= node.threshold ? node.left : node.right
+ end
end
+ node.leaf_id
end
def build_tree(x, y, g, h)
@feature_ids = Array.new(x.shape[1]) { |v| v }
@tree = grow_node(0, x, y, g, h)