lib/svmkit/tree/decision_tree_classifier.rb in svmkit-0.5.0 vs lib/svmkit/tree/decision_tree_classifier.rb in svmkit-0.5.1

- old
+ new

@@ -110,12 +110,11 @@ # # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the probailities. # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted probability of each class per sample. def predict_proba(x) SVMKit::Validation.check_sample_array(x) - probs = Numo::DFloat[*(Array.new(x.shape[0]) { |n| predict_at_node(@tree, x[n, true]) })] - probs[true, @classes] + Numo::DFloat[*(Array.new(x.shape[0]) { |n| predict_at_node(@tree, x[n, true]) })] end # Return the index of the leaf that each sample reached. # # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the labels. @@ -212,14 +211,14 @@ node.leaf = false node end def put_leaf(node, y) - node.probs = y.bincount(minlength: @classes.max + 1) / node.n_samples.to_f + node.probs = Numo::DFloat[*(@classes.to_a.map { |c| y.eq(c).count })] / node.n_samples node.leaf = true node.leaf_id = @n_leaves @n_leaves += 1 - @leaf_labels.push(node.probs.max_index) + @leaf_labels.push(@classes[node.probs.max_index]) node end def rand_ids(n) [*0...n].sample(@params[:max_features], random: @rng)