ext/torch/ext.cpp in torch-rb-0.3.2 vs ext/torch/ext.cpp in torch-rb-0.3.3
- old
+ new
@@ -14,10 +14,11 @@
#include "torch_functions.hpp"
#include "tensor_functions.hpp"
#include "nn_functions.hpp"
using namespace Rice;
+using torch::indexing::TensorIndex;
// need to make a distinction between parameters and tensors
class Parameter: public torch::autograd::Variable {
public:
Parameter(Tensor&& t) : torch::autograd::Variable(t) { }
@@ -26,10 +27,19 @@
void handle_error(torch::Error const & ex)
{
  throw Exception(rb_eRuntimeError, ex.what_without_backtrace());
}

+std::vector<TensorIndex> index_vector(Array a) {
+  auto indices = std::vector<TensorIndex>();
+  indices.reserve(a.size());
+  for (size_t i = 0; i < a.size(); i++) {
+    indices.push_back(from_ruby<TensorIndex>(a[i]));
+  }
+  return indices;
+}
+
extern "C"
void Init_ext()
{
  Module rb_mTorch = define_module("Torch");
  rb_mTorch.add_handler<torch::Error>(handle_error);
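
The new index_vector helper adapts a Ruby Array into the std::vector<TensorIndex> that LibTorch's C++ indexing API consumes. A minimal standalone sketch of that consuming API, outside the Ruby binding (tensor shapes and values are illustrative):

    #include <torch/torch.h>

    using torch::indexing::Slice;
    using torch::indexing::TensorIndex;

    int main() {
      torch::Tensor t = torch::arange(12).reshape({3, 4});

      // The C++ analogue of Python's t[0:2, 1]: a vector of TensorIndex
      std::vector<TensorIndex> idx = {Slice(0, 2), 1};
      torch::Tensor sub = t.index(idx);    // 1-D tensor with values [1, 5]

      // A default-constructed Slice() acts like a bare ':'
      torch::Tensor col = t.index({Slice(), 2});
    }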
@@ -56,10 +66,17 @@
        // TODO set for CUDA when available
        auto generator = at::detail::getDefaultCPUGenerator();
        return generator.seed();
      });

+  Class rb_cTensorIndex = define_class_under<TensorIndex>(rb_mTorch, "TensorIndex")
+    .define_singleton_method("boolean", *[](bool value) { return TensorIndex(value); })
+    .define_singleton_method("integer", *[](int64_t value) { return TensorIndex(value); })
+    .define_singleton_method("tensor", *[](torch::Tensor& value) { return TensorIndex(value); })
+    .define_singleton_method("slice", *[](torch::optional<int64_t> start_index, torch::optional<int64_t> stop_index) { return TensorIndex(torch::indexing::Slice(start_index, stop_index)); })
+    .define_singleton_method("none", *[]() { return TensorIndex(torch::indexing::None); });
+
  // https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
  Class rb_cIValue = define_class_under<torch::IValue>(rb_mTorch, "IValue")
    .add_handler<torch::Error>(handle_error)
    .define_constructor(Constructor<torch::IValue>())
    .define_method("bool?", &torch::IValue::isBool)
@@ -329,10 +346,22 @@
.define_method("dim", &torch::Tensor::dim)
.define_method("numel", &torch::Tensor::numel)
.define_method("element_size", &torch::Tensor::element_size)
.define_method("requires_grad", &torch::Tensor::requires_grad)
.define_method(
+ "_index",
+ *[](Tensor& self, Array indices) {
+ auto vec = index_vector(indices);
+ return self.index(vec);
+ })
+ .define_method(
+ "_index_put_custom",
+ *[](Tensor& self, Array indices, torch::Tensor& value) {
+ auto vec = index_vector(indices);
+ return self.index_put_(vec, value);
+ })
+ .define_method(
"contiguous?",
*[](Tensor& self) {
return self.is_contiguous();
})
.define_method(
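
_index wraps the read side and _index_put_custom the write side of the same indexing vector: Tensor::index returns a new tensor holding the selection, while Tensor::index_put_ assigns into self in place and returns self, presumably backing Ruby's [] and []= operators once the Ruby layer converts its arguments. A short sketch of that distinction (shapes and values are illustrative):

    #include <torch/torch.h>

    using torch::indexing::Slice;
    using torch::indexing::TensorIndex;

    int main() {
      torch::Tensor t = torch::zeros({3, 4});
      std::vector<TensorIndex> idx = {Slice(0, 2), TensorIndex(1)};

      torch::Tensor sub = t.index(idx);     // read: new tensor, shape [2]
      t.index_put_(idx, torch::ones({2}));  // write: t[0:2, 1] = 1, in place
    }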
@@ -506,10 +535,11 @@
      *[](torch::TensorOptions& self, bool requires_grad) {
        return self.requires_grad(requires_grad);
      });

  Module rb_mInit = define_module_under(rb_mNN, "Init")
+    .add_handler<torch::Error>(handle_error)
    .define_singleton_method(
      "_calculate_gain",
      *[](NonlinearityType nonlinearity, double param) {
        return torch::nn::init::calculate_gain(nonlinearity, param);
      })
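
Adding the error handler to the Init module means a torch::Error thrown inside these wrapped functions (for example, by calculate_gain on an unsupported nonlinearity) is converted by handle_error into a Ruby RuntimeError rather than escaping as an unhandled C++ exception. For reference, the underlying LibTorch call being wrapped, with illustrative arguments:

    #include <torch/torch.h>

    int main() {
      // Recommended gain values, used by e.g. Kaiming initialization
      double relu_gain  = torch::nn::init::calculate_gain(torch::kReLU);           // sqrt(2)
      double leaky_gain = torch::nn::init::calculate_gain(torch::kLeakyReLU, 0.2); // uses param
    }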
@@ -592,11 +622,11 @@
      *[](Parameter& self, torch::Tensor& grad) {
        self.grad() = grad;
      });

  Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
-    .define_constructor(Constructor<torch::Device, std::string>())
    .add_handler<torch::Error>(handle_error)
+    .define_constructor(Constructor<torch::Device, std::string>())
    .define_method("index", &torch::Device::index)
    .define_method("index?", &torch::Device::has_index)
    .define_method(
      "type",
      *[](torch::Device& self) {
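
The Device change only reorders the chain so that add_handler comes before define_constructor; handlers registered earlier cover what is defined after them, which appears to be the point of the reorder: an unparseable device string now raises in Ruby instead of propagating an unhandled c10::Error. An illustrative use of the underlying C++ type (device strings are examples):

    #include <torch/torch.h>

    int main() {
      torch::Device d("cuda:1");  // parses the string; no GPU required here
      bool has  = d.has_index();  // true
      auto idx  = d.index();      // 1
      auto type = d.type();       // torch::kCUDA

      // torch::Device bad("not-a-device");  // would throw c10::Error, which
      // the binding's handler now maps to a Ruby RuntimeError
    }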