lib/tensorflow/keras/metrics/mean.rb in tensorflow-0.1.2 vs lib/tensorflow/keras/metrics/mean.rb in tensorflow-0.2.0
- old
+ new
@@ -1,16 +1,30 @@
module TensorFlow
module Keras
module Metrics
class Mean
+ def initialize(name: nil, dtype: :float)
+ @dtype = dtype
+ @total = Utils.add_weight(name: "total", initializer: "zeros", dtype: @dtype)
+ @count = Utils.add_weight(name: "count", initializer: "zeros", dtype: @dtype)
+ end
+
+ def call(*args)
+ update_state(*args)
+ end
+
def update_state(values)
- input = TensorFlow.convert_to_tensor(values, dtype: :float)
- @total = Math.reduce_sum(input)
- @count = RawOps.size(input: input)
+ input = TensorFlow.convert_to_tensor(values)
+ input = TensorFlow.cast(input, @dtype)
+ @total.assign_add(Math.reduce_sum(input))
+ @count.assign_add(TensorFlow.cast(RawOps.size(input: input), @dtype))
end
def result
RawOps.div_no_nan(x: @total, y: TensorFlow.cast(@count, :float))
+ end
+
+ def reset_states
end
end
end
end
end