Sha256: 266397c2cb40e271f5d9a91f265452a767922f692a1d0bd171d28c6fa8d66c81
Contents?: true
Size: 1.39 KB
Versions: 1
Compression:
Stored size: 1.39 KB
Contents
module TensorFlow
  # A variable backed by a resource handle obtained from +var_handle_op+.
  # Reads and writes go through the read/assign/assign_add/assign_sub
  # variable ops, always using the dtype fixed at construction time.
  class Variable
    # Creates the handle and stores +initial_value+ in it.
    #
    # initial_value - starting contents; also used to infer the dtype when
    #                 +dtype+ is not supplied.
    # dtype:        - optional explicit dtype (must respond to #to_sym).
    def initialize(initial_value, dtype: nil)
      @dtype = dtype || Utils.infer_type(initial_value)
      # Evaluate in the same order as a single argument list would:
      # the dtype enum first, then the context's shared name.
      enum = type_enum
      shared = TensorFlow.send(:default_context).shared_name
      @pointer = TensorFlow.var_handle_op(enum, nil, shared_name: shared)
      assign(initial_value)
    end

    # Replaces the stored contents with +value+. Returns self for chaining.
    def assign(value)
      TensorFlow.assign_variable_op(@pointer, to_dtype_tensor(value))
      self
    end

    # Adds +value+ to the stored contents in place. Returns self.
    def assign_add(value)
      TensorFlow.assign_add_variable_op(@pointer, to_dtype_tensor(value))
      self
    end

    # Subtracts +value+ from the stored contents in place. Returns self.
    def assign_sub(value)
      TensorFlow.assign_sub_variable_op(@pointer, to_dtype_tensor(value))
      self
    end

    # Reads the current contents via +read_variable_op+.
    def read_value
      TensorFlow.read_variable_op(@pointer, type_enum)
    end

    # Returns the sum of the current contents and +other+.
    # The arithmetic runs on a temporary copy, so self is left unchanged.
    def +(other)
      scratch_copy.assign_add(other).read_value
    end

    # Returns the current contents minus +other+.
    # The arithmetic runs on a temporary copy, so self is left unchanged.
    def -(other)
      scratch_copy.assign_sub(other).read_value
    end

    def to_s
      inspect
    end

    # Renders the current value, shape and dtype of the stored contents.
    def inspect
      snapshot = read_value
      fields = %w(value shape dtype).map do |attr|
        "#{attr}: #{snapshot.send(attr).inspect}"
      end
      "#<#{self.class} #{fields.join(", ")}>"
    end

    private

    # Converts +value+ to a tensor of this variable's dtype.
    def to_dtype_tensor(value)
      TensorFlow.convert_to_tensor(value, dtype: @dtype)
    end

    # Builds a fresh Variable holding a snapshot of the current contents.
    def scratch_copy
      Variable.new(read_value.value, dtype: @dtype)
    end

    # Maps the dtype to its FFI enum, or nil when no dtype is set.
    def type_enum
      return nil unless @dtype
      FFI::DataType[@dtype.to_sym]
    end
  end
end
Version data entries
1 entries across 1 versions & 1 rubygems
Version | Path |
---|---|
tensorflow-0.1.0 | lib/tensorflow/variable.rb |