lib/svmkit/tree/decision_tree_classifier.rb in svmkit-0.3.1 vs lib/svmkit/tree/decision_tree_classifier.rb in svmkit-0.3.2

- old
+ new

@@ -1,78 +1,15 @@ # frozen_string_literal: true require 'svmkit/validation' require 'svmkit/base/base_estimator' require 'svmkit/base/classifier' +require 'svmkit/tree/node' module SVMKit # This module consists of the classes that implement tree models. module Tree - # Node is a class that implements node used for construction of decision tree. - # This class is used for internal data structures. - class Node - # @!visibility private - attr_accessor :depth, :impurity, :n_samples, :probs, :leaf, :leaf_id, :left, :right, :feature_id, :threshold - - # Create a new node for decision tree. - # - # @param depth [Integer] The depth of the node in tree. - # @param impurity [Float] The impurity of the node. - # @param n_samples [Integer] The number of the samples in the node. - # @param probs [Float] The probability of the node. - # @param leaf [Boolean] The flag indicating whether the node is a leaf. - # @param leaf_id [Integer] The leaf index of the node. - # @param left [Node] The left node. - # @param right [Node] The right node. - # @param feature_id [Integer] The feature index used for evaluation. - # @param threshold [Float] The threshold value of the feature for splitting the node. - def initialize(depth: 0, impurity: 0.0, n_samples: 0, probs: 0.0, - leaf: true, leaf_id: 0, - left: nil, right: nil, feature_id: 0, threshold: 0.0) - @depth = depth - @impurity = impurity - @n_samples = n_samples - @probs = probs - @leaf = leaf - @leaf_id = leaf_id - @left = left - @right = right - @feature_id = feature_id - @threshold = threshold - end - - # Dump marshal data. - # @return [Hash] The marshal data about Node - def marshal_dump - { depth: @depth, - impurity: @impurity, - n_samples: @n_samples, - probs: @probs, - leaf: @leaf, - leaf_id: @leaf_id, - left: @left, - right: @right, - feature_id: @feature_id, - threshold: @threshold } - end - - # Load marshal data. - # @return [nil] - def marshal_load(obj) - @depth = obj[:depth] - @impurity = obj[:impurity] - @n_samples = obj[:n_samples] - @probs = obj[:probs] - @leaf = obj[:leaf] - @leaf_id = obj[:leaf_id] - @left = obj[:left] - @right = obj[:right] - @feature_id = obj[:feature_id] - @threshold = obj[:threshold] - end - end - # DecisionTreeClassifier is a class that implements decision tree for classification. # # @example # estimator = # SVMKit::Tree::DecisionTreeClassifier.new( @@ -94,10 +31,10 @@ # Return the learned tree. # @return [Node] attr_reader :tree - # Return the random generator for performing random sampling in the Pegasos algorithm. + # Return the random generator for random selection of feature index. # @return [Random] attr_reader :rng # Return the labels assigned each leaf. # @return [Numo::Int32] (size: n_leafs)