ext/torch/ext.cpp in torch-rb-0.3.7 vs ext/torch/ext.cpp in torch-rb-0.4.0

- old
+ new

@@ -5,17 +5,18 @@ #include <rice/Array.hpp> #include <rice/Class.hpp> #include <rice/Constructor.hpp> #include <rice/Hash.hpp> -#include "templates.hpp" +#include "templates.h" +#include "utils.h" // generated with: // rake generate:functions -#include "torch_functions.hpp" -#include "tensor_functions.hpp" -#include "nn_functions.hpp" +#include "torch_functions.h" +#include "tensor_functions.h" +#include "nn_functions.h" using namespace Rice; using torch::indexing::TensorIndex; // need to make a distinction between parameters and tensors @@ -27,15 +28,51 @@ void handle_error(torch::Error const & ex) { throw Exception(rb_eRuntimeError, ex.what_without_backtrace()); } +Class rb_cTensor; + std::vector<TensorIndex> index_vector(Array a) { - auto indices = std::vector<TensorIndex>(); + Object obj; + + std::vector<TensorIndex> indices; indices.reserve(a.size()); + for (size_t i = 0; i < a.size(); i++) { - indices.push_back(from_ruby<TensorIndex>(a[i])); + obj = a[i]; + + if (obj.is_instance_of(rb_cInteger)) { + indices.push_back(from_ruby<int64_t>(obj)); + } else if (obj.is_instance_of(rb_cRange)) { + torch::optional<int64_t> start_index = from_ruby<int64_t>(obj.call("begin")); + torch::optional<int64_t> stop_index = -1; + + Object end = obj.call("end"); + if (!end.is_nil()) { + stop_index = from_ruby<int64_t>(end); + } + + Object exclude_end = obj.call("exclude_end?"); + if (!exclude_end) { + if (stop_index.value() == -1) { + stop_index = torch::nullopt; + } else { + stop_index = stop_index.value() + 1; + } + } + + indices.push_back(torch::indexing::Slice(start_index, stop_index)); + } else if (obj.is_instance_of(rb_cTensor)) { + indices.push_back(from_ruby<Tensor>(obj)); + } else if (obj.is_nil()) { + indices.push_back(torch::indexing::None); + } else if (obj == True || obj == False) { + indices.push_back(from_ruby<bool>(obj)); + } else { + throw Exception(rb_eArgError, "Unsupported index type: %s", rb_obj_classname(obj)); + } } return indices; } extern "C" @@ -43,13 +80,14 @@ { Module rb_mTorch = define_module("Torch"); rb_mTorch.add_handler<torch::Error>(handle_error); add_torch_functions(rb_mTorch); - Class rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor"); + rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor"); rb_cTensor.add_handler<torch::Error>(handle_error); add_tensor_functions(rb_cTensor); + THPVariableClass = rb_cTensor.value(); Module rb_mNN = define_module_under(rb_mTorch, "NN"); rb_mNN.add_handler<torch::Error>(handle_error); add_nn_functions(rb_mNN); @@ -66,17 +104,10 @@ // TODO set for CUDA when available auto generator = at::detail::getDefaultCPUGenerator(); return generator.seed(); }); - Class rb_cTensorIndex = define_class_under<TensorIndex>(rb_mTorch, "TensorIndex") - .define_singleton_method("boolean", *[](bool value) { return TensorIndex(value); }) - .define_singleton_method("integer", *[](int64_t value) { return TensorIndex(value); }) - .define_singleton_method("tensor", *[](torch::Tensor& value) { return TensorIndex(value); }) - .define_singleton_method("slice", *[](torch::optional<int64_t> start_index, torch::optional<int64_t> stop_index) { return TensorIndex(torch::indexing::Slice(start_index, stop_index)); }) - .define_singleton_method("none", *[]() { return TensorIndex(torch::indexing::None); }); - // https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html Class rb_cIValue = define_class_under<torch::IValue>(rb_mTorch, "IValue") .add_handler<torch::Error>(handle_error) .define_constructor(Constructor<torch::IValue>()) .define_method("bool?", &torch::IValue::isBool) @@ -222,71 +253,10 @@ .define_singleton_method( "parallel_info", *[] { return torch::get_parallel_info(); }) - // begin tensor creation - .define_singleton_method( - "_arange", - *[](Scalar start, Scalar end, Scalar step, const torch::TensorOptions &options) { - return torch::arange(start, end, step, options); - }) - .define_singleton_method( - "_empty", - *[](std::vector<int64_t> size, const torch::TensorOptions &options) { - return torch::empty(size, options); - }) - .define_singleton_method( - "_eye", - *[](int64_t m, int64_t n, const torch::TensorOptions &options) { - return torch::eye(m, n, options); - }) - .define_singleton_method( - "_full", - *[](std::vector<int64_t> size, Scalar fill_value, const torch::TensorOptions& options) { - return torch::full(size, fill_value, options); - }) - .define_singleton_method( - "_linspace", - *[](Scalar start, Scalar end, int64_t steps, const torch::TensorOptions& options) { - return torch::linspace(start, end, steps, options); - }) - .define_singleton_method( - "_logspace", - *[](Scalar start, Scalar end, int64_t steps, double base, const torch::TensorOptions& options) { - return torch::logspace(start, end, steps, base, options); - }) - .define_singleton_method( - "_ones", - *[](std::vector<int64_t> size, const torch::TensorOptions &options) { - return torch::ones(size, options); - }) - .define_singleton_method( - "_rand", - *[](std::vector<int64_t> size, const torch::TensorOptions &options) { - return torch::rand(size, options); - }) - .define_singleton_method( - "_randint", - *[](int64_t low, int64_t high, std::vector<int64_t> size, const torch::TensorOptions &options) { - return torch::randint(low, high, size, options); - }) - .define_singleton_method( - "_randn", - *[](std::vector<int64_t> size, const torch::TensorOptions &options) { - return torch::randn(size, options); - }) - .define_singleton_method( - "_randperm", - *[](int64_t n, const torch::TensorOptions &options) { - return torch::randperm(n, options); - }) - .define_singleton_method( - "_zeros", - *[](std::vector<int64_t> size, const torch::TensorOptions &options) { - return torch::zeros(size, options); - }) // begin operations .define_singleton_method( "_save", *[](const torch::IValue &value) { auto v = torch::pickle_save(value); @@ -347,9 +317,18 @@ "shape", *[](Tensor& self) { Array a; for (auto &size : self.sizes()) { a.push(size); + } + return a; + }) + .define_method( + "_strides", + *[](Tensor& self) { + Array a; + for (auto &stride : self.strides()) { + a.push(stride); } return a; }) .define_method( "_index",