ext/torch/ext.cpp in torch-rb-0.2.4 vs ext/torch/ext.cpp in torch-rb-0.2.5

- old
+ new

@@ -21,28 +21,32 @@ class Parameter: public torch::autograd::Variable { public: Parameter(Tensor&& t) : torch::autograd::Variable(t) { } }; -void handle_error(c10::Error const & ex) +void handle_error(torch::Error const & ex) { throw Exception(rb_eRuntimeError, ex.what_without_backtrace()); } extern "C" void Init_ext() { Module rb_mTorch = define_module("Torch"); + rb_mTorch.add_handler<torch::Error>(handle_error); add_torch_functions(rb_mTorch); Class rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor"); + rb_cTensor.add_handler<torch::Error>(handle_error); add_tensor_functions(rb_cTensor); Module rb_mNN = define_module_under(rb_mTorch, "NN"); + rb_mNN.add_handler<torch::Error>(handle_error); add_nn_functions(rb_mNN); Module rb_mRandom = define_module_under(rb_mTorch, "Random") + .add_handler<torch::Error>(handle_error) .define_singleton_method( "initial_seed", *[]() { return at::detail::getDefaultCPUGenerator()->current_seed(); }) @@ -53,10 +57,11 @@ return at::detail::getDefaultCPUGenerator()->seed(); }); // https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html Class rb_cIValue = define_class_under<torch::IValue>(rb_mTorch, "IValue") + .add_handler<torch::Error>(handle_error) .define_constructor(Constructor<torch::IValue>()) .define_method("bool?", &torch::IValue::isBool) .define_method("bool_list?", &torch::IValue::isBoolList) .define_method("capsule?", &torch::IValue::isCapsule) .define_method("custom_class?", &torch::IValue::isCustomClass) @@ -315,11 +320,10 @@ } return t.reshape(size); }); rb_cTensor - .add_handler<c10::Error>(handle_error) .define_method("cuda?", &torch::Tensor::is_cuda) .define_method("sparse?", &torch::Tensor::is_sparse) .define_method("quantized?", &torch::Tensor::is_quantized) .define_method("dim", &torch::Tensor::dim) .define_method("numel", &torch::Tensor::numel) @@ -373,10 +377,25 @@ std::stringstream s; s << self.device(); return s.str(); }) .define_method( + "_data_str", + *[](Tensor& self) { + Tensor tensor = self; + + // move to CPU to get data + if (tensor.device().type() != torch::kCPU) { + torch::Device device("cpu"); + tensor = tensor.to(device); + } + + auto data_ptr = (const char *) tensor.data_ptr(); + return std::string(data_ptr, tensor.numel() * tensor.element_size()); + }) + // TODO figure out a better way to do this + .define_method( "_flat_data", *[](Tensor& self) { Tensor tensor = self; // move to CPU to get data @@ -386,50 +405,44 @@ } Array a; auto dtype = tensor.dtype(); + Tensor view = tensor.reshape({tensor.numel()}); + // TODO DRY if someone knows C++ if (dtype == torch::kByte) { - uint8_t* data = tensor.data_ptr<uint8_t>(); for (int i = 0; i < tensor.numel(); i++) { - a.push(data[i]); + a.push(view[i].item().to<uint8_t>()); } } else if (dtype == torch::kChar) { - int8_t* data = tensor.data_ptr<int8_t>(); for (int i = 0; i < tensor.numel(); i++) { - a.push(to_ruby<int>(data[i])); + a.push(to_ruby<int>(view[i].item().to<int8_t>())); } } else if (dtype == torch::kShort) { - int16_t* data = tensor.data_ptr<int16_t>(); for (int i = 0; i < tensor.numel(); i++) { - a.push(data[i]); + a.push(view[i].item().to<int16_t>()); } } else if (dtype == torch::kInt) { - int32_t* data = tensor.data_ptr<int32_t>(); for (int i = 0; i < tensor.numel(); i++) { - a.push(data[i]); + a.push(view[i].item().to<int32_t>()); } } else if (dtype == torch::kLong) { - int64_t* data = tensor.data_ptr<int64_t>(); for (int i = 0; i < tensor.numel(); i++) { - a.push(data[i]); + a.push(view[i].item().to<int64_t>()); } } else if (dtype == torch::kFloat) { - float* data = tensor.data_ptr<float>(); for (int i = 0; i < tensor.numel(); i++) { - a.push(data[i]); + a.push(view[i].item().to<float>()); } } else if (dtype == torch::kDouble) { - double* data = tensor.data_ptr<double>(); for (int i = 0; i < tensor.numel(); i++) { - a.push(data[i]); + a.push(view[i].item().to<double>()); } } else if (dtype == torch::kBool) { - bool* data = tensor.data_ptr<bool>(); for (int i = 0; i < tensor.numel(); i++) { - a.push(data[i] ? True : False); + a.push(view[i].item().to<bool>() ? True : False); } } else { throw std::runtime_error("Unsupported type"); } return a; @@ -447,11 +460,11 @@ auto var = data.set_requires_grad(requires_grad); return Parameter(std::move(var)); }); Class rb_cTensorOptions = define_class_under<torch::TensorOptions>(rb_mTorch, "TensorOptions") - .add_handler<c10::Error>(handle_error) + .add_handler<torch::Error>(handle_error) .define_constructor(Constructor<torch::TensorOptions>()) .define_method( "dtype", *[](torch::TensorOptions& self, int dtype) { return self.dtype((torch::ScalarType) dtype); @@ -553,19 +566,21 @@ *[](Tensor tensor, double sparsity, double std) { return torch::nn::init::sparse_(tensor, sparsity, std); }); Class rb_cParameter = define_class_under<Parameter, torch::Tensor>(rb_mNN, "Parameter") + .add_handler<torch::Error>(handle_error) .define_method( "grad", *[](Parameter& self) { auto grad = self.grad(); return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil; }); Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device") .define_constructor(Constructor<torch::Device, std::string>()) + .add_handler<torch::Error>(handle_error) .define_method("index", &torch::Device::index) .define_method("index?", &torch::Device::has_index) .define_method( "type", *[](torch::Device& self) { @@ -573,8 +588,9 @@ s << self.type(); return s.str(); }); Module rb_mCUDA = define_module_under(rb_mTorch, "CUDA") + .add_handler<torch::Error>(handle_error) .define_singleton_method("available?", &torch::cuda::is_available) .define_singleton_method("device_count", &torch::cuda::device_count); }