ext/torch/ext.cpp in torch-rb-0.3.6 vs ext/torch/ext.cpp in torch-rb-0.3.7

- old (lines prefixed with "-" are from torch-rb 0.3.6)
+ new (lines prefixed with "+" are from torch-rb 0.3.7)

@@ -230,21 +230,21 @@ *[](Scalar start, Scalar end, Scalar step, const torch::TensorOptions &options) { return torch::arange(start, end, step, options); }) .define_singleton_method( "_empty", - *[](IntArrayRef size, const torch::TensorOptions &options) { + *[](std::vector<int64_t> size, const torch::TensorOptions &options) { return torch::empty(size, options); }) .define_singleton_method( "_eye", *[](int64_t m, int64_t n, const torch::TensorOptions &options) { return torch::eye(m, n, options); }) .define_singleton_method( "_full", - *[](IntArrayRef size, Scalar fill_value, const torch::TensorOptions& options) { + *[](std::vector<int64_t> size, Scalar fill_value, const torch::TensorOptions& options) { return torch::full(size, fill_value, options); }) .define_singleton_method( "_linspace", *[](Scalar start, Scalar end, int64_t steps, const torch::TensorOptions& options) { @@ -255,36 +255,36 @@ *[](Scalar start, Scalar end, int64_t steps, double base, const torch::TensorOptions& options) { return torch::logspace(start, end, steps, base, options); }) .define_singleton_method( "_ones", - *[](IntArrayRef size, const torch::TensorOptions &options) { + *[](std::vector<int64_t> size, const torch::TensorOptions &options) { return torch::ones(size, options); }) .define_singleton_method( "_rand", - *[](IntArrayRef size, const torch::TensorOptions &options) { + *[](std::vector<int64_t> size, const torch::TensorOptions &options) { return torch::rand(size, options); }) .define_singleton_method( "_randint", - *[](int64_t low, int64_t high, IntArrayRef size, const torch::TensorOptions &options) { + *[](int64_t low, int64_t high, std::vector<int64_t> size, const torch::TensorOptions &options) { return torch::randint(low, high, size, options); }) .define_singleton_method( "_randn", - *[](IntArrayRef size, const torch::TensorOptions &options) { + *[](std::vector<int64_t> size, const torch::TensorOptions &options) { return torch::randn(size, options); }) .define_singleton_method( 
"_randperm", *[](int64_t n, const torch::TensorOptions &options) { return torch::randperm(n, options); }) .define_singleton_method( "_zeros", - *[](IntArrayRef size, const torch::TensorOptions &options) { + *[](std::vector<int64_t> size, const torch::TensorOptions &options) { return torch::zeros(size, options); }) // begin operations .define_singleton_method( "_save", @@ -301,17 +301,17 @@ // https://github.com/pytorch/pytorch/issues/20356#issuecomment-567663701 return torch::pickle_load(v); }) .define_singleton_method( "_from_blob", - *[](String s, IntArrayRef size, const torch::TensorOptions &options) { + *[](String s, std::vector<int64_t> size, const torch::TensorOptions &options) { void *data = const_cast<char *>(s.c_str()); return torch::from_blob(data, size, options); }) .define_singleton_method( "_tensor", - *[](Array a, IntArrayRef size, const torch::TensorOptions &options) { + *[](Array a, std::vector<int64_t> size, const torch::TensorOptions &options) { auto dtype = options.dtype(); torch::Tensor t; if (dtype == torch::kBool) { std::vector<uint8_t> vec; for (size_t i = 0; i < a.size(); i++) { @@ -340,11 +340,21 @@ .define_method("quantized?", &torch::Tensor::is_quantized) .define_method("dim", &torch::Tensor::dim) .define_method("numel", &torch::Tensor::numel) .define_method("element_size", &torch::Tensor::element_size) .define_method("requires_grad", &torch::Tensor::requires_grad) + // in C++ for performance .define_method( + "shape", + *[](Tensor& self) { + Array a; + for (auto &size : self.sizes()) { + a.push(size); + } + return a; + }) + .define_method( "_index", *[](Tensor& self, Array indices) { auto vec = index_vector(indices); return self.index(vec); }) @@ -418,11 +428,21 @@ if (tensor.device().type() != torch::kCPU) { torch::Device device("cpu"); tensor = tensor.to(device); } + if (!tensor.is_contiguous()) { + tensor = tensor.contiguous(); + } + auto data_ptr = (const char *) tensor.data_ptr(); return std::string(data_ptr, tensor.numel() * 
tensor.element_size()); + }) + // for TorchVision + .define_method( + "_data_ptr", + *[](Tensor& self) { + return reinterpret_cast<uintptr_t>(self.data_ptr()); }) // TODO figure out a better way to do this .define_method( "_flat_data", *[](Tensor& self) {