ext/torch/torch.cpp in torch-rb-0.13.2 vs ext/torch/torch.cpp in torch-rb-0.14.0
- old
+ new
@@ -1,9 +1,11 @@
#include <torch/torch.h>
#include <rice/rice.hpp>
+#include <fstream>
+
#include "torch_functions.h"
#include "templates.h"
#include "utils.h"
template<typename T>
@@ -55,19 +57,21 @@
// begin operations
.define_singleton_function(
"_save",
[](const torch::IValue &value) {
auto v = torch::pickle_save(value);
- std::string str(v.begin(), v.end());
- return str;
+ return Object(rb_str_new(v.data(), v.size()));
})
.define_singleton_function(
"_load",
- [](const std::string &s) {
- std::vector<char> v;
- std::copy(s.begin(), s.end(), std::back_inserter(v));
+ [](const std::string &filename) {
// https://github.com/pytorch/pytorch/issues/20356#issuecomment-567663701
- return torch::pickle_load(v);
+ std::ifstream input(filename, std::ios::binary);
+ std::vector<char> bytes(
+ (std::istreambuf_iterator<char>(input)),
+ (std::istreambuf_iterator<char>()));
+ input.close();
+ return torch::pickle_load(bytes);
})
.define_singleton_function(
"_from_blob",
[](Rice::String s, std::vector<int64_t> size, const torch::TensorOptions &options) {
void *data = const_cast<char *>(s.c_str());