ext/torch/tensor.cpp in torch-rb-0.6.0 vs torch-rb-0.7.0
- old (torch-rb-0.6.0)
+ new (torch-rb-0.7.0)
@@ -1,18 +1,50 @@
#include <torch/torch.h>
-#include <rice/Constructor.hpp>
-#include <rice/Module.hpp>
+#include <rice/rice.hpp>
#include "tensor_functions.h"
#include "ruby_arg_parser.h"
#include "templates.h"
#include "utils.h"
using namespace Rice;
using torch::indexing::TensorIndex;
+namespace Rice::detail
+{
+  template<typename T>
+  struct Type<c10::complex<T>>
+  {
+    static bool verify()
+    {
+      return true;
+    }
+  };
+
+  template<typename T>
+  class To_Ruby<c10::complex<T>>
+  {
+  public:
+    VALUE convert(c10::complex<T> const& x)
+    {
+      return rb_dbl_complex_new(x.real(), x.imag());
+    }
+  };
+}
+
+template<typename T>
+Array flat_data(Tensor& tensor) {
+  Tensor view = tensor.reshape({tensor.numel()});
+
+  Array a;
+  for (int i = 0; i < tensor.numel(); i++) {
+    a.push(view[i].item().to<T>());
+  }
+  return a;
+}
+
Class rb_cTensor;
std::vector<TensorIndex> index_vector(Array a) {
Object obj;
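Note on the 0.7.0 additions above: Rice 4 collapses the old per-feature headers into the single rice/rice.hpp, and custom conversions are declared by specializing templates in Rice::detail. The To_Ruby<c10::complex<T>> specialization is what later lets flat_data<T> push complex tensor elements into a Ruby Array; rb_dbl_complex_new builds a Ruby Complex from the real and imaginary parts. Below is a minimal sketch of the same converter pattern applied to a hypothetical type; Point2D and its Array representation are illustrative only, not part of torch-rb.

// Sketch only: the Rice 4 conversion hooks shown above, for a made-up type.
struct Point2D { double x, y; };

namespace Rice::detail
{
  template<>
  struct Type<Point2D>
  {
    // Lets Rice treat the type as known/convertible.
    static bool verify() { return true; }
  };

  template<>
  class To_Ruby<Point2D>
  {
  public:
    // Builds the Ruby VALUE whenever a Point2D is returned to Ruby.
    VALUE convert(Point2D const& p)
    {
      return rb_ary_new_from_args(2, DBL2NUM(p.x), DBL2NUM(p.y));
    }
  };
}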
@@ -21,23 +53,23 @@
  for (long i = 0; i < a.size(); i++) {
    obj = a[i];
    if (obj.is_instance_of(rb_cInteger)) {
-      indices.push_back(from_ruby<int64_t>(obj));
+      indices.push_back(Rice::detail::From_Ruby<int64_t>().convert(obj.value()));
    } else if (obj.is_instance_of(rb_cRange)) {
      torch::optional<int64_t> start_index = torch::nullopt;
      torch::optional<int64_t> stop_index = torch::nullopt;
      Object begin = obj.call("begin");
      if (!begin.is_nil()) {
-        start_index = from_ruby<int64_t>(begin);
+        start_index = Rice::detail::From_Ruby<int64_t>().convert(begin.value());
      }
      Object end = obj.call("end");
      if (!end.is_nil()) {
-        stop_index = from_ruby<int64_t>(end);
+        stop_index = Rice::detail::From_Ruby<int64_t>().convert(end.value());
      }
      Object exclude_end = obj.call("exclude_end?");
      if (stop_index.has_value() && !exclude_end) {
        if (stop_index.value() == -1) {
@@ -47,15 +79,15 @@
        }
      }
      indices.push_back(torch::indexing::Slice(start_index, stop_index));
    } else if (obj.is_instance_of(rb_cTensor)) {
-      indices.push_back(from_ruby<Tensor>(obj));
+      indices.push_back(Rice::detail::From_Ruby<Tensor>().convert(obj.value()));
    } else if (obj.is_nil()) {
      indices.push_back(torch::indexing::None);
    } else if (obj == True || obj == False) {
-      indices.push_back(from_ruby<bool>(obj));
+      indices.push_back(Rice::detail::From_Ruby<bool>().convert(obj.value()));
    } else {
      throw Exception(rb_eArgError, "Unsupported index type: %s", rb_obj_classname(obj));
    }
  }
  return indices;
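The change in these two hunks is mechanical: Rice 3's free function from_ruby<T>(obj) becomes an explicit converter object in Rice 4, and it operates on the raw VALUE rather than on the Rice::Object wrapper. A before/after sketch of the call shape (not additional torch-rb code; obj stands for any Rice::Object):

// Rice 3 (torch-rb 0.6.0): free conversion function on a Rice::Object.
int64_t i_old = from_ruby<int64_t>(obj);

// Rice 4 (torch-rb 0.7.0): From_Ruby<T> functor on the underlying VALUE.
int64_t i_new = Rice::detail::From_Ruby<int64_t>().convert(obj.value());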
@@ -66,11 +98,11 @@
// TODO add support for inputs argument
// _backward
static VALUE tensor__backward(int argc, VALUE* argv, VALUE self_)
{
HANDLE_TH_ERRORS
- Tensor& self = from_ruby<Tensor&>(self_);
+ Tensor& self = Rice::detail::From_Ruby<Tensor&>().convert(self_);
static RubyArgParser parser({
"_backward(Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False)"
});
ParsedArgs<4> parsed_args;
auto _r = parser.parse(self_, argc, argv, parsed_args);
@@ -82,12 +114,12 @@
dispatch__backward(self, {}, _r.optionalTensor(0), _r.toBoolOptional(1), _r.toBool(2));
RETURN_NIL
END_HANDLE_TH_ERRORS
}
-void init_tensor(Rice::Module& m) {
- rb_cTensor = Rice::define_class_under<torch::Tensor>(m, "Tensor");
+void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions) {
+ rb_cTensor = c;
rb_cTensor.add_handler<torch::Error>(handle_error);
add_tensor_functions(rb_cTensor);
THPVariableClass = rb_cTensor.value();
rb_define_method(rb_cTensor, "backward", (VALUE (*)(...)) tensor__backward, -1);
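In 0.7.0, init_tensor no longer creates Torch::Tensor itself; it receives already-defined classes and only attaches behavior to them (THPVariableClass and the hand-written backward binding are unchanged). The caller lives outside this file; a hypothetical wiring, with the entry-point name and module setup assumed rather than taken from this diff, would look roughly like:

// Hypothetical caller sketch: classes defined up front in the extension entry
// point, then passed into init_tensor so this file only adds methods.
extern "C" void Init_ext()
{
  auto m = Rice::define_module("Torch");
  auto rb_cTensor = Rice::define_class_under<torch::Tensor>(m, "Tensor");
  auto rb_cTensorOptions = Rice::define_class_under<torch::TensorOptions>(m, "TensorOptions");
  init_tensor(m, rb_cTensor, rb_cTensorOptions);
}

Defining the classes centrally and passing them in presumably lets one class be referenced by other bindings before the file that fills in its methods has run.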
@@ -100,98 +132,98 @@
.define_method("numel", &torch::Tensor::numel)
.define_method("element_size", &torch::Tensor::element_size)
.define_method("requires_grad", &torch::Tensor::requires_grad)
.define_method(
"_size",
- *[](Tensor& self, int64_t dim) {
+ [](Tensor& self, int64_t dim) {
return self.size(dim);
})
.define_method(
"_stride",
- *[](Tensor& self, int64_t dim) {
+ [](Tensor& self, int64_t dim) {
return self.stride(dim);
})
// in C++ for performance
.define_method(
"shape",
- *[](Tensor& self) {
+ [](Tensor& self) {
Array a;
for (auto &size : self.sizes()) {
a.push(size);
}
return a;
})
.define_method(
"_strides",
- *[](Tensor& self) {
+ [](Tensor& self) {
Array a;
for (auto &stride : self.strides()) {
a.push(stride);
}
return a;
})
.define_method(
"_index",
- *[](Tensor& self, Array indices) {
+ [](Tensor& self, Array indices) {
auto vec = index_vector(indices);
return self.index(vec);
})
.define_method(
"_index_put_custom",
- *[](Tensor& self, Array indices, torch::Tensor& value) {
+ [](Tensor& self, Array indices, torch::Tensor& value) {
auto vec = index_vector(indices);
return self.index_put_(vec, value);
})
.define_method(
"contiguous?",
- *[](Tensor& self) {
+ [](Tensor& self) {
return self.is_contiguous();
})
.define_method(
"_requires_grad!",
- *[](Tensor& self, bool requires_grad) {
+ [](Tensor& self, bool requires_grad) {
return self.set_requires_grad(requires_grad);
})
.define_method(
"grad",
- *[](Tensor& self) {
+ [](Tensor& self) {
auto grad = self.grad();
- return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
+ return grad.defined() ? Object(Rice::detail::To_Ruby<torch::Tensor>().convert(grad)) : Nil;
})
.define_method(
"grad=",
- *[](Tensor& self, torch::Tensor& grad) {
+ [](Tensor& self, torch::Tensor& grad) {
self.mutable_grad() = grad;
})
.define_method(
"_dtype",
- *[](Tensor& self) {
+ [](Tensor& self) {
return (int) at::typeMetaToScalarType(self.dtype());
})
.define_method(
"_type",
- *[](Tensor& self, int dtype) {
+ [](Tensor& self, int dtype) {
return self.toType((torch::ScalarType) dtype);
})
.define_method(
"_layout",
- *[](Tensor& self) {
+ [](Tensor& self) {
std::stringstream s;
s << self.layout();
return s.str();
})
.define_method(
"device",
- *[](Tensor& self) {
+ [](Tensor& self) {
std::stringstream s;
s << self.device();
return s.str();
})
.define_method(
"_data_str",
- *[](Tensor& self) {
- Tensor tensor = self;
+ [](Tensor& self) {
+ auto tensor = self;
// move to CPU to get data
if (tensor.device().type() != torch::kCPU) {
torch::Device device("cpu");
tensor = tensor.to(device);
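Across the whole method chain above, the 0.7.0 change is dropping the leading * on each lambda: Rice 3's define_method wanted the captureless lambda decayed into a plain function, while Rice 4 accepts the callable object directly. The grad getter additionally swaps to_ruby<torch::Tensor> for Rice::detail::To_Ruby<torch::Tensor>().convert(...), wrapped in an Object so it can still fall back to Nil. Schematically:

// Rice 3 (0.6.0): lambda forced to a plain function via the leading *.
//   .define_method("contiguous?", *[](Tensor& self) { return self.is_contiguous(); })
// Rice 4 (0.7.0): the lambda itself is accepted.
//   .define_method("contiguous?", [](Tensor& self) { return self.is_contiguous(); })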
@@ -205,85 +237,66 @@
return std::string(data_ptr, tensor.numel() * tensor.element_size());
})
// for TorchVision
.define_method(
"_data_ptr",
- *[](Tensor& self) {
+ [](Tensor& self) {
return reinterpret_cast<uintptr_t>(self.data_ptr());
})
// TODO figure out a better way to do this
.define_method(
"_flat_data",
- *[](Tensor& self) {
- Tensor tensor = self;
+ [](Tensor& self) {
+ auto tensor = self;
// move to CPU to get data
if (tensor.device().type() != torch::kCPU) {
torch::Device device("cpu");
tensor = tensor.to(device);
}
- Array a;
auto dtype = tensor.dtype();
-
- Tensor view = tensor.reshape({tensor.numel()});
-
- // TODO DRY if someone knows C++
if (dtype == torch::kByte) {
- for (int i = 0; i < tensor.numel(); i++) {
- a.push(view[i].item().to<uint8_t>());
- }
+ return flat_data<uint8_t>(tensor);
} else if (dtype == torch::kChar) {
- for (int i = 0; i < tensor.numel(); i++) {
- a.push(to_ruby<int>(view[i].item().to<int8_t>()));
- }
+ return flat_data<int8_t>(tensor);
} else if (dtype == torch::kShort) {
- for (int i = 0; i < tensor.numel(); i++) {
- a.push(view[i].item().to<int16_t>());
- }
+ return flat_data<int16_t>(tensor);
} else if (dtype == torch::kInt) {
- for (int i = 0; i < tensor.numel(); i++) {
- a.push(view[i].item().to<int32_t>());
- }
+ return flat_data<int32_t>(tensor);
} else if (dtype == torch::kLong) {
- for (int i = 0; i < tensor.numel(); i++) {
- a.push(view[i].item().to<int64_t>());
- }
+ return flat_data<int64_t>(tensor);
} else if (dtype == torch::kFloat) {
- for (int i = 0; i < tensor.numel(); i++) {
- a.push(view[i].item().to<float>());
- }
+ return flat_data<float>(tensor);
} else if (dtype == torch::kDouble) {
- for (int i = 0; i < tensor.numel(); i++) {
- a.push(view[i].item().to<double>());
- }
+ return flat_data<double>(tensor);
} else if (dtype == torch::kBool) {
- for (int i = 0; i < tensor.numel(); i++) {
- a.push(view[i].item().to<bool>() ? True : False);
- }
+ return flat_data<bool>(tensor);
+ } else if (dtype == torch::kComplexFloat) {
+ return flat_data<c10::complex<float>>(tensor);
+ } else if (dtype == torch::kComplexDouble) {
+ return flat_data<c10::complex<double>>(tensor);
} else {
throw std::runtime_error("Unsupported type");
}
- return a;
})
.define_method(
"_to",
- *[](Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
+ [](Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
return self.to(device, (torch::ScalarType) dtype, non_blocking, copy);
});
- Rice::define_class_under<torch::TensorOptions>(m, "TensorOptions")
+ rb_cTensorOptions
.add_handler<torch::Error>(handle_error)
- .define_constructor(Rice::Constructor<torch::TensorOptions>())
.define_method(
"dtype",
- *[](torch::TensorOptions& self, int dtype) {
+ [](torch::TensorOptions& self, int dtype) {
return self.dtype((torch::ScalarType) dtype);
})
.define_method(
"layout",
- *[](torch::TensorOptions& self, const std::string& layout) {
+ [](torch::TensorOptions& self, const std::string& layout) {
torch::Layout l;
if (layout == "strided") {
l = torch::kStrided;
} else if (layout == "sparse") {
l = torch::kSparse;
@@ -293,15 +306,15 @@
}
return self.layout(l);
})
.define_method(
"device",
- *[](torch::TensorOptions& self, const std::string& device) {
+ [](torch::TensorOptions& self, const std::string& device) {
torch::Device d(device);
return self.device(d);
})
.define_method(
"requires_grad",
- *[](torch::TensorOptions& self, bool requires_grad) {
+ [](torch::TensorOptions& self, bool requires_grad) {
return self.requires_grad(requires_grad);
});
}
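As with Tensor, the TensorOptions class is no longer created in this file: 0.7.0 drops both the define_class_under call and the define_constructor line and chains the remaining methods onto the rb_cTensorOptions argument. Presumably the class, and the constructor that disappeared here, now live at the definition site outside this diff; a sketch of what that site might contain, with m assumed to be the Torch module as in init_tensor:

// Assumed definition site (illustrative only, not shown in this diff): the
// constructor formerly attached in init_tensor moves to where the class is made.
auto rb_cTensorOptions =
    Rice::define_class_under<torch::TensorOptions>(m, "TensorOptions")
        .define_constructor(Rice::Constructor<torch::TensorOptions>());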