Sha256: 2f6a6caeb9db1561907b12b8c9c11d8fe3c0974756182bf1379d767198c5696a

Contents?: true

Size: 1.72 KB

Versions: 9

Compression:

Stored size: 1.72 KB

Contents

module TensorStream
  # Mixin that lets a host class (Integer, Float, Array in this file) appear on
  # the left-hand side of arithmetic with a TensorStream::Tensor, e.g.
  # `2 * tensor` or `[1, 2] + tensor`.
  #
  # On inclusion the host class's own operator methods are renamed to
  # `_tensor_stream_<name>_orig` and then removed from the class, so the
  # module's definitions below (next in the ancestor chain) are dispatched
  # instead. Each patched operator converts `self` to a tensor (using the other
  # operand's data_type) when the other operand is a Tensor, and otherwise
  # delegates to the saved original implementation.
  module OpPatch
    # Hook invoked by Module#include. Moves the host class's original
    # operators aside so the definitions in this module take over.
    #
    # @param klass [Class] the class that included this module
    def self.included(klass)
      # Array only defines +, - and * among these operators, so aliasing the
      # others would fail for it.
      ops = if klass == Array
        {:+ => "add", :- => "sub", :* => "mul"}
      else
        {:+ => "add", :- => "sub", :/ => "div", :% => "mod", :* => "mul", :** => "pow"}
      end

      ops.each do |m, name|
        orig = :"_tensor_stream_#{name}_orig"
        # Module#include calls `included` again on a repeated include. Without
        # this guard the second pass would alias the already-patched operator
        # (this module's method) over the saved original, causing infinite
        # recursion, and then raise NameError from remove_method because the
        # class no longer defines `m` itself.
        next if klass.method_defined?(orig) || klass.private_method_defined?(orig)

        klass.send(:alias_method, orig, m)
        klass.send(:remove_method, m)
      end
    end

    # Addition. Promotes self to a tensor when `other` is a Tensor.
    #
    # @param other [Object] right-hand operand
    # @return [Object] a tensor expression, or the original `+` result
    def +(other)
      if other.is_a?(TensorStream::Tensor)
        TensorStream.convert_to_tensor(self, dtype: other.data_type) + other
      else
        _tensor_stream_add_orig(other)
      end
    end

    # Subtraction. Promotes self to a tensor when `other` is a Tensor.
    #
    # @param other [Object] right-hand operand
    # @return [Object] a tensor expression, or the original `-` result
    def -(other)
      if other.is_a?(TensorStream::Tensor)
        TensorStream.convert_to_tensor(self, dtype: other.data_type) - other
      else
        _tensor_stream_sub_orig(other)
      end
    end

    # Multiplication. Promotes self to a tensor when `other` is a Tensor.
    #
    # @param other [Object] right-hand operand
    # @return [Object] a tensor expression, or the original `*` result
    def *(other)
      if other.is_a?(TensorStream::Tensor)
        TensorStream.convert_to_tensor(self, dtype: other.data_type) * other
      else
        _tensor_stream_mul_orig(other)
      end
    end

    # Division. Promotes self to a tensor when `other` is a Tensor.
    #
    # @param other [Object] right-hand operand
    # @return [Object] a tensor expression, or the original `/` result
    def /(other)
      if other.is_a?(TensorStream::Tensor)
        TensorStream.convert_to_tensor(self, dtype: other.data_type) / other
      else
        _tensor_stream_div_orig(other)
      end
    end

    # Modulo. Promotes self to a tensor when `other` is a Tensor.
    #
    # @param other [Object] right-hand operand
    # @return [Object] a tensor expression, or the original `%` result
    def %(other)
      if other.is_a?(TensorStream::Tensor)
        TensorStream.convert_to_tensor(self, dtype: other.data_type) % other
      else
        _tensor_stream_mod_orig(other)
      end
    end

    # Exponentiation. Promotes self to a tensor when `other` is a Tensor.
    #
    # @param other [Object] right-hand operand (the exponent)
    # @return [Object] a tensor expression, or the original `**` result
    def **(other)
      if other.is_a?(TensorStream::Tensor)
        TensorStream.convert_to_tensor(self, dtype: other.data_type)**other
      else
        _tensor_stream_pow_orig(other)
      end
    end
  end
end

# Install the operator patch on each core class whose arithmetic should
# accept a TensorStream::Tensor as the right-hand operand, in this order.
[Integer, Float, Array].each { |core_class| core_class.include(TensorStream::OpPatch) }

Version data entries

9 entries across 9 versions & 1 rubygems

Version Path
tensor_stream-1.0.9 lib/tensor_stream/monkey_patches/op_patch.rb
tensor_stream-1.0.8 lib/tensor_stream/monkey_patches/op_patch.rb
tensor_stream-1.0.7 lib/tensor_stream/monkey_patches/op_patch.rb
tensor_stream-1.0.6 lib/tensor_stream/monkey_patches/op_patch.rb
tensor_stream-1.0.5 lib/tensor_stream/monkey_patches/op_patch.rb
tensor_stream-1.0.4 lib/tensor_stream/monkey_patches/op_patch.rb
tensor_stream-1.0.3 lib/tensor_stream/monkey_patches/op_patch.rb
tensor_stream-1.0.2 lib/tensor_stream/monkey_patches/op_patch.rb
tensor_stream-1.0.1 lib/tensor_stream/monkey_patches/op_patch.rb