lib/svmkit/tree/decision_tree_classifier.rb in svmkit-0.6.1 vs lib/svmkit/tree/decision_tree_classifier.rb in svmkit-0.6.2
- old
+ new
@@ -89,12 +89,13 @@
SVMKit::Validation.check_label_array(y)
SVMKit::Validation.check_sample_label_size(x, y)
n_samples, n_features = x.shape
@params[:max_features] = n_features if @params[:max_features].nil?
@params[:max_features] = [@params[:max_features], n_features].min
- @classes = Numo::Int32.asarray(y.to_a.uniq.sort)
- build_tree(x, y)
+ uniq_y = y.to_a.uniq.sort
+ @classes = Numo::Int32.asarray(uniq_y)
+ build_tree(x, y.map { |v| uniq_y.index(v) })
eval_importance(n_samples, n_features)
self
end
# Predict class labels for samples.
@@ -172,50 +173,49 @@
end
def build_tree(x, y)
@n_leaves = 0
@leaf_labels = []
- @tree = grow_node(0, x, y)
+ @tree = grow_node(0, x, y, impurity(y))
@leaf_labels = Numo::Int32[*@leaf_labels]
nil
end
- def grow_node(depth, x, y)
- if @params[:max_leaf_nodes].is_a?(Integer)
+ def grow_node(depth, x, y, whole_impurity)
+ unless @params[:max_leaf_nodes].nil?
return nil if @n_leaves >= @params[:max_leaf_nodes]
end
n_samples, n_features = x.shape
- if @params[:min_samples_leaf].is_a?(Integer)
- return nil if n_samples <= @params[:min_samples_leaf]
- end
+ return nil if n_samples <= @params[:min_samples_leaf]
- node = Node.new(depth: depth, impurity: impurity(y), n_samples: n_samples)
+ node = Node.new(depth: depth, impurity: whole_impurity, n_samples: n_samples)
return put_leaf(node, y) if y.to_a.uniq.size == 1
- if @params[:max_depth].is_a?(Integer)
+ unless @params[:max_depth].nil?
return put_leaf(node, y) if depth == @params[:max_depth]
end
- feature_id, threshold, left_ids, right_ids, max_gain =
- rand_ids(n_features).map { |f_id| [f_id, *best_split(x[true, f_id], y)] }.max_by(&:last)
- return put_leaf(node, y) if max_gain.nil?
- return put_leaf(node, y) if max_gain.zero?
+ feature_id, threshold, left_ids, right_ids, left_impurity, right_impurity, gain =
+ rand_ids(n_features).map { |f_id| [f_id, *best_split(x[true, f_id], y, whole_impurity)] }.max_by(&:last)
- node.left = grow_node(depth + 1, x[left_ids, true], y[left_ids])
- node.right = grow_node(depth + 1, x[right_ids, true], y[right_ids])
+ return put_leaf(node, y) if gain.nil? || gain.zero?
+
+ node.left = grow_node(depth + 1, x[left_ids, true], y[left_ids], left_impurity)
+ node.right = grow_node(depth + 1, x[right_ids, true], y[right_ids], right_impurity)
+
return put_leaf(node, y) if node.left.nil? && node.right.nil?
node.feature_id = feature_id
node.threshold = threshold
node.leaf = false
node
end
def put_leaf(node, y)
- node.probs = Numo::DFloat.cast(@classes.map { |c| y.eq(c).count_true }) / node.n_samples
+ node.probs = y.bincount(minlength: @classes.size) / node.n_samples.to_f
node.leaf = true
node.leaf_id = @n_leaves
@n_leaves += 1
@leaf_labels.push(@classes[node.probs.max_index])
node
@@ -223,39 +223,35 @@
def rand_ids(n)
[*0...n].sample(@params[:max_features], random: @rng)
end
- def best_split(features, labels)
+ def best_split(features, labels, whole_impurity)
+ n_samples = labels.size
features.to_a.uniq.sort.each_cons(2).map do |l, r|
threshold = 0.5 * (l + r)
- left_ids, right_ids = splited_ids(features, threshold)
- [threshold, left_ids, right_ids, gain(labels, labels[left_ids], labels[right_ids])]
+ left_ids = features.le(threshold).where
+ right_ids = features.gt(threshold).where
+ left_impurity = impurity(labels[left_ids])
+ right_impurity = impurity(labels[right_ids])
+ gain = whole_impurity -
+ left_impurity * left_ids.size.fdiv(n_samples) -
+ right_impurity * right_ids.size.fdiv(n_samples)
+ [threshold, left_ids, right_ids, left_impurity, right_impurity, gain]
end.max_by(&:last)
end
- def splited_ids(features, threshold)
- [features.le(threshold).where, features.gt(threshold).where]
- end
-
- def gain(labels, labels_left, labels_right)
- prob_left = labels_left.size.fdiv(labels.size)
- prob_right = labels_right.size.fdiv(labels.size)
- impurity(labels) - prob_left * impurity(labels_left) - prob_right * impurity(labels_right)
- end
-
def impurity(labels)
- cls = labels.to_a.uniq.sort
- cls.size == 1 ? 0.0 : send(@criterion, Numo::DFloat[*(cls.map { |c| labels.eq(c).count_true.fdiv(labels.size) })])
+ send(@criterion, labels.bincount / labels.size.to_f)
end
def gini(posterior_probs)
1.0 - (posterior_probs * posterior_probs).sum
end
def entropy(posterior_probs)
- -(posterior_probs * Numo::NMath.log(posterior_probs)).sum
+ -(posterior_probs * Numo::NMath.log(posterior_probs + 1)).sum
end
def eval_importance(n_samples, n_features)
@feature_importances = Numo::DFloat.zeros(n_features)
eval_importance_at_node(@tree)
@@ -267,10 +263,11 @@
def eval_importance_at_node(node)
return nil if node.leaf
return nil if node.left.nil? || node.right.nil?
gain = node.n_samples * node.impurity -
- node.left.n_samples * node.left.impurity - node.right.n_samples * node.right.impurity
+ node.left.n_samples * node.left.impurity -
+ node.right.n_samples * node.right.impurity
@feature_importances[node.feature_id] += gain
eval_importance_at_node(node.left)
eval_importance_at_node(node.right)
end
end