Sha256: b4a1a79595d2461ee1beb8b49610c316bf816a40a7f2f14b209c1e6aaf6ce12b
Contents?: true
Size: 853 Bytes
Versions: 1
Compression:
Stored size: 853 Bytes
Contents
module TensorFlow
  module Keras
    module Metrics
      # Streaming mean metric: accumulates the sum of all values passed to
      # update_state and the number of elements seen, and reports sum / count.
      class Mean
        # name  - accepted for Keras-style API compatibility; it is not used by
        #         this class (NOTE(review): not forwarded to the weights - confirm
        #         whether that is intended).
        # dtype - dtype of both accumulators; incoming values are cast to it
        #         (default :float).
        def initialize(name: nil, dtype: :float)
          @dtype = dtype
          # presumably scalar, zero-initialised variables - confirm against Utils.add_weight
          @total = Utils.add_weight(name: "total", initializer: "zeros", dtype: @dtype)
          @count = Utils.add_weight(name: "count", initializer: "zeros", dtype: @dtype)
        end

        # Shorthand for update_state: forwards all arguments unchanged and
        # returns whatever update_state returns.
        def call(*args)
          update_state(*args)
        end

        # Folds a batch of values into the running state.
        #
        # values - anything TensorFlow.convert_to_tensor accepts; every element
        #          contributes with equal weight (sum of elements is added to the
        #          total, number of elements is added to the count).
        #
        # Returns whatever @count.assign_add returns.
        def update_state(values)
          input = TensorFlow.convert_to_tensor(values)
          input = TensorFlow.cast(input, @dtype)
          @total.assign_add(Math.reduce_sum(input))
          # RawOps.size yields an integer element count; cast it so it matches
          # the dtype of the @count accumulator.
          @count.assign_add(TensorFlow.cast(RawOps.size(input: input), @dtype))
        end

        # Returns total / count. div_no_nan yields 0 instead of NaN when the count
        # is still zero (nothing accumulated yet).
        def result
          # Cast to @dtype rather than a fixed :float so both operands share the
          # accumulators' dtype even when the metric is built with e.g. :double.
          RawOps.div_no_nan(x: @total, y: TensorFlow.cast(@count, @dtype))
        end

        # Clears the accumulated state so the metric can be reused (e.g. once per
        # epoch): both accumulators are set back to zero, keeping their own
        # dtype and shape via zeros_like.
        def reset_states
          @total.assign(RawOps.zeros_like(x: @total))
          @count.assign(RawOps.zeros_like(x: @count))
        end
      end
    end
  end
end
Version data entries
1 entry across 1 version & 1 rubygem
Version | Path |
---|---|
tensorflow-0.2.0 | lib/tensorflow/keras/metrics/mean.rb |