ext/torch/ext.cpp in torch-rb-0.3.6 vs ext/torch/ext.cpp in torch-rb-0.3.7
- old
+ new
@@ -230,21 +230,21 @@
*[](Scalar start, Scalar end, Scalar step, const torch::TensorOptions &options) {
return torch::arange(start, end, step, options);
})
.define_singleton_method(
"_empty",
- *[](IntArrayRef size, const torch::TensorOptions &options) {
+ *[](std::vector<int64_t> size, const torch::TensorOptions &options) {
return torch::empty(size, options);
})
.define_singleton_method(
"_eye",
*[](int64_t m, int64_t n, const torch::TensorOptions &options) {
return torch::eye(m, n, options);
})
.define_singleton_method(
"_full",
- *[](IntArrayRef size, Scalar fill_value, const torch::TensorOptions& options) {
+ *[](std::vector<int64_t> size, Scalar fill_value, const torch::TensorOptions& options) {
return torch::full(size, fill_value, options);
})
.define_singleton_method(
"_linspace",
*[](Scalar start, Scalar end, int64_t steps, const torch::TensorOptions& options) {
@@ -255,36 +255,36 @@
*[](Scalar start, Scalar end, int64_t steps, double base, const torch::TensorOptions& options) {
return torch::logspace(start, end, steps, base, options);
})
.define_singleton_method(
"_ones",
- *[](IntArrayRef size, const torch::TensorOptions &options) {
+ *[](std::vector<int64_t> size, const torch::TensorOptions &options) {
return torch::ones(size, options);
})
.define_singleton_method(
"_rand",
- *[](IntArrayRef size, const torch::TensorOptions &options) {
+ *[](std::vector<int64_t> size, const torch::TensorOptions &options) {
return torch::rand(size, options);
})
.define_singleton_method(
"_randint",
- *[](int64_t low, int64_t high, IntArrayRef size, const torch::TensorOptions &options) {
+ *[](int64_t low, int64_t high, std::vector<int64_t> size, const torch::TensorOptions &options) {
return torch::randint(low, high, size, options);
})
.define_singleton_method(
"_randn",
- *[](IntArrayRef size, const torch::TensorOptions &options) {
+ *[](std::vector<int64_t> size, const torch::TensorOptions &options) {
return torch::randn(size, options);
})
.define_singleton_method(
"_randperm",
*[](int64_t n, const torch::TensorOptions &options) {
return torch::randperm(n, options);
})
.define_singleton_method(
"_zeros",
- *[](IntArrayRef size, const torch::TensorOptions &options) {
+ *[](std::vector<int64_t> size, const torch::TensorOptions &options) {
return torch::zeros(size, options);
})
// begin operations
.define_singleton_method(
"_save",
@@ -301,17 +301,17 @@
// https://github.com/pytorch/pytorch/issues/20356#issuecomment-567663701
return torch::pickle_load(v);
})
.define_singleton_method(
"_from_blob",
- *[](String s, IntArrayRef size, const torch::TensorOptions &options) {
+ *[](String s, std::vector<int64_t> size, const torch::TensorOptions &options) {
void *data = const_cast<char *>(s.c_str());
return torch::from_blob(data, size, options);
})
.define_singleton_method(
"_tensor",
- *[](Array a, IntArrayRef size, const torch::TensorOptions &options) {
+ *[](Array a, std::vector<int64_t> size, const torch::TensorOptions &options) {
auto dtype = options.dtype();
torch::Tensor t;
if (dtype == torch::kBool) {
std::vector<uint8_t> vec;
for (size_t i = 0; i < a.size(); i++) {
@@ -340,11 +340,21 @@
.define_method("quantized?", &torch::Tensor::is_quantized)
.define_method("dim", &torch::Tensor::dim)
.define_method("numel", &torch::Tensor::numel)
.define_method("element_size", &torch::Tensor::element_size)
.define_method("requires_grad", &torch::Tensor::requires_grad)
+ // in C++ for performance
.define_method(
+ "shape",
+ *[](Tensor& self) {
+ Array a;
+ for (auto &size : self.sizes()) {
+ a.push(size);
+ }
+ return a;
+ })
+ .define_method(
"_index",
*[](Tensor& self, Array indices) {
auto vec = index_vector(indices);
return self.index(vec);
})
@@ -418,11 +428,21 @@
if (tensor.device().type() != torch::kCPU) {
torch::Device device("cpu");
tensor = tensor.to(device);
}
+ if (!tensor.is_contiguous()) {
+ tensor = tensor.contiguous();
+ }
+
auto data_ptr = (const char *) tensor.data_ptr();
return std::string(data_ptr, tensor.numel() * tensor.element_size());
+ })
+ // for TorchVision
+ .define_method(
+ "_data_ptr",
+ *[](Tensor& self) {
+ return reinterpret_cast<uintptr_t>(self.data_ptr());
})
// TODO figure out a better way to do this
.define_method(
"_flat_data",
*[](Tensor& self) {