lib/torch.rb in torch-rb-0.1.2 vs lib/torch.rb in torch-rb-0.1.3
- old
+ new
@@ -6,29 +6,56 @@
require "torch/tensor"
require "torch/version"
# optim
require "torch/optim/optimizer"
+require "torch/optim/adadelta"
+require "torch/optim/adagrad"
+require "torch/optim/adam"
+require "torch/optim/adamax"
+require "torch/optim/adamw"
+require "torch/optim/asgd"
+require "torch/optim/rmsprop"
+require "torch/optim/rprop"
require "torch/optim/sgd"
-# nn
+# optim lr_scheduler
+require "torch/optim/lr_scheduler/lr_scheduler"
+require "torch/optim/lr_scheduler/step_lr"
+
+# nn base classes
require "torch/nn/module"
-require "torch/nn/init"
+require "torch/nn/convnd"
+require "torch/nn/dropoutnd"
+
+# nn
+require "torch/nn/alpha_dropout"
require "torch/nn/conv2d"
+require "torch/nn/dropout"
+require "torch/nn/dropout2d"
+require "torch/nn/dropout3d"
+require "torch/nn/embedding"
+require "torch/nn/feature_alpha_dropout"
require "torch/nn/functional"
+require "torch/nn/init"
require "torch/nn/linear"
+require "torch/nn/mse_loss"
require "torch/nn/parameter"
-require "torch/nn/sequential"
require "torch/nn/relu"
-require "torch/nn/mse_loss"
+require "torch/nn/sequential"
# utils
require "torch/utils/data/data_loader"
require "torch/utils/data/tensor_dataset"
module Torch
class Error < StandardError; end
+ class NotImplementedYet < StandardError
+ def message
+ "This feature has not been implemented yet. Consider submitting a PR."
+ end
+ end
# keys: https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.dtype
# values: https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h
# complex and quantized types not supported by PyTorch yet
DTYPE_TO_ENUM = {
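The requires added above pull in the new optimizers (Adadelta through Rprop), the StepLR learning-rate scheduler, and several new nn modules (the dropout variants, Embedding). A minimal, hedged sketch of how they might fit together; the layer sizes, learning rate, and step size are illustrative, and the scheduler constant is assumed to follow the PyTorch-style Torch::Optim::LRScheduler::StepLR naming:

```ruby
require "torch"

# illustrative model using the newly required Dropout module
model = Torch::NN::Sequential.new(
  Torch::NN::Linear.new(10, 8),
  Torch::NN::ReLU.new,
  Torch::NN::Dropout.new(p: 0.5),
  Torch::NN::Linear.new(8, 1)
)

optimizer = Torch::Optim::Adam.new(model.parameters, lr: 0.01)
scheduler = Torch::Optim::LRScheduler::StepLR.new(optimizer, step_size: 30, gamma: 0.1)

# inside a training loop: update the weights, then advance the schedule
optimizer.step
scheduler.step
```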
@@ -73,15 +100,22 @@
def tensor?(obj)
obj.is_a?(Tensor)
end
- # TODO don't copy
def from_numo(ndarray)
dtype = _dtype_to_numo.find { |k, v| ndarray.is_a?(v) }
raise Error, "Cannot convert #{ndarray.class.name} to tensor" unless dtype
- tensor(ndarray.to_a, dtype: dtype[0])
+ options = tensor_options(device: "cpu", dtype: dtype[0])
+ # TODO pass pointer to array instead of creating string
+ str = ndarray.to_string
+ tensor = _from_blob(str, ndarray.shape, options)
+ # from_blob does not own the data, so we need to keep
+ # a reference to it for the duration of the tensor
+ # can remove when passing the pointer directly
+ tensor.instance_variable_set("@_numo_str", str)
+ tensor
end
# private
# use method for cases when Numo not available
# or available after Torch loaded
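The rewritten from_numo no longer round-trips through to_a; it hands the ndarray's raw bytes (ndarray.to_string) to _from_blob and keeps the string pinned on the tensor so the buffer outlives the call. A short usage sketch, assuming Numo is loaded before converting:

```ruby
require "numo/narray"
require "torch"

ndarray = Numo::SFloat.new(2, 3).seq  # 2x3 float32 array with values 0..5
tensor  = Torch.from_numo(ndarray)    # dtype inferred via _dtype_to_numo

tensor.shape # => [2, 3]
```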
@@ -195,11 +229,11 @@
# ruby doesn't support input, low = 0, high, ...
if high.nil?
high = low
low = 0
end
- rand(input.size, like_options(input, options))
+ randint(low, high, input.size, like_options(input, options))
end
def randn_like(input, **options)
randn(input.size, like_options(input, options))
end
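The randint_like change is a bug fix: the old body called rand, so it returned uniform floats rather than integers. It now delegates to randint over [low, high). A hedged sketch of the corrected behavior (the reference tensor is illustrative):

```ruby
reference = Torch.zeros([2, 3])

# same size as reference, integers drawn from [0, 10)
Torch.randint_like(reference, 10)

# explicit low and high: integers drawn from [5, 10)
Torch.randint_like(reference, 5, 10)
```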
@@ -270,35 +304,72 @@
def min(input)
_min(input)
end
- def max(input)
- _max(input)
+ def max(input, dim = nil, keepdim: false, out: nil)
+ if dim
+ raise NotImplementedYet unless out
+ _max_out(out[0], out[1], input, dim, keepdim)
+ else
+ _max(input)
+ end
end
def exp(input)
_exp(input)
end
def log(input)
_log(input)
end
+ def sign(input)
+ _sign(input)
+ end
+
+ def gt(input, other)
+ _gt(input, other)
+ end
+
+ def lt(input, other)
+ _lt(input, other)
+ end
+
def unsqueeze(input, dim)
_unsqueeze(input, dim)
end
def dot(input, tensor)
_dot(input, tensor)
end
+ def cat(tensors, dim = 0)
+ _cat(tensors, dim)
+ end
+
def matmul(input, other)
_matmul(input, other)
end
def reshape(input, shape)
_reshape(input, shape)
+ end
+
+ def flatten(input, start_dim: 0, end_dim: -1)
+ _flatten(input, start_dim, end_dim)
+ end
+
+ def sqrt(input)
+ _sqrt(input)
+ end
+
+ def abs(input)
+ _abs(input)
+ end
+
+ def device(str)
+ Device.new(str)
end
private
def execute_op(op, input, other, out: nil)
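A few of the module-level methods added in this hunk, shown in use. Values and shapes are illustrative; the scalar argument to gt assumes the binding accepts scalars (as the new optimizers use it), and the out: pair for max assumes pre-allocated value and index tensors:

```ruby
a = Torch.tensor([[1.0, -2.0], [3.0, -4.0]])

Torch.sign(a)             # elementwise sign: -1, 0, or 1
Torch.abs(a)              # elementwise absolute value
Torch.sqrt(Torch.abs(a))  # elementwise square root
Torch.flatten(a)          # 1-D tensor of all four elements
Torch.cat([a, a], 0)      # 4x2 tensor, concatenated along dim 0
Torch.gt(a, 0)            # elementwise comparison a > 0
Torch.device("cpu")       # builds a Torch::Device

# max along a dim requires an out pair (values, indices) in this version
values  = Torch.zeros([2])
indices = Torch.zeros([2], dtype: :int64)
Torch.max(a, 1, out: [values, indices])
```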