lib/torch.rb in torch-rb-0.1.3 vs lib/torch.rb in torch-rb-0.1.4
- old (removed in 0.1.4)
+ new (added in 0.1.4)
@@ -20,46 +20,104 @@
# optim lr_scheduler
require "torch/optim/lr_scheduler/lr_scheduler"
require "torch/optim/lr_scheduler/step_lr"
-# nn base classes
+# nn parameters
+require "torch/nn/parameter"
+
+# nn containers
require "torch/nn/module"
+require "torch/nn/sequential"
+
+# nn convolution layers
require "torch/nn/convnd"
-require "torch/nn/dropoutnd"
+require "torch/nn/conv2d"
-# nn
+# nn pooling layers
+require "torch/nn/max_poolnd"
+require "torch/nn/max_pool2d"
+require "torch/nn/avg_poolnd"
+require "torch/nn/avg_pool2d"
+
+# nn linear layers
+require "torch/nn/bilinear"
+require "torch/nn/identity"
+require "torch/nn/linear"
+
+# nn dropout layers
+require "torch/nn/dropoutnd"
require "torch/nn/alpha_dropout"
-require "torch/nn/conv2d"
require "torch/nn/dropout"
require "torch/nn/dropout2d"
require "torch/nn/dropout3d"
-require "torch/nn/embedding"
require "torch/nn/feature_alpha_dropout"
+
+# nn activations
+require "torch/nn/leaky_relu"
+require "torch/nn/prelu"
+require "torch/nn/relu"
+require "torch/nn/sigmoid"
+require "torch/nn/softplus"
+
+# nn activations other
+require "torch/nn/log_softmax"
+require "torch/nn/softmax"
+require "torch/nn/softmax2d"
+require "torch/nn/softmin"
+
+# nn sparse layers
+require "torch/nn/embedding"
+require "torch/nn/embedding_bag"
+
+# nn distance functions
+require "torch/nn/cosine_similarity"
+require "torch/nn/pairwise_distance"
+
+# nn loss functions
+require "torch/nn/loss"
+require "torch/nn/weighted_loss"
+require "torch/nn/bce_loss"
+# require "torch/nn/bce_with_logits_loss"
+# require "torch/nn/cosine_embedding_loss"
+require "torch/nn/cross_entropy_loss"
+require "torch/nn/ctc_loss"
+# require "torch/nn/hinge_embedding_loss"
+require "torch/nn/kl_div_loss"
+require "torch/nn/l1_loss"
+# require "torch/nn/margin_ranking_loss"
+require "torch/nn/mse_loss"
+# require "torch/nn/multi_label_margin_loss"
+# require "torch/nn/multi_label_soft_margin_loss"
+# require "torch/nn/multi_margin_loss"
+require "torch/nn/nll_loss"
+require "torch/nn/poisson_nll_loss"
+# require "torch/nn/smooth_l1_loss"
+# require "torch/nn/soft_margin_loss"
+# require "torch/nn/triplet_margin_loss"
+
+# nn other
require "torch/nn/functional"
require "torch/nn/init"
-require "torch/nn/linear"
-require "torch/nn/mse_loss"
-require "torch/nn/parameter"
-require "torch/nn/relu"
-require "torch/nn/sequential"
# utils
require "torch/utils/data/data_loader"
require "torch/utils/data/tensor_dataset"
+# random
+require "torch/random"
+
module Torch
class Error < StandardError; end
class NotImplementedYet < StandardError
def message
"This feature has not been implemented yet. Consider submitting a PR."
end
end
# keys: https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.dtype
# values: https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h
- # complex and quantized types not supported by PyTorch yet
DTYPE_TO_ENUM = {
uint8: 0,
int8: 1,
short: 2,
int16: 2,
@@ -71,18 +129,18 @@
float16: 5,
float: 6,
float32: 6,
double: 7,
float64: 7,
- # complex_half: 8,
- # complex_float: 9,
- # complex_double: 10,
+ complex_half: 8,
+ complex_float: 9,
+ complex_double: 10,
bool: 11,
- # qint8: 12,
- # quint8: 13,
- # qint32: 14,
- # bfloat16: 15
+ qint8: 12,
+ quint8: 13,
+ qint32: 14,
+ bfloat16: 15
}
ENUM_TO_DTYPE = DTYPE_TO_ENUM.map(&:reverse).to_h
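
With the complex, quantized, and bfloat16 entries uncommented, those dtype symbols now round-trip through both lookup tables (whether the native extension accepts them for tensor creation is a separate question this diff does not answer):

  Torch::DTYPE_TO_ENUM[:bfloat16]       # => 15
  Torch::DTYPE_TO_ENUM[:complex_float]  # => 9
  Torch::ENUM_TO_DTYPE[12]              # => :qint8
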
class << self
# Torch.float, Torch.long, etc
@@ -118,10 +176,12 @@
# private
# use method for cases when Numo not available
# or available after Torch loaded
def _dtype_to_numo
+ raise Error, "Numo not found" unless defined?(Numo::NArray)
+
{
uint8: Numo::UInt8,
int8: Numo::Int8,
int16: Numo::Int16,
int32: Numo::Int32,
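
The added guard makes _dtype_to_numo raise a Torch::Error ("Numo not found") instead of a NameError when numo-narray is not loaded. A hedged sketch of the effect; Torch.from_numo and Tensor#numo are the Numo entry points described in the gem's README, not part of this diff:

  require "numo/narray"   # must be loadable, or the dtype mapping below raises Torch::Error
  require "torch"

  na = Numo::Int32[1, 2, 3]
  t  = Torch.from_numo(na)   # conversion routes through the Numo dtype mapping above
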
@@ -198,12 +258,16 @@
data = data.flatten
else
data = [data].compact
end
- if options[:dtype].nil? && data.all? { |v| v.is_a?(Integer) }
- options[:dtype] = :int64
+ if options[:dtype].nil?
+ if data.all? { |v| v.is_a?(Integer) }
+ options[:dtype] = :int64
+ elsif data.all? { |v| v == true || v == false }
+ options[:dtype] = :bool
+ end
end
_tensor(data, size, tensor_options(**options))
end
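
Torch.tensor now infers :bool when every element is true or false, alongside the existing :int64 default for all-integer data; an explicit dtype: option still takes precedence, since inference only runs when no dtype is given:

  Torch.tensor([1, 2, 3])              # dtype inferred as :int64 (unchanged)
  Torch.tensor([true, false, true])    # dtype inferred as :bool (new in 0.1.4)
  Torch.tensor([true], dtype: :uint8)  # explicit dtype wins, inference is skipped
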
@@ -300,10 +364,14 @@
def pow(input, exponent)
_pow(input, exponent)
end
+ def topk(input, k)
+ _topk(input, k)
+ end
+
def min(input)
_min(input)
end
def max(input, dim = nil, keepdim: false, out: nil)
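
Torch.topk is a thin wrapper over the native _topk binding. Assuming it mirrors PyTorch's torch.topk, it returns the k largest values along with their indices:

  x = Torch.tensor([1, 5, 3, 4])
  values, indices = Torch.topk(x, 2)   # per PyTorch semantics: values 5 and 4, at indices 1 and 3
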
@@ -325,10 +393,14 @@
def sign(input)
_sign(input)
end
+ def sigmoid(input)
+ _sigmoid(input)
+ end
+
def gt(input, other)
_gt(input, other)
end
def lt(input, other)
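
Torch.sigmoid wraps the native _sigmoid and applies 1 / (1 + e^-x) elementwise, so zero maps to 0.5:

  Torch.sigmoid(Torch.tensor([0.0]))   # => value of 0.5
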
@@ -359,9 +431,18 @@
_flatten(input, start_dim, end_dim)
end
def sqrt(input)
_sqrt(input)
+ end
+
+ # TODO make dim keyword argument
+ def log_softmax(input, dim)
+ _log_softmax(input, dim)
+ end
+
+ def softmax(input, dim: nil)
+ _softmax(input, dim)
end
def abs(input)
_abs(input)
end
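
The two new softmax wrappers differ in how dim is passed: log_softmax takes it positionally (hence the TODO above), while softmax accepts a dim: keyword. A small sketch:

  x = Torch.tensor([1.0, 2.0, 3.0])

  Torch.softmax(x, dim: 0)   # probabilities along dim 0, summing to 1
  Torch.log_softmax(x, 0)    # log of the same probabilities; dim is positional for now
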