lib/rumale/tree/gradient_tree_regressor.rb in rumale-0.23.0 vs lib/rumale/tree/gradient_tree_regressor.rb in rumale-0.23.1

- old
+ new

@@ -1,10 +1,10 @@ # frozen_string_literal: true -require 'rumale/rumale' require 'rumale/base/base_estimator' require 'rumale/base/regressor' +require 'rumale/rumaleext' require 'rumale/tree/node' module Rumale module Tree # GradientTreeRegressor is a class that implements decision tree for regression with exact gredy algorithm. @@ -112,24 +112,28 @@ # # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the labels. # @return [Numo::Int32] (shape: [n_samples]) Leaf index for sample. def apply(x) x = check_convert_sample_array(x) - Numo::Int32[*(Array.new(x.shape[0]) { |n| apply_at_node(@tree, x[n, true]) })] + Numo::Int32[*(Array.new(x.shape[0]) { |n| partial_apply(@tree, x[n, true]) })] end private - def apply_at_node(node, sample) - return node.leaf_id if node.leaf - return apply_at_node(node.left, sample) if node.right.nil? - return apply_at_node(node.right, sample) if node.left.nil? - - if sample[node.feature_id] <= node.threshold - apply_at_node(node.left, sample) - else - apply_at_node(node.right, sample) + def partial_apply(tree, sample) + node = tree + until node.leaf + # :nocov: + node = if node.right.nil? + node.left + elsif node.left.nil? + node.right + # :nocov: + else + sample[node.feature_id] <= node.threshold ? node.left : node.right + end end + node.leaf_id end def build_tree(x, y, g, h) @feature_ids = Array.new(x.shape[1]) { |v| v } @tree = grow_node(0, x, y, g, h)