lib/tensorflow/variable.rb in tensorflow-0.1.2 vs lib/tensorflow/variable.rb in tensorflow-0.2.0
- old
+ new
@@ -1,11 +1,15 @@
module TensorFlow
class Variable
- def initialize(initial_value, dtype: nil)
+ attr_reader :name
+
+ def initialize(initial_value = nil, dtype: nil, shape: nil, name: nil)
@dtype = dtype || Utils.infer_type(Array(initial_value).flatten)
+ @shape = shape
+ @name = name
@pointer = RawOps.var_handle_op(dtype: type_enum, shape: [], shared_name: Utils.default_context.shared_name)
- assign(initial_value)
+ assign(initial_value) if initial_value
end
def assign(value)
value = TensorFlow.convert_to_tensor(value, dtype: @dtype)
RawOps.assign_variable_op(resource: @pointer, value: value)
@@ -40,13 +44,22 @@
def to_s
inspect
end
+ def shape
+ read_value.shape
+ end
+
def inspect
value = read_value
inspection = %w(numo shape dtype).map { |v| "#{v}: #{value.send(v).inspect}"}
+ inspection.unshift("name: #{name}") if name
"#<#{self.class} #{inspection.join(", ")}>"
+ end
+
+ def to_ptr
+ read_value.to_ptr
end
private
def type_enum