lib/svmkit/tree/decision_tree_classifier.rb in svmkit-0.5.0 vs lib/svmkit/tree/decision_tree_classifier.rb in svmkit-0.5.1
- old
+ new
@@ -110,12 +110,11 @@
#
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the probailities.
# @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted probability of each class per sample.
def predict_proba(x)
SVMKit::Validation.check_sample_array(x)
- probs = Numo::DFloat[*(Array.new(x.shape[0]) { |n| predict_at_node(@tree, x[n, true]) })]
- probs[true, @classes]
+ Numo::DFloat[*(Array.new(x.shape[0]) { |n| predict_at_node(@tree, x[n, true]) })]
end
# Return the index of the leaf that each sample reached.
#
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the labels.
@@ -212,14 +211,14 @@
node.leaf = false
node
end
def put_leaf(node, y)
- node.probs = y.bincount(minlength: @classes.max + 1) / node.n_samples.to_f
+ node.probs = Numo::DFloat[*(@classes.to_a.map { |c| y.eq(c).count })] / node.n_samples
node.leaf = true
node.leaf_id = @n_leaves
@n_leaves += 1
- @leaf_labels.push(node.probs.max_index)
+ @leaf_labels.push(@classes[node.probs.max_index])
node
end
def rand_ids(n)
[*0...n].sample(@params[:max_features], random: @rng)