ext/torch/ext.cpp in torch-rb-0.1.3 vs ext/torch/ext.cpp in torch-rb-0.1.4

- old
+ new

@@ -130,10 +130,116 @@
 TensorList from_ruby<TensorList>(Object x)
 {
   return TensorList(x);
 }
 
+class FanModeType {
+  std::string s;
+  public:
+    FanModeType(Object o) {
+      s = String(o).str();
+    }
+    // TODO switch FanModeType after LibTorch 1.4 release
+    operator torch::nn::init::FanMode() {
+      if (s == "fan_in") {
+        return torch::nn::init::FanMode::FanIn;
+      } else if (s == "fan_out") {
+        return torch::nn::init::FanMode::FanOut;
+      } else {
+        throw std::runtime_error("Unsupported fan mode: " + s);
+      }
+    }
+};
+
+template<>
+inline
+FanModeType from_ruby<FanModeType>(Object x)
+{
+  return FanModeType(x);
+}
+
+class NonlinearityType {
+  std::string s;
+  public:
+    NonlinearityType(Object o) {
+      s = String(o).str();
+    }
+    // TODO switch NonlinearityType after LibTorch 1.4 release
+    operator torch::nn::init::Nonlinearity() {
+      if (s == "linear") {
+        return torch::nn::init::Nonlinearity::Linear;
+      } else if (s == "conv1d") {
+        return torch::nn::init::Nonlinearity::Conv1D;
+      } else if (s == "conv2d") {
+        return torch::nn::init::Nonlinearity::Conv2D;
+      } else if (s == "conv3d") {
+        return torch::nn::init::Nonlinearity::Conv3D;
+      } else if (s == "conv_transpose1d") {
+        return torch::nn::init::Nonlinearity::ConvTranspose1D;
+      } else if (s == "conv_transpose2d") {
+        return torch::nn::init::Nonlinearity::ConvTranspose2D;
+      } else if (s == "conv_transpose3d") {
+        return torch::nn::init::Nonlinearity::ConvTranspose3D;
+      } else if (s == "sigmoid") {
+        return torch::nn::init::Nonlinearity::Sigmoid;
+      } else if (s == "tanh") {
+        return torch::nn::init::Nonlinearity::Tanh;
+      } else if (s == "relu") {
+        return torch::nn::init::Nonlinearity::ReLU;
+      } else if (s == "leaky_relu") {
+        return torch::nn::init::Nonlinearity::LeakyReLU;
+      } else {
+        throw std::runtime_error("Unsupported nonlinearity type: " + s);
+      }
+    }
+};
+
+template<>
+inline
+NonlinearityType from_ruby<NonlinearityType>(Object x)
+{
+  return NonlinearityType(x);
+}
+
+class MyReduction {
+  Object value;
+  public:
+    MyReduction(Object o) {
+      value = o;
+    }
+    operator int64_t() {
+      if (value.is_nil()) {
+        return Reduction::None;
+      }
+
+      std::string s = String(value).str();
+      if (s == "mean") {
+        return Reduction::Mean;
+      } else if (s == "sum") {
+        return Reduction::Sum;
+      } else {
+        throw std::runtime_error("Unsupported reduction: " + s);
+      }
+    }
+};
+
+template<>
+inline
+MyReduction from_ruby<MyReduction>(Object x)
+{
+  return MyReduction(x);
+}
+
+typedef torch::Tensor Tensor;
+
+Object tensor_array(std::tuple<torch::Tensor, torch::Tensor> x) {
+  Array a;
+  a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
+  a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
+  return Object(a);
+}
+
 extern "C"
 void Init_ext()
 {
   Module rb_mTorch = define_module("Torch")
     .define_singleton_method(
@@ -146,11 +252,11 @@
         *[](bool enabled) {
          torch::GradMode::set_enabled(enabled);
        })
     .define_singleton_method(
        "floating_point?",
-       *[](torch::Tensor& input) {
+       *[](Tensor& input) {
          return torch::is_floating_point(input);
        })
     .define_singleton_method(
        "manual_seed",
        *[](uint64_t seed) {
@@ -218,279 +324,362 @@
          return torch::zeros(size, options);
        })
     // begin operations
     .define_singleton_method(
        "_mean",
-       *[](torch::Tensor& input) {
+       *[](Tensor& input) {
          return torch::mean(input);
        })
     .define_singleton_method(
        "_mean_dim",
-       *[](torch::Tensor& input, int64_t dim, bool keepdim) {
+       *[](Tensor& input, int64_t dim, bool keepdim) {
          return torch::mean(input, dim, keepdim);
        })
     .define_singleton_method(
        "_sum",
-       *[](torch::Tensor& input) {
+       *[](Tensor& input) {
          return torch::sum(input);
        })
     .define_singleton_method(
        "_sum_dim",
-       *[](torch::Tensor& input, int64_t dim, bool keepdim) {
+       *[](Tensor& input, int64_t dim, bool keepdim) {
          return torch::sum(input, dim, keepdim);
        })
     .define_singleton_method(
        "_argmax",
-       *[](torch::Tensor& input) {
+       *[](Tensor& input) {
          return torch::argmax(input);
        })
     .define_singleton_method(
        "_argmax_dim",
-       *[](torch::Tensor& input, int64_t dim, bool keepdim) {
+       *[](Tensor& input, int64_t dim, bool keepdim) {
          return torch::argmax(input, dim, keepdim);
        })
     .define_singleton_method(
        "_cat",
        *[](TensorList tensors, int64_t dim) {
          return torch::cat(tensors, dim);
        })
     .define_singleton_method(
        "_norm",
-       *[](torch::Tensor& input) {
+       *[](Tensor& input) {
          return torch::norm(input);
        })
     .define_singleton_method(
        "_min",
-       *[](torch::Tensor& input) {
+       *[](Tensor& input) {
          return torch::min(input);
        })
     .define_singleton_method(
        "_max",
-       *[](torch::Tensor& input) {
+       *[](Tensor& input) {
          return torch::max(input);
        })
     .define_singleton_method(
        "_max_out",
-       *[](torch::Tensor &max, torch::Tensor &max_indices, const torch::Tensor &input, int64_t dim, bool keepdim) {
-         // TODO add return value
-         torch::_max_out(max, max_indices, input, dim, keepdim);
+       *[](Tensor &max, Tensor &max_indices, const Tensor &input, int64_t dim, bool keepdim) {
+         return tensor_array(torch::_max_out(max, max_indices, input, dim, keepdim));
        })
     .define_singleton_method(
        "_sqrt",
-       *[](torch::Tensor& input) {
+       *[](Tensor& input) {
          return torch::sqrt(input);
        })
     .define_singleton_method(
        "_exp",
-       *[](torch::Tensor& input) {
+       *[](Tensor& input) {
          return torch::exp(input);
        })
     .define_singleton_method(
        "_log",
-       *[](torch::Tensor& input) {
+       *[](Tensor& input) {
          return torch::log(input);
        })
     .define_singleton_method(
        "_sign",
-       *[](torch::Tensor& input) {
+       *[](Tensor& input) {
          return torch::sign(input);
        })
     .define_singleton_method(
        "_unsqueeze",
-       *[](torch::Tensor& input, int64_t dim) {
+       *[](Tensor& input, int64_t dim) {
          return torch::unsqueeze(input, dim);
        })
     .define_singleton_method(
        "_dot",
-       *[](torch::Tensor& input, torch::Tensor& tensor) {
+       *[](Tensor& input, Tensor& tensor) {
          return torch::dot(input, tensor);
        })
     .define_singleton_method(
        "_matmul",
-       *[](torch::Tensor& input, torch::Tensor& other) {
+       *[](Tensor& input, Tensor& other) {
          return torch::matmul(input, other);
        })
     .define_singleton_method(
        "_eq",
-       *[](torch::Tensor& input, torch::Tensor& other) {
+       *[](Tensor& input, Tensor& other) {
          return torch::eq(input, other);
        })
     .define_singleton_method(
        "_gt",
        // TODO support tensors
-       *[](torch::Tensor& input, Scalar other) {
+       *[](Tensor& input, Scalar other) {
          return torch::gt(input, other);
        })
     .define_singleton_method(
        "_lt",
        // TODO support tensors
-       *[](torch::Tensor& input, Scalar other) {
+       *[](Tensor& input, Scalar other) {
          return torch::lt(input, other);
        })
     .define_singleton_method(
        "_add",
-       *[](torch::Tensor& input, torch::Tensor& other) {
+       *[](Tensor& input, Tensor& other) {
          return torch::add(input, other);
        })
     .define_singleton_method(
        "_add_scalar",
-       *[](torch::Tensor& input, Scalar other) {
+       *[](Tensor& input, Scalar other) {
          return torch::add(input, other);
        })
     .define_singleton_method(
        "_add_out",
-       *[](torch::Tensor& out, torch::Tensor& input, torch::Tensor& other) {
+       *[](Tensor& out, Tensor& input, Tensor& other) {
          return torch::add_out(out, input, other);
        })
     .define_singleton_method(
        "_sub",
-       *[](torch::Tensor& input, torch::Tensor& other) {
+       *[](Tensor& input, Tensor& other) {
          return torch::sub(input, other);
        })
     .define_singleton_method(
        "_sub_scalar",
-       *[](torch::Tensor& input, Scalar other) {
+       *[](Tensor& input, Scalar other) {
          return torch::sub(input, other);
        })
     .define_singleton_method(
        "_mul",
-       *[](torch::Tensor& input, torch::Tensor& other) {
+       *[](Tensor& input, Tensor& other) {
          return torch::mul(input, other);
        })
     .define_singleton_method(
        "_mul_scalar",
-       *[](torch::Tensor& input, Scalar other) {
+       *[](Tensor& input, Scalar other) {
          return torch::mul(input, other);
        })
     .define_singleton_method(
        "_div",
-       *[](torch::Tensor& input, torch::Tensor& other) {
+       *[](Tensor& input, Tensor& other) {
          return torch::div(input, other);
        })
     .define_singleton_method(
        "_div_scalar",
-       *[](torch::Tensor& input, Scalar other) {
+       *[](Tensor& input, Scalar other) {
          return torch::div(input, other);
        })
     .define_singleton_method(
        "_remainder",
-       *[](torch::Tensor& input, torch::Tensor& other) {
+       *[](Tensor& input, Tensor& other) {
          return torch::remainder(input, other);
        })
     .define_singleton_method(
        "_remainder_scalar",
-       *[](torch::Tensor& input, Scalar other) {
+       *[](Tensor& input, Scalar other) {
          return torch::remainder(input, other);
        })
     .define_singleton_method(
        "_pow",
-       *[](torch::Tensor& input, Scalar exponent) {
+       *[](Tensor& input, Scalar exponent) {
          return torch::pow(input, exponent);
        })
     .define_singleton_method(
+       "_topk",
+       *[](Tensor& input, int64_t k) {
+         return tensor_array(torch::topk(input, k));
+       })
+    .define_singleton_method(
+       "_sigmoid",
+       *[](Tensor& input) {
+         return torch::sigmoid(input);
+       })
+    .define_singleton_method(
+       "_softplus",
+       *[](const Tensor &input, Scalar beta, Scalar threshold) {
+         return torch::softplus(input, beta, threshold);
+       })
+    .define_singleton_method(
+       "_softmax",
+       *[](const Tensor &input, int64_t dim) {
+         return torch::softmax(input, dim);
+       })
+    .define_singleton_method(
+       "_log_softmax",
+       *[](Tensor& input, int64_t dim) {
+         return torch::log_softmax(input, dim);
+       })
+    .define_singleton_method(
        "_abs",
-       *[](torch::Tensor& input) {
+       *[](Tensor& input) {
          return torch::abs(input);
        })
     .define_singleton_method(
        "_neg",
-       *[](torch::Tensor& input) {
+       *[](Tensor& input) {
          return torch::neg(input);
        })
     .define_singleton_method(
        "_reshape",
-       *[](torch::Tensor& input, IntArrayRef shape) {
+       *[](Tensor& input, IntArrayRef shape) {
          return torch::reshape(input, shape);
        })
     .define_singleton_method(
        "_flatten",
-       *[](torch::Tensor& input, int64_t start_dim, int64_t end_dim) {
+       *[](Tensor& input, int64_t start_dim, int64_t end_dim) {
          return torch::flatten(input, start_dim, end_dim);
        })
     .define_singleton_method(
        "relu",
-       *[](torch::Tensor& input) {
+       *[](Tensor& input) {
          return torch::relu(input);
        })
     .define_singleton_method(
+       "prelu",
+       *[](torch::Tensor& input, torch::Tensor& weight) {
+         return torch::prelu(input, weight);
+       })
+    .define_singleton_method(
+       "leaky_relu",
+       *[](torch::Tensor& input, Scalar negative_slope) {
+         return torch::leaky_relu(input, negative_slope);
+       })
+    .define_singleton_method(
        "conv2d",
-       *[](torch::Tensor& input, torch::Tensor& weight, torch::Tensor& bias, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, int64_t groups) {
+       *[](Tensor& input, Tensor& weight, Tensor& bias, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, int64_t groups) {
          return torch::conv2d(input, weight, bias, stride, padding, dilation, groups);
        })
+    // linear layers
     .define_singleton_method(
+       "bilinear",
+       *[](const Tensor &input1, const Tensor &input2, const Tensor &weight, const Tensor &bias) {
+         return torch::bilinear(input1, input2, weight, bias);
+       })
+    .define_singleton_method(
        "linear",
-       *[](torch::Tensor& input, torch::Tensor& weight, torch::Tensor& bias) {
+       *[](Tensor& input, Tensor& weight, Tensor& bias) {
          return torch::linear(input, weight, bias);
        })
+    // pooling layers
     .define_singleton_method(
        "max_pool2d",
-       *[](torch::Tensor& input, IntArrayRef kernel_size) {
+       *[](Tensor& input, IntArrayRef kernel_size) {
          return torch::max_pool2d(input, kernel_size);
        })
     .define_singleton_method(
        "avg_pool2d",
-       *[](torch::Tensor& input, IntArrayRef kernel_size) {
+       *[](Tensor& input, IntArrayRef kernel_size) {
          return torch::avg_pool2d(input, kernel_size);
        })
     .define_singleton_method(
        "_dropout",
-       *[](torch::Tensor& input, float p, bool train) {
+       *[](Tensor& input, float p, bool train) {
          return torch::dropout(input, p, train);
        })
     .define_singleton_method(
        "_dropout!",
-       *[](torch::Tensor& input, float p, bool train) {
+       *[](Tensor& input, float p, bool train) {
          return torch::dropout_(input, p, train);
        })
     .define_singleton_method(
        "_feature_dropout",
-       *[](torch::Tensor& input, float p, bool train) {
+       *[](Tensor& input, float p, bool train) {
          return torch::feature_dropout(input, p, train);
        })
     .define_singleton_method(
        "_feature_dropout!",
-       *[](torch::Tensor& input, float p, bool train) {
+       *[](Tensor& input, float p, bool train) {
          return torch::feature_dropout_(input, p, train);
        })
     .define_singleton_method(
        "_alpha_dropout",
-       *[](torch::Tensor& input, float p, bool train) {
+       *[](Tensor& input, float p, bool train) {
          return torch::alpha_dropout(input, p, train);
        })
     .define_singleton_method(
        "_alpha_dropout!",
-       *[](torch::Tensor& input, float p, bool train) {
+       *[](Tensor& input, float p, bool train) {
          return torch::alpha_dropout_(input, p, train);
        })
     .define_singleton_method(
        "_feature_alpha_dropout",
-       *[](torch::Tensor& input, float p, bool train) {
+       *[](Tensor& input, float p, bool train) {
          return torch::feature_alpha_dropout(input, p, train);
        })
     .define_singleton_method(
        "_feature_alpha_dropout!",
-       *[](torch::Tensor& input, float p, bool train) {
+       *[](Tensor& input, float p, bool train) {
          return torch::feature_alpha_dropout_(input, p, train);
        })
+    // sparse layers
     .define_singleton_method(
        "_embedding",
        // weight and indices are swapped from Python interface
-       *[](const torch::Tensor &indices, const torch::Tensor &weight, int64_t padding_idx, bool scale_grad_by_freq, bool sparse) {
+       *[](const Tensor &indices, const Tensor &weight, int64_t padding_idx, bool scale_grad_by_freq, bool sparse) {
          return torch::embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse);
        })
     .define_singleton_method(
+       "_embedding_bag",
+       // weight and indices are swapped from Python interface
+       *[](const Tensor &weight, const Tensor &indices, const Tensor &offsets, bool scale_grad_by_freq, int64_t mode, bool sparse, const Tensor &per_sample_weights) {
+         return torch::embedding_bag(weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights);
+       })
+    // distance functions
+    .define_singleton_method(
+       "_cosine_similarity",
+       *[](const Tensor &x1, const Tensor &x2, int64_t dim, double eps) {
+         return torch::cosine_similarity(x1, x2, dim, eps);
+       })
+    .define_singleton_method(
+       "_pairwise_distance",
+       *[](const Tensor &x1, const Tensor &x2, double p, double eps, bool keepdim) {
+         return torch::pairwise_distance(x1, x2, p, eps, keepdim);
+       })
+    // loss functions
+    .define_singleton_method(
+       "binary_cross_entropy",
+       *[](Tensor& input, Tensor& target, MyReduction reduction) {
+         return torch::binary_cross_entropy(input, target, {}, reduction);
+       })
+    .define_singleton_method(
+       "ctc_loss",
+       *[](const Tensor &log_probs, const Tensor &targets, IntArrayRef input_lengths, IntArrayRef target_lengths, int64_t blank, MyReduction reduction, bool zero_infinity) {
+         return torch::ctc_loss(log_probs, targets, input_lengths, target_lengths, blank, reduction, zero_infinity);
+       })
+    .define_singleton_method(
+       "kl_div",
+       *[](Tensor& input, Tensor& target, MyReduction reduction) {
+         return torch::kl_div(input, target, reduction);
+       })
+    .define_singleton_method(
+       "l1_loss",
+       *[](Tensor& input, Tensor& target, MyReduction reduction) {
+         return torch::l1_loss(input, target, reduction);
+       })
+    .define_singleton_method(
        "mse_loss",
-       *[](torch::Tensor& input, torch::Tensor& target, std::string reduction) {
-         auto red = reduction == "mean" ? Reduction::Mean : Reduction::Sum;
-         return torch::mse_loss(input, target, red);
+       *[](Tensor& input, Tensor& target, MyReduction reduction) {
+         return torch::mse_loss(input, target, reduction);
        })
     .define_singleton_method(
        "nll_loss",
-       *[](torch::Tensor& input, torch::Tensor& target, std::string reduction) {
-         auto red = reduction == "mean" ? Reduction::Mean : Reduction::Sum;
-         return torch::nll_loss(input, target, {}, red);
+       *[](Tensor& input, Tensor& target, MyReduction reduction, int64_t ignore_index) {
+         return torch::nll_loss(input, target, {}, reduction, ignore_index);
        })
+    .define_singleton_method(
+       "poisson_nll_loss",
+       *[](const Tensor &input, const Tensor &target, bool log_input, bool full, double eps, MyReduction reduction) {
+         return torch::poisson_nll_loss(input, target, log_input, full, eps, reduction);
+       })
     .define_singleton_method("numel", &torch::numel)
     .define_singleton_method(
        "_from_blob",
        *[](String s, IntArrayRef size, const torch::TensorOptions &options) {
          void *data = const_cast<char *>(s.c_str());
@@ -498,15 +687,22 @@
        })
     .define_singleton_method(
        "_tensor",
        *[](Object o, IntArrayRef size, const torch::TensorOptions &options) {
          Array a = Array(o);
-         std::vector<float> vec;
-         for (size_t i = 0; i < a.size(); i++) {
-           vec.push_back(from_ruby<float>(a[i]));
+         auto dtype = options.dtype();
+         torch::Tensor t;
+         if (dtype == torch::kBool) {
+           throw std::runtime_error("Cannot create bool tensor from tensor method yet");
+         } else {
+           std::vector<float> vec;
+           for (size_t i = 0; i < a.size(); i++) {
+             vec.push_back(from_ruby<float>(a[i]));
+           }
+           t = torch::tensor(vec, options);
          }
-         return torch::tensor(vec, options).reshape(size);
+         return t.reshape(size);
        });
 
   Class rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor")
     .define_method("cuda?", &torch::Tensor::is_cuda)
     .define_method("distributed?", &torch::Tensor::is_distributed)
@@ -519,170 +715,175 @@
     .define_method("element_size", &torch::Tensor::element_size)
     .define_method("requires_grad", &torch::Tensor::requires_grad)
     .define_method("view_as", &torch::Tensor::view_as)
     .define_method(
        "addcmul!",
-       *[](torch::Tensor& self, Scalar value, const torch::Tensor & tensor1, const torch::Tensor & tensor2) {
+       *[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
          return self.addcmul_(tensor1, tensor2, value);
        })
     .define_method(
        "addcdiv!",
-       *[](torch::Tensor& self, Scalar value, const torch::Tensor & tensor1, const torch::Tensor & tensor2) {
+       *[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
          return self.addcdiv_(tensor1, tensor2, value);
        })
     .define_method(
        "zero!",
-       *[](torch::Tensor& self) {
+       *[](Tensor& self) {
          return self.zero_();
        })
     .define_method(
+       "detach",
+       *[](Tensor& self) {
+         return self.detach();
+       })
+    .define_method(
        "detach!",
-       *[](torch::Tensor& self) {
+       *[](Tensor& self) {
          return self.detach_();
        })
     .define_method(
        "_select",
-       *[](torch::Tensor& self, int64_t dim, int64_t index) {
+       *[](Tensor& self, int64_t dim, int64_t index) {
          return self.select(dim, index);
        })
     .define_method(
        "_slice",
-       *[](torch::Tensor& self, int64_t dim, int64_t start, int64_t end, int64_t step) {
+       *[](Tensor& self, int64_t dim, int64_t start, int64_t end, int64_t step) {
          return self.slice(dim, start, end, step);
        })
     .define_method(
        "_requires_grad!",
-       *[](torch::Tensor& self, bool requires_grad) {
+       *[](Tensor& self, bool requires_grad) {
          return self.set_requires_grad(requires_grad);
        })
     .define_method(
        "_backward",
-       *[](torch::Tensor& self) {
-         return self.backward();
+       *[](Tensor& self, Object gradient) {
+         return gradient.is_nil() ? self.backward() : self.backward(from_ruby<torch::Tensor>(gradient));
        })
     .define_method(
-       "_backward_gradient",
-       *[](torch::Tensor& self, const torch::Tensor& gradient) {
-         return self.backward(gradient);
-       })
-    .define_method(
        "grad",
-       *[](torch::Tensor& self) {
+       *[](Tensor& self) {
          return self.grad();
        })
     .define_method(
        "_dtype",
-       *[](torch::Tensor& self) {
+       *[](Tensor& self) {
          return (int) at::typeMetaToScalarType(self.dtype());
        })
     .define_method(
        "_type",
-       *[](torch::Tensor& self, int dtype) {
+       *[](Tensor& self, int dtype) {
          return self.toType((torch::ScalarType) dtype);
        })
     .define_method(
        "_layout",
-       *[](torch::Tensor& self) {
+       *[](Tensor& self) {
          std::stringstream s;
          s << self.layout();
          return s.str();
        })
     .define_method(
        "device",
-       *[](torch::Tensor& self) {
+       *[](Tensor& self) {
          std::stringstream s;
          s << self.device();
          return s.str();
        })
     .define_method(
        "_view",
-       *[](torch::Tensor& self, IntArrayRef size) {
+       *[](Tensor& self, IntArrayRef size) {
          return self.view(size);
        })
     .define_method(
        "resize_as!",
-       *[](torch::Tensor& self, torch::Tensor& other) {
+       *[](Tensor& self, Tensor& other) {
          return self.resize_as_(other);
        })
     .define_method(
        "fill!",
-       *[](torch::Tensor& self, Scalar value) {
+       *[](Tensor& self, Scalar value) {
          return self.fill_(value);
        })
     .define_method(
+       "relu!",
+       *[](Tensor& self) {
+         return self.relu_();
+       })
+    .define_method(
        "_add!",
-       *[](torch::Tensor& self, torch::Tensor& other) {
+       *[](Tensor& self, Tensor& other) {
          return self.add_(other);
        })
     .define_method(
        "_add_alpha!",
-       *[](torch::Tensor& self, torch::Tensor& other, Scalar alpha) {
+       *[](Tensor& self, Tensor& other, Scalar alpha) {
          return self.add_(other, alpha);
        })
     .define_method(
        "_add_scalar!",
-       *[](torch::Tensor& self, Scalar other) {
+       *[](Tensor& self, Scalar other) {
          return self.add_(other);
        })
     .define_method(
        "normal!",
-       *[](torch::Tensor& self, double mean, double std) {
+       *[](Tensor& self, double mean, double std) {
          return self.normal_(mean, std);
        })
     .define_method(
+       "random!",
+       *[](Tensor& self, int64_t to) {
+         return self.random_(to);
+       })
+    .define_method(
        "sub!",
-       *[](torch::Tensor& self, torch::Tensor& other) {
+       *[](Tensor& self, Tensor& other) {
          return self.sub_(other);
        })
     .define_method(
        "_mul!",
-       *[](torch::Tensor& self, torch::Tensor& other) {
+       *[](Tensor& self, Tensor& other) {
          return self.mul_(other);
        })
     .define_method(
        "_mul_scalar!",
-       *[](torch::Tensor& self, Scalar other) {
+       *[](Tensor& self, Scalar other) {
          return self.mul_(other);
        })
     .define_method(
        "div!",
-       *[](torch::Tensor& self, torch::Tensor& other) {
+       *[](Tensor& self, Tensor& other) {
          return self.div_(other);
        })
     .define_method(
        "sqrt!",
-       *[](torch::Tensor& self) {
+       *[](Tensor& self) {
          return self.sqrt_();
        })
     .define_method(
        "unsqueeze!",
-       *[](torch::Tensor& self, int64_t dim) {
+       *[](Tensor& self, int64_t dim) {
          return self.unsqueeze_(dim);
        })
     .define_method(
        "copy!",
-       *[](torch::Tensor& self, torch::Tensor& src) {
+       *[](Tensor& self, Tensor& src) {
          return self.copy_(src);
        })
     .define_method(
        "clone",
-       *[](torch::Tensor& self) {
+       *[](Tensor& self) {
          return self.clone();
        })
     .define_method(
-       "log_softmax",
-       *[](torch::Tensor& self, int64_t dim) {
-         return self.log_softmax(dim);
-       })
-    .define_method(
        "data",
-       *[](torch::Tensor& self) {
+       *[](Tensor& self) {
          return self.data();
        })
     .define_method(
        "_data",
-       *[](torch::Tensor& self) {
+       *[](Tensor& self) {
          Array a;
          auto dtype = self.dtype();
 
          // TODO DRY if someone knows C++
          if (dtype == torch::kByte) {
@@ -730,21 +931,21 @@
          }
          return a;
        })
     .define_method(
        "_size",
-       *[](torch::Tensor& self, int i) {
+       *[](Tensor& self, int i) {
          return self.size(i);
        })
     .define_method(
        "_to",
-       *[](torch::Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
+       *[](Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
          return self.to(device, (torch::ScalarType) dtype, non_blocking, copy);
        })
     .define_singleton_method(
        "_make_subclass",
-       *[](torch::Tensor& rd, bool requires_grad) {
+       *[](Tensor& rd, bool requires_grad) {
          auto data = torch::autograd::as_variable_ref(rd).detach();
          data.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
          auto var = data.set_requires_grad(requires_grad);
          return torch::autograd::Variable(std::move(var));
        });
@@ -791,35 +992,85 @@
   Module rb_mNN = define_module_under(rb_mTorch, "NN");
 
   Module rb_mInit = define_module_under(rb_mNN, "Init")
     .define_singleton_method(
-       "kaiming_uniform!",
-       *[](torch::Tensor& input, double a) {
-         return torch::nn::init::kaiming_uniform_(input, a);
+       "_calculate_gain",
+       *[](NonlinearityType nonlinearity, double param) {
+         return torch::nn::init::calculate_gain(nonlinearity, param);
        })
     .define_singleton_method(
-       "normal!",
-       *[](torch::Tensor& input) {
-         return torch::nn::init::normal_(input);
+       "_uniform!",
+       *[](Tensor tensor, double low, double high) {
+         return torch::nn::init::uniform_(tensor, low, high);
        })
     .define_singleton_method(
-       "uniform!",
-       *[](torch::Tensor& input, double to, double from) {
-         return torch::nn::init::uniform_(input, to, from);
+       "_normal!",
+       *[](Tensor tensor, double mean, double std) {
+         return torch::nn::init::normal_(tensor, mean, std);
+       })
+    .define_singleton_method(
+       "_constant!",
+       *[](Tensor tensor, Scalar value) {
+         return torch::nn::init::constant_(tensor, value);
+       })
+    .define_singleton_method(
+       "_ones!",
+       *[](Tensor tensor) {
+         return torch::nn::init::ones_(tensor);
+       })
+    .define_singleton_method(
+       "_zeros!",
+       *[](Tensor tensor) {
+         return torch::nn::init::zeros_(tensor);
+       })
+    .define_singleton_method(
+       "_eye!",
+       *[](Tensor tensor) {
+         return torch::nn::init::eye_(tensor);
+       })
+    .define_singleton_method(
+       "_dirac!",
+       *[](Tensor tensor) {
+         return torch::nn::init::dirac_(tensor);
+       })
+    .define_singleton_method(
+       "_xavier_uniform!",
+       *[](Tensor tensor, double gain) {
+         return torch::nn::init::xavier_uniform_(tensor, gain);
+       })
+    .define_singleton_method(
+       "_xavier_normal!",
+       *[](Tensor tensor, double gain) {
+         return torch::nn::init::xavier_normal_(tensor, gain);
+       })
+    .define_singleton_method(
+       "_kaiming_uniform!",
+       *[](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
+         return torch::nn::init::kaiming_uniform_(tensor, a, mode, nonlinearity);
+       })
+    .define_singleton_method(
+       "_kaiming_normal!",
+       *[](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
+         return torch::nn::init::kaiming_normal_(tensor, a, mode, nonlinearity);
+       })
+    .define_singleton_method(
+       "_orthogonal!",
+       *[](Tensor tensor, double gain) {
+         return torch::nn::init::orthogonal_(tensor, gain);
+       })
+    .define_singleton_method(
+       "_sparse!",
+       *[](Tensor tensor, double sparsity, double std) {
+         return torch::nn::init::sparse_(tensor, sparsity, std);
        });
 
   Class rb_cParameter = define_class_under<torch::autograd::Variable, torch::Tensor>(rb_mNN, "Parameter")
-    // TODO return grad or nil to remove need for 2nd function
     .define_method(
-       "_grad",
+       "grad",
        *[](torch::autograd::Variable& self) {
-         return self.grad();
-       })
-    .define_method(
-       "_grad_defined",
-       *[](torch::autograd::Variable& self) {
-         return self.grad().defined();
+         auto grad = self.grad();
+         return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
        });
 
   Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
     .define_constructor(Constructor<torch::Device, std::string>())
    .define_method("index", &torch::Device::index)
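
A note on the conversion classes this release introduces: Rice constructs a FanModeType, NonlinearityType, or MyReduction from the Ruby argument via the from_ruby<T> specialization, and the class's implicit conversion operator translates the stored string into the LibTorch enum at the point the value is forwarded to the native call. The standalone sketch below illustrates that idiom in plain C++; the Reduction enum and the report function are hypothetical stand-ins for the LibTorch and Rice types, not part of the gem.

#include <iostream>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for LibTorch's reduction constants.
enum class Reduction { None, Mean, Sum };

// Mirrors MyReduction above: hold the raw string, then validate and
// convert only when the value is actually used as a Reduction.
class ReductionArg {
  std::string s;
 public:
  explicit ReductionArg(std::string v) : s(std::move(v)) {}
  operator Reduction() const {
    if (s == "mean") return Reduction::Mean;
    if (s == "sum") return Reduction::Sum;
    throw std::runtime_error("Unsupported reduction: " + s);
  }
};

// Stand-in for a native function that expects the enum, the way
// torch::mse_loss expects a reduction value in the bindings above.
void report(Reduction r) {
  std::cout << "reduction enum value: " << static_cast<int>(r) << "\n";
}

int main() {
  ReductionArg arg("mean");
  report(arg);  // the implicit conversion operator fires here
  return 0;
}

Centralizing the string-to-enum mapping this way means every binding that accepts a reduction argument shares the same validation and error message, rather than repeating the mapping inline as the 0.1.3 mse_loss and nll_loss lambdas did.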