ext/torch/templates.hpp in torch-rb-0.3.5 vs ext/torch/templates.hpp in torch-rb-0.3.6
- old
+ new
@@ -8,10 +8,11 @@
#include <rice/Object.hpp>
using namespace Rice;
using torch::Device;
+using torch::Scalar;
using torch::ScalarType;
using torch::Tensor;
// need to wrap torch::IntArrayRef() since
// it doesn't own underlying data
@@ -34,34 +35,10 @@
IntArrayRef from_ruby<IntArrayRef>(Object x)
{
return IntArrayRef(x);
}
-// for now
-class Scalar {
- torch::Scalar value;
- public:
- Scalar(Object o) {
- // TODO cast based on Ruby type
- if (o.rb_type() == T_FIXNUM) {
- value = torch::Scalar(from_ruby<int64_t>(o));
- } else {
- value = torch::Scalar(from_ruby<float>(o));
- }
- }
- operator torch::Scalar() {
- return value;
- }
-};
-
-template<>
-inline
-Scalar from_ruby<Scalar>(Object x)
-{
- return Scalar(x);
-}
-
class TensorList {
std::vector<torch::Tensor> vec;
public:
TensorList(Object o) {
Array a = Array(o);
@@ -192,10 +169,21 @@
}
};
template<>
inline
+Scalar from_ruby<Scalar>(Object x)
+{
+ if (x.rb_type() == T_FIXNUM) {
+ return torch::Scalar(from_ruby<int64_t>(x));
+ } else {
+ return torch::Scalar(from_ruby<double>(x));
+ }
+}
+
+template<>
+inline
OptionalTensor from_ruby<OptionalTensor>(Object x)
{
return OptionalTensor(x);
}
@@ -219,11 +207,45 @@
} else {
return torch::optional<int64_t>{from_ruby<int64_t>(x)};
}
}
+template<>
+inline
+torch::optional<double> from_ruby<torch::optional<double>>(Object x)
+{
+ if (x.is_nil()) {
+ return torch::nullopt;
+ } else {
+ return torch::optional<double>{from_ruby<double>(x)};
+ }
+}
+
+template<>
+inline
+torch::optional<bool> from_ruby<torch::optional<bool>>(Object x)
+{
+ if (x.is_nil()) {
+ return torch::nullopt;
+ } else {
+ return torch::optional<bool>{from_ruby<bool>(x)};
+ }
+}
+
+template<>
+inline
+torch::optional<Scalar> from_ruby<torch::optional<Scalar>>(Object x)
+{
+ if (x.is_nil()) {
+ return torch::nullopt;
+ } else {
+ return torch::optional<Scalar>{from_ruby<Scalar>(x)};
+ }
+}
+
Object wrap(std::tuple<torch::Tensor, torch::Tensor> x);
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x);
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x);
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x);
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x);
Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x);
+Object wrap(std::vector<torch::Tensor> x);