Sha256: e53c785ece52f9d2a994cba796d9b8777bb71903731c94c69477d51c6f53aa71

Contents?: true

Size: 1.96 KB

Versions: 7

Compression:

Stored size: 1.96 KB

Contents

# frozen_string_literal: true

require 'rumale/base/estimator'
require 'rumale/base/classifier'
require 'rumale/validation'

module Rumale
  module NaiveBayes
    # BaseNaiveBayes is a class that has methods for common processes of naive bayes classifier.
    # This class is used internally.
    #
    # The methods here call `decision_function(x)` and read `@classes`; neither is defined in
    # this file, so they are presumably supplied by the concrete naive bayes subclasses.
    class BaseNaiveBayes < ::Rumale::Base::Estimator
      include ::Rumale::Base::Classifier

      # Create a new base estimator. Only delegates to the parent initializer.
      def initialize # rubocop:disable Lint/UselessMethodDefinition
        super()
      end

      # Predict class labels for samples.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the labels.
      # @return [Numo::Int32] (shape: [n_samples]) Predicted class label per sample.
      def predict(x)
        x = ::Rumale::Validation.check_convert_sample_array(x)

        n_samples = x.shape.first
        decision_values = decision_function(x)
        # For each sample, pick the class whose decision value is the largest in that row.
        Numo::Int32.asarray(Array.new(n_samples) { |n| @classes[decision_values[n, true].max_index] })
      end

      # Predict log-probability for samples.
      #
      # The decision values are normalized per sample with the log-sum-exp trick, so the
      # result stays finite even when every decision value of a sample is strongly negative.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the log-probabilities.
      # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted log-probability of each class per sample.
      def predict_log_proba(x)
        x = ::Rumale::Validation.check_convert_sample_array(x)

        n_samples, = x.shape
        log_likelihoods = decision_function(x)
        # Factor out the row-wise maximum before exponentiating: a direct exp() of strongly
        # negative values underflows to 0, and log(0) would turn the whole row into Infinity/NaN.
        row_max = log_likelihoods.max(axis: 1).reshape(n_samples, 1)
        shifted = log_likelihoods - row_max
        shifted - Numo::NMath.log(Numo::NMath.exp(shifted).sum(axis: 1)).reshape(n_samples, 1)
      end

      # Predict probability for samples.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the probabilities.
      # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted probability of each class per sample.
      def predict_proba(x)
        x = ::Rumale::Validation.check_convert_sample_array(x)

        # Exponentiate the normalized log-probabilities; abs is kept from the original implementation.
        Numo::NMath.exp(predict_log_proba(x)).abs
      end
    end
  end
end

Version data entries

7 entries across 7 versions & 1 rubygems

Version Path
rumale-naive_bayes-0.29.0 lib/rumale/naive_bayes/base_naive_bayes.rb
rumale-naive_bayes-0.28.1 lib/rumale/naive_bayes/base_naive_bayes.rb
rumale-naive_bayes-0.28.0 lib/rumale/naive_bayes/base_naive_bayes.rb
rumale-naive_bayes-0.27.0 lib/rumale/naive_bayes/base_naive_bayes.rb
rumale-naive_bayes-0.26.0 lib/rumale/naive_bayes/base_naive_bayes.rb
rumale-naive_bayes-0.25.0 lib/rumale/naive_bayes/base_naive_bayes.rb
rumale-naive_bayes-0.24.0 lib/rumale/naive_bayes/base_naive_bayes.rb