ext/torch/ext.cpp in torch-rb-0.2.1 vs ext/torch/ext.cpp in torch-rb-0.2.2
- old
+ new
@@ -81,10 +81,20 @@
"to_int",
*[](torch::IValue& self) {
return self.toInt();
})
.define_method(
+ "to_list",
+ *[](torch::IValue& self) {
+ auto list = self.toListRef();
+ Array obj;
+ for (auto& elem : list) {
+ obj.push(to_ruby<torch::IValue>(torch::IValue{elem}));
+ }
+ return obj;
+ })
+ .define_method(
"to_string_ref",
*[](torch::IValue& self) {
return self.toStringRef();
})
.define_method(
@@ -94,19 +104,29 @@
})
.define_method(
"to_generic_dict",
*[](torch::IValue& self) {
auto dict = self.toGenericDict();
- Hash h;
+ Hash obj;
for (auto& pair : dict) {
- h[to_ruby<torch::IValue>(torch::IValue{pair.key()})] = to_ruby<torch::IValue>(torch::IValue{pair.value()});
+ obj[to_ruby<torch::IValue>(torch::IValue{pair.key()})] = to_ruby<torch::IValue>(torch::IValue{pair.value()});
}
- return h;
+ return obj;
})
.define_singleton_method(
"from_tensor",
*[](torch::Tensor& v) {
return torch::IValue(v);
+ })
+ // TODO create specialized list types?
+ .define_singleton_method(
+ "from_list",
+ *[](Array obj) {
+ c10::impl::GenericList list(c10::AnyType::get());
+ for (auto entry : obj) {
+ list.push_back(from_ruby<torch::IValue>(entry));
+ }
+ return torch::IValue(list);
})
.define_singleton_method(
"from_string",
*[](String v) {
return torch::IValue(v.str());