lib/svmkit/tree/decision_tree_classifier.rb in svmkit-0.6.0 vs lib/svmkit/tree/decision_tree_classifier.rb in svmkit-0.6.1
- old
+ new
@@ -211,11 +211,11 @@
node.leaf = false
node
end
def put_leaf(node, y)
- node.probs = Numo::DFloat[*(@classes.to_a.map { |c| y.eq(c).count })] / node.n_samples
+ node.probs = Numo::DFloat.cast(@classes.map { |c| y.eq(c).count_true }) / node.n_samples
node.leaf = true
node.leaf_id = @n_leaves
@n_leaves += 1
@leaf_labels.push(@classes[node.probs.max_index])
node
@@ -232,21 +232,21 @@
[threshold, left_ids, right_ids, gain(labels, labels[left_ids], labels[right_ids])]
end.max_by(&:last)
end
def splited_ids(features, threshold)
- [features.le(threshold).where.to_a, features.gt(threshold).where.to_a]
+ [features.le(threshold).where, features.gt(threshold).where]
end
def gain(labels, labels_left, labels_right)
- prob_left = labels_left.size / labels.size.to_f
- prob_right = labels_right.size / labels.size.to_f
+ prob_left = labels_left.size.fdiv(labels.size)
+ prob_right = labels_right.size.fdiv(labels.size)
impurity(labels) - prob_left * impurity(labels_left) - prob_right * impurity(labels_right)
end
def impurity(labels)
- posterior_probs = Numo::DFloat[*(labels.to_a.uniq.sort.map { |c| labels.eq(c).count })] / labels.size.to_f
- send(@criterion, posterior_probs)
+ cls = labels.to_a.uniq.sort
+ cls.size == 1 ? 0.0 : send(@criterion, Numo::DFloat[*(cls.map { |c| labels.eq(c).count_true.fdiv(labels.size) })])
end
def gini(posterior_probs)
1.0 - (posterior_probs * posterior_probs).sum
end