Sha256: 8f16cc937f681558df121ee039ae7d4d79c82879de509d50f012d990f558a0d8

Contents?: true

Size: 1.7 KB

Versions: 1

Compression:

Stored size: 1.7 KB

Contents

module TensorFlow
  # A mutable variable whose storage is a resource handle created with
  # RawOps.var_handle_op. All reads and writes go through the RawOps
  # *_variable_op functions using that handle.
  class Variable
    # @return [String, nil] the optional name given at construction
    attr_reader :name

    # Creates the variable handle and, when a value is supplied, stores it.
    #
    # @param initial_value [Object, nil] value to store; only +nil+ means
    #   "no initial value", so falsy values such as +false+ are still assigned
    # @param dtype [Object, nil] element type; when omitted it is inferred by
    #   Utils.infer_type from the flattened initial value
    # @param shape [Object, nil] kept in @shape. NOTE(review): it is not
    #   forwarded to var_handle_op, which always receives shape: [] — confirm
    #   whether that is intended
    # @param name [String, nil] optional name shown by #inspect
    def initialize(initial_value = nil, dtype: nil, shape: nil, name: nil)
      @dtype = dtype || Utils.infer_type(Array(initial_value).flatten)
      @shape = shape
      @name = name
      @pointer = RawOps.var_handle_op(dtype: type_enum, shape: [], shared_name: Utils.default_context.shared_name)
      # Compare against nil explicitly: a plain truthiness check would
      # silently skip storing a boolean +false+ initial value.
      assign(initial_value) unless initial_value.nil?
    end

    # Replaces the stored value.
    #
    # @param value [Object] converted with TensorFlow.convert_to_tensor using
    #   this variable's dtype before being written
    # @return [Variable] self, so calls can be chained
    def assign(value)
      RawOps.assign_variable_op(resource: @pointer, value: to_variable_tensor(value))
      self
    end

    # Adds +value+ to the stored value in place.
    #
    # @param value [Object] converted to this variable's dtype first
    # @return [Variable] self
    def assign_add(value)
      RawOps.assign_add_variable_op(resource: @pointer, value: to_variable_tensor(value))
      self
    end

    # Subtracts +value+ from the stored value in place.
    #
    # @param value [Object] converted to this variable's dtype first
    # @return [Variable] self
    def assign_sub(value)
      RawOps.assign_sub_variable_op(resource: @pointer, value: to_variable_tensor(value))
      self
    end

    # @return [Object] the result of RawOps.read_variable_op for this handle
    def read_value
      RawOps.read_variable_op(resource: @pointer, dtype: type_enum)
    end

    # Returns the sum of the current value and +other+ without modifying self.
    # The addition runs on a temporary Variable seeded with a copy of the value.
    def +(other)
      scratch_copy.assign_add(other).read_value
    end

    # Returns the current value minus +other+ without modifying self.
    # The subtraction runs on a temporary Variable seeded with a copy of the value.
    def -(other)
      scratch_copy.assign_sub(other).read_value
    end

    def to_s
      inspect
    end

    # @return [Object] the shape reported by the value currently read back
    def shape
      read_value.shape
    end

    # @return [String] e.g. "#<TensorFlow::Variable name: w, numo: ..., shape: ..., dtype: ...>"
    #   where the name part is present only when a name was given
    def inspect
      value = read_value
      inspection = %w[numo shape dtype].map { |attr| "#{attr}: #{value.public_send(attr).inspect}" }
      inspection.unshift("name: #{name}") if name
      "#<#{self.class} #{inspection.join(", ")}>"
    end

    # @return [Object] the pointer of the value currently read back
    def to_ptr
      read_value.to_ptr
    end

    private

    # Converts +value+ to a tensor of this variable's dtype before it is written.
    def to_variable_tensor(value)
      TensorFlow.convert_to_tensor(value, dtype: @dtype)
    end

    # Builds a throwaway Variable holding a copy of the current value, so the
    # arithmetic operators can use the in-place assign ops without mutating self.
    def scratch_copy
      Variable.new(read_value.value, dtype: @dtype)
    end

    # @return [Object, nil] the FFI::DataType entry for @dtype, or nil when no dtype is set
    def type_enum
      FFI::DataType[@dtype.to_sym] if @dtype
    end
  end
end

Version data entries

1 entry across 1 version & 1 rubygem

Version Path
tensorflow-0.2.0 lib/tensorflow/variable.rb