ext/torch/templates.hpp in torch-rb-0.3.2 vs ext/torch/templates.hpp in torch-rb-0.3.3
- old
+ new
@@ -7,10 +7,14 @@
#include <rice/Array.hpp>
#include <rice/Object.hpp>
using namespace Rice;
+using torch::Device;
+using torch::ScalarType;
+using torch::Tensor;
+
// need to wrap torch::IntArrayRef() since
// it doesn't own underlying data
class IntArrayRef {
std::vector<int64_t> vec;
public:
@@ -172,12 +176,10 @@
MyReduction from_ruby<MyReduction>(Object x)
{
return MyReduction(x);
}
-typedef torch::Tensor Tensor;
-
class OptionalTensor {
Object value;
public:
OptionalTensor(Object o) {
value = o;
@@ -195,49 +197,30 @@
OptionalTensor from_ruby<OptionalTensor>(Object x)
{
return OptionalTensor(x);
}
-class ScalarType {
- Object value;
- public:
- ScalarType(Object o) {
- value = o;
- }
- operator at::ScalarType() {
- throw std::runtime_error("ScalarType arguments not implemented yet");
- }
-};
-
template<>
inline
-ScalarType from_ruby<ScalarType>(Object x)
+torch::optional<torch::ScalarType> from_ruby<torch::optional<torch::ScalarType>>(Object x)
{
- return ScalarType(x);
+ if (x.is_nil()) {
+ return torch::nullopt;
+ } else {
+ return torch::optional<torch::ScalarType>{from_ruby<torch::ScalarType>(x)};
+ }
}
-class OptionalScalarType {
- Object value;
- public:
- OptionalScalarType(Object o) {
- value = o;
- }
- operator c10::optional<at::ScalarType>() {
- if (value.is_nil()) {
- return c10::nullopt;
- }
- return ScalarType(value);
- }
-};
-
template<>
inline
-OptionalScalarType from_ruby<OptionalScalarType>(Object x)
+torch::optional<int64_t> from_ruby<torch::optional<int64_t>>(Object x)
{
- return OptionalScalarType(x);
+ if (x.is_nil()) {
+ return torch::nullopt;
+ } else {
+ return torch::optional<int64_t>{from_ruby<int64_t>(x)};
+ }
}
-
-typedef torch::Device Device;
Object wrap(std::tuple<torch::Tensor, torch::Tensor> x);
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x);
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x);
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x);