Sha256: 266397c2cb40e271f5d9a91f265452a767922f692a1d0bd171d28c6fa8d66c81

Contents?: true

Size: 1.39 KB

Versions: 1

Compression:

Stored size: 1.39 KB

Contents

module TensorFlow
  # Wraps a TensorFlow resource variable handle.
  #
  # Writes go through TensorFlow.assign_variable_op / assign_add_variable_op /
  # assign_sub_variable_op, and reads through TensorFlow.read_variable_op.
  # Operands for the assign methods and the + / - operators may be anything
  # TensorFlow.convert_to_tensor accepts, or another Variable (in which case
  # that variable's current value is used).
  class Variable
    # Creates the resource handle and seeds it with +initial_value+.
    #
    # @param initial_value [Object] value the variable starts with
    # @param dtype [Symbol, String, nil] element type; when nil it is inferred
    #   from +initial_value+ via Utils.infer_type
    def initialize(initial_value, dtype: nil)
      @dtype = dtype || Utils.infer_type(initial_value)
      # NOTE(review): every handle is created with the default context's
      # shared_name; presumably this yields a distinct resource per handle — confirm.
      shared_name = TensorFlow.send(:default_context).shared_name
      @pointer = TensorFlow.var_handle_op(type_enum, nil, shared_name: shared_name)
      assign(initial_value)
    end

    # Replaces the stored value.
    #
    # @param value [Object, Variable] new value, coerced to this variable's dtype
    # @return [Variable] self, for chaining
    def assign(value)
      TensorFlow.assign_variable_op(@pointer, operand_tensor(value))
      self
    end

    # Adds +value+ to the stored value in place.
    #
    # @param value [Object, Variable] addend, coerced to this variable's dtype
    # @return [Variable] self, for chaining
    def assign_add(value)
      TensorFlow.assign_add_variable_op(@pointer, operand_tensor(value))
      self
    end

    # Subtracts +value+ from the stored value in place.
    #
    # @param value [Object, Variable] subtrahend, coerced to this variable's dtype
    # @return [Variable] self, for chaining
    def assign_sub(value)
      TensorFlow.assign_sub_variable_op(@pointer, operand_tensor(value))
      self
    end

    # @return the current value, as produced by TensorFlow.read_variable_op
    def read_value
      TensorFlow.read_variable_op(@pointer, type_enum)
    end

    # Computes self + other without modifying self.
    #
    # @param other [Object, Variable]
    # @return the sum, read back from a temporary Variable
    def +(other)
      scratch_copy.assign_add(other).read_value
    end

    # Computes self - other without modifying self.
    #
    # @param other [Object, Variable]
    # @return the difference, read back from a temporary Variable
    def -(other)
      scratch_copy.assign_sub(other).read_value
    end

    def to_s
      inspect
    end

    # @return [String] e.g. "#<TensorFlow::Variable value: ..., shape: ..., dtype: ...>"
    def inspect
      current = read_value
      fields = %w[value shape dtype].map { |attr| "#{attr}: #{current.send(attr).inspect}" }
      "#<#{self.class} #{fields.join(", ")}>"
    end

    private

    # FFI::DataType enum for @dtype, or nil when no dtype is known.
    def type_enum
      FFI::DataType[@dtype.to_sym] if @dtype
    end

    # Coerces an operand to a tensor of this variable's dtype. A Variable
    # operand is first unwrapped to its current plain value (read_value.value),
    # which is the form this class already feeds to convert_to_tensor.
    def operand_tensor(value)
      value = value.read_value.value if value.is_a?(Variable)
      TensorFlow.convert_to_tensor(value, dtype: @dtype)
    end

    # A new Variable holding a copy of the current value, so + and - can use
    # the in-place assign ops without mutating self.
    def scratch_copy
      Variable.new(read_value.value, dtype: @dtype)
    end
  end
end

Version data entries

1 entries across 1 versions & 1 rubygems

Version Path
tensorflow-0.1.0 lib/tensorflow/variable.rb