lib/tensorflow/keras/metrics/mean.rb in tensorflow-0.1.2 vs lib/tensorflow/keras/metrics/mean.rb in tensorflow-0.2.0

- old
+ new

@@ -1,16 +1,30 @@ module TensorFlow module Keras module Metrics class Mean + def initialize(name: nil, dtype: :float) + @dtype = dtype + @total = Utils.add_weight(name: "total", initializer: "zeros", dtype: @dtype) + @count = Utils.add_weight(name: "count", initializer: "zeros", dtype: @dtype) + end + + def call(*args) + update_state(*args) + end + def update_state(values) - input = TensorFlow.convert_to_tensor(values, dtype: :float) - @total = Math.reduce_sum(input) - @count = RawOps.size(input: input) + input = TensorFlow.convert_to_tensor(values) + input = TensorFlow.cast(input, @dtype) + @total.assign_add(Math.reduce_sum(input)) + @count.assign_add(TensorFlow.cast(RawOps.size(input: input), @dtype)) end def result RawOps.div_no_nan(x: @total, y: TensorFlow.cast(@count, :float)) + end + + def reset_states end end end end end