Sha256: b4a1a79595d2461ee1beb8b49610c316bf816a40a7f2f14b209c1e6aaf6ce12b

Size: 853 Bytes

Versions: 1

Contents

module TensorFlow
  module Keras
    module Metrics
      class Mean
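        attr_reader :name

        # Streaming mean: keeps a running sum ("total") and a running element
        # count ("count"), both stored as weights of the given dtype.
        def initialize(name: nil, dtype: :float)
          @name = name
          @dtype = dtype
          @total = Utils.add_weight(name: "total", initializer: "zeros", dtype: @dtype)
          @count = Utils.add_weight(name: "count", initializer: "zeros", dtype: @dtype)
        end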

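        # Updates the state with the given values and returns the current mean,
        # following the Keras convention for calling a metric.
        def call(*args)
          update_state(*args)
          result
        end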

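        # Adds the sum of all elements of values to total and the number of
        # elements to count. Every element carries equal weight.
        def update_state(values)
          input = TensorFlow.convert_to_tensor(values)
          input = TensorFlow.cast(input, @dtype)
          @total.assign_add(Math.reduce_sum(input))
          @count.assign_add(TensorFlow.cast(RawOps.size(input: input), @dtype))
        end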

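        # Returns total / count, or 0 when count is 0 (div_no_nan).
        # count is cast to @dtype so both operands share the same type
        # even when the metric was created with a dtype other than :float.
        def result
          RawOps.div_no_nan(x: @total, y: TensorFlow.cast(@count, @dtype))
        end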

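        # Resets both accumulators to zero so the metric can be reused.
        # Assumes Variable#assign, the counterpart of the assign_add used above.
        def reset_states
          zero = TensorFlow.cast(TensorFlow.convert_to_tensor(0), @dtype)
          @total.assign(zero)
          @count.assign(zero)
        end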
      end
    end
  end
end
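The arithmetic in `update_state` and `result` above can be illustrated without a TensorFlow dependency. The sketch below is a hypothetical pure-Ruby class (`RunningMean` is not part of the gem): it keeps the same two accumulators, counts every element of a possibly nested input once (as `reduce_sum` and `size` do), returns 0 instead of NaN when nothing has been seen (as `div_no_nan` does), and shows what clearing both accumulators in `reset_states` would look like. It assumes equal weights for all elements, with no `sample_weight` support.

```ruby
# Hypothetical pure-Ruby sketch (not the gem's API) of a streaming mean that
# keeps a running total and a running element count.
class RunningMean
  def initialize
    reset_states
  end

  # Accepts a scalar or a (possibly nested) array; every element counts once.
  def update_state(values)
    flat = Array(values).flatten.map(&:to_f)
    @total += flat.sum
    @count += flat.size
    self
  end

  # Mirrors div_no_nan: 0.0 instead of NaN when no elements have been seen.
  def result
    @count.zero? ? 0.0 : @total / @count
  end

  # Clears both accumulators.
  def reset_states
    @total = 0.0
    @count = 0
  end
end

m = RunningMean.new
m.update_state([1, 2, 3])
m.update_state([[4], [5]])
puts m.result # prints 3.0
m.reset_states
puts m.result # prints 0.0
```

Because the mean is stored as a sum and a count rather than as a mean, batches of different sizes combine correctly: the result above is 15 / 5, not the average of the two batch means (2.0 and 4.5).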

Version data entries

1 entry across 1 version & 1 rubygem

Version: tensorflow-0.2.0
Path: lib/tensorflow/keras/metrics/mean.rb