lib/svmkit/tree/decision_tree_classifier.rb in svmkit-0.6.0 vs lib/svmkit/tree/decision_tree_classifier.rb in svmkit-0.6.1

- old
+ new

# NOTE(review): the line below is two diff hunks collapsed onto one line. A " - " token
# starts a removed (0.6.0) line and a " + " token starts an added (0.6.1) line; everything
# else is unchanged context. Exception: inside `gain`, the expression
# `impurity(labels) - prob_left * impurity(labels_left) - prob_right * impurity(labels_right)`
# reads as a single context line, so its " - " tokens are arithmetic subtraction.
#
# Hunk 1 (@@ -211,11 +211,11 @@)
#   * Opens with the tail of a method whose `def` is outside this hunk
#     (`node.leaf = false`, returns `node`).
#   * put_leaf(node, y): sets node.probs to the per-class counts of `y` divided by
#     node.n_samples, marks the node as a leaf, assigns leaf_id from @n_leaves, increments
#     @n_leaves, pushes @classes[node.probs.max_index] onto @leaf_labels and evaluates to
#     `node`. The closing `end` lies outside the hunk.
#       old: Numo::DFloat[*(@classes.to_a.map { |c| y.eq(c).count })]
#       new: Numo::DFloat.cast(@classes.map { |c| y.eq(c).count_true })
#
# Hunk 2 (@@ -232,21 +232,21 @@)
#   * Opens with the tail of a method whose `def` is outside this hunk; it ends by
#     selecting, via max_by(&:last), the [threshold, left_ids, right_ids, gain] entry
#     with the largest gain.
#   * splited_ids(features, threshold): returns a pair built from
#     features.le(threshold).where and features.gt(threshold).where.
#       old: each `.where` result was converted with `.to_a`
#       new: the `.where` results are returned as-is
#   * gain(labels, labels_left, labels_right): impurity(labels) minus the impurity of each
#     side weighted by that side's share of labels.size.
#       old: `size / labels.size.to_f`
#       new: `size.fdiv(labels.size)`
#   * impurity(labels): dispatches to the method named by @criterion via `send`, passing
#     the per-class proportions of the sorted unique labels.
#       old: counts from `.count`, the whole vector divided by labels.size.to_f
#       new: returns 0.0 directly when exactly one distinct label is present; otherwise
#            each proportion is `count_true.fdiv(labels.size)`
#   * gini(posterior_probs): 1.0 minus the sum of squared posterior_probs (unchanged).
@@ -211,11 +211,11 @@ node.leaf = false node end def put_leaf(node, y) - node.probs = Numo::DFloat[*(@classes.to_a.map { |c| y.eq(c).count })] / node.n_samples + node.probs = Numo::DFloat.cast(@classes.map { |c| y.eq(c).count_true }) / node.n_samples node.leaf = true node.leaf_id = @n_leaves @n_leaves += 1 @leaf_labels.push(@classes[node.probs.max_index]) node @@ -232,21 +232,21 @@ [threshold, left_ids, right_ids, gain(labels, labels[left_ids], labels[right_ids])] end.max_by(&:last) end def splited_ids(features, threshold) - [features.le(threshold).where.to_a, features.gt(threshold).where.to_a] + [features.le(threshold).where, features.gt(threshold).where] end def gain(labels, labels_left, labels_right) - prob_left = labels_left.size / labels.size.to_f - prob_right = labels_right.size / labels.size.to_f + prob_left = labels_left.size.fdiv(labels.size) + prob_right = labels_right.size.fdiv(labels.size) impurity(labels) - prob_left * impurity(labels_left) - prob_right * impurity(labels_right) end def impurity(labels) - posterior_probs = Numo::DFloat[*(labels.to_a.uniq.sort.map { |c| labels.eq(c).count })] / labels.size.to_f - send(@criterion, posterior_probs) + cls = labels.to_a.uniq.sort + cls.size == 1 ? 0.0 : send(@criterion, Numo::DFloat[*(cls.map { |c| labels.eq(c).count_true.fdiv(labels.size) })]) end def gini(posterior_probs) 1.0 - (posterior_probs * posterior_probs).sum end