Sha256: a019da9885e34fa6ad4e7c9bf877816140fb32c68f4ce602efa90a76128eaf4f
Size: 516 Bytes
Stored size: 516 Bytes
Versions: 1

Contents
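module TensorFlow
  module Keras
    module Metrics
      # Fraction of samples whose integer label equals the arg max of the
      # predicted scores, averaged by the Mean base class.
      class SparseCategoricalAccuracy < Mean
        def update_state(y_true, y_pred)
          y_true = TensorFlow.convert_to_tensor(y_true)
          y_pred = TensorFlow.convert_to_tensor(y_pred)

          # Predicted class: index of the largest score along the last axis.
          y_pred = RawOps.arg_max(input: y_pred, dimension: -1)

          # Align dtypes so labels and predicted indices can be compared.
          if y_pred.dtype != y_true.dtype
            y_pred = TensorFlow.cast(y_pred, y_true.dtype)
          end

          # Per-sample match mask, handed to Mean#update_state.
          super(Math.equal(y_true, y_pred))
        end
      end
    end
  end
end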
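The computation can be illustrated without TensorFlow. The following is a minimal pure-Ruby sketch. The class name `SparseCategoricalAccuracySketch` is hypothetical. It assumes that `y_pred` is an array of per-class score rows and that `y_true` is an array of class ids. It also assumes that, as in Keras, `Mean` keeps a running total and count across `update_state` calls. The sketch takes the arg max of each row, compares it with the label, and accumulates the matches into a streaming mean.

```ruby
# Hypothetical pure-Ruby sketch of the metric above (no TensorFlow).
# Assumption: Mean keeps a running total and count, as in Keras.
class SparseCategoricalAccuracySketch
  def initialize
    @total = 0.0 # running sum of per-sample matches (1.0 or 0.0)
    @count = 0   # number of samples seen so far
  end

  # y_true: array of class ids; y_pred: array of score rows.
  def update_state(y_true, y_pred)
    y_true.zip(y_pred).each do |label, scores|
      # Index of the highest score, mirroring arg_max over the last axis.
      predicted = scores.each_with_index.max_by { |score, _i| score }[1]
      # Ruby's == treats 1 and 1.0 as equal, standing in for the dtype cast.
      @total += predicted == label ? 1.0 : 0.0
      @count += 1
    end
    self
  end

  # Mean accuracy over every sample passed to update_state so far.
  def result
    @count.zero? ? 0.0 : @total / @count
  end
end

metric = SparseCategoricalAccuracySketch.new
metric.update_state([2, 0], [[0.1, 0.2, 0.7], [0.3, 0.6, 0.1]]) # 1 of 2 correct
metric.update_state([1], [[0.2, 0.5, 0.3]])                     # 1 of 1 correct
puts metric.result # 2 correct out of 3 samples
```

Because the running totals persist across calls, the result is averaged over all samples seen, not over the batch means. The two give different answers when batch sizes differ.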
Version data entries

1 entry across 1 version and 1 gem

| Version | Path |
|---|---|
| tensorflow-0.2.0 | lib/tensorflow/keras/metrics/sparse_categorical_accuracy.rb |