Sha256: 08d4f05475d2274758e4dcf9ba7e237e61ffdb0707d405ada04461545512036d

Contents?: true

Size: 721 Bytes

Versions: 9

Compression:

Stored size: 721 Bytes

Contents

# Declares the :sub operation (also reachable as "subtract") through the
# OpMaker DSL: two tensor parameters, data type coercion, broadcasting
# support, an optional name, and the gradient definition.
TensorStream::OpMaker.define_operation :sub do |op|
  op.other_names %w[subtract]
  op.what_it_does "Returns x - y element-wise."

  op.parameter :input_a, "tensor X"
  op.parameter :input_b, "tensor Y"

  op.apply_data_type_coercion!
  op.supports_broadcasting!

  # The default is the Symbol :nil, not the nil literal; kept exactly as is.
  # NOTE(review): presumably OpMaker renders this symbol into generated
  # code -- confirm before "fixing" it.
  op.option :name, "Optional name", :nil

  # Yields a two-element array: the gradient for the first input and the
  # gradient for the second input (negated, since the op computes x - y).
  op.define_gradient do |grad, node, params|
    lhs, rhs = params

    if shapes_fully_specified_and_equal(lhs, rhs)
      # Shapes are known and identical: pass grad through as-is / negated.
      [grad, -grad]
    else
      lhs_shape = ts.shape(lhs, name: "sub/shape_x")
      rhs_shape = ts.shape(rhs, name: "sub/shape_y")
      # Presumably the axes each input was broadcast along -- they are used
      # below as the reduction axes for reduce_sum.
      lhs_axes, rhs_axes = _broadcast_gradient_args(lhs_shape, rhs_shape)

      # Sum grad over those axes, then reshape back to each input's shape.
      # NOTE(review): the "add/" prefix in these names looks carried over
      # from the add op; left untouched because node names are runtime data.
      lhs_grad = ts.reshape(ts.reduce_sum(grad, lhs_axes, name: "add/reduce_sub_x"), lhs_shape)
      rhs_grad = -ts.reshape(ts.reduce_sum(grad, rhs_axes, name: "add/reduce_sub_y"), rhs_shape)

      [lhs_grad, rhs_grad]
    end
  end
end

Version data entries

9 entries across 9 versions & 1 rubygems

Version Path
tensor_stream-1.0.9 lib/tensor_stream/ops/sub.rb
tensor_stream-1.0.8 lib/tensor_stream/ops/sub.rb
tensor_stream-1.0.7 lib/tensor_stream/ops/sub.rb
tensor_stream-1.0.6 lib/tensor_stream/ops/sub.rb
tensor_stream-1.0.5 lib/tensor_stream/ops/sub.rb
tensor_stream-1.0.4 lib/tensor_stream/ops/sub.rb
tensor_stream-1.0.3 lib/tensor_stream/ops/sub.rb
tensor_stream-1.0.2 lib/tensor_stream/ops/sub.rb
tensor_stream-1.0.1 lib/tensor_stream/ops/sub.rb