lib/tensorflow/keras/utils.rb in tensorflow-0.1.2 vs lib/tensorflow/keras/utils.rb in tensorflow-0.2.0
- old
+ new
@@ -1,8 +1,34 @@
module TensorFlow
module Keras
module Utils
class << self
+ def add_weight(name: nil, shape: [], initializer: nil, dtype: :float)
+ variable = Variable.new(shape: shape, name: name, dtype: dtype)
+ initial_value =
+ case initializer
+ when "zeros"
+ TensorFlow.fill(TensorFlow.convert_to_tensor(shape, dtype: :int64), 0.0)
+ when "glorot_uniform"
+ # TODO compute fans
+ fan_in = shape[0]
+ fan_out = shape[1]
+ scale = 1.0
+ scale /= [1.0, (fan_in + fan_out) / 2.0].max
+ limit = ::Math.sqrt(3.0 * scale)
+
+ minval = -limit
+ maxval = limit
+
+ rnd = RawOps.random_uniform(shape: shape, dtype: :float)
+ Math.add(rnd * (maxval - minval), minval)
+ else
+ raise Error, "Unknown initializer: #{initializer}"
+ end
+ variable.assign(initial_value)
+ variable
+ end
+
def get_file(fname, origin, file_hash: nil, cache_subdir: "datasets")
# destination
# TODO handle this better
raise "No HOME" unless ENV["HOME"]
dest = "#{ENV["HOME"]}/.keras/#{cache_subdir}/#{fname}"