Sha256: e3aaa9bd6877c8f123b8a16aa220f0e0f4504414ab896c5cce0ec92ab1d10eb7
Contents?: true
Size: 1.34 KB
Versions: 3
Compression:
Stored size: 1.34 KB
Contents
module TensorStream # Class that defines a TensorStream variable class Variable < Tensor attr_accessor :trainable def initialize(data_type, rank, shape, options = {}) @graph = options[:graph] || TensorStream.get_default_graph @data_type = data_type @rank = rank @shape = TensorShape.new(shape, rank) @value = nil @source = format_source(caller_locations) @name = options[:name] || build_name @initalizer_tensor = options[:initializer] if options[:initializer] @trainable = options.fetch(:trainable, true) @graph.add_variable(self, options) end def initializer @initalizer_tensor.shape = @shape assign(@initalizer_tensor) end def assign(value) Operation.new(:assign, self, value) end def read_value @value end def assign_add(value) Operation.new(:assign_add, self, value) end def to_math(_tensor, _name_only = false, _max_depth = 99) @name end def assign_sub(value) Operation.new(:assign_sub, self, value) end def self.variables_initializer(collection) TensorStream.group(TensorStream.get_default_graph.get_collection(collection).map(&:initializer)) end def self.global_variables_initializer variables_initializer(TensorStream::GraphKeys::GLOBAL_VARIABLES) end end end
Version data entries
3 entries across 3 versions & 1 rubygems
Version | Path |
---|---|
tensor_stream-0.1.4 | lib/tensor_stream/variable.rb |
tensor_stream-0.1.3 | lib/tensor_stream/variable.rb |
tensor_stream-0.1.2 | lib/tensor_stream/variable.rb |