lib/torch.rb in torch-rb-0.1.2 vs lib/torch.rb in torch-rb-0.1.3

- old
+ new

@@ -6,29 +6,56 @@ require "torch/tensor" require "torch/version" # optim require "torch/optim/optimizer" +require "torch/optim/adadelta" +require "torch/optim/adagrad" +require "torch/optim/adam" +require "torch/optim/adamax" +require "torch/optim/adamw" +require "torch/optim/asgd" +require "torch/optim/rmsprop" +require "torch/optim/rprop" require "torch/optim/sgd" -# nn +# optim lr_scheduler +require "torch/optim/lr_scheduler/lr_scheduler" +require "torch/optim/lr_scheduler/step_lr" + +# nn base classes require "torch/nn/module" -require "torch/nn/init" +require "torch/nn/convnd" +require "torch/nn/dropoutnd" + +# nn +require "torch/nn/alpha_dropout" require "torch/nn/conv2d" +require "torch/nn/dropout" +require "torch/nn/dropout2d" +require "torch/nn/dropout3d" +require "torch/nn/embedding" +require "torch/nn/feature_alpha_dropout" require "torch/nn/functional" +require "torch/nn/init" require "torch/nn/linear" +require "torch/nn/mse_loss" require "torch/nn/parameter" -require "torch/nn/sequential" require "torch/nn/relu" -require "torch/nn/mse_loss" +require "torch/nn/sequential" # utils require "torch/utils/data/data_loader" require "torch/utils/data/tensor_dataset" module Torch class Error < StandardError; end + class NotImplementedYet < StandardError + def message + "This feature has not been implemented yet. Consider submitting a PR." + end + end # keys: https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.dtype # values: https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h # complex and quantized types not supported by PyTorch yet DTYPE_TO_ENUM = { @@ -73,15 +100,22 @@ def tensor?(obj) obj.is_a?(Tensor) end - # TODO don't copy def from_numo(ndarray) dtype = _dtype_to_numo.find { |k, v| ndarray.is_a?(v) } raise Error, "Cannot convert #{ndarray.class.name} to tensor" unless dtype - tensor(ndarray.to_a, dtype: dtype[0]) + options = tensor_options(device: "cpu", dtype: dtype[0]) + # TODO pass pointer to array instead of creating string + str = ndarray.to_string + tensor = _from_blob(str, ndarray.shape, options) + # from_blob does not own the data, so we need to keep + # a reference to it for duration of tensor + # can remove when passing pointer directly + tensor.instance_variable_set("@_numo_str", str) + tensor end # private # use method for cases when Numo not available # or available after Torch loaded @@ -195,11 +229,11 @@ # ruby doesn't support input, low = 0, high, ... if high.nil? high = low low = 0 end - rand(input.size, like_options(input, options)) + randint(low, high, input.size, like_options(input, options)) end def randn_like(input, **options) randn(input.size, like_options(input, options)) end @@ -270,35 +304,72 @@ def min(input) _min(input) end - def max(input) - _max(input) + def max(input, dim = nil, keepdim: false, out: nil) + if dim + raise NotImplementedYet unless out + _max_out(out[0], out[1], input, dim, keepdim) + else + _max(input) + end end def exp(input) _exp(input) end def log(input) _log(input) end + def sign(input) + _sign(input) + end + + def gt(input, other) + _gt(input, other) + end + + def lt(input, other) + _lt(input, other) + end + def unsqueeze(input, dim) _unsqueeze(input, dim) end def dot(input, tensor) _dot(input, tensor) end + def cat(tensors, dim = 0) + _cat(tensors, dim) + end + def matmul(input, other) _matmul(input, other) end def reshape(input, shape) _reshape(input, shape) + end + + def flatten(input, start_dim: 0, end_dim: -1) + _flatten(input, start_dim, end_dim) + end + + def sqrt(input) + _sqrt(input) + end + + def abs(input) + _abs(input) + end + + def device(str) + Device.new(str) end private def execute_op(op, input, other, out: nil)