lib/torch.rb in torch-rb-0.1.3 vs lib/torch.rb in torch-rb-0.1.4

- old
+ new

@@ -20,46 +20,104 @@ # optim lr_scheduler require "torch/optim/lr_scheduler/lr_scheduler" require "torch/optim/lr_scheduler/step_lr" -# nn base classes +# nn parameters +require "torch/nn/parameter" + +# nn containers require "torch/nn/module" +require "torch/nn/sequential" + +# nn convolution layers require "torch/nn/convnd" -require "torch/nn/dropoutnd" +require "torch/nn/conv2d" -# nn +# nn pooling layers +require "torch/nn/max_poolnd" +require "torch/nn/max_pool2d" +require "torch/nn/avg_poolnd" +require "torch/nn/avg_pool2d" + +# nn linear layers +require "torch/nn/bilinear" +require "torch/nn/identity" +require "torch/nn/linear" + +# nn dropout layers +require "torch/nn/dropoutnd" require "torch/nn/alpha_dropout" -require "torch/nn/conv2d" require "torch/nn/dropout" require "torch/nn/dropout2d" require "torch/nn/dropout3d" -require "torch/nn/embedding" require "torch/nn/feature_alpha_dropout" + +# nn activations +require "torch/nn/leaky_relu" +require "torch/nn/prelu" +require "torch/nn/relu" +require "torch/nn/sigmoid" +require "torch/nn/softplus" + +# nn activations other +require "torch/nn/log_softmax" +require "torch/nn/softmax" +require "torch/nn/softmax2d" +require "torch/nn/softmin" + +# nn sparse layers +require "torch/nn/embedding" +require "torch/nn/embedding_bag" + +# nn distance functions +require "torch/nn/cosine_similarity" +require "torch/nn/pairwise_distance" + +# nn loss functions +require "torch/nn/loss" +require "torch/nn/weighted_loss" +require "torch/nn/bce_loss" +# require "torch/nn/bce_with_logits_loss" +# require "torch/nn/cosine_embedding_loss" +require "torch/nn/cross_entropy_loss" +require "torch/nn/ctc_loss" +# require "torch/nn/hinge_embedding_loss" +require "torch/nn/kl_div_loss" +require "torch/nn/l1_loss" +# require "torch/nn/margin_ranking_loss" +require "torch/nn/mse_loss" +# require "torch/nn/multi_label_margin_loss" +# require "torch/nn/multi_label_soft_margin_loss" +# require "torch/nn/multi_margin_loss" +require "torch/nn/nll_loss" +require "torch/nn/poisson_nll_loss" +# require "torch/nn/smooth_l1_loss" +# require "torch/nn/soft_margin_loss" +# require "torch/nn/triplet_margin_loss" + +# nn other require "torch/nn/functional" require "torch/nn/init" -require "torch/nn/linear" -require "torch/nn/mse_loss" -require "torch/nn/parameter" -require "torch/nn/relu" -require "torch/nn/sequential" # utils require "torch/utils/data/data_loader" require "torch/utils/data/tensor_dataset" +# random +require "torch/random" + module Torch class Error < StandardError; end class NotImplementedYet < StandardError def message "This feature has not been implemented yet. Consider submitting a PR." end end # keys: https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.dtype # values: https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h - # complex and quantized types not supported by PyTorch yet DTYPE_TO_ENUM = { uint8: 0, int8: 1, short: 2, int16: 2, @@ -71,18 +129,18 @@ float16: 5, float: 6, float32: 6, double: 7, float64: 7, - # complex_half: 8, - # complex_float: 9, - # complex_double: 10, + complex_half: 8, + complex_float: 9, + complex_double: 10, bool: 11, - # qint8: 12, - # quint8: 13, - # qint32: 14, - # bfloat16: 15 + qint8: 12, + quint8: 13, + qint32: 14, + bfloat16: 15 } ENUM_TO_DTYPE = DTYPE_TO_ENUM.map(&:reverse).to_h class << self # Torch.float, Torch.long, etc @@ -118,10 +176,12 @@ # private # use method for cases when Numo not available # or available after Torch loaded def _dtype_to_numo + raise Error, "Numo not found" unless defined?(Numo::NArray) + { uint8: Numo::UInt8, int8: Numo::Int8, int16: Numo::Int16, int32: Numo::Int32, @@ -198,12 +258,16 @@ data = data.flatten else data = [data].compact end - if options[:dtype].nil? && data.all? { |v| v.is_a?(Integer) } - options[:dtype] = :int64 + if options[:dtype].nil? + if data.all? { |v| v.is_a?(Integer) } + options[:dtype] = :int64 + elsif data.all? { |v| v == true || v == false } + options[:dtype] = :bool + end end _tensor(data, size, tensor_options(**options)) end @@ -300,10 +364,14 @@ def pow(input, exponent) _pow(input, exponent) end + def topk(input, k) + _topk(input, k) + end + def min(input) _min(input) end def max(input, dim = nil, keepdim: false, out: nil) @@ -325,10 +393,14 @@ def sign(input) _sign(input) end + def sigmoid(input) + _sigmoid(input) + end + def gt(input, other) _gt(input, other) end def lt(input, other) @@ -359,9 +431,18 @@ _flatten(input, start_dim, end_dim) end def sqrt(input) _sqrt(input) + end + + # TODO make dim keyword argument + def log_softmax(input, dim) + _log_softmax(input, dim) + end + + def softmax(input, dim: nil) + _softmax(input, dim) end def abs(input) _abs(input) end