lib/tensorflow/variable.rb in tensorflow-0.1.2 vs lib/tensorflow/variable.rb in tensorflow-0.2.0

- old
+ new

@@ -1,11 +1,15 @@ module TensorFlow class Variable - def initialize(initial_value, dtype: nil) + attr_reader :name + + def initialize(initial_value = nil, dtype: nil, shape: nil, name: nil) @dtype = dtype || Utils.infer_type(Array(initial_value).flatten) + @shape = shape + @name = name @pointer = RawOps.var_handle_op(dtype: type_enum, shape: [], shared_name: Utils.default_context.shared_name) - assign(initial_value) + assign(initial_value) if initial_value end def assign(value) value = TensorFlow.convert_to_tensor(value, dtype: @dtype) RawOps.assign_variable_op(resource: @pointer, value: value) @@ -40,13 +44,22 @@ def to_s inspect end + def shape + read_value.shape + end + def inspect value = read_value inspection = %w(numo shape dtype).map { |v| "#{v}: #{value.send(v).inspect}"} + inspection.unshift("name: #{name}") if name "#<#{self.class} #{inspection.join(", ")}>" + end + + def to_ptr + read_value.to_ptr end private def type_enum