lib/tensorflow/keras/utils.rb in tensorflow-0.1.2 vs lib/tensorflow/keras/utils.rb in tensorflow-0.2.0

- old
+ new

@@ -1,8 +1,34 @@ module TensorFlow module Keras module Utils class << self + def add_weight(name: nil, shape: [], initializer: nil, dtype: :float) + variable = Variable.new(shape: shape, name: name, dtype: dtype) + initial_value = + case initializer + when "zeros" + TensorFlow.fill(TensorFlow.convert_to_tensor(shape, dtype: :int64), 0.0) + when "glorot_uniform" + # TODO compute fans + fan_in = shape[0] + fan_out = shape[1] + scale = 1.0 + scale /= [1.0, (fan_in + fan_out) / 2.0].max + limit = ::Math.sqrt(3.0 * scale) + + minval = -limit + maxval = limit + + rnd = RawOps.random_uniform(shape: shape, dtype: :float) + Math.add(rnd * (maxval - minval), minval) + else + raise Error, "Unknown initializer: #{initializer}" + end + variable.assign(initial_value) + variable + end + def get_file(fname, origin, file_hash: nil, cache_subdir: "datasets") # destination # TODO handle this better raise "No HOME" unless ENV["HOME"] dest = "#{ENV["HOME"]}/.keras/#{cache_subdir}/#{fname}"