Sha256: 63f8c13033173b6ee866d17a470f392d4e76efd662cd6e1f42064a8fdde9a50a

Contents?: true

Size: 1.49 KB

Versions: 1

Compression:

Stored size: 1.49 KB

Contents

module TensorStream
  # Class that defines a TensorStream constant: a tensor whose value is fixed at
  # construction time and which is registered in the graph as a :const op.
  class Constant < Tensor
    # Builds the constant and registers a :const op for it on the default graph.
    #
    # @param data_type element type; forwarded to Tensor.cast_dtype and to the :const op
    # @param rank rank forwarded to TensorShape.new when no value is supplied
    # @param shape requested shape; a flat Array value is reshaped to it when it has 2+ dims
    # @param options [Hash]
    #   :value    - the constant's value (scalar or Array); +false+ is a legitimate value
    #   :name     - base name for the op (defaults to build_name, i.e. "Const")
    #   :internal - stored in @internal
    def initialize(data_type, rank, shape, options = {})
      setup_initial_state(options)
      @data_type = data_type
      @rank = rank
      @breakpoint = false
      @shape = TensorShape.new(shape, rank)
      @value = nil
      @options = options
      @is_const = true
      @internal = options[:internal]
      # Name = current name scope + base name, skipping nil/empty parts, joined with '/'.
      # NOTE(review): @graph is presumably assigned by setup_initial_state, while the op
      # below is added to Graph.get_default_graph -- confirm these are always the same graph.
      @name = [@graph.get_name_scope, options[:name] || build_name].compact.reject(&:empty?).join('/')
      @given_name = @name

      value = options[:value]
      # Compare against nil explicitly: a plain truthiness check would silently
      # discard a boolean constant whose value is +false+.
      unless value.nil?
        if value.is_a?(Array)
          # A flat (single-dimension) array combined with a multi-dimensional shape is
          # reshaped first. The result is written back into options (as before) so that
          # @options, which references the same hash, sees the reshaped value.
          if shape.size >= 2 && !value.empty? && !value[0].is_a?(Array)
            value = options[:value] = _reshape(value, shape.reverse)
          end

          @value = value.map { |v| v.is_a?(Tensor) ? Tensor.cast_dtype(v, @data_type) : v }
        elsif !shape.empty?
          @value = _reshape(Tensor.cast_dtype(value, @data_type), shape.dup)
        else
          @value = Tensor.cast_dtype(value, @data_type)
        end
        # Recompute the shape from the stored value via shape_eval.
        @shape = TensorShape.new(shape_eval(@value))
      end

      @op = Graph.get_default_graph.add_op!(:const, value: @value, data_type: @data_type, internal_name: @name, shape: @shape)
      # Adopt the name reported by the registered op (presumably uniquified by the graph -- confirm).
      @name = @op.name
    end

    # Human-readable description of the constant: value, name, shape and data type.
    def inspect
      "Constant(#{@value}, name: #{@name}, shape: #{@shape}, data_type: #{@data_type})"
    end

    protected

    # Default base name used when options[:name] is not given.
    def build_name
      "Const"
    end
  end
end

Version data entries

1 entry across 1 version and 1 rubygem

Version Path
tensor_stream-1.0.0 lib/tensor_stream/constant.rb