Sha256: 5a96b336600dd0a24b4184aafcf6639e9e7c81d284b6d8c73c76cb4e4ebf3f5e

Contents?: true

Size: 325 Bytes

Versions: 9

Compression:

Stored size: 325 Bytes

Contents

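TensorStream::OpMaker.define_operation :cast do |op|
  op.exclude!

  # Gradient of cast: send the incoming gradient back in the input's dtype,
  # but only when both the input and the gradient dtypes are floating point.
  op.define_gradient do |grad, node, params|
    t = %i[float16 float32 float64]
    src_type = node.inputs[0].data_type # dtype of the tensor being cast
    dst_type = grad.data_type           # dtype of the upstream gradient

    # t is an Array, so membership is tested with include? (Array has no key?)
    if t.include?(src_type) && t.include?(dst_type)
      next ts.cast(grad, src_type)
    end

    # Any cast involving a non-float dtype gets no gradient.
    nil
  end
end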
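The gradient rule above can be illustrated without TensorStream. The sketch below is a hypothetical, library-free model: `SketchTensor`, `FLOAT_TYPES`, and `cast_gradient` are illustrative names, not part of TensorStream's API, and the "cast" only re-tags the dtype and converts values to Ruby `Float`s rather than modelling real float16/float32 precision.

```ruby
# Hypothetical sketch of the cast-gradient rule; not TensorStream's API.
FLOAT_TYPES = %i[float16 float32 float64].freeze

# Minimal stand-in for a tensor: raw values plus a dtype tag (assumption).
SketchTensor = Struct.new(:values, :data_type)

# grad:     gradient flowing back from the cast's output
# src_type: dtype of the cast's input
# Returns the gradient re-tagged with src_type, or nil (no gradient)
# when either dtype is not floating point.
def cast_gradient(grad, src_type)
  return nil unless FLOAT_TYPES.include?(src_type) &&
                    FLOAT_TYPES.include?(grad.data_type)

  # Values are converted to Float; precision differences are not modelled.
  SketchTensor.new(grad.values.map(&:to_f), src_type)
end

g = SketchTensor.new([0.5, 1.0], :float32)
p cast_gradient(g, :float64)                                   # gradient re-tagged as float64
p cast_gradient(g, :int32)                                     # nil: integer input
p cast_gradient(SketchTensor.new([1, 2], :int32), :float32)    # nil: integer gradient
```

Note that `%i[...]` produces an Array, which answers membership through `include?`; `key?` is a `Hash` method, so calling it on the Array raises `NoMethodError`.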

Version data entries

9 entries across 9 versions & 1 rubygem

Version Path
tensor_stream-1.0.9 lib/tensor_stream/ops/cast.rb
tensor_stream-1.0.8 lib/tensor_stream/ops/cast.rb
tensor_stream-1.0.7 lib/tensor_stream/ops/cast.rb
tensor_stream-1.0.6 lib/tensor_stream/ops/cast.rb
tensor_stream-1.0.5 lib/tensor_stream/ops/cast.rb
tensor_stream-1.0.4 lib/tensor_stream/ops/cast.rb
tensor_stream-1.0.3 lib/tensor_stream/ops/cast.rb
tensor_stream-1.0.2 lib/tensor_stream/ops/cast.rb
tensor_stream-1.0.1 lib/tensor_stream/ops/cast.rb