Sha256: ce09cc7dda6e68b830ba2d955483bb02be0217f7200ba59e97b856bac2b3770a

Contents?: true

Size: 1.62 KB

Versions: 12

Compression:

Stored size: 1.62 KB

Contents

module Eps
  # Evaluation metrics for regression and classification predictions.
  #
  # Every metric takes an optional +weight+ array. When given, each
  # observation's contribution is multiplied by the matching weight and the
  # total is divided by the sum of the weights; otherwise it is divided by
  # the number of observations. Empty input yields NaN (0 / 0.0).
  module Metrics
    class << self
      # Root mean squared error.
      #
      # @param y_true [Array<Numeric>] actual values
      # @param y_pred [Array<Numeric>] predicted values
      # @param weight [Array<Numeric>, nil] optional per-observation weights
      # @return [Float]
      # @raise [ArgumentError] if y_true, y_pred (and weight, when given) differ in size
      def rmse(y_true, y_pred, weight: nil)
        check_size(y_true, y_pred, weight: weight)
        Math.sqrt(mean(errors(y_true, y_pred).map { |v| v**2 }, weight: weight))
      end

      # Mean absolute error.
      #
      # @param y_true [Array<Numeric>] actual values
      # @param y_pred [Array<Numeric>] predicted values
      # @param weight [Array<Numeric>, nil] optional per-observation weights
      # @return [Float]
      # @raise [ArgumentError] if y_true, y_pred (and weight, when given) differ in size
      def mae(y_true, y_pred, weight: nil)
        check_size(y_true, y_pred, weight: weight)
        mean(errors(y_true, y_pred).map(&:abs), weight: weight)
      end

      # Mean error, i.e. the (weighted) mean of y_true - y_pred.
      # Positive values mean predictions are too low on average.
      #
      # @param y_true [Array<Numeric>] actual values
      # @param y_pred [Array<Numeric>] predicted values
      # @param weight [Array<Numeric>, nil] optional per-observation weights
      # @return [Float]
      # @raise [ArgumentError] if y_true, y_pred (and weight, when given) differ in size
      def me(y_true, y_pred, weight: nil)
        check_size(y_true, y_pred, weight: weight)
        mean(errors(y_true, y_pred), weight: weight)
      end

      # Fraction of predictions equal (==) to the actual label.
      #
      # @param y_true [Array] actual labels
      # @param y_pred [Array] predicted labels
      # @param weight [Array<Numeric>, nil] optional per-observation weights
      # @return [Float]
      # @raise [ArgumentError] if y_true, y_pred (and weight, when given) differ in size
      def accuracy(y_true, y_pred, weight: nil)
        check_size(y_true, y_pred, weight: weight)
        # 1 per correct prediction, 0 otherwise; their (weighted) mean is the accuracy
        hits = y_true.zip(y_pred).map { |yt, yp| yt == yp ? 1 : 0 }
        mean(hits, weight: weight)
      end

      # Binary log loss (cross-entropy).
      # http://wiki.fast.ai/index.php/Log_Loss
      #
      # @param y_true [Array] actual labels; an element equal to 1 is the
      #   positive class, anything else is treated as the negative class
      # @param y_pred [Array<Numeric>] predicted probabilities of the positive class
      # @param eps [Float] predictions are clamped to [eps, 1 - eps] so log never sees 0
      # @param weight [Array<Numeric>, nil] optional per-observation weights
      # @return [Float]
      # @raise [ArgumentError] if y_true, y_pred (and weight, when given) differ in size
      def log_loss(y_true, y_pred, eps: 1e-15, weight: nil)
        check_size(y_true, y_pred, weight: weight)
        p = y_pred.map { |yp| yp.clamp(eps, 1 - eps) }
        mean(y_true.zip(p).map { |yt, pi| yt == 1 ? -Math.log(pi) : -Math.log(1 - pi) }, weight: weight)
      end

      private

      # Ensures predictions (and weights, when given) line up one-to-one with
      # the actual values. A mismatched weight array would otherwise either
      # raise an opaque TypeError (too short) or silently skew the weighted
      # mean through extra entries in weight.sum (too long).
      #
      # @raise [ArgumentError] on any size mismatch
      def check_size(y_true, y_pred, weight: nil)
        raise ArgumentError, "Different sizes" if y_true.size != y_pred.size
        raise ArgumentError, "Different sizes" if weight && weight.size != y_true.size
      end

      # Arithmetic mean of +arr+, or the weighted mean when +weight+ is given.
      #
      # @return [Float]
      def mean(arr, weight: nil)
        if weight
          arr.map.with_index { |v, i| v * weight[i] }.sum / weight.sum.to_f
        else
          arr.sum / arr.size.to_f
        end
      end

      # Element-wise residuals y_true - y_pred.
      #
      # @return [Array<Numeric>]
      def errors(y_true, y_pred)
        y_true.zip(y_pred).map { |yt, yp| yt - yp }
      end
    end
  end
end

Version data entries

12 entries across 12 versions & 1 rubygem

Version Path
eps-0.5.0 lib/eps/metrics.rb
eps-0.4.1 lib/eps/metrics.rb
eps-0.4.0 lib/eps/metrics.rb
eps-0.3.9 lib/eps/metrics.rb
eps-0.3.8 lib/eps/metrics.rb
eps-0.3.7 lib/eps/metrics.rb
eps-0.3.6 lib/eps/metrics.rb
eps-0.3.5 lib/eps/metrics.rb
eps-0.3.4 lib/eps/metrics.rb
eps-0.3.3 lib/eps/metrics.rb
eps-0.3.2 lib/eps/metrics.rb
eps-0.3.1 lib/eps/metrics.rb