Sha256: fed307b845ed7a1384515f79c465d021e3e8aeed5d7b52145d951aae3c8f3080

Contents?: true

Size: 1.91 KB

Versions: 17

Compression:

Stored size: 1.91 KB

Contents

# frozen_string_literal: true

require 'svmkit/validation'
require 'svmkit/base/evaluator'
require 'svmkit/preprocessing/one_hot_encoder'

module SVMKit
  module EvaluationMeasure
    # LogLoss is a class that calculates the logarithmic loss of predicted class probability.
    #
    # @example
    #   evaluator = SVMKit::EvaluationMeasure::LogLoss.new
    #   puts evaluator.score(ground_truth, predicted)
    class LogLoss
      include Base::Evaluator

      # Calculate mean logarithmic loss.
      #
      # If y_pred is one-dimensional (shape: [n_samples]), this method calculates the mean logarithmic
      # loss for binary classification: each value of y_pred is treated as the probability of the
      # positive class, and samples whose label equals the smallest label in y_true are the negative class.
      #
      # If y_pred is two-dimensional (shape: [n_samples, n_classes]), column j of y_pred is matched to
      # the j-th smallest distinct label in y_true, and each row is renormalized to sum to one after clipping.
      #
      # @param y_true [Numo::Int32] (shape: [n_samples]) Ground truth labels.
      # @param y_pred [Numo::DFloat] (shape: [n_samples] or [n_samples, n_classes]) Predicted class probability.
      # @param eps [Float] A small value close to zero to avoid outputting infinity in logarithmic calculation.
      # @raise [ArgumentError] If y_true and y_pred have different numbers of samples, or if the number of
      #   distinct labels in y_true differs from the number of columns of a two-dimensional y_pred.
      # @return [Float] mean logarithmic loss
      def score(y_true, y_pred, eps = 1e-15)
        SVMKit::Validation.check_params_type(Numo::Int32, y_true: y_true)
        SVMKit::Validation.check_params_type(Numo::DFloat, y_pred: y_pred)

        n_samples, n_classes = y_pred.shape
        # Without this check, Numo broadcasting of size-1 axes could silently yield a meaningless score.
        unless y_true.size == n_samples
          raise ArgumentError, 'Expect y_true and y_pred to have the same number of samples.'
        end

        # Clipping into [eps, 1 - eps] keeps every argument passed to log() strictly positive.
        clipped_p = y_pred.clip(eps, 1 - eps)
        log_loss = if n_classes.nil?
                     binary_log_loss(y_true, clipped_p)
                   else
                     multiclass_log_loss(y_true, clipped_p, n_classes)
                   end
        log_loss.sum / n_samples
      end

      private

      # Per-sample binary cross-entropy; the smallest label in y_true is taken as the negative class.
      def binary_log_loss(y_true, clipped_p)
        negative_label = y_true.to_a.uniq.min
        bin_y_true = Numo::DFloat.cast(y_true.ne(negative_label))
        -(bin_y_true * Numo::NMath.log(clipped_p) + (1 - bin_y_true) * Numo::NMath.log(1 - clipped_p))
      end

      # Per-sample multiclass cross-entropy; column j of clipped_p belongs to the j-th smallest label.
      def multiclass_log_loss(y_true, clipped_p, n_classes)
        labels = y_true.to_a
        classes = labels.uniq.sort
        unless classes.size == n_classes
          raise ArgumentError, "Expect y_true to contain #{n_classes} distinct labels to match the columns " \
                               "of y_pred, but got #{classes.size}."
        end

        # One-hot encode y_true against the sorted label set (works for any integer labels, incl. negative).
        column_of = classes.each_with_index.to_h
        encoded_y_true = Numo::DFloat.zeros(*clipped_p.shape)
        labels.each_with_index { |label, row| encoded_y_true[row, column_of[label]] = 1.0 }

        # Renormalize after clipping so that each row is again a probability distribution.
        normalized_p = clipped_p / clipped_p.sum(1).expand_dims(1)
        -(encoded_y_true * Numo::NMath.log(normalized_p)).sum(1)
      end
    end
  end
end

Version data entries

17 entries across 17 versions & 1 rubygems

Version Path
svmkit-0.7.3 lib/svmkit/evaluation_measure/log_loss.rb
svmkit-0.7.2 lib/svmkit/evaluation_measure/log_loss.rb
svmkit-0.7.1 lib/svmkit/evaluation_measure/log_loss.rb
svmkit-0.7.0 lib/svmkit/evaluation_measure/log_loss.rb
svmkit-0.6.3 lib/svmkit/evaluation_measure/log_loss.rb
svmkit-0.6.2 lib/svmkit/evaluation_measure/log_loss.rb
svmkit-0.6.1 lib/svmkit/evaluation_measure/log_loss.rb
svmkit-0.6.0 lib/svmkit/evaluation_measure/log_loss.rb
svmkit-0.5.2 lib/svmkit/evaluation_measure/log_loss.rb
svmkit-0.5.1 lib/svmkit/evaluation_measure/log_loss.rb
svmkit-0.5.0 lib/svmkit/evaluation_measure/log_loss.rb
svmkit-0.4.1 lib/svmkit/evaluation_measure/log_loss.rb
svmkit-0.4.0 lib/svmkit/evaluation_measure/log_loss.rb
svmkit-0.3.3 lib/svmkit/evaluation_measure/log_loss.rb
svmkit-0.3.2 lib/svmkit/evaluation_measure/log_loss.rb
svmkit-0.3.1 lib/svmkit/evaluation_measure/log_loss.rb
svmkit-0.3.0 lib/svmkit/evaluation_measure/log_loss.rb