ext/torch/ext.cpp in torch-rb-0.1.3 vs ext/torch/ext.cpp in torch-rb-0.1.4
- old
+ new
@@ -130,10 +130,116 @@
TensorList from_ruby<TensorList>(Object x)
{
return TensorList(x);
}
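
Throughout this file the bindings are written as `*[](...) { ... }`: a captureless lambda converts implicitly to a plain function pointer, and the leading dereference produces the function Rice's define_method overloads expect. Reduced to an illustration (not part of the diff):

    // The unary * forces the lambda through its function-pointer
    // conversion operator; `f` decays to int (*)(int).
    auto f = *[](int x) { return x + 1; };
    f(2); // => 3
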
+class FanModeType {
+ std::string s;
+ public:
+ FanModeType(Object o) {
+ s = String(o).str();
+ }
+ // TODO switch FanModeType after LibTorch 1.4 release
+ operator torch::nn::init::FanMode() {
+ if (s == "fan_in") {
+ return torch::nn::init::FanMode::FanIn;
+ } else if (s == "fan_out") {
+ return torch::nn::init::FanMode::FanOut;
+ } else {
+ throw std::runtime_error("Unsupported nonlinearity type: " + s);
+ }
+ }
+};
+
+template<>
+inline
+FanModeType from_ruby<FanModeType>(Object x)
+{
+ return FanModeType(x);
+}
+
+class NonlinearityType {
+ std::string s;
+ public:
+ NonlinearityType(Object o) {
+ s = String(o).str();
+ }
+ // TODO switch NonlinearityType after LibTorch 1.4 release
+ operator torch::nn::init::Nonlinearity() {
+ if (s == "linear") {
+ return torch::nn::init::Nonlinearity::Linear;
+ } else if (s == "conv1d") {
+ return torch::nn::init::Nonlinearity::Conv1D;
+ } else if (s == "conv2d") {
+ return torch::nn::init::Nonlinearity::Conv2D;
+ } else if (s == "conv3d") {
+ return torch::nn::init::Nonlinearity::Conv3D;
+ } else if (s == "conv_transpose1d") {
+ return torch::nn::init::Nonlinearity::ConvTranspose1D;
+ } else if (s == "conv_transpose2d") {
+ return torch::nn::init::Nonlinearity::ConvTranspose2D;
+ } else if (s == "conv_transpose3d") {
+ return torch::nn::init::Nonlinearity::ConvTranspose3D;
+ } else if (s == "sigmoid") {
+ return torch::nn::init::Nonlinearity::Sigmoid;
+ } else if (s == "tanh") {
+ return torch::nn::init::Nonlinearity::Tanh;
+ } else if (s == "relu") {
+ return torch::nn::init::Nonlinearity::ReLU;
+ } else if (s == "leaky_relu") {
+ return torch::nn::init::Nonlinearity::LeakyReLU;
+ } else {
+ throw std::runtime_error("Unsupported nonlinearity type: " + s);
+ }
+ }
+};
+
+template<>
+inline
+NonlinearityType from_ruby<NonlinearityType>(Object x)
+{
+ return NonlinearityType(x);
+}
+
+class MyReduction {
+ Object value;
+ public:
+ MyReduction(Object o) {
+ value = o;
+ }
+ operator int64_t() {
+ if (value.is_nil()) {
+ return Reduction::None;
+ }
+
+ std::string s = String(value).str();
+ if (s == "mean") {
+ return Reduction::Mean;
+ } else if (s == "sum") {
+ return Reduction::Sum;
+ } else {
+ throw std::runtime_error("Unsupported reduction: " + s);
+ }
+ }
+};
+
+template<>
+inline
+MyReduction from_ruby<MyReduction>(Object x)
+{
+ return MyReduction(x);
+}
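
FanModeType, NonlinearityType, and MyReduction all follow the same Rice conversion pattern: a bound lambda declares a parameter of the wrapper type, Rice routes the incoming Ruby object through the matching from_ruby<T> specialization, and the wrapper's implicit conversion operator yields the LibTorch value at the call site. A hedged sketch of a binding that would consume MyReduction this way (the method is hypothetical, not part of this diff):

    // Ruby passes "mean", "sum", or nil; the lambda receives a
    // MyReduction, and operator int64_t() fires at the call below.
    .define_singleton_method(
        "_smooth_l1_loss",
        *[](torch::Tensor& input, torch::Tensor& target, MyReduction reduction) {
          return torch::smooth_l1_loss(input, target, reduction);
        })
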
+
+typedef torch::Tensor Tensor;
+
+Object tensor_array(std::tuple<torch::Tensor, torch::Tensor> x) {
+ Array a;
+ a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
+ a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
+ return Object(a);
+}
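
tensor_array hard-codes a two-element tuple, so each additional tuple-returning op would need another overload. A variadic alternative is possible in C++17; a sketch, assuming every element is a tensor (hypothetical, not part of this diff):

    // Expand any tuple of tensors into a Ruby Array via std::apply
    // and a fold expression (requires <tuple>).
    template <typename... Ts>
    Object tensor_tuple(const std::tuple<Ts...>& x) {
      Array a;
      std::apply([&a](const Ts&... t) {
        (a.push(to_ruby<torch::Tensor>(t)), ...);
      }, x);
      return Object(a);
    }
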
+
extern "C"
void Init_ext()
{
Module rb_mTorch = define_module("Torch")
.define_singleton_method(
@@ -146,11 +252,11 @@
*[](bool enabled) {
torch::GradMode::set_enabled(enabled);
})
.define_singleton_method(
"floating_point?",
- *[](torch::Tensor& input) {
+ *[](Tensor& input) {
return torch::is_floating_point(input);
})
.define_singleton_method(
"manual_seed",
*[](uint64_t seed) {
@@ -218,279 +324,362 @@
return torch::zeros(size, options);
})
// begin operations
.define_singleton_method(
"_mean",
- *[](torch::Tensor& input) {
+ *[](Tensor& input) {
return torch::mean(input);
})
.define_singleton_method(
"_mean_dim",
- *[](torch::Tensor& input, int64_t dim, bool keepdim) {
+ *[](Tensor& input, int64_t dim, bool keepdim) {
return torch::mean(input, dim, keepdim);
})
.define_singleton_method(
"_sum",
- *[](torch::Tensor& input) {
+ *[](Tensor& input) {
return torch::sum(input);
})
.define_singleton_method(
"_sum_dim",
- *[](torch::Tensor& input, int64_t dim, bool keepdim) {
+ *[](Tensor& input, int64_t dim, bool keepdim) {
return torch::sum(input, dim, keepdim);
})
.define_singleton_method(
"_argmax",
- *[](torch::Tensor& input) {
+ *[](Tensor& input) {
return torch::argmax(input);
})
.define_singleton_method(
"_argmax_dim",
- *[](torch::Tensor& input, int64_t dim, bool keepdim) {
+ *[](Tensor& input, int64_t dim, bool keepdim) {
return torch::argmax(input, dim, keepdim);
})
.define_singleton_method(
"_cat",
*[](TensorList tensors, int64_t dim) {
return torch::cat(tensors, dim);
})
.define_singleton_method(
"_norm",
- *[](torch::Tensor& input) {
+ *[](Tensor& input) {
return torch::norm(input);
})
.define_singleton_method(
"_min",
- *[](torch::Tensor& input) {
+ *[](Tensor& input) {
return torch::min(input);
})
.define_singleton_method(
"_max",
- *[](torch::Tensor& input) {
+ *[](Tensor& input) {
return torch::max(input);
})
.define_singleton_method(
"_max_out",
- *[](torch::Tensor &max, torch::Tensor &max_indices, const torch::Tensor &input, int64_t dim, bool keepdim) {
- // TODO add return value
- torch::_max_out(max, max_indices, input, dim, keepdim);
+ *[](Tensor &max, Tensor &max_indices, const Tensor &input, int64_t dim, bool keepdim) {
+ return tensor_array(torch::_max_out(max, max_indices, input, dim, keepdim));
})
.define_singleton_method(
"_sqrt",
- *[](torch::Tensor& input) {
+ *[](Tensor& input) {
return torch::sqrt(input);
})
.define_singleton_method(
"_exp",
- *[](torch::Tensor& input) {
+ *[](Tensor& input) {
return torch::exp(input);
})
.define_singleton_method(
"_log",
- *[](torch::Tensor& input) {
+ *[](Tensor& input) {
return torch::log(input);
})
.define_singleton_method(
"_sign",
- *[](torch::Tensor& input) {
+ *[](Tensor& input) {
return torch::sign(input);
})
.define_singleton_method(
"_unsqueeze",
- *[](torch::Tensor& input, int64_t dim) {
+ *[](Tensor& input, int64_t dim) {
return torch::unsqueeze(input, dim);
})
.define_singleton_method(
"_dot",
- *[](torch::Tensor& input, torch::Tensor& tensor) {
+ *[](Tensor& input, Tensor& tensor) {
return torch::dot(input, tensor);
})
.define_singleton_method(
"_matmul",
- *[](torch::Tensor& input, torch::Tensor& other) {
+ *[](Tensor& input, Tensor& other) {
return torch::matmul(input, other);
})
.define_singleton_method(
"_eq",
- *[](torch::Tensor& input, torch::Tensor& other) {
+ *[](Tensor& input, Tensor& other) {
return torch::eq(input, other);
})
.define_singleton_method(
"_gt",
// TODO support tensors
- *[](torch::Tensor& input, Scalar other) {
+ *[](Tensor& input, Scalar other) {
return torch::gt(input, other);
})
.define_singleton_method(
"_lt",
// TODO support tensors
- *[](torch::Tensor& input, Scalar other) {
+ *[](Tensor& input, Scalar other) {
return torch::lt(input, other);
})
.define_singleton_method(
"_add",
- *[](torch::Tensor& input, torch::Tensor& other) {
+ *[](Tensor& input, Tensor& other) {
return torch::add(input, other);
})
.define_singleton_method(
"_add_scalar",
- *[](torch::Tensor& input, Scalar other) {
+ *[](Tensor& input, Scalar other) {
return torch::add(input, other);
})
.define_singleton_method(
"_add_out",
- *[](torch::Tensor& out, torch::Tensor& input, torch::Tensor& other) {
+ *[](Tensor& out, Tensor& input, Tensor& other) {
return torch::add_out(out, input, other);
})
.define_singleton_method(
"_sub",
- *[](torch::Tensor& input, torch::Tensor& other) {
+ *[](Tensor& input, Tensor& other) {
return torch::sub(input, other);
})
.define_singleton_method(
"_sub_scalar",
- *[](torch::Tensor& input, Scalar other) {
+ *[](Tensor& input, Scalar other) {
return torch::sub(input, other);
})
.define_singleton_method(
"_mul",
- *[](torch::Tensor& input, torch::Tensor& other) {
+ *[](Tensor& input, Tensor& other) {
return torch::mul(input, other);
})
.define_singleton_method(
"_mul_scalar",
- *[](torch::Tensor& input, Scalar other) {
+ *[](Tensor& input, Scalar other) {
return torch::mul(input, other);
})
.define_singleton_method(
"_div",
- *[](torch::Tensor& input, torch::Tensor& other) {
+ *[](Tensor& input, Tensor& other) {
return torch::div(input, other);
})
.define_singleton_method(
"_div_scalar",
- *[](torch::Tensor& input, Scalar other) {
+ *[](Tensor& input, Scalar other) {
return torch::div(input, other);
})
.define_singleton_method(
"_remainder",
- *[](torch::Tensor& input, torch::Tensor& other) {
+ *[](Tensor& input, Tensor& other) {
return torch::remainder(input, other);
})
.define_singleton_method(
"_remainder_scalar",
- *[](torch::Tensor& input, Scalar other) {
+ *[](Tensor& input, Scalar other) {
return torch::remainder(input, other);
})
.define_singleton_method(
"_pow",
- *[](torch::Tensor& input, Scalar exponent) {
+ *[](Tensor& input, Scalar exponent) {
return torch::pow(input, exponent);
})
.define_singleton_method(
+ "_topk",
+ *[](Tensor& input, int64_t k) {
+ return tensor_array(torch::topk(input, k));
+ })
+ .define_singleton_method(
+ "_sigmoid",
+ *[](Tensor& input) {
+ return torch::sigmoid(input);
+ })
+ .define_singleton_method(
+ "_softplus",
+ *[](const Tensor &input, Scalar beta, Scalar threshold) {
+ return torch::softplus(input, beta, threshold);
+ })
+ .define_singleton_method(
+ "_softmax",
+ *[](const Tensor &input, int64_t dim) {
+ return torch::softmax(input, dim);
+ })
+ .define_singleton_method(
+ "_log_softmax",
+ *[](Tensor& input, int64_t dim) {
+ return torch::log_softmax(input, dim);
+ })
+ .define_singleton_method(
"_abs",
- *[](torch::Tensor& input) {
+ *[](Tensor& input) {
return torch::abs(input);
})
.define_singleton_method(
"_neg",
- *[](torch::Tensor& input) {
+ *[](Tensor& input) {
return torch::neg(input);
})
.define_singleton_method(
"_reshape",
- *[](torch::Tensor& input, IntArrayRef shape) {
+ *[](Tensor& input, IntArrayRef shape) {
return torch::reshape(input, shape);
})
.define_singleton_method(
"_flatten",
- *[](torch::Tensor& input, int64_t start_dim, int64_t end_dim) {
+ *[](Tensor& input, int64_t start_dim, int64_t end_dim) {
return torch::flatten(input, start_dim, end_dim);
})
.define_singleton_method(
"relu",
- *[](torch::Tensor& input) {
+ *[](Tensor& input) {
return torch::relu(input);
})
.define_singleton_method(
+ "prelu",
+ *[](torch::Tensor& input, torch::Tensor& weight) {
+ return torch::prelu(input, weight);
+ })
+ .define_singleton_method(
+ "leaky_relu",
+ *[](torch::Tensor& input, Scalar negative_slope) {
+ return torch::leaky_relu(input, negative_slope);
+ })
+ .define_singleton_method(
"conv2d",
- *[](torch::Tensor& input, torch::Tensor& weight, torch::Tensor& bias, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, int64_t groups) {
+ *[](Tensor& input, Tensor& weight, Tensor& bias, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, int64_t groups) {
return torch::conv2d(input, weight, bias, stride, padding, dilation, groups);
})
+ // linear layers
.define_singleton_method(
+ "bilinear",
+ *[](const Tensor &input1, const Tensor &input2, const Tensor &weight, const Tensor &bias) {
+ return torch::bilinear(input1, input2, weight, bias);
+ })
+ .define_singleton_method(
"linear",
- *[](torch::Tensor& input, torch::Tensor& weight, torch::Tensor& bias) {
+ *[](Tensor& input, Tensor& weight, Tensor& bias) {
return torch::linear(input, weight, bias);
})
+ // pooling layers
.define_singleton_method(
"max_pool2d",
- *[](torch::Tensor& input, IntArrayRef kernel_size) {
+ *[](Tensor& input, IntArrayRef kernel_size) {
return torch::max_pool2d(input, kernel_size);
})
.define_singleton_method(
"avg_pool2d",
- *[](torch::Tensor& input, IntArrayRef kernel_size) {
+ *[](Tensor& input, IntArrayRef kernel_size) {
return torch::avg_pool2d(input, kernel_size);
})
.define_singleton_method(
"_dropout",
- *[](torch::Tensor& input, float p, bool train) {
+ *[](Tensor& input, float p, bool train) {
return torch::dropout(input, p, train);
})
.define_singleton_method(
"_dropout!",
- *[](torch::Tensor& input, float p, bool train) {
+ *[](Tensor& input, float p, bool train) {
return torch::dropout_(input, p, train);
})
.define_singleton_method(
"_feature_dropout",
- *[](torch::Tensor& input, float p, bool train) {
+ *[](Tensor& input, float p, bool train) {
return torch::feature_dropout(input, p, train);
})
.define_singleton_method(
"_feature_dropout!",
- *[](torch::Tensor& input, float p, bool train) {
+ *[](Tensor& input, float p, bool train) {
return torch::feature_dropout_(input, p, train);
})
.define_singleton_method(
"_alpha_dropout",
- *[](torch::Tensor& input, float p, bool train) {
+ *[](Tensor& input, float p, bool train) {
return torch::alpha_dropout(input, p, train);
})
.define_singleton_method(
"_alpha_dropout!",
- *[](torch::Tensor& input, float p, bool train) {
+ *[](Tensor& input, float p, bool train) {
return torch::alpha_dropout_(input, p, train);
})
.define_singleton_method(
"_feature_alpha_dropout",
- *[](torch::Tensor& input, float p, bool train) {
+ *[](Tensor& input, float p, bool train) {
return torch::feature_alpha_dropout(input, p, train);
})
.define_singleton_method(
"_feature_alpha_dropout!",
- *[](torch::Tensor& input, float p, bool train) {
+ *[](Tensor& input, float p, bool train) {
return torch::feature_alpha_dropout_(input, p, train);
})
+ // sparse layers
.define_singleton_method(
"_embedding",
// weight and indices are swapped from Python interface
- *[](const torch::Tensor &indices, const torch::Tensor &weight, int64_t padding_idx, bool scale_grad_by_freq, bool sparse) {
+ *[](const Tensor &indices, const Tensor &weight, int64_t padding_idx, bool scale_grad_by_freq, bool sparse) {
return torch::embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse);
})
.define_singleton_method(
+ "_embedding_bag",
+ // weight and indices are swapped from Python interface
+ *[](const Tensor &weight, const Tensor &indices, const Tensor &offsets, bool scale_grad_by_freq, int64_t mode, bool sparse, const Tensor &per_sample_weights) {
+ return torch::embedding_bag(weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights);
+ })
+ // distance functions
+ .define_singleton_method(
+ "_cosine_similarity",
+ *[](const Tensor &x1, const Tensor &x2, int64_t dim, double eps) {
+ return torch::cosine_similarity(x1, x2, dim, eps);
+ })
+ .define_singleton_method(
+ "_pairwise_distance",
+ *[](const Tensor &x1, const Tensor &x2, double p, double eps, bool keepdim) {
+ return torch::pairwise_distance(x1, x2, p, eps, keepdim);
+ })
+ // loss functions
+ .define_singleton_method(
+ "binary_cross_entropy",
+ *[](Tensor& input, Tensor& target, MyReduction reduction) {
+ return torch::binary_cross_entropy(input, target, {}, reduction);
+ })
+ .define_singleton_method(
+ "ctc_loss",
+ *[](const Tensor &log_probs, const Tensor &targets, IntArrayRef input_lengths, IntArrayRef target_lengths, int64_t blank, MyReduction reduction, bool zero_infinity) {
+ return torch::ctc_loss(log_probs, targets, input_lengths, target_lengths, blank, reduction, zero_infinity);
+ })
+ .define_singleton_method(
+ "kl_div",
+ *[](Tensor& input, Tensor& target, MyReduction reduction) {
+ return torch::kl_div(input, target, reduction);
+ })
+ .define_singleton_method(
+ "l1_loss",
+ *[](Tensor& input, Tensor& target, MyReduction reduction) {
+ return torch::l1_loss(input, target, reduction);
+ })
+ .define_singleton_method(
"mse_loss",
- *[](torch::Tensor& input, torch::Tensor& target, std::string reduction) {
- auto red = reduction == "mean" ? Reduction::Mean : Reduction::Sum;
- return torch::mse_loss(input, target, red);
+ *[](Tensor& input, Tensor& target, MyReduction reduction) {
+ return torch::mse_loss(input, target, reduction);
})
.define_singleton_method(
"nll_loss",
- *[](torch::Tensor& input, torch::Tensor& target, std::string reduction) {
- auto red = reduction == "mean" ? Reduction::Mean : Reduction::Sum;
- return torch::nll_loss(input, target, {}, red);
+ *[](Tensor& input, Tensor& target, MyReduction reduction, int64_t ignore_index) {
+ return torch::nll_loss(input, target, {}, reduction, ignore_index);
})
+ .define_singleton_method(
+ "poisson_nll_loss",
+ *[](const Tensor &input, const Tensor &target, bool log_input, bool full, double eps, MyReduction reduction) {
+ return torch::poisson_nll_loss(input, target, log_input, full, eps, reduction);
+ })
.define_singleton_method("numel", &torch::numel)
.define_singleton_method(
"_from_blob",
*[](String s, IntArrayRef size, const torch::TensorOptions &options) {
void *data = const_cast<char *>(s.c_str());
@@ -498,15 +687,22 @@
})
.define_singleton_method(
"_tensor",
*[](Object o, IntArrayRef size, const torch::TensorOptions &options) {
Array a = Array(o);
- std::vector<float> vec;
- for (size_t i = 0; i < a.size(); i++) {
- vec.push_back(from_ruby<float>(a[i]));
+ auto dtype = options.dtype();
+ torch::Tensor t;
+ if (dtype == torch::kBool) {
+ throw std::runtime_error("Cannot create bool from tensor method yet");
+ } else {
+ std::vector<float> vec;
+ for (size_t i = 0; i < a.size(); i++) {
+ vec.push_back(from_ruby<float>(a[i]));
+ }
+ t = torch::tensor(vec, options);
}
- return torch::tensor(vec, options).reshape(size);
+ return t.reshape(size);
});
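
The kBool branch above only raises for now. One possible implementation, hedged: it assumes from_ruby<bool> is available and that torch::tensor accepts a uint8_t vector, which is then cast to bool.

    // Hypothetical bool path: std::vector<bool> is bit-packed and has no
    // contiguous buffer, so collect into uint8_t and cast afterwards.
    std::vector<uint8_t> vec;
    for (size_t i = 0; i < a.size(); i++) {
      vec.push_back(from_ruby<bool>(a[i]) ? 1 : 0);
    }
    t = torch::tensor(vec, options.dtype(torch::kUInt8)).to(torch::kBool);
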
Class rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor")
.define_method("cuda?", &torch::Tensor::is_cuda)
.define_method("distributed?", &torch::Tensor::is_distributed)
@@ -519,170 +715,175 @@
.define_method("element_size", &torch::Tensor::element_size)
.define_method("requires_grad", &torch::Tensor::requires_grad)
.define_method("view_as", &torch::Tensor::view_as)
.define_method(
"addcmul!",
- *[](torch::Tensor& self, Scalar value, const torch::Tensor & tensor1, const torch::Tensor & tensor2) {
+ *[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
return self.addcmul_(tensor1, tensor2, value);
})
.define_method(
"addcdiv!",
- *[](torch::Tensor& self, Scalar value, const torch::Tensor & tensor1, const torch::Tensor & tensor2) {
+ *[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
return self.addcdiv_(tensor1, tensor2, value);
})
.define_method(
"zero!",
- *[](torch::Tensor& self) {
+ *[](Tensor& self) {
return self.zero_();
})
.define_method(
+ "detach",
+ *[](Tensor& self) {
+ return self.detach();
+ })
+ .define_method(
"detach!",
- *[](torch::Tensor& self) {
+ *[](Tensor& self) {
return self.detach_();
})
.define_method(
"_select",
- *[](torch::Tensor& self, int64_t dim, int64_t index) {
+ *[](Tensor& self, int64_t dim, int64_t index) {
return self.select(dim, index);
})
.define_method(
"_slice",
- *[](torch::Tensor& self, int64_t dim, int64_t start, int64_t end, int64_t step) {
+ *[](Tensor& self, int64_t dim, int64_t start, int64_t end, int64_t step) {
return self.slice(dim, start, end, step);
})
.define_method(
"_requires_grad!",
- *[](torch::Tensor& self, bool requires_grad) {
+ *[](Tensor& self, bool requires_grad) {
return self.set_requires_grad(requires_grad);
})
.define_method(
"_backward",
- *[](torch::Tensor& self) {
- return self.backward();
+ *[](Tensor& self, Object gradient) {
+ return gradient.is_nil() ? self.backward() : self.backward(from_ruby<torch::Tensor>(gradient));
})
.define_method(
- "_backward_gradient",
- *[](torch::Tensor& self, const torch::Tensor& gradient) {
- return self.backward(gradient);
- })
- .define_method(
"grad",
- *[](torch::Tensor& self) {
+ *[](Tensor& self) {
return self.grad();
})
.define_method(
"_dtype",
- *[](torch::Tensor& self) {
+ *[](Tensor& self) {
return (int) at::typeMetaToScalarType(self.dtype());
})
.define_method(
"_type",
- *[](torch::Tensor& self, int dtype) {
+ *[](Tensor& self, int dtype) {
return self.toType((torch::ScalarType) dtype);
})
.define_method(
"_layout",
- *[](torch::Tensor& self) {
+ *[](Tensor& self) {
std::stringstream s;
s << self.layout();
return s.str();
})
.define_method(
"device",
- *[](torch::Tensor& self) {
+ *[](Tensor& self) {
std::stringstream s;
s << self.device();
return s.str();
})
.define_method(
"_view",
- *[](torch::Tensor& self, IntArrayRef size) {
+ *[](Tensor& self, IntArrayRef size) {
return self.view(size);
})
.define_method(
"resize_as!",
- *[](torch::Tensor& self, torch::Tensor& other) {
+ *[](Tensor& self, Tensor& other) {
return self.resize_as_(other);
})
.define_method(
"fill!",
- *[](torch::Tensor& self, Scalar value) {
+ *[](Tensor& self, Scalar value) {
return self.fill_(value);
})
.define_method(
+ "relu!",
+ *[](Tensor& self) {
+ return self.relu_();
+ })
+ .define_method(
"_add!",
- *[](torch::Tensor& self, torch::Tensor& other) {
+ *[](Tensor& self, Tensor& other) {
return self.add_(other);
})
.define_method(
"_add_alpha!",
- *[](torch::Tensor& self, torch::Tensor& other, Scalar alpha) {
+ *[](Tensor& self, Tensor& other, Scalar alpha) {
return self.add_(other, alpha);
})
.define_method(
"_add_scalar!",
- *[](torch::Tensor& self, Scalar other) {
+ *[](Tensor& self, Scalar other) {
return self.add_(other);
})
.define_method(
"normal!",
- *[](torch::Tensor& self, double mean, double std) {
+ *[](Tensor& self, double mean, double std) {
return self.normal_(mean, std);
})
.define_method(
+ "random!",
+ *[](Tensor& self, int64_t to) {
+ return self.random_(to);
+ })
+ .define_method(
"sub!",
- *[](torch::Tensor& self, torch::Tensor& other) {
+ *[](Tensor& self, Tensor& other) {
return self.sub_(other);
})
.define_method(
"_mul!",
- *[](torch::Tensor& self, torch::Tensor& other) {
+ *[](Tensor& self, Tensor& other) {
return self.mul_(other);
})
.define_method(
"_mul_scalar!",
- *[](torch::Tensor& self, Scalar other) {
+ *[](Tensor& self, Scalar other) {
return self.mul_(other);
})
.define_method(
"div!",
- *[](torch::Tensor& self, torch::Tensor& other) {
+ *[](Tensor& self, Tensor& other) {
return self.div_(other);
})
.define_method(
"sqrt!",
- *[](torch::Tensor& self) {
+ *[](Tensor& self) {
return self.sqrt_();
})
.define_method(
"unsqueeze!",
- *[](torch::Tensor& self, int64_t dim) {
+ *[](Tensor& self, int64_t dim) {
return self.unsqueeze_(dim);
})
.define_method(
"copy!",
- *[](torch::Tensor& self, torch::Tensor& src) {
+ *[](Tensor& self, Tensor& src) {
return self.copy_(src);
})
.define_method(
"clone",
- *[](torch::Tensor& self) {
+ *[](Tensor& self) {
return self.clone();
})
.define_method(
- "log_softmax",
- *[](torch::Tensor& self, int64_t dim) {
- return self.log_softmax(dim);
- })
- .define_method(
"data",
- *[](torch::Tensor& self) {
+ *[](Tensor& self) {
return self.data();
})
.define_method(
"_data",
- *[](torch::Tensor& self) {
+ *[](Tensor& self) {
Array a;
auto dtype = self.dtype();
// TODO DRY if someone knows C++
if (dtype == torch::kByte) {
@@ -730,21 +931,21 @@
}
return a;
})
.define_method(
"_size",
- *[](torch::Tensor& self, int i) {
+ *[](Tensor& self, int i) {
return self.size(i);
})
.define_method(
"_to",
- *[](torch::Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
+ *[](Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
return self.to(device, (torch::ScalarType) dtype, non_blocking, copy);
})
.define_singleton_method(
"_make_subclass",
- *[](torch::Tensor& rd, bool requires_grad) {
+ *[](Tensor& rd, bool requires_grad) {
auto data = torch::autograd::as_variable_ref(rd).detach();
data.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
auto var = data.set_requires_grad(requires_grad);
return torch::autograd::Variable(std::move(var));
});
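
The _data implementation above carries a "TODO DRY if someone knows C++": each dtype gets its own copy of the same read loop. A sketch of a single template the branches could delegate to, assuming a contiguous tensor and the templated element accessor being available in this LibTorch version:

    // Hypothetical helper: one loop for all dtypes; _data's kByte branch
    // would become flat_data<uint8_t>(self), and so on.
    template <typename T>
    Array flat_data(torch::Tensor& self) {
      Array a;
      auto tensor = self.contiguous();
      T* ptr = tensor.data_ptr<T>();
      for (int64_t i = 0; i < tensor.numel(); i++) {
        a.push(ptr[i]);
      }
      return a;
    }
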
@@ -791,35 +992,85 @@
Module rb_mNN = define_module_under(rb_mTorch, "NN");
Module rb_mInit = define_module_under(rb_mNN, "Init")
.define_singleton_method(
- "kaiming_uniform!",
- *[](torch::Tensor& input, double a) {
- return torch::nn::init::kaiming_uniform_(input, a);
+ "_calculate_gain",
+ *[](NonlinearityType nonlinearity, double param) {
+ return torch::nn::init::calculate_gain(nonlinearity, param);
})
.define_singleton_method(
- "normal!",
- *[](torch::Tensor& input) {
- return torch::nn::init::normal_(input);
+ "_uniform!",
+ *[](Tensor tensor, double low, double high) {
+ return torch::nn::init::uniform_(tensor, low, high);
})
.define_singleton_method(
- "uniform!",
- *[](torch::Tensor& input, double to, double from) {
- return torch::nn::init::uniform_(input, to, from);
+ "_normal!",
+ *[](Tensor tensor, double mean, double std) {
+ return torch::nn::init::normal_(tensor, mean, std);
+ })
+ .define_singleton_method(
+ "_constant!",
+ *[](Tensor tensor, Scalar value) {
+ return torch::nn::init::constant_(tensor, value);
+ })
+ .define_singleton_method(
+ "_ones!",
+ *[](Tensor tensor) {
+ return torch::nn::init::ones_(tensor);
+ })
+ .define_singleton_method(
+ "_zeros!",
+ *[](Tensor tensor) {
+ return torch::nn::init::zeros_(tensor);
+ })
+ .define_singleton_method(
+ "_eye!",
+ *[](Tensor tensor) {
+ return torch::nn::init::eye_(tensor);
+ })
+ .define_singleton_method(
+ "_dirac!",
+ *[](Tensor tensor) {
+ return torch::nn::init::dirac_(tensor);
+ })
+ .define_singleton_method(
+ "_xavier_uniform!",
+ *[](Tensor tensor, double gain) {
+ return torch::nn::init::xavier_uniform_(tensor, gain);
+ })
+ .define_singleton_method(
+ "_xavier_normal!",
+ *[](Tensor tensor, double gain) {
+ return torch::nn::init::xavier_normal_(tensor, gain);
+ })
+ .define_singleton_method(
+ "_kaiming_uniform!",
+ *[](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
+ return torch::nn::init::kaiming_uniform_(tensor, a, mode, nonlinearity);
+ })
+ .define_singleton_method(
+ "_kaiming_normal!",
+ *[](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
+ return torch::nn::init::kaiming_normal_(tensor, a, mode, nonlinearity);
+ })
+ .define_singleton_method(
+ "_orthogonal!",
+ *[](Tensor tensor, double gain) {
+ return torch::nn::init::orthogonal_(tensor, gain);
+ })
+ .define_singleton_method(
+ "_sparse!",
+ *[](Tensor tensor, double sparsity, double std) {
+ return torch::nn::init::sparse_(tensor, sparsity, std);
});
Class rb_cParameter = define_class_under<torch::autograd::Variable, torch::Tensor>(rb_mNN, "Parameter")
- // TODO return grad or nil to remove need for 2nd function
.define_method(
- "_grad",
+ "grad",
*[](torch::autograd::Variable& self) {
- return self.grad();
- })
- .define_method(
- "_grad_defined",
- *[](torch::autograd::Variable& self) {
- return self.grad().defined();
+ auto grad = self.grad();
+ return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
});
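
Parameter#grad now folds the old _grad/_grad_defined pair into one method by returning nil for an undefined gradient. The defined-or-nil check recurs wherever LibTorch hands back an undefined tensor, so it could be named once (hypothetical helper, not in this diff):

    // Convert an undefined tensor to Ruby nil, otherwise wrap it.
    Object to_ruby_or_nil(const torch::Tensor& t) {
      return t.defined() ? to_ruby<torch::Tensor>(t) : Nil;
    }
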
Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
.define_constructor(Constructor<torch::Device, std::string>())
.define_method("index", &torch::Device::index)