Sha256: c9da71aeecbd2f661abf1ba9e4ef21195c80171b2b55123a2e4152baebcc1f5b

Contents?: true

Size: 773 Bytes

Versions: 9

Compression:

Stored size: 773 Bytes

Contents

module TensorStream
  # Graph node for a dynamic-stitch style operation.
  #
  # The operation symbol handed to the evaluator is :"flow_<flow_type>", and
  # options[:n] records how many entries inputs[0] holds.
  # NOTE(review): inputs presumably has the shape [indices, data], mirroring
  # TensorFlow's dynamic_stitch(indices, data) — TODO confirm against callers.
  class DynamicStitch < Operation
    attr_accessor :ops

    # Builds the node and registers it with the graph.
    #
    # @param flow_type [#to_s] suffix appended to "flow_" to form the operation symbol
    # @param inputs [Array] nested input list; it is flattened one level, and each
    #   element is converted to a tensor whose op is recorded as a node input
    # @param ops [Object, nil] stored as-is and exposed through #ops
    # @param options [Hash] node options; :name, when given, replaces the generated name
    def initialize(flow_type, inputs, ops = nil, options = {})
      setup_initial_state(options)

      @operation = :"flow_#{flow_type}"
      @options = options.merge(n: inputs[0].size)

      # Two passes on purpose: every input is converted before any op is read.
      # A falsy conversion result is recorded as a nil input.
      converted = inputs.flatten(1).map { |raw| TensorStream.convert_to_tensor(raw) }
      @inputs = converted.map { |tensor| (tensor.op if tensor) }

      @consumers = Set.new
      # The data type is derived from the second input group only.
      @data_type = Tensor.detect_type(inputs[1])

      scope = @graph.get_name_scope
      @name = [scope, options[:name] || set_name].compact.join("/")

      @ops = ops
      # Shape is left unspecified for this node.
      @shape = TensorShape.new(nil)
      @graph.add_node(self)
    end

    # Ignores the requested type; this node always reports :unknown.
    def set_data_type(_ignored)
      :unknown
    end

    # Evaluates the node by calling eval (presumably the evaluation method
    # inherited through Operation/Tensor — TODO confirm).
    def run
      eval
    end
  end
end

Version data entries

9 entries across 9 versions & 1 rubygem

Version Path
tensor_stream-1.0.9 lib/tensor_stream/dynamic_stitch.rb
tensor_stream-1.0.8 lib/tensor_stream/dynamic_stitch.rb
tensor_stream-1.0.7 lib/tensor_stream/dynamic_stitch.rb
tensor_stream-1.0.6 lib/tensor_stream/dynamic_stitch.rb
tensor_stream-1.0.5 lib/tensor_stream/dynamic_stitch.rb
tensor_stream-1.0.4 lib/tensor_stream/dynamic_stitch.rb
tensor_stream-1.0.3 lib/tensor_stream/dynamic_stitch.rb
tensor_stream-1.0.2 lib/tensor_stream/dynamic_stitch.rb
tensor_stream-1.0.1 lib/tensor_stream/dynamic_stitch.rb