ext/torch/ext.cpp in torch-rb-0.2.1 vs ext/torch/ext.cpp in torch-rb-0.2.2

- old
+ new

@@ -81,10 +81,20 @@ "to_int", *[](torch::IValue& self) { return self.toInt(); }) .define_method( + "to_list", + *[](torch::IValue& self) { + auto list = self.toListRef(); + Array obj; + for (auto& elem : list) { + obj.push(to_ruby<torch::IValue>(torch::IValue{elem})); + } + return obj; + }) + .define_method( "to_string_ref", *[](torch::IValue& self) { return self.toStringRef(); }) .define_method( @@ -94,19 +104,29 @@ }) .define_method( "to_generic_dict", *[](torch::IValue& self) { auto dict = self.toGenericDict(); - Hash h; + Hash obj; for (auto& pair : dict) { - h[to_ruby<torch::IValue>(torch::IValue{pair.key()})] = to_ruby<torch::IValue>(torch::IValue{pair.value()}); + obj[to_ruby<torch::IValue>(torch::IValue{pair.key()})] = to_ruby<torch::IValue>(torch::IValue{pair.value()}); } - return h; + return obj; }) .define_singleton_method( "from_tensor", *[](torch::Tensor& v) { return torch::IValue(v); + }) + // TODO create specialized list types? + .define_singleton_method( + "from_list", + *[](Array obj) { + c10::impl::GenericList list(c10::AnyType::get()); + for (auto entry : obj) { + list.push_back(from_ruby<torch::IValue>(entry)); + } + return torch::IValue(list); }) .define_singleton_method( "from_string", *[](String v) { return torch::IValue(v.str());