ext/torch/tensor.cpp in torch-rb-0.6.0 vs ext/torch/tensor.cpp in torch-rb-0.7.0

- old
+ new

@@ -1,18 +1,50 @@ #include <torch/torch.h> -#include <rice/Constructor.hpp> -#include <rice/Module.hpp> +#include <rice/rice.hpp> #include "tensor_functions.h" #include "ruby_arg_parser.h" #include "templates.h" #include "utils.h" using namespace Rice; using torch::indexing::TensorIndex; +namespace Rice::detail +{ + template<typename T> + struct Type<c10::complex<T>> + { + static bool verify() + { + return true; + } + }; + + template<typename T> + class To_Ruby<c10::complex<T>> + { + public: + VALUE convert(c10::complex<T> const& x) + { + return rb_dbl_complex_new(x.real(), x.imag()); + } + }; +} + +template<typename T> +Array flat_data(Tensor& tensor) { + Tensor view = tensor.reshape({tensor.numel()}); + + Array a; + for (int i = 0; i < tensor.numel(); i++) { + a.push(view[i].item().to<T>()); + } + return a; +} + Class rb_cTensor; std::vector<TensorIndex> index_vector(Array a) { Object obj; @@ -21,23 +53,23 @@ for (long i = 0; i < a.size(); i++) { obj = a[i]; if (obj.is_instance_of(rb_cInteger)) { - indices.push_back(from_ruby<int64_t>(obj)); + indices.push_back(Rice::detail::From_Ruby<int64_t>().convert(obj.value())); } else if (obj.is_instance_of(rb_cRange)) { torch::optional<int64_t> start_index = torch::nullopt; torch::optional<int64_t> stop_index = torch::nullopt; Object begin = obj.call("begin"); if (!begin.is_nil()) { - start_index = from_ruby<int64_t>(begin); + start_index = Rice::detail::From_Ruby<int64_t>().convert(begin.value()); } Object end = obj.call("end"); if (!end.is_nil()) { - stop_index = from_ruby<int64_t>(end); + stop_index = Rice::detail::From_Ruby<int64_t>().convert(end.value()); } Object exclude_end = obj.call("exclude_end?"); if (stop_index.has_value() && !exclude_end) { if (stop_index.value() == -1) { @@ -47,15 +79,15 @@ } } indices.push_back(torch::indexing::Slice(start_index, stop_index)); } else if (obj.is_instance_of(rb_cTensor)) { - indices.push_back(from_ruby<Tensor>(obj)); + indices.push_back(Rice::detail::From_Ruby<Tensor>().convert(obj.value())); } else if (obj.is_nil()) { indices.push_back(torch::indexing::None); } else if (obj == True || obj == False) { - indices.push_back(from_ruby<bool>(obj)); + indices.push_back(Rice::detail::From_Ruby<bool>().convert(obj.value())); } else { throw Exception(rb_eArgError, "Unsupported index type: %s", rb_obj_classname(obj)); } } return indices; @@ -66,11 +98,11 @@ // TODO add support for inputs argument // _backward static VALUE tensor__backward(int argc, VALUE* argv, VALUE self_) { HANDLE_TH_ERRORS - Tensor& self = from_ruby<Tensor&>(self_); + Tensor& self = Rice::detail::From_Ruby<Tensor&>().convert(self_); static RubyArgParser parser({ "_backward(Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False)" }); ParsedArgs<4> parsed_args; auto _r = parser.parse(self_, argc, argv, parsed_args); @@ -82,12 +114,12 @@ dispatch__backward(self, {}, _r.optionalTensor(0), _r.toBoolOptional(1), _r.toBool(2)); RETURN_NIL END_HANDLE_TH_ERRORS } -void init_tensor(Rice::Module& m) { - rb_cTensor = Rice::define_class_under<torch::Tensor>(m, "Tensor"); +void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions) { + rb_cTensor = c; rb_cTensor.add_handler<torch::Error>(handle_error); add_tensor_functions(rb_cTensor); THPVariableClass = rb_cTensor.value(); rb_define_method(rb_cTensor, "backward", (VALUE (*)(...)) tensor__backward, -1); @@ -100,98 +132,98 @@ .define_method("numel", &torch::Tensor::numel) .define_method("element_size", &torch::Tensor::element_size) .define_method("requires_grad", &torch::Tensor::requires_grad) .define_method( "_size", - *[](Tensor& self, int64_t dim) { + [](Tensor& self, int64_t dim) { return self.size(dim); }) .define_method( "_stride", - *[](Tensor& self, int64_t dim) { + [](Tensor& self, int64_t dim) { return self.stride(dim); }) // in C++ for performance .define_method( "shape", - *[](Tensor& self) { + [](Tensor& self) { Array a; for (auto &size : self.sizes()) { a.push(size); } return a; }) .define_method( "_strides", - *[](Tensor& self) { + [](Tensor& self) { Array a; for (auto &stride : self.strides()) { a.push(stride); } return a; }) .define_method( "_index", - *[](Tensor& self, Array indices) { + [](Tensor& self, Array indices) { auto vec = index_vector(indices); return self.index(vec); }) .define_method( "_index_put_custom", - *[](Tensor& self, Array indices, torch::Tensor& value) { + [](Tensor& self, Array indices, torch::Tensor& value) { auto vec = index_vector(indices); return self.index_put_(vec, value); }) .define_method( "contiguous?", - *[](Tensor& self) { + [](Tensor& self) { return self.is_contiguous(); }) .define_method( "_requires_grad!", - *[](Tensor& self, bool requires_grad) { + [](Tensor& self, bool requires_grad) { return self.set_requires_grad(requires_grad); }) .define_method( "grad", - *[](Tensor& self) { + [](Tensor& self) { auto grad = self.grad(); - return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil; + return grad.defined() ? Object(Rice::detail::To_Ruby<torch::Tensor>().convert(grad)) : Nil; }) .define_method( "grad=", - *[](Tensor& self, torch::Tensor& grad) { + [](Tensor& self, torch::Tensor& grad) { self.mutable_grad() = grad; }) .define_method( "_dtype", - *[](Tensor& self) { + [](Tensor& self) { return (int) at::typeMetaToScalarType(self.dtype()); }) .define_method( "_type", - *[](Tensor& self, int dtype) { + [](Tensor& self, int dtype) { return self.toType((torch::ScalarType) dtype); }) .define_method( "_layout", - *[](Tensor& self) { + [](Tensor& self) { std::stringstream s; s << self.layout(); return s.str(); }) .define_method( "device", - *[](Tensor& self) { + [](Tensor& self) { std::stringstream s; s << self.device(); return s.str(); }) .define_method( "_data_str", - *[](Tensor& self) { - Tensor tensor = self; + [](Tensor& self) { + auto tensor = self; // move to CPU to get data if (tensor.device().type() != torch::kCPU) { torch::Device device("cpu"); tensor = tensor.to(device); @@ -205,85 +237,66 @@ return std::string(data_ptr, tensor.numel() * tensor.element_size()); }) // for TorchVision .define_method( "_data_ptr", - *[](Tensor& self) { + [](Tensor& self) { return reinterpret_cast<uintptr_t>(self.data_ptr()); }) // TODO figure out a better way to do this .define_method( "_flat_data", - *[](Tensor& self) { - Tensor tensor = self; + [](Tensor& self) { + auto tensor = self; // move to CPU to get data if (tensor.device().type() != torch::kCPU) { torch::Device device("cpu"); tensor = tensor.to(device); } - Array a; auto dtype = tensor.dtype(); - - Tensor view = tensor.reshape({tensor.numel()}); - - // TODO DRY if someone knows C++ if (dtype == torch::kByte) { - for (int i = 0; i < tensor.numel(); i++) { - a.push(view[i].item().to<uint8_t>()); - } + return flat_data<uint8_t>(tensor); } else if (dtype == torch::kChar) { - for (int i = 0; i < tensor.numel(); i++) { - a.push(to_ruby<int>(view[i].item().to<int8_t>())); - } + return flat_data<int8_t>(tensor); } else if (dtype == torch::kShort) { - for (int i = 0; i < tensor.numel(); i++) { - a.push(view[i].item().to<int16_t>()); - } + return flat_data<int16_t>(tensor); } else if (dtype == torch::kInt) { - for (int i = 0; i < tensor.numel(); i++) { - a.push(view[i].item().to<int32_t>()); - } + return flat_data<int32_t>(tensor); } else if (dtype == torch::kLong) { - for (int i = 0; i < tensor.numel(); i++) { - a.push(view[i].item().to<int64_t>()); - } + return flat_data<int64_t>(tensor); } else if (dtype == torch::kFloat) { - for (int i = 0; i < tensor.numel(); i++) { - a.push(view[i].item().to<float>()); - } + return flat_data<float>(tensor); } else if (dtype == torch::kDouble) { - for (int i = 0; i < tensor.numel(); i++) { - a.push(view[i].item().to<double>()); - } + return flat_data<double>(tensor); } else if (dtype == torch::kBool) { - for (int i = 0; i < tensor.numel(); i++) { - a.push(view[i].item().to<bool>() ? True : False); - } + return flat_data<bool>(tensor); + } else if (dtype == torch::kComplexFloat) { + return flat_data<c10::complex<float>>(tensor); + } else if (dtype == torch::kComplexDouble) { + return flat_data<c10::complex<double>>(tensor); } else { throw std::runtime_error("Unsupported type"); } - return a; }) .define_method( "_to", - *[](Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) { + [](Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) { return self.to(device, (torch::ScalarType) dtype, non_blocking, copy); }); - Rice::define_class_under<torch::TensorOptions>(m, "TensorOptions") + rb_cTensorOptions .add_handler<torch::Error>(handle_error) - .define_constructor(Rice::Constructor<torch::TensorOptions>()) .define_method( "dtype", - *[](torch::TensorOptions& self, int dtype) { + [](torch::TensorOptions& self, int dtype) { return self.dtype((torch::ScalarType) dtype); }) .define_method( "layout", - *[](torch::TensorOptions& self, const std::string& layout) { + [](torch::TensorOptions& self, const std::string& layout) { torch::Layout l; if (layout == "strided") { l = torch::kStrided; } else if (layout == "sparse") { l = torch::kSparse; @@ -293,15 +306,15 @@ } return self.layout(l); }) .define_method( "device", - *[](torch::TensorOptions& self, const std::string& device) { + [](torch::TensorOptions& self, const std::string& device) { torch::Device d(device); return self.device(d); }) .define_method( "requires_grad", - *[](torch::TensorOptions& self, bool requires_grad) { + [](torch::TensorOptions& self, bool requires_grad) { return self.requires_grad(requires_grad); }); }