Sha256: 88afc9d3d2e857f36b8d464383302c21c80c8cd088656debccceb4057025def1

Contents?: true

Size: 930 Bytes

Versions: 9

Compression:

Stored size: 930 Bytes

Contents

# Defines the `reshape` op: returns a tensor holding the same values as `input`
# laid out with the dimensions given by `shape`.
TensorStream::OpMaker.define_operation :reshape do |op|
  op.what_it_does "Reshapes a tensor."
  op.what_it_does "Given tensor, this operation returns a tensor that has the same values as tensor with shape shape."

  op.parameter :input, "A tensor"
  op.parameter :shape, "A new tensor shape"
  # NOTE(review): `:nil` appears to be OpMaker's marker for a nil default — confirm.
  op.option :name, "Optional name", :nil

  # Gradient w.r.t. `input` is the incoming gradient reshaped back to the
  # input's runtime shape; the `shape` argument receives no gradient (nil).
  op.define_gradient do |grad, node, _params|
    [ts.reshape(grad, ts.shape(node.inputs[0])), nil]
  end

  # Static shape inference. Yields the output shape as an array, or nil when
  # it cannot be determined at graph-construction time.
  op.define_shape do |tensor|
    # Only a constant `shape` argument can be inspected statically.
    new_shape = tensor.inputs[1]&.const_value
    next nil unless new_shape
    next nil if tensor.inputs[0].shape.nil?

    input_shape = tensor.inputs[0].shape.shape
    input_known = !input_shape.nil? && !input_shape.include?(nil)

    unless input_known
      # Without a fully known input shape the output shape is still known when
      # the requested shape has no inferred (-1) or unknown (nil) dimensions.
      fully_specified = !new_shape.include?(-1) && !new_shape.include?(nil)
      next fully_specified ? new_shape : nil
    end

    # Seed the product with 1 so a scalar input (shape []) counts as one
    # element instead of `[].reduce(:*)` returning nil.
    TensorStream::TensorShape.fix_inferred_elements(new_shape, input_shape.reduce(1, :*))
  end
end

Version data entries

9 entries across 9 versions & 1 rubygem

Version Path
tensor_stream-1.0.9 lib/tensor_stream/ops/reshape.rb
tensor_stream-1.0.8 lib/tensor_stream/ops/reshape.rb
tensor_stream-1.0.7 lib/tensor_stream/ops/reshape.rb
tensor_stream-1.0.6 lib/tensor_stream/ops/reshape.rb
tensor_stream-1.0.5 lib/tensor_stream/ops/reshape.rb
tensor_stream-1.0.4 lib/tensor_stream/ops/reshape.rb
tensor_stream-1.0.3 lib/tensor_stream/ops/reshape.rb
tensor_stream-1.0.2 lib/tensor_stream/ops/reshape.rb
tensor_stream-1.0.1 lib/tensor_stream/ops/reshape.rb