lib/torch.rb in torch-rb-0.1.1 vs lib/torch.rb in torch-rb-0.1.2

- old
+ new

@@ -4,10 +4,14 @@ # modules require "torch/inspector" require "torch/tensor" require "torch/version" +# optim +require "torch/optim/optimizer" +require "torch/optim/sgd" + # nn require "torch/nn/module" require "torch/nn/init" require "torch/nn/conv2d" require "torch/nn/functional" @@ -53,14 +57,18 @@ } ENUM_TO_DTYPE = DTYPE_TO_ENUM.map(&:reverse).to_h class << self # Torch.float, Torch.long, etc - DTYPE_TO_ENUM.each_key do |type| - define_method(type) do - type + DTYPE_TO_ENUM.each_key do |dtype| + define_method(dtype) do + dtype end + + Tensor.define_method(dtype) do + type(dtype) + end end # https://pytorch.org/docs/stable/torch.html def tensor?(obj) @@ -238,10 +246,22 @@ else _sum(input) end end + def argmax(input, dim = nil, keepdim: false) + if dim + _argmax_dim(input, dim, keepdim) + else + _argmax(input) + end + end + + def eq(input, other) + _eq(input, other) + end + def norm(input) _norm(input) end def pow(input, exponent) @@ -272,9 +292,13 @@ _dot(input, tensor) end def matmul(input, other) _matmul(input, other) + end + + def reshape(input, shape) + _reshape(input, shape) end private def execute_op(op, input, other, out: nil)