ext/torch/ivalue.cpp in torch-rb-0.6.0 vs ext/torch/ivalue.cpp in torch-rb-0.7.0

- old
+ new

@@ -1,20 +1,15 @@ #include <torch/torch.h> -#include <rice/Array.hpp> -#include <rice/Constructor.hpp> -#include <rice/Hash.hpp> -#include <rice/Module.hpp> -#include <rice/String.hpp> +#include <rice/rice.hpp> #include "utils.h" -void init_ivalue(Rice::Module& m) { +void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue) { // https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html - Rice::define_class_under<torch::IValue>(m, "IValue") + rb_cIValue .add_handler<torch::Error>(handle_error) - .define_constructor(Rice::Constructor<torch::IValue>()) .define_method("bool?", &torch::IValue::isBool) .define_method("bool_list?", &torch::IValue::isBoolList) .define_method("capsule?", &torch::IValue::isCapsule) .define_method("custom_class?", &torch::IValue::isCustomClass) .define_method("device?", &torch::IValue::isDevice) @@ -37,98 +32,101 @@ .define_method("tensor?", &torch::IValue::isTensor) .define_method("tensor_list?", &torch::IValue::isTensorList) .define_method("tuple?", &torch::IValue::isTuple) .define_method( "to_bool", - *[](torch::IValue& self) { + [](torch::IValue& self) { return self.toBool(); }) .define_method( "to_double", - *[](torch::IValue& self) { + [](torch::IValue& self) { return self.toDouble(); }) .define_method( "to_int", - *[](torch::IValue& self) { + [](torch::IValue& self) { return self.toInt(); }) .define_method( "to_list", - *[](torch::IValue& self) { + [](torch::IValue& self) { auto list = self.toListRef(); Rice::Array obj; for (auto& elem : list) { - obj.push(to_ruby<torch::IValue>(torch::IValue{elem})); + auto v = torch::IValue{elem}; + obj.push(Rice::Object(Rice::detail::To_Ruby<torch::IValue>().convert(v))); } return obj; }) .define_method( "to_string_ref", - *[](torch::IValue& self) { + [](torch::IValue& self) { return self.toStringRef(); }) .define_method( "to_tensor", - *[](torch::IValue& self) { + [](torch::IValue& self) { return self.toTensor(); }) .define_method( "to_generic_dict", - *[](torch::IValue& self) { + [](torch::IValue& self) { auto dict = self.toGenericDict(); Rice::Hash obj; for (auto& pair : dict) { - obj[to_ruby<torch::IValue>(torch::IValue{pair.key()})] = to_ruby<torch::IValue>(torch::IValue{pair.value()}); + auto k = torch::IValue{pair.key()}; + auto v = torch::IValue{pair.value()}; + obj[Rice::Object(Rice::detail::To_Ruby<torch::IValue>().convert(k))] = Rice::Object(Rice::detail::To_Ruby<torch::IValue>().convert(v)); } return obj; }) - .define_singleton_method( + .define_singleton_function( "from_tensor", - *[](torch::Tensor& v) { + [](torch::Tensor& v) { return torch::IValue(v); }) // TODO create specialized list types? - .define_singleton_method( + .define_singleton_function( "from_list", - *[](Rice::Array obj) { + [](Rice::Array obj) { c10::impl::GenericList list(c10::AnyType::get()); for (auto entry : obj) { - list.push_back(from_ruby<torch::IValue>(entry)); + list.push_back(Rice::detail::From_Ruby<torch::IValue>().convert(entry.value())); } return torch::IValue(list); }) - .define_singleton_method( + .define_singleton_function( "from_string", - *[](Rice::String v) { + [](Rice::String v) { return torch::IValue(v.str()); }) - .define_singleton_method( + .define_singleton_function( "from_int", - *[](int64_t v) { + [](int64_t v) { return torch::IValue(v); }) - .define_singleton_method( + .define_singleton_function( "from_double", - *[](double v) { + [](double v) { return torch::IValue(v); }) - .define_singleton_method( + .define_singleton_function( "from_bool", - *[](bool v) { + [](bool v) { return torch::IValue(v); }) // see https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/python/pybind_utils.h // createGenericDict and toIValue - .define_singleton_method( + .define_singleton_function( "from_dict", - *[](Rice::Hash obj) { + [](Rice::Hash obj) { auto key_type = c10::AnyType::get(); auto value_type = c10::AnyType::get(); c10::impl::GenericDict elems(key_type, value_type); elems.reserve(obj.size()); for (auto entry : obj) { - elems.insert(from_ruby<torch::IValue>(entry.first), from_ruby<torch::IValue>((Rice::Object) entry.second)); + elems.insert(Rice::detail::From_Ruby<torch::IValue>().convert(entry.first), Rice::detail::From_Ruby<torch::IValue>().convert((Rice::Object) entry.second)); } return torch::IValue(std::move(elems)); }); }