ext/torch/ext.cpp in torch-rb-0.3.2 vs ext/torch/ext.cpp in torch-rb-0.3.3

- old
+ new

@@ -14,10 +14,11 @@ #include "torch_functions.hpp" #include "tensor_functions.hpp" #include "nn_functions.hpp" using namespace Rice; +using torch::indexing::TensorIndex; // need to make a distinction between parameters and tensors class Parameter: public torch::autograd::Variable { public: Parameter(Tensor&& t) : torch::autograd::Variable(t) { } @@ -26,10 +27,19 @@ void handle_error(torch::Error const & ex) { throw Exception(rb_eRuntimeError, ex.what_without_backtrace()); } +std::vector<TensorIndex> index_vector(Array a) { + auto indices = std::vector<TensorIndex>(); + indices.reserve(a.size()); + for (size_t i = 0; i < a.size(); i++) { + indices.push_back(from_ruby<TensorIndex>(a[i])); + } + return indices; +} + extern "C" void Init_ext() { Module rb_mTorch = define_module("Torch"); rb_mTorch.add_handler<torch::Error>(handle_error); @@ -56,10 +66,17 @@ // TODO set for CUDA when available auto generator = at::detail::getDefaultCPUGenerator(); return generator.seed(); }); + Class rb_cTensorIndex = define_class_under<TensorIndex>(rb_mTorch, "TensorIndex") + .define_singleton_method("boolean", *[](bool value) { return TensorIndex(value); }) + .define_singleton_method("integer", *[](int64_t value) { return TensorIndex(value); }) + .define_singleton_method("tensor", *[](torch::Tensor& value) { return TensorIndex(value); }) + .define_singleton_method("slice", *[](torch::optional<int64_t> start_index, torch::optional<int64_t> stop_index) { return TensorIndex(torch::indexing::Slice(start_index, stop_index)); }) + .define_singleton_method("none", *[]() { return TensorIndex(torch::indexing::None); }); + // https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html Class rb_cIValue = define_class_under<torch::IValue>(rb_mTorch, "IValue") .add_handler<torch::Error>(handle_error) .define_constructor(Constructor<torch::IValue>()) .define_method("bool?", &torch::IValue::isBool) @@ -329,10 +346,22 @@ .define_method("dim", &torch::Tensor::dim) .define_method("numel", &torch::Tensor::numel) .define_method("element_size", &torch::Tensor::element_size) .define_method("requires_grad", &torch::Tensor::requires_grad) .define_method( + "_index", + *[](Tensor& self, Array indices) { + auto vec = index_vector(indices); + return self.index(vec); + }) + .define_method( + "_index_put_custom", + *[](Tensor& self, Array indices, torch::Tensor& value) { + auto vec = index_vector(indices); + return self.index_put_(vec, value); + }) + .define_method( "contiguous?", *[](Tensor& self) { return self.is_contiguous(); }) .define_method( @@ -506,10 +535,11 @@ *[](torch::TensorOptions& self, bool requires_grad) { return self.requires_grad(requires_grad); }); Module rb_mInit = define_module_under(rb_mNN, "Init") + .add_handler<torch::Error>(handle_error) .define_singleton_method( "_calculate_gain", *[](NonlinearityType nonlinearity, double param) { return torch::nn::init::calculate_gain(nonlinearity, param); }) @@ -592,11 +622,11 @@ *[](Parameter& self, torch::Tensor& grad) { self.grad() = grad; }); Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device") - .define_constructor(Constructor<torch::Device, std::string>()) .add_handler<torch::Error>(handle_error) + .define_constructor(Constructor<torch::Device, std::string>()) .define_method("index", &torch::Device::index) .define_method("index?", &torch::Device::has_index) .define_method( "type", *[](torch::Device& self) {