lib/rumale/tree/base_decision_tree.rb in rumale-0.18.6 vs lib/rumale/tree/base_decision_tree.rb in rumale-0.18.7
- old
+ new
@@ -51,10 +51,11 @@
def apply_at_node(node, sample)
return node.leaf_id if node.leaf
return apply_at_node(node.left, sample) if node.right.nil?
return apply_at_node(node.right, sample) if node.left.nil?
+
if sample[node.feature_id] <= node.threshold
apply_at_node(node.left, sample)
else
apply_at_node(node.right, sample)
end
@@ -136,9 +137,10 @@
end
def eval_importance_at_node(node)
return nil if node.leaf
return nil if node.left.nil? || node.right.nil?
+
gain = node.n_samples * node.impurity -
node.left.n_samples * node.left.impurity -
node.right.n_samples * node.right.impurity
@feature_importances[node.feature_id] += gain
eval_importance_at_node(node.left)