lib/torch.rb in torch-rb-0.1.1 vs lib/torch.rb in torch-rb-0.1.2
- old
+ new
@@ -4,10 +4,14 @@
# modules
require "torch/inspector"
require "torch/tensor"
require "torch/version"
+# optim
+require "torch/optim/optimizer"
+require "torch/optim/sgd"
+
# nn
require "torch/nn/module"
require "torch/nn/init"
require "torch/nn/conv2d"
require "torch/nn/functional"
@@ -53,14 +57,18 @@
}
ENUM_TO_DTYPE = DTYPE_TO_ENUM.map(&:reverse).to_h
class << self
# Torch.float, Torch.long, etc
- DTYPE_TO_ENUM.each_key do |type|
- define_method(type) do
- type
+ DTYPE_TO_ENUM.each_key do |dtype|
+ define_method(dtype) do
+ dtype
end
+
+ Tensor.define_method(dtype) do
+ type(dtype)
+ end
end
# https://pytorch.org/docs/stable/torch.html
def tensor?(obj)
@@ -238,10 +246,22 @@
else
_sum(input)
end
end
+ def argmax(input, dim = nil, keepdim: false)
+ if dim
+ _argmax_dim(input, dim, keepdim)
+ else
+ _argmax(input)
+ end
+ end
+
+ def eq(input, other)
+ _eq(input, other)
+ end
+
def norm(input)
_norm(input)
end
def pow(input, exponent)
@@ -272,9 +292,13 @@
_dot(input, tensor)
end
def matmul(input, other)
_matmul(input, other)
+ end
+
+ def reshape(input, shape)
+ _reshape(input, shape)
end
private
def execute_op(op, input, other, out: nil)