ext/torch/ext.cpp in torch-rb-0.3.7 vs ext/torch/ext.cpp in torch-rb-0.4.0
- old
+ new
@@ -5,17 +5,18 @@
#include <rice/Array.hpp>
#include <rice/Class.hpp>
#include <rice/Constructor.hpp>
#include <rice/Hash.hpp>
-#include "templates.hpp"
+#include "templates.h"
+#include "utils.h"
// generated with:
// rake generate:functions
-#include "torch_functions.hpp"
-#include "tensor_functions.hpp"
-#include "nn_functions.hpp"
+#include "torch_functions.h"
+#include "tensor_functions.h"
+#include "nn_functions.h"
using namespace Rice;
using torch::indexing::TensorIndex;
// need to make a distinction between parameters and tensors
@@ -27,15 +28,51 @@
void handle_error(torch::Error const & ex)
{
throw Exception(rb_eRuntimeError, ex.what_without_backtrace());
}
+Class rb_cTensor;
+
std::vector<TensorIndex> index_vector(Array a) {
- auto indices = std::vector<TensorIndex>();
+ Object obj;
+
+ std::vector<TensorIndex> indices;
indices.reserve(a.size());
+
for (size_t i = 0; i < a.size(); i++) {
- indices.push_back(from_ruby<TensorIndex>(a[i]));
+ obj = a[i];
+
+ if (obj.is_instance_of(rb_cInteger)) {
+ indices.push_back(from_ruby<int64_t>(obj));
+ } else if (obj.is_instance_of(rb_cRange)) {
+ torch::optional<int64_t> start_index = from_ruby<int64_t>(obj.call("begin"));
+ torch::optional<int64_t> stop_index = -1;
+
+ Object end = obj.call("end");
+ if (!end.is_nil()) {
+ stop_index = from_ruby<int64_t>(end);
+ }
+
+ Object exclude_end = obj.call("exclude_end?");
+ if (!exclude_end) {
+ if (stop_index.value() == -1) {
+ stop_index = torch::nullopt;
+ } else {
+ stop_index = stop_index.value() + 1;
+ }
+ }
+
+ indices.push_back(torch::indexing::Slice(start_index, stop_index));
+ } else if (obj.is_instance_of(rb_cTensor)) {
+ indices.push_back(from_ruby<Tensor>(obj));
+ } else if (obj.is_nil()) {
+ indices.push_back(torch::indexing::None);
+ } else if (obj == True || obj == False) {
+ indices.push_back(from_ruby<bool>(obj));
+ } else {
+ throw Exception(rb_eArgError, "Unsupported index type: %s", rb_obj_classname(obj));
+ }
}
return indices;
}
extern "C"
@@ -43,13 +80,14 @@
{
Module rb_mTorch = define_module("Torch");
rb_mTorch.add_handler<torch::Error>(handle_error);
add_torch_functions(rb_mTorch);
- Class rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor");
+ rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor");
rb_cTensor.add_handler<torch::Error>(handle_error);
add_tensor_functions(rb_cTensor);
+ THPVariableClass = rb_cTensor.value();
Module rb_mNN = define_module_under(rb_mTorch, "NN");
rb_mNN.add_handler<torch::Error>(handle_error);
add_nn_functions(rb_mNN);
@@ -66,17 +104,10 @@
// TODO set for CUDA when available
auto generator = at::detail::getDefaultCPUGenerator();
return generator.seed();
});
- Class rb_cTensorIndex = define_class_under<TensorIndex>(rb_mTorch, "TensorIndex")
- .define_singleton_method("boolean", *[](bool value) { return TensorIndex(value); })
- .define_singleton_method("integer", *[](int64_t value) { return TensorIndex(value); })
- .define_singleton_method("tensor", *[](torch::Tensor& value) { return TensorIndex(value); })
- .define_singleton_method("slice", *[](torch::optional<int64_t> start_index, torch::optional<int64_t> stop_index) { return TensorIndex(torch::indexing::Slice(start_index, stop_index)); })
- .define_singleton_method("none", *[]() { return TensorIndex(torch::indexing::None); });
-
// https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
Class rb_cIValue = define_class_under<torch::IValue>(rb_mTorch, "IValue")
.add_handler<torch::Error>(handle_error)
.define_constructor(Constructor<torch::IValue>())
.define_method("bool?", &torch::IValue::isBool)
@@ -222,71 +253,10 @@
.define_singleton_method(
"parallel_info",
*[] {
return torch::get_parallel_info();
})
- // begin tensor creation
- .define_singleton_method(
- "_arange",
- *[](Scalar start, Scalar end, Scalar step, const torch::TensorOptions &options) {
- return torch::arange(start, end, step, options);
- })
- .define_singleton_method(
- "_empty",
- *[](std::vector<int64_t> size, const torch::TensorOptions &options) {
- return torch::empty(size, options);
- })
- .define_singleton_method(
- "_eye",
- *[](int64_t m, int64_t n, const torch::TensorOptions &options) {
- return torch::eye(m, n, options);
- })
- .define_singleton_method(
- "_full",
- *[](std::vector<int64_t> size, Scalar fill_value, const torch::TensorOptions& options) {
- return torch::full(size, fill_value, options);
- })
- .define_singleton_method(
- "_linspace",
- *[](Scalar start, Scalar end, int64_t steps, const torch::TensorOptions& options) {
- return torch::linspace(start, end, steps, options);
- })
- .define_singleton_method(
- "_logspace",
- *[](Scalar start, Scalar end, int64_t steps, double base, const torch::TensorOptions& options) {
- return torch::logspace(start, end, steps, base, options);
- })
- .define_singleton_method(
- "_ones",
- *[](std::vector<int64_t> size, const torch::TensorOptions &options) {
- return torch::ones(size, options);
- })
- .define_singleton_method(
- "_rand",
- *[](std::vector<int64_t> size, const torch::TensorOptions &options) {
- return torch::rand(size, options);
- })
- .define_singleton_method(
- "_randint",
- *[](int64_t low, int64_t high, std::vector<int64_t> size, const torch::TensorOptions &options) {
- return torch::randint(low, high, size, options);
- })
- .define_singleton_method(
- "_randn",
- *[](std::vector<int64_t> size, const torch::TensorOptions &options) {
- return torch::randn(size, options);
- })
- .define_singleton_method(
- "_randperm",
- *[](int64_t n, const torch::TensorOptions &options) {
- return torch::randperm(n, options);
- })
- .define_singleton_method(
- "_zeros",
- *[](std::vector<int64_t> size, const torch::TensorOptions &options) {
- return torch::zeros(size, options);
- })
// begin operations
.define_singleton_method(
"_save",
*[](const torch::IValue &value) {
auto v = torch::pickle_save(value);
@@ -347,9 +317,18 @@
"shape",
*[](Tensor& self) {
Array a;
for (auto &size : self.sizes()) {
a.push(size);
+ }
+ return a;
+ })
+ .define_method(
+ "_strides",
+ *[](Tensor& self) {
+ Array a;
+ for (auto &stride : self.strides()) {
+ a.push(stride);
}
return a;
})
.define_method(
"_index",