ext/torch/nn.cpp in torch-rb-0.6.0 vs ext/torch/nn.cpp in torch-rb-0.7.0
- old
+ new
@@ -1,8 +1,8 @@
#include <torch/torch.h>
-#include <rice/Module.hpp>
+#include <rice/rice.hpp>
#include "nn_functions.h"
#include "templates.h"
#include "utils.h"
@@ -17,96 +17,96 @@
rb_mNN.add_handler<torch::Error>(handle_error);
add_nn_functions(rb_mNN);
Rice::define_module_under(rb_mNN, "Init")
.add_handler<torch::Error>(handle_error)
- .define_singleton_method(
+ .define_singleton_function(
"_calculate_gain",
- *[](NonlinearityType nonlinearity, double param) {
+ [](NonlinearityType nonlinearity, double param) {
return torch::nn::init::calculate_gain(nonlinearity, param);
})
- .define_singleton_method(
+ .define_singleton_function(
"_uniform!",
- *[](Tensor tensor, double low, double high) {
+ [](Tensor tensor, double low, double high) {
return torch::nn::init::uniform_(tensor, low, high);
})
- .define_singleton_method(
+ .define_singleton_function(
"_normal!",
- *[](Tensor tensor, double mean, double std) {
+ [](Tensor tensor, double mean, double std) {
return torch::nn::init::normal_(tensor, mean, std);
})
- .define_singleton_method(
+ .define_singleton_function(
"_constant!",
- *[](Tensor tensor, Scalar value) {
+ [](Tensor tensor, Scalar value) {
return torch::nn::init::constant_(tensor, value);
})
- .define_singleton_method(
+ .define_singleton_function(
"_ones!",
- *[](Tensor tensor) {
+ [](Tensor tensor) {
return torch::nn::init::ones_(tensor);
})
- .define_singleton_method(
+ .define_singleton_function(
"_zeros!",
- *[](Tensor tensor) {
+ [](Tensor tensor) {
return torch::nn::init::zeros_(tensor);
})
- .define_singleton_method(
+ .define_singleton_function(
"_eye!",
- *[](Tensor tensor) {
+ [](Tensor tensor) {
return torch::nn::init::eye_(tensor);
})
- .define_singleton_method(
+ .define_singleton_function(
"_dirac!",
- *[](Tensor tensor) {
+ [](Tensor tensor) {
return torch::nn::init::dirac_(tensor);
})
- .define_singleton_method(
+ .define_singleton_function(
"_xavier_uniform!",
- *[](Tensor tensor, double gain) {
+ [](Tensor tensor, double gain) {
return torch::nn::init::xavier_uniform_(tensor, gain);
})
- .define_singleton_method(
+ .define_singleton_function(
"_xavier_normal!",
- *[](Tensor tensor, double gain) {
+ [](Tensor tensor, double gain) {
return torch::nn::init::xavier_normal_(tensor, gain);
})
- .define_singleton_method(
+ .define_singleton_function(
"_kaiming_uniform!",
- *[](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
+ [](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
return torch::nn::init::kaiming_uniform_(tensor, a, mode, nonlinearity);
})
- .define_singleton_method(
+ .define_singleton_function(
"_kaiming_normal!",
- *[](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
+ [](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
return torch::nn::init::kaiming_normal_(tensor, a, mode, nonlinearity);
})
- .define_singleton_method(
+ .define_singleton_function(
"_orthogonal!",
- *[](Tensor tensor, double gain) {
+ [](Tensor tensor, double gain) {
return torch::nn::init::orthogonal_(tensor, gain);
})
- .define_singleton_method(
+ .define_singleton_function(
"_sparse!",
- *[](Tensor tensor, double sparsity, double std) {
+ [](Tensor tensor, double sparsity, double std) {
return torch::nn::init::sparse_(tensor, sparsity, std);
});
Rice::define_class_under<Parameter, torch::Tensor>(rb_mNN, "Parameter")
.add_handler<torch::Error>(handle_error)
.define_method(
"grad",
- *[](Parameter& self) {
+ [](Parameter& self) {
auto grad = self.grad();
- return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
+ return grad.defined() ? Object(Rice::detail::To_Ruby<torch::Tensor>().convert(grad)) : Nil;
})
.define_method(
"grad=",
- *[](Parameter& self, torch::Tensor& grad) {
+ [](Parameter& self, torch::Tensor& grad) {
self.mutable_grad() = grad;
})
- .define_singleton_method(
+ .define_singleton_function(
"_make_subclass",
- *[](Tensor& rd, bool requires_grad) {
+ [](Tensor& rd, bool requires_grad) {
auto data = rd.detach();
data.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
auto var = data.set_requires_grad(requires_grad);
return Parameter(std::move(var));
});
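Taken together, the larger hunk is the mechanical part of the Rice 3 to Rice 4 migration: singleton bindings move from define_singleton_method to define_singleton_function, lambdas are passed by value instead of being dereferenced with a leading *, and the one manual Ruby conversion switches from the old to_ruby<T> helper to Rice::detail::To_Ruby<T>().convert wrapped in an Object. Below is a minimal sketch of the same Rice 4 pattern outside torch-rb, assuming Rice 4.x and a standard Ruby extension entry point; the Example module and the add / wrap_count names are illustrative only, not part of this gem.

#include <rice/rice.hpp>

using namespace Rice;

extern "C" void Init_example() {
  Module rb_mExample = define_module("Example");  // illustrative module name

  // Rice 4 takes the lambda directly; the Rice 3 leading '*' dereference is gone
  rb_mExample.define_singleton_function(
    "add",
    [](int a, int b) {
      return a + b;
    });

  // Manual conversion to a Ruby object now goes through detail::To_Ruby<T>().convert,
  // mirroring the Parameter#grad change above
  rb_mExample.define_singleton_function(
    "wrap_count",
    [](int count) {
      return Object(detail::To_Ruby<int>().convert(count));
    });
}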