ext/torch/ext.cpp in torch-rb-0.2.0 vs ext/torch/ext.cpp in torch-rb-0.2.1
- old
+ new
@@ -3,10 +3,11 @@
#include <torch/torch.h>
#include <rice/Array.hpp>
#include <rice/Class.hpp>
#include <rice/Constructor.hpp>
+#include <rice/Hash.hpp>
#include "templates.hpp"
// generated with:
// rake generate:functions
@@ -20,10 +21,15 @@
class Parameter: public torch::autograd::Variable {
public:
Parameter(Tensor&& t) : torch::autograd::Variable(t) { }
};
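+// convert libtorch's c10::Error into a Ruby RuntimeError; registered on the
+// bound classes below via Rice's add_handler so C++ exceptions reach Ruby
+// instead of terminating the process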
+void handle_error(c10::Error const & ex)
+{
+ // pass the message through an explicit "%s": Rice's Exception constructor is
+ // printf-style, so a "%" in the libtorch message must not be interpreted
+ throw Exception(rb_eRuntimeError, "%s", ex.what_without_backtrace());
+}
+
extern "C"
void Init_ext()
{
Module rb_mTorch = define_module("Torch");
add_torch_functions(rb_mTorch);
@@ -32,10 +38,112 @@
add_tensor_functions(rb_cTensor);
Module rb_mNN = define_module_under(rb_mTorch, "NN");
add_nn_functions(rb_mNN);
+ // https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
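+ // a minimal example of the round trip these bindings enable (Ruby):
+ //   iv = Torch::IValue.from_string("hi")
+ //   iv.string?       # => true
+ //   iv.to_string_ref # => "hi"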
+ Class rb_cIValue = define_class_under<torch::IValue>(rb_mTorch, "IValue")
+ .define_constructor(Constructor<torch::IValue>())
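+ // type predicates: each isX accessor is exposed as a Ruby x? method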
+ .define_method("bool?", &torch::IValue::isBool)
+ .define_method("bool_list?", &torch::IValue::isBoolList)
+ .define_method("capsule?", &torch::IValue::isCapsule)
+ .define_method("custom_class?", &torch::IValue::isCustomClass)
+ .define_method("device?", &torch::IValue::isDevice)
+ .define_method("double?", &torch::IValue::isDouble)
+ .define_method("double_list?", &torch::IValue::isDoubleList)
+ .define_method("future?", &torch::IValue::isFuture)
+ // .define_method("generator?", &torch::IValue::isGenerator)
+ .define_method("generic_dict?", &torch::IValue::isGenericDict)
+ .define_method("list?", &torch::IValue::isList)
+ .define_method("int?", &torch::IValue::isInt)
+ .define_method("int_list?", &torch::IValue::isIntList)
+ .define_method("module?", &torch::IValue::isModule)
+ .define_method("none?", &torch::IValue::isNone)
+ .define_method("object?", &torch::IValue::isObject)
+ .define_method("ptr_type?", &torch::IValue::isPtrType)
+ .define_method("py_object?", &torch::IValue::isPyObject)
+ .define_method("r_ref?", &torch::IValue::isRRef)
+ .define_method("scalar?", &torch::IValue::isScalar)
+ .define_method("string?", &torch::IValue::isString)
+ .define_method("tensor?", &torch::IValue::isTensor)
+ .define_method("tensor_list?", &torch::IValue::isTensorList)
+ .define_method("tuple?", &torch::IValue::isTuple)
+ .define_method(
+ "to_bool",
+ *[](torch::IValue& self) {
+ return self.toBool();
+ })
+ .define_method(
+ "to_double",
+ *[](torch::IValue& self) {
+ return self.toDouble();
+ })
+ .define_method(
+ "to_int",
+ *[](torch::IValue& self) {
+ return self.toInt();
+ })
+ .define_method(
+ "to_string_ref",
+ *[](torch::IValue& self) {
+ return self.toStringRef();
+ })
+ .define_method(
+ "to_tensor",
+ *[](torch::IValue& self) {
+ return self.toTensor();
+ })
+ .define_method(
+ "to_generic_dict",
+ *[](torch::IValue& self) {
+ auto dict = self.toGenericDict();
+ Hash h;
+ for (auto& pair : dict) {
+ h[to_ruby<torch::IValue>(torch::IValue{pair.key()})] = to_ruby<torch::IValue>(torch::IValue{pair.value()});
+ }
+ return h;
+ })
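+ // factory methods: build IValues from Ruby objects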
+ .define_singleton_method(
+ "from_tensor",
+ *[](torch::Tensor& v) {
+ return torch::IValue(v);
+ })
+ .define_singleton_method(
+ "from_string",
+ *[](String v) {
+ return torch::IValue(v.str());
+ })
+ .define_singleton_method(
+ "from_int",
+ *[](int64_t v) {
+ return torch::IValue(v);
+ })
+ .define_singleton_method(
+ "from_double",
+ *[](double v) {
+ return torch::IValue(v);
+ })
+ .define_singleton_method(
+ "from_bool",
+ *[](bool v) {
+ return torch::IValue(v);
+ })
+ // see https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/python/pybind_utils.h
+ // createGenericDict and toIValue
+ .define_singleton_method(
+ "from_dict",
+ *[](Hash obj) {
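+ // key and value types are Any, so mixed dicts round-trip; see the
+ // pybind_utils.h reference above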
+ auto key_type = c10::AnyType::get();
+ auto value_type = c10::AnyType::get();
+ c10::impl::GenericDict elems(key_type, value_type);
+ elems.reserve(obj.size());
+ for (auto entry : obj) {
+ elems.insert(from_ruby<torch::IValue>(entry.first), from_ruby<torch::IValue>((Object) entry.second));
+ }
+ return torch::IValue(std::move(elems));
+ });
+
rb_mTorch.define_singleton_method(
"grad_enabled?",
*[]() {
return torch::GradMode::is_enabled();
})
@@ -111,16 +219,24 @@
return torch::zeros(size, options);
})
// begin operations
.define_singleton_method(
"_save",
- *[](const Tensor &value) {
+ *[](const torch::IValue &value) {
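+ // the binding now takes an IValue rather than a Tensor, matching
+ // torch::pickle_save and allowing non-tensor values to be serialized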
auto v = torch::pickle_save(value);
std::string str(v.begin(), v.end());
return str;
})
.define_singleton_method(
+ "_load",
+ *[](const std::string &s) {
+ std::vector<char> v;
+ std::copy(s.begin(), s.end(), std::back_inserter(v));
+ // https://github.com/pytorch/pytorch/issues/20356#issuecomment-567663701
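+ // pickle_load returns an IValue, surfaced to Ruby as Torch::IValue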
+ return torch::pickle_load(v);
+ })
+ .define_singleton_method(
"_binary_cross_entropy_with_logits",
*[](const Tensor &input, const Tensor &target, OptionalTensor weight, OptionalTensor pos_weight, MyReduction reduction) {
return torch::binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction);
})
.define_singleton_method(
@@ -155,10 +271,11 @@
}
return t.reshape(size);
});
rb_cTensor
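+ // convert libtorch errors on tensor methods into Ruby exceptions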
+ .add_handler<c10::Error>(handle_error)
.define_method("cuda?", &torch::Tensor::is_cuda)
.define_method("sparse?", &torch::Tensor::is_sparse)
.define_method("quantized?", &torch::Tensor::is_quantized)
.define_method("dim", &torch::Tensor::dim)
.define_method("numel", &torch::Tensor::numel)
@@ -286,10 +403,11 @@
auto var = data.set_requires_grad(requires_grad);
return Parameter(std::move(var));
});
Class rb_cTensorOptions = define_class_under<torch::TensorOptions>(rb_mTorch, "TensorOptions")
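+ // register the same c10::Error handler for TensorOptions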
+ .add_handler<c10::Error>(handle_error)
.define_constructor(Constructor<torch::TensorOptions>())
.define_method(
"dtype",
*[](torch::TensorOptions& self, int dtype) {
return self.dtype((torch::ScalarType) dtype);
@@ -309,16 +427,11 @@
return self.layout(l);
})
.define_method(
"device",
*[](torch::TensorOptions& self, std::string device) {
- try {
- // needed to catch exception
- torch::Device d(device);
- return self.device(d);
- } catch (const c10::Error& error) {
- throw std::runtime_error(error.what_without_backtrace());
- }
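+ // the per-method try/catch is no longer needed; an invalid device string
+ // now raises through the class-level add_handler above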
+ torch::Device d(device);
+ return self.device(d);
})
.define_method(
"requires_grad",
*[](torch::TensorOptions& self, bool requires_grad) {
return self.requires_grad(requires_grad);