lib/svmkit/tree/decision_tree_classifier.rb in svmkit-0.2.8 vs lib/svmkit/tree/decision_tree_classifier.rb in svmkit-0.2.9
- old
+ new
@@ -5,10 +5,74 @@
require 'ostruct'
module SVMKit
# This module consists of the classes that implement tree models.
module Tree
+ # Node is a class that implements node used for construction of decision tree.
+ # This class is used for internal data structures.
+ class Node
+ # @!visibility private
+ attr_accessor :depth, :impurity, :n_samples, :probs, :leaf, :leaf_id, :left, :right, :feature_id, :threshold
+
+ # Create a new node for decision tree.
+ #
+ # @param depth [Integer] The depth of the node in tree.
+ # @param impurity [Float] The impurity of the node.
+ # @param n_samples [Integer] The number of the samples in the node.
+ # @param probs [Float] The probability of the node.
+ # @param leaf [Boolean] The flag indicating whether the node is a leaf.
+ # @param leaf_id [Integer] The leaf index of the node.
+ # @param left [Node] The left node.
+ # @param right [Node] The right node.
+ # @param feature_id [Integer] The feature index used for evaluation.
+ # @param threshold [Float] The threshold value of the feature for splitting the node.
+ def initialize(depth: 0, impurity: 0.0, n_samples: 0, probs: 0.0,
+ leaf: true, leaf_id: 0,
+ left: nil, right: nil, feature_id: 0, threshold: 0.0)
+ @depth = depth
+ @impurity = impurity
+ @n_samples = n_samples
+ @probs = probs
+ @leaf = leaf
+ @leaf_id = leaf_id
+ @left = left
+ @right = right
+ @feature_id = feature_id
+ @threshold = threshold
+ end
+
+ # Dump marshal data.
+ # @return [Hash] The marshal data about Node
+ def marshal_dump
+ { depth: @depth,
+ impurity: @impurity,
+ n_samples: @n_samples,
+ probs: @probs,
+ leaf: @leaf,
+ leaf_id: @leaf_id,
+ left: @left,
+ right: @right,
+ feature_id: @feature_id,
+ threshold: @threshold }
+ end
+
+ # Load marshal data.
+ # @return [nil]
+ def marshal_load(obj)
+ @depth = obj[:depth]
+ @impurity = obj[:impurity]
+ @n_samples = obj[:n_samples]
+ @probs = obj[:probs]
+ @leaf = obj[:leaf]
+ @leaf_id = obj[:leaf_id]
+ @left = obj[:left]
+ @right = obj[:right]
+ @feature_id = obj[:feature_id]
+ @threshold = obj[:threshold]
+ end
+ end
+
# DecisionTreeClassifier is a class that implements decision tree for classification.
#
# @example
# estimator =
# SVMKit::Tree::DecisionTreeClassifier.new(
@@ -27,11 +91,11 @@
# Return the importance for each feature.
# @return [Numo::DFloat] (size: n_features)
attr_reader :feature_importances
# Return the learned tree.
- # @return [OpenStruct]
+ # @return [Node]
attr_reader :tree
    # Return the random generator used for random sampling during tree construction.
# @return [Random]
attr_reader :rng
@@ -53,22 +117,25 @@
# @param random_seed [Integer] The seed value using to initialize the random generator.
    # It is used to randomly determine the order of features when deciding splitting point.
def initialize(criterion: 'gini', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil,
random_seed: nil)
SVMKit::Validation.check_params_type_or_nil(Integer, max_depth: max_depth, max_leaf_nodes: max_leaf_nodes,
- max_features: max_features, random_seed: random_seed)
+ max_features: max_features, random_seed: random_seed)
SVMKit::Validation.check_params_integer(min_samples_leaf: min_samples_leaf)
SVMKit::Validation.check_params_string(criterion: criterion)
-
+ SVMKit::Validation.check_params_positive(max_depth: max_depth, max_leaf_nodes: max_leaf_nodes,
+ min_samples_leaf: min_samples_leaf, max_features: max_features)
@params = {}
@params[:criterion] = criterion
@params[:max_depth] = max_depth
@params[:max_leaf_nodes] = max_leaf_nodes
@params[:min_samples_leaf] = min_samples_leaf
@params[:max_features] = max_features
@params[:random_seed] = random_seed
@params[:random_seed] ||= srand
+ @criterion = :gini
+ @criterion = :entropy if @params[:criterion] == 'entropy'
@tree = nil
@classes = nil
@feature_importances = nil
@n_leaves = nil
@leaf_labels = nil
@@ -81,13 +148,14 @@
# @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
# @return [DecisionTreeClassifier] The learned classifier itself.
def fit(x, y)
SVMKit::Validation.check_sample_array(x)
SVMKit::Validation.check_label_array(y)
+ SVMKit::Validation.check_sample_label_size(x, y)
n_samples, n_features = x.shape
- @params[:max_features] = n_features unless @params[:max_features].is_a?(Integer)
- @params[:max_features] = [[1, @params[:max_features]].max, n_features].min
+ @params[:max_features] = n_features if @params[:max_features].nil?
+ @params[:max_features] = [@params[:max_features], n_features].min
@classes = Numo::Int32.asarray(y.to_a.uniq.sort)
build_tree(x, y)
eval_importance(n_samples, n_features)
self
end
@@ -123,10 +191,11 @@
# Dump marshal data.
# @return [Hash] The marshal data about DecisionTreeClassifier
def marshal_dump
{ params: @params,
classes: @classes,
+ criterion: @criterion,
tree: @tree,
feature_importances: @feature_importances,
leaf_labels: @leaf_labels,
rng: @rng }
end
@@ -134,10 +203,11 @@
# Load marshal data.
# @return [nil]
def marshal_load(obj)
@params = obj[:params]
@classes = obj[:classes]
+ @criterion = obj[:criterion]
@tree = obj[:tree]
@feature_importances = obj[:feature_importances]
@leaf_labels = obj[:leaf_labels]
@rng = obj[:rng]
nil
@@ -181,11 +251,11 @@
n_samples, n_features = x.shape
if @params[:min_samples_leaf].is_a?(Integer)
return nil if n_samples <= @params[:min_samples_leaf]
end
- node = OpenStruct.new(depth: depth, impurity: impurity(y), n_samples: n_samples)
+ node = Node.new(depth: depth, impurity: impurity(y), n_samples: n_samples)
return put_leaf(node, y) if y.to_a.uniq.size == 1
if @params[:max_depth].is_a?(Integer)
return put_leaf(node, y) if depth == @params[:max_depth]
@@ -236,19 +306,19 @@
prob_right = labels_right.size / labels.size.to_f
impurity(labels) - prob_left * impurity(labels_left) - prob_right * impurity(labels_right)
end
def impurity(labels)
- posterior_probs = labels.to_a.uniq.sort.map { |c| labels.eq(c).count / labels.size.to_f }
- @params[:criterion] == 'entropy' ? entropy(posterior_probs) : gini(posterior_probs)
+ posterior_probs = Numo::DFloat[*(labels.to_a.uniq.sort.map { |c| labels.eq(c).count })] / labels.size.to_f
+ send(@criterion, posterior_probs)
end
def gini(posterior_probs)
- 1.0 - posterior_probs.map { |p| p**2 }.inject(:+)
+ 1.0 - (posterior_probs * posterior_probs).sum
end
def entropy(posterior_probs)
- -posterior_probs.map { |p| p * Math.log(p) }.inject(:+)
+ -(posterior_probs * Numo::NMath.log(posterior_probs)).sum
end
def eval_importance(n_samples, n_features)
@feature_importances = Numo::DFloat.zeros(n_features)
eval_importance_at_node(@tree)