Sha256: 006b7aa26d7e20af9d7bef6d56364341f60648569e1f08df16b46b43e6209f8b
Contents?: true
Size: 1.3 KB
Versions: 2
Compression:
Stored size: 1.3 KB
Contents
module TensorStream
  class Variable < Tensor
    attr_accessor :trainable

    def initialize(data_type, rank, shape, options = {})
      @data_type = data_type
      @rank = rank
      @shape = TensorShape.new(shape, rank)
      @value = nil
      @source = set_source(caller_locations)
      @graph = options[:graph] || TensorStream.get_default_graph
      @name = options[:name] || build_name
      if options[:initializer]
        @initalizer_tensor = options[:initializer]
      end
      @trainable = options.fetch(:trainable, true)
      @graph.add_variable(self, options)
    end

    def initializer
      @initalizer_tensor.shape = @shape
      assign(@initalizer_tensor)
    end

    def assign(value)
      Operation.new(:assign, self, value)
    end

    def read_value
      @value
    end

    def assign_add(value)
      Operation.new(:assign_add, self, value)
    end

    def to_math(tensor, name_only = false, max_depth = 99)
      @name
    end

    def assign_sub(value)
      Operation.new(:assign_sub, self, value)
    end

    def self.variables_initializer(collection)
      TensorStream.group(TensorStream.get_default_graph.get_collection(collection).map(&:initializer))
    end

    def self.global_variables_initializer
      variables_initializer(TensorStream::GraphKeys::GLOBAL_VARIABLES)
    end
  end
end
Version data entries
2 entries across 2 versions & 1 rubygems
Version | Path |
---|---|
tensor_stream-0.1.1 | lib/tensor_stream/variable.rb |
tensor_stream-0.1.0 | lib/tensor_stream/variable.rb |