ext/torch/ext.cpp in torch-rb-0.2.0 vs ext/torch/ext.cpp in torch-rb-0.2.1

- old
+ new

@@ -3,10 +3,11 @@ #include <torch/torch.h> #include <rice/Array.hpp> #include <rice/Class.hpp> #include <rice/Constructor.hpp> +#include <rice/Hash.hpp> #include "templates.hpp" // generated with: // rake generate:functions @@ -20,10 +21,15 @@ class Parameter: public torch::autograd::Variable { public: Parameter(Tensor&& t) : torch::autograd::Variable(t) { } }; +void handle_error(c10::Error const & ex) +{ + throw Exception(rb_eRuntimeError, ex.what_without_backtrace()); +} + extern "C" void Init_ext() { Module rb_mTorch = define_module("Torch"); add_torch_functions(rb_mTorch); @@ -32,10 +38,112 @@ add_tensor_functions(rb_cTensor); Module rb_mNN = define_module_under(rb_mTorch, "NN"); add_nn_functions(rb_mNN); + // https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html + Class rb_cIValue = define_class_under<torch::IValue>(rb_mTorch, "IValue") + .define_constructor(Constructor<torch::IValue>()) + .define_method("bool?", &torch::IValue::isBool) + .define_method("bool_list?", &torch::IValue::isBoolList) + .define_method("capsule?", &torch::IValue::isCapsule) + .define_method("custom_class?", &torch::IValue::isCustomClass) + .define_method("device?", &torch::IValue::isDevice) + .define_method("double?", &torch::IValue::isDouble) + .define_method("double_list?", &torch::IValue::isDoubleList) + .define_method("future?", &torch::IValue::isFuture) + // .define_method("generator?", &torch::IValue::isGenerator) + .define_method("generic_dict?", &torch::IValue::isGenericDict) + .define_method("list?", &torch::IValue::isList) + .define_method("int?", &torch::IValue::isInt) + .define_method("int_list?", &torch::IValue::isIntList) + .define_method("module?", &torch::IValue::isModule) + .define_method("none?", &torch::IValue::isNone) + .define_method("object?", &torch::IValue::isObject) + .define_method("ptr_type?", &torch::IValue::isPtrType) + .define_method("py_object?", &torch::IValue::isPyObject) + .define_method("r_ref?", &torch::IValue::isRRef) + .define_method("scalar?", &torch::IValue::isScalar) + .define_method("string?", &torch::IValue::isString) + .define_method("tensor?", &torch::IValue::isTensor) + .define_method("tensor_list?", &torch::IValue::isTensorList) + .define_method("tuple?", &torch::IValue::isTuple) + .define_method( + "to_bool", + *[](torch::IValue& self) { + return self.toBool(); + }) + .define_method( + "to_double", + *[](torch::IValue& self) { + return self.toDouble(); + }) + .define_method( + "to_int", + *[](torch::IValue& self) { + return self.toInt(); + }) + .define_method( + "to_string_ref", + *[](torch::IValue& self) { + return self.toStringRef(); + }) + .define_method( + "to_tensor", + *[](torch::IValue& self) { + return self.toTensor(); + }) + .define_method( + "to_generic_dict", + *[](torch::IValue& self) { + auto dict = self.toGenericDict(); + Hash h; + for (auto& pair : dict) { + h[to_ruby<torch::IValue>(torch::IValue{pair.key()})] = to_ruby<torch::IValue>(torch::IValue{pair.value()}); + } + return h; + }) + .define_singleton_method( + "from_tensor", + *[](torch::Tensor& v) { + return torch::IValue(v); + }) + .define_singleton_method( + "from_string", + *[](String v) { + return torch::IValue(v.str()); + }) + .define_singleton_method( + "from_int", + *[](int64_t v) { + return torch::IValue(v); + }) + .define_singleton_method( + "from_double", + *[](double v) { + return torch::IValue(v); + }) + .define_singleton_method( + "from_bool", + *[](bool v) { + return torch::IValue(v); + }) + // see https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/python/pybind_utils.h + // createGenericDict and toIValue + .define_singleton_method( + "from_dict", + *[](Hash obj) { + auto key_type = c10::AnyType::get(); + auto value_type = c10::AnyType::get(); + c10::impl::GenericDict elems(key_type, value_type); + elems.reserve(obj.size()); + for (auto entry : obj) { + elems.insert(from_ruby<torch::IValue>(entry.first), from_ruby<torch::IValue>((Object) entry.second)); + } + return torch::IValue(std::move(elems)); + }); + rb_mTorch.define_singleton_method( "grad_enabled?", *[]() { return torch::GradMode::is_enabled(); }) @@ -111,16 +219,24 @@ return torch::zeros(size, options); }) // begin operations .define_singleton_method( "_save", - *[](const Tensor &value) { + *[](const torch::IValue &value) { auto v = torch::pickle_save(value); std::string str(v.begin(), v.end()); return str; }) .define_singleton_method( + "_load", + *[](const std::string &s) { + std::vector<char> v; + std::copy(s.begin(), s.end(), std::back_inserter(v)); + // https://github.com/pytorch/pytorch/issues/20356#issuecomment-567663701 + return torch::pickle_load(v); + }) + .define_singleton_method( "_binary_cross_entropy_with_logits", *[](const Tensor &input, const Tensor &target, OptionalTensor weight, OptionalTensor pos_weight, MyReduction reduction) { return torch::binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction); }) .define_singleton_method( @@ -155,10 +271,11 @@ } return t.reshape(size); }); rb_cTensor + .add_handler<c10::Error>(handle_error) .define_method("cuda?", &torch::Tensor::is_cuda) .define_method("sparse?", &torch::Tensor::is_sparse) .define_method("quantized?", &torch::Tensor::is_quantized) .define_method("dim", &torch::Tensor::dim) .define_method("numel", &torch::Tensor::numel) @@ -286,10 +403,11 @@ auto var = data.set_requires_grad(requires_grad); return Parameter(std::move(var)); }); Class rb_cTensorOptions = define_class_under<torch::TensorOptions>(rb_mTorch, "TensorOptions") + .add_handler<c10::Error>(handle_error) .define_constructor(Constructor<torch::TensorOptions>()) .define_method( "dtype", *[](torch::TensorOptions& self, int dtype) { return self.dtype((torch::ScalarType) dtype); @@ -309,16 +427,11 @@ return self.layout(l); }) .define_method( "device", *[](torch::TensorOptions& self, std::string device) { - try { - // needed to catch exception - torch::Device d(device); - return self.device(d); - } catch (const c10::Error& error) { - throw std::runtime_error(error.what_without_backtrace()); - } + torch::Device d(device); + return self.device(d); }) .define_method( "requires_grad", *[](torch::TensorOptions& self, bool requires_grad) { return self.requires_grad(requires_grad);