Sha256: aded79d5fe70fd2df6fd7e65b123580ca7d32dcd473de5894e6d774d5b9b5e52

Contents?: true

Size: 1.84 KB

Versions: 25

Compression:

Stored size: 1.84 KB

Contents

# frozen_string_literal: true

require 'rumale/base/base_estimator'
require 'rumale/base/classifier'

module Rumale
  # This module consists of the classes that implement naive bayes models.
  module NaiveBayes
    # BaseNaiveBayes is a class that has methods for common processes of naive bayes classifiers.
    # This class is used internally.
    #
    # NOTE(review): this class does not define +decision_function+ or +@classes+ itself; presumably
    # concrete subclasses implement +decision_function+ (returning an [n_samples, n_classes] matrix of
    # unnormalized class log-likelihoods) and set +@classes+ during fitting — confirm against subclasses.
    class BaseNaiveBayes
      include Base::BaseEstimator
      include Base::Classifier

      # Predict class labels for samples.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the labels.
      # @return [Numo::Int32] (shape: [n_samples]) Predicted class label per sample.
      def predict(x)
        x = check_convert_sample_array(x)
        n_samples = x.shape.first
        decision_values = decision_function(x)
        # Map each row's arg-max column index back to the corresponding class label.
        Numo::Int32.asarray(Array.new(n_samples) { |n| @classes[decision_values[n, true].max_index] })
      end

      # Predict log-probability for samples.
      #
      # The per-class log-likelihoods are normalized with a numerically stable log-sum-exp, so very
      # large or very small log-likelihoods do not overflow/underflow into Inf or NaN.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the log-probabilities.
      # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted log-probability of each class per sample.
      def predict_log_proba(x)
        x = check_convert_sample_array(x)
        n_samples, = x.shape
        log_likelihoods = decision_function(x)
        # log(sum(exp(l))) == m + log(sum(exp(l - m))) with m = row max. Shifting by the maximum keeps the
        # largest exponent at exp(0) = 1, avoiding exp underflow to 0 (log(0) = -Inf) or overflow to Inf.
        max_log_likelihoods = log_likelihoods.max(1).reshape(n_samples, 1)
        shifted_sum = Numo::NMath.exp(log_likelihoods - max_log_likelihoods).sum(1).reshape(n_samples, 1)
        log_likelihoods - (max_log_likelihoods + Numo::NMath.log(shifted_sum))
      end

      # Predict probability for samples.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the probabilities.
      # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted probability of each class per sample.
      def predict_proba(x)
        x = check_convert_sample_array(x)
        # exp is never negative, so no sign correction is needed.
        Numo::NMath.exp(predict_log_proba(x))
      end
    end
  end
end

Version data entries

25 entries across 25 versions & 1 rubygems

Version Path
rumale-0.23.3 lib/rumale/naive_bayes/base_naive_bayes.rb
rumale-0.23.2 lib/rumale/naive_bayes/base_naive_bayes.rb
rumale-0.23.1 lib/rumale/naive_bayes/base_naive_bayes.rb
rumale-0.23.0 lib/rumale/naive_bayes/base_naive_bayes.rb
rumale-0.22.5 lib/rumale/naive_bayes/base_naive_bayes.rb
rumale-0.22.4 lib/rumale/naive_bayes/base_naive_bayes.rb
rumale-0.22.3 lib/rumale/naive_bayes/base_naive_bayes.rb
rumale-0.22.2 lib/rumale/naive_bayes/base_naive_bayes.rb
rumale-0.22.1 lib/rumale/naive_bayes/base_naive_bayes.rb
rumale-0.22.0 lib/rumale/naive_bayes/base_naive_bayes.rb
rumale-0.21.0 lib/rumale/naive_bayes/base_naive_bayes.rb
rumale-0.20.3 lib/rumale/naive_bayes/base_naive_bayes.rb
rumale-0.20.2 lib/rumale/naive_bayes/base_naive_bayes.rb
rumale-0.20.1 lib/rumale/naive_bayes/base_naive_bayes.rb
rumale-0.20.0 lib/rumale/naive_bayes/base_naive_bayes.rb
rumale-0.19.3 lib/rumale/naive_bayes/base_naive_bayes.rb
rumale-0.19.2 lib/rumale/naive_bayes/base_naive_bayes.rb
rumale-0.19.1 lib/rumale/naive_bayes/base_naive_bayes.rb
rumale-0.19.0 lib/rumale/naive_bayes/base_naive_bayes.rb
rumale-0.18.7 lib/rumale/naive_bayes/base_naive_bayes.rb