lib/svmkit/tree/decision_tree_classifier.rb in svmkit-0.6.1 vs lib/svmkit/tree/decision_tree_classifier.rb in svmkit-0.6.2

- old
+ new

@@ -89,12 +89,13 @@ SVMKit::Validation.check_label_array(y) SVMKit::Validation.check_sample_label_size(x, y) n_samples, n_features = x.shape @params[:max_features] = n_features if @params[:max_features].nil? @params[:max_features] = [@params[:max_features], n_features].min - @classes = Numo::Int32.asarray(y.to_a.uniq.sort) - build_tree(x, y) + uniq_y = y.to_a.uniq.sort + @classes = Numo::Int32.asarray(uniq_y) + build_tree(x, y.map { |v| uniq_y.index(v) }) eval_importance(n_samples, n_features) self end # Predict class labels for samples. @@ -172,50 +173,49 @@ end def build_tree(x, y) @n_leaves = 0 @leaf_labels = [] - @tree = grow_node(0, x, y) + @tree = grow_node(0, x, y, impurity(y)) @leaf_labels = Numo::Int32[*@leaf_labels] nil end - def grow_node(depth, x, y) - if @params[:max_leaf_nodes].is_a?(Integer) + def grow_node(depth, x, y, whole_impurity) + unless @params[:max_leaf_nodes].nil? return nil if @n_leaves >= @params[:max_leaf_nodes] end n_samples, n_features = x.shape - if @params[:min_samples_leaf].is_a?(Integer) - return nil if n_samples <= @params[:min_samples_leaf] - end + return nil if n_samples <= @params[:min_samples_leaf] - node = Node.new(depth: depth, impurity: impurity(y), n_samples: n_samples) + node = Node.new(depth: depth, impurity: whole_impurity, n_samples: n_samples) return put_leaf(node, y) if y.to_a.uniq.size == 1 - if @params[:max_depth].is_a?(Integer) + unless @params[:max_depth].nil? return put_leaf(node, y) if depth == @params[:max_depth] end - feature_id, threshold, left_ids, right_ids, max_gain = - rand_ids(n_features).map { |f_id| [f_id, *best_split(x[true, f_id], y)] }.max_by(&:last) - return put_leaf(node, y) if max_gain.nil? - return put_leaf(node, y) if max_gain.zero? + feature_id, threshold, left_ids, right_ids, left_impurity, right_impurity, gain = + rand_ids(n_features).map { |f_id| [f_id, *best_split(x[true, f_id], y, whole_impurity)] }.max_by(&:last) - node.left = grow_node(depth + 1, x[left_ids, true], y[left_ids]) - node.right = grow_node(depth + 1, x[right_ids, true], y[right_ids]) + return put_leaf(node, y) if gain.nil? || gain.zero? + + node.left = grow_node(depth + 1, x[left_ids, true], y[left_ids], left_impurity) + node.right = grow_node(depth + 1, x[right_ids, true], y[right_ids], right_impurity) + return put_leaf(node, y) if node.left.nil? && node.right.nil? node.feature_id = feature_id node.threshold = threshold node.leaf = false node end def put_leaf(node, y) - node.probs = Numo::DFloat.cast(@classes.map { |c| y.eq(c).count_true }) / node.n_samples + node.probs = y.bincount(minlength: @classes.size) / node.n_samples.to_f node.leaf = true node.leaf_id = @n_leaves @n_leaves += 1 @leaf_labels.push(@classes[node.probs.max_index]) node @@ -223,39 +223,35 @@ def rand_ids(n) [*0...n].sample(@params[:max_features], random: @rng) end - def best_split(features, labels) + def best_split(features, labels, whole_impurity) + n_samples = labels.size features.to_a.uniq.sort.each_cons(2).map do |l, r| threshold = 0.5 * (l + r) - left_ids, right_ids = splited_ids(features, threshold) - [threshold, left_ids, right_ids, gain(labels, labels[left_ids], labels[right_ids])] + left_ids = features.le(threshold).where + right_ids = features.gt(threshold).where + left_impurity = impurity(labels[left_ids]) + right_impurity = impurity(labels[right_ids]) + gain = whole_impurity - + left_impurity * left_ids.size.fdiv(n_samples) - + right_impurity * right_ids.size.fdiv(n_samples) + [threshold, left_ids, right_ids, left_impurity, right_impurity, gain] end.max_by(&:last) end - def splited_ids(features, threshold) - [features.le(threshold).where, features.gt(threshold).where] - end - - def gain(labels, labels_left, labels_right) - prob_left = labels_left.size.fdiv(labels.size) - prob_right = labels_right.size.fdiv(labels.size) - impurity(labels) - prob_left * impurity(labels_left) - prob_right * impurity(labels_right) - end - def impurity(labels) - cls = labels.to_a.uniq.sort - cls.size == 1 ? 0.0 : send(@criterion, Numo::DFloat[*(cls.map { |c| labels.eq(c).count_true.fdiv(labels.size) })]) + send(@criterion, labels.bincount / labels.size.to_f) end def gini(posterior_probs) 1.0 - (posterior_probs * posterior_probs).sum end def entropy(posterior_probs) - -(posterior_probs * Numo::NMath.log(posterior_probs)).sum + -(posterior_probs * Numo::NMath.log(posterior_probs + 1)).sum end def eval_importance(n_samples, n_features) @feature_importances = Numo::DFloat.zeros(n_features) eval_importance_at_node(@tree) @@ -267,10 +263,11 @@ def eval_importance_at_node(node) return nil if node.leaf return nil if node.left.nil? || node.right.nil? gain = node.n_samples * node.impurity - - node.left.n_samples * node.left.impurity - node.right.n_samples * node.right.impurity + node.left.n_samples * node.left.impurity - + node.right.n_samples * node.right.impurity @feature_importances[node.feature_id] += gain eval_importance_at_node(node.left) eval_importance_at_node(node.right) end end