Sha256: ca4839a36f375fa48d4a946c47b40c6a0d4e9d9f29f6d7eb8b0dfa2913187091

Contents?: true

Size: 442 Bytes

Versions: 9

Compression:

Stored size: 442 Bytes

Contents

# Registers the internal `:case` operation with the op maker, along with
# its gradient rule and its shape rule.
TensorStream::OpMaker.define_operation :case do |op|
  op.exclude!

  # Gradient rule. Returns one entry per element of node.inputs, in order:
  #   index 0  -> nil (no gradient is propagated to inputs[0])
  #   index 1  -> case_grad with branch index -1 for inputs[1]
  #               (presumably the default/fallback value — TODO confirm)
  #   index 2+ -> case_grad with branch index k for inputs[k + 2]
  # inputs[0] is passed to every case_grad call; it presumably holds the
  # predicates. TODO confirm against the forward kernel.
  op.define_gradient do |grad, node, _params|
    inputs = node.inputs
    predicates = inputs[0]

    # Build the per-branch gradients before the inputs[1] gradient, so that
    # i_op is invoked in the same order as before.
    branch_grads = Array.new(inputs.size - 2) do |branch_idx|
      i_op(:case_grad, branch_idx, predicates, inputs[branch_idx + 2], grad)
    end
    fallback_grad = i_op(:case_grad, -1, predicates, inputs[1], grad)

    [nil, fallback_grad].concat(branch_grads)
  end

  # Shape rule. Reports the shape of inputs[2], the first branch value.
  # Returns nil when that input, or its shape, is absent.
  op.define_shape do |tensor|
    first_branch = tensor.inputs[2]
    first_branch&.shape&.shape
  end
end

Version data entries

9 entries across 9 versions &amp; 1 rubygem

Version Path
tensor_stream-1.0.9 lib/tensor_stream/ops/case.rb
tensor_stream-1.0.8 lib/tensor_stream/ops/case.rb
tensor_stream-1.0.7 lib/tensor_stream/ops/case.rb
tensor_stream-1.0.6 lib/tensor_stream/ops/case.rb
tensor_stream-1.0.5 lib/tensor_stream/ops/case.rb
tensor_stream-1.0.4 lib/tensor_stream/ops/case.rb
tensor_stream-1.0.3 lib/tensor_stream/ops/case.rb
tensor_stream-1.0.2 lib/tensor_stream/ops/case.rb
tensor_stream-1.0.1 lib/tensor_stream/ops/case.rb