Sha256: a019da9885e34fa6ad4e7c9bf877816140fb32c68f4ce602efa90a76128eaf4f

Size: 516 Bytes

Versions: 1

Contents

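module TensorFlow
  module Keras
    module Metrics
      # Accuracy for integer class labels: compares each label in y_true with
      # the index of the highest score in the matching row of y_pred.
      class SparseCategoricalAccuracy < Mean
        def update_state(y_true, y_pred)
          y_true = TensorFlow.convert_to_tensor(y_true)
          y_pred = TensorFlow.convert_to_tensor(y_pred)

          # Reduce per-class scores to predicted class ids (argmax over the last axis)
          y_pred = RawOps.arg_max(input: y_pred, dimension: -1)

          # ArgMax yields an integer tensor; cast it to y_true's dtype so that
          # both operands of the comparison below share a type
          if y_pred.dtype != y_true.dtype
            y_pred = TensorFlow.cast(y_pred, y_true.dtype)
          end

          # Element-wise match (true where prediction == label), averaged by Mean.
          # y_true should have the argmax's shape (e.g. [batch]); a [batch, 1]
          # label tensor would broadcast against it instead of matching row by row.
          super(Math.equal(y_true, y_pred))
        end
      end
    end
  end
end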
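As a rough illustration of what update_state computes, here is a minimal pure-Ruby sketch that works on plain arrays instead of tensors. It assumes, as in Keras, that the Mean parent class keeps a running total and count across calls. The class name SparseAccuracySketch and its methods are hypothetical and are not part of the tensorflow gem.

```ruby
# Hypothetical pure-Ruby sketch of sparse categorical accuracy.
# It does not use the tensorflow gem; inputs are plain arrays.
class SparseAccuracySketch
  def initialize
    @total = 0.0 # sum of per-sample matches (1.0 or 0.0)
    @count = 0   # number of samples seen so far
  end

  # y_true: integer class ids, e.g. [2, 0]
  # y_pred: one row of class scores per sample, e.g. [[0.1, 0.2, 0.7], [0.9, 0.05, 0.05]]
  def update_state(y_true, y_pred)
    raise ArgumentError, "batch size mismatch" unless y_true.size == y_pred.size

    y_true.zip(y_pred).each do |label, scores|
      # Stand-in for ArgMax over the last axis: first index of the highest score
      predicted = scores.index(scores.max)
      @total += (predicted == label.to_i ? 1.0 : 0.0)
      @count += 1
    end
    self
  end

  # Running mean of matches across all update_state calls (0.0 before any data)
  def result
    @count.zero? ? 0.0 : @total / @count
  end
end

metric = SparseAccuracySketch.new
metric.update_state([2, 0], [[0.1, 0.2, 0.7], [0.9, 0.05, 0.05]])
metric.update_state([1], [[0.6, 0.3, 0.1]])
puts metric.result # => 0.6666666666666666
```

Both samples in the first batch are predicted correctly and the single sample in the second batch is not, so the running accuracy is 2 matches out of 3 samples.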

Version data entries

1 entry across 1 version & 1 rubygem

Version: tensorflow-0.2.0
Path: lib/tensorflow/keras/metrics/sparse_categorical_accuracy.rb