Sha256: ca4839a36f375fa48d4a946c47b40c6a0d4e9d9f29f6d7eb8b0dfa2913187091
Size: 442 Bytes
Versions: 9
Stored size: 442 Bytes
Contents
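TensorStream::OpMaker.define_operation :case do |op|
  op.exclude!

  # Inputs: [predicates, default, branch_0, branch_1, ...]
  op.define_gradient do |grad, node, params|
    n_preds = node.inputs.size - 2

    # One case_grad per branch; `index` tells it which predicate gates that branch.
    case_grads = Array.new(n_preds) do |index|
      i_op(:case_grad, index, node.inputs[0], node.inputs[2 + index], grad)
    end

    # The predicates receive no gradient (nil); index -1 marks the default input.
    [nil, i_op(:case_grad, -1, node.inputs[0], node.inputs[1], grad)] + case_grads
  end

  # Output shape follows the first branch; nil if that shape is unknown.
  op.define_shape do |tensor|
    tensor.inputs[2]&.shape&.shape
  end
end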
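The gradient definition above returns one entry per input: nil for the predicate tensor, a case_grad for the default (index -1), and one case_grad per branch. Below is a minimal pure-Ruby sketch of that routing on scalars. It assumes case_grad passes the upstream gradient through only to the selected input and yields zero everywhere else. The helper names (selected_branch, case_grad_sketch, case_gradients_sketch) are hypothetical. Treating the first true predicate as the winner is an assumption borrowed from TensorFlow's tf.case with exclusive=False and is not confirmed for TensorStream.

```ruby
# Hypothetical sketch of gradient routing through a `case` op.
# ASSUMPTION: branch i is selected when preds[i] is the first true predicate;
# the default is selected when no predicate is true.

# Index of the selected branch, or -1 for the default input.
def selected_branch(preds)
  preds.index(true) || -1
end

# Gradient for one input slot: -1 is the default, i >= 0 is branch i.
# Only the selected slot receives the upstream gradient; the rest get 0.0.
def case_grad_sketch(index, preds, grad)
  selected_branch(preds) == index ? grad : 0.0
end

# Gradients in input order: [predicates, default, branch_0, branch_1, ...].
def case_gradients_sketch(preds, grad)
  branch_grads = Array.new(preds.size) { |i| case_grad_sketch(i, preds, grad) }
  [nil, case_grad_sketch(-1, preds, grad)] + branch_grads
end

p case_gradients_sketch([false, true, true], 1.5)
# => [nil, 0.0, 0.0, 1.5, 0.0]
p case_gradients_sketch([false, false], 2.0)
# => [nil, 2.0, 0.0, 0.0]
```

With real tensors, the 0.0 placeholder would need to be a zero tensor shaped like the corresponding input.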