Sha256: a6c4e367994bdb2f93c83e790ca297929747f0099e2bbb197ded8dc29350c572

Contents?: true

Size: 1.49 KB

Versions: 9

Compression:

Stored size: 1.49 KB

Contents

module TensorStream
  # A constant tensor: a graph node whose value is fixed at construction time.
  #
  # Construction does the following:
  # - casts and normalizes the supplied value;
  # - re-derives the shape from that value;
  # - registers a :const op carrying the value with the default graph.
  class Constant < Tensor
    # @param data_type [Object] dtype passed to Tensor.cast_dtype
    #   (presumably a Symbol such as :float32 — confirm against callers)
    # @param rank [Integer] declared rank, used for the initial TensorShape
    # @param shape [Array] declared shape (dimension sizes)
    # @param options [Hash] construction options. Keys read here:
    #   :value    - the constant's value; when absent or nil no value is stored
    #   :name     - explicit base name; defaults to build_name ("Const")
    #   :internal - stored as @internal
    #   The same hash is handed to setup_initial_state and kept as @options.
    #   NOTE: a flat Array :value combined with a shape of rank >= 2 is reshaped,
    #   and the result is written back into the caller's options[:value].
    def initialize(data_type, rank, shape, options = {})
      # Inherited setup; presumably assigns @graph, which is read below — confirm.
      setup_initial_state(options)
      @data_type = data_type
      @rank = rank
      @breakpoint = false
      @shape = TensorShape.new(shape, rank)
      @value = nil
      @options = options
      @is_const = true
      @internal = options[:internal]
      # Prefix with the current name scope.
      # nil/empty parts are dropped so the joined name never has stray "/".
      @name = [@graph.get_name_scope, options[:name] || build_name].compact.reject(&:empty?).join("/")
      @given_name = @name

      # Check for nil explicitly: a plain truthiness test would also skip a
      # literal `false`, leaving boolean false constants without a value.
      unless options[:value].nil?
        if options[:value].is_a?(Array)
          # Reshape a single-dimension (flat) array passed for a shape of rank >= 2.
          # Array#reverse already returns a new array, so no extra dup is needed.
          if shape.size >= 2 && !options[:value].empty? && !options[:value][0].is_a?(Array)
            options[:value] = _reshape(options[:value], shape.reverse)
          end

          # Only Tensor elements are cast; plain elements are stored as given.
          @value = options[:value].map { |v| v.is_a?(Tensor) ? Tensor.cast_dtype(v, @data_type) : v }
        elsif !shape.empty?
          # Non-array value with a non-scalar shape: cast it, then _reshape it to
          # `shape` (presumably filling that shape — confirm _reshape semantics).
          # dup guards the caller's `shape` in case _reshape mutates its argument.
          @value = _reshape(Tensor.cast_dtype(options[:value], @data_type), shape.dup)
        else
          @value = Tensor.cast_dtype(options[:value], @data_type)
        end
        # The shape computed from the stored value replaces the declared one.
        @shape = TensorShape.new(shape_eval(@value))
      end

      # NOTE(review): the op is added to Graph.get_default_graph, while the name
      # scope above comes from @graph; if these can differ, confirm which is intended.
      @op = Graph.get_default_graph.add_op!(:const, value: @value, data_type: @data_type, internal_name: @name, shape: @shape)
      # Adopt the name the op ended up with (the graph may adjust it — confirm).
      @name = @op.name
    end

    # Debug representation showing value, name, shape and dtype.
    def inspect
      "Constant(#{@value}, name: #{@name}, shape: #{@shape}, data_type: #{@data_type})"
    end

    protected

    # Base name used when options[:name] is not supplied.
    def build_name
      "Const"
    end
  end
end

Version data entries

9 entries across 9 versions & 1 rubygem

Version Path
tensor_stream-1.0.9 lib/tensor_stream/constant.rb
tensor_stream-1.0.8 lib/tensor_stream/constant.rb
tensor_stream-1.0.7 lib/tensor_stream/constant.rb
tensor_stream-1.0.6 lib/tensor_stream/constant.rb
tensor_stream-1.0.5 lib/tensor_stream/constant.rb
tensor_stream-1.0.4 lib/tensor_stream/constant.rb
tensor_stream-1.0.3 lib/tensor_stream/constant.rb
tensor_stream-1.0.2 lib/tensor_stream/constant.rb
tensor_stream-1.0.1 lib/tensor_stream/constant.rb