ext/torch/ext.cpp in torch-rb-0.1.2 vs ext/torch/ext.cpp in torch-rb-0.1.3
- old
+ new
@@ -86,12 +86,54 @@
{
return IntArrayRef(x);
}
// for now
-typedef float Scalar;
+class Scalar {
+ torch::Scalar value;
+ public:
+ Scalar(Object o) {
+ // TODO cast based on Ruby type
+ if (o.rb_type() == T_FIXNUM) {
+ value = torch::Scalar(from_ruby<int64_t>(o));
+ } else {
+ value = torch::Scalar(from_ruby<float>(o));
+ }
+ }
+ operator torch::Scalar() {
+ return value;
+ }
+};
+template<>
+inline
+Scalar from_ruby<Scalar>(Object x)
+{
+ return Scalar(x);
+}
+
+class TensorList {
+ std::vector<torch::Tensor> vec;
+ public:
+ TensorList(Object o) {
+ Array a = Array(o);
+ for (size_t i = 0; i < a.size(); i++) {
+ vec.push_back(from_ruby<torch::Tensor>(a[i]));
+ }
+ }
+ operator torch::TensorList() {
+ return torch::TensorList(vec);
+ }
+};
+
+template<>
+inline
+TensorList from_ruby<TensorList>(Object x)
+{
+ return TensorList(x);
+}
+
extern "C"
void Init_ext()
{
Module rb_mTorch = define_module("Torch")
.define_singleton_method(
@@ -205,10 +247,15 @@
"_argmax_dim",
*[](torch::Tensor& input, int64_t dim, bool keepdim) {
return torch::argmax(input, dim, keepdim);
})
.define_singleton_method(
+ "_cat",
+ *[](TensorList tensors, int64_t dim) {
+ return torch::cat(tensors, dim);
+ })
+ .define_singleton_method(
"_norm",
*[](torch::Tensor& input) {
return torch::norm(input);
})
.define_singleton_method(
@@ -220,20 +267,36 @@
"_max",
*[](torch::Tensor& input) {
return torch::max(input);
})
.define_singleton_method(
+ "_max_out",
+ *[](torch::Tensor &max, torch::Tensor &max_indices, const torch::Tensor &input, int64_t dim, bool keepdim) {
+ // TODO add return value
+ torch::_max_out(max, max_indices, input, dim, keepdim);
+ })
+ .define_singleton_method(
+ "_sqrt",
+ *[](torch::Tensor& input) {
+ return torch::sqrt(input);
+ })
+ .define_singleton_method(
"_exp",
*[](torch::Tensor& input) {
return torch::exp(input);
})
.define_singleton_method(
"_log",
*[](torch::Tensor& input) {
return torch::log(input);
})
.define_singleton_method(
+ "_sign",
+ *[](torch::Tensor& input) {
+ return torch::sign(input);
+ })
+ .define_singleton_method(
"_unsqueeze",
*[](torch::Tensor& input, int64_t dim) {
return torch::unsqueeze(input, dim);
})
.define_singleton_method(
@@ -250,17 +313,29 @@
"_eq",
*[](torch::Tensor& input, torch::Tensor& other) {
return torch::eq(input, other);
})
.define_singleton_method(
+ "_gt",
+ // TODO support tensors
+ *[](torch::Tensor& input, Scalar other) {
+ return torch::gt(input, other);
+ })
+ .define_singleton_method(
+ "_lt",
+ // TODO support tensors
+ *[](torch::Tensor& input, Scalar other) {
+ return torch::lt(input, other);
+ })
+ .define_singleton_method(
"_add",
*[](torch::Tensor& input, torch::Tensor& other) {
return torch::add(input, other);
})
.define_singleton_method(
"_add_scalar",
- *[](torch::Tensor& input, float other) {
+ *[](torch::Tensor& input, Scalar other) {
return torch::add(input, other);
})
.define_singleton_method(
"_add_out",
*[](torch::Tensor& out, torch::Tensor& input, torch::Tensor& other) {
@@ -271,77 +346,77 @@
*[](torch::Tensor& input, torch::Tensor& other) {
return torch::sub(input, other);
})
.define_singleton_method(
"_sub_scalar",
- *[](torch::Tensor& input, float other) {
+ *[](torch::Tensor& input, Scalar other) {
return torch::sub(input, other);
})
.define_singleton_method(
"_mul",
*[](torch::Tensor& input, torch::Tensor& other) {
return torch::mul(input, other);
})
.define_singleton_method(
"_mul_scalar",
- *[](torch::Tensor& input, float other) {
+ *[](torch::Tensor& input, Scalar other) {
return torch::mul(input, other);
})
.define_singleton_method(
"_div",
*[](torch::Tensor& input, torch::Tensor& other) {
return torch::div(input, other);
})
.define_singleton_method(
"_div_scalar",
- *[](torch::Tensor& input, float other) {
+ *[](torch::Tensor& input, Scalar other) {
return torch::div(input, other);
})
.define_singleton_method(
"_remainder",
*[](torch::Tensor& input, torch::Tensor& other) {
return torch::remainder(input, other);
})
.define_singleton_method(
"_remainder_scalar",
- *[](torch::Tensor& input, float other) {
+ *[](torch::Tensor& input, Scalar other) {
return torch::remainder(input, other);
})
.define_singleton_method(
"_pow",
*[](torch::Tensor& input, Scalar exponent) {
return torch::pow(input, exponent);
})
.define_singleton_method(
+ "_abs",
+ *[](torch::Tensor& input) {
+ return torch::abs(input);
+ })
+ .define_singleton_method(
"_neg",
*[](torch::Tensor& input) {
return torch::neg(input);
})
.define_singleton_method(
"_reshape",
*[](torch::Tensor& input, IntArrayRef shape) {
return torch::reshape(input, shape);
})
.define_singleton_method(
+ "_flatten",
+ *[](torch::Tensor& input, int64_t start_dim, int64_t end_dim) {
+ return torch::flatten(input, start_dim, end_dim);
+ })
+ .define_singleton_method(
"relu",
*[](torch::Tensor& input) {
return torch::relu(input);
})
.define_singleton_method(
- "prelu",
- *[](torch::Tensor& input, torch::Tensor& weight) {
- return torch::prelu(input, weight);
- })
- .define_singleton_method(
- "leaky_relu",
- *[](torch::Tensor& input, Scalar negative_slope = 0.01) {
- return torch::leaky_relu(input, negative_slope);
- })
- .define_singleton_method(
"conv2d",
- *[](torch::Tensor& input, torch::Tensor& weight, torch::Tensor& bias, IntArrayRef stride, IntArrayRef padding) {
- return torch::conv2d(input, weight, bias, stride, padding);
+ *[](torch::Tensor& input, torch::Tensor& weight, torch::Tensor& bias, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, int64_t groups) {
+ return torch::conv2d(input, weight, bias, stride, padding, dilation, groups);
})
.define_singleton_method(
"linear",
*[](torch::Tensor& input, torch::Tensor& weight, torch::Tensor& bias) {
return torch::linear(input, weight, bias);
@@ -355,21 +430,75 @@
"avg_pool2d",
*[](torch::Tensor& input, IntArrayRef kernel_size) {
return torch::avg_pool2d(input, kernel_size);
})
.define_singleton_method(
+ "_dropout",
+ *[](torch::Tensor& input, float p, bool train) {
+ return torch::dropout(input, p, train);
+ })
+ .define_singleton_method(
+ "_dropout!",
+ *[](torch::Tensor& input, float p, bool train) {
+ return torch::dropout_(input, p, train);
+ })
+ .define_singleton_method(
+ "_feature_dropout",
+ *[](torch::Tensor& input, float p, bool train) {
+ return torch::feature_dropout(input, p, train);
+ })
+ .define_singleton_method(
+ "_feature_dropout!",
+ *[](torch::Tensor& input, float p, bool train) {
+ return torch::feature_dropout_(input, p, train);
+ })
+ .define_singleton_method(
+ "_alpha_dropout",
+ *[](torch::Tensor& input, float p, bool train) {
+ return torch::alpha_dropout(input, p, train);
+ })
+ .define_singleton_method(
+ "_alpha_dropout!",
+ *[](torch::Tensor& input, float p, bool train) {
+ return torch::alpha_dropout_(input, p, train);
+ })
+ .define_singleton_method(
+ "_feature_alpha_dropout",
+ *[](torch::Tensor& input, float p, bool train) {
+ return torch::feature_alpha_dropout(input, p, train);
+ })
+ .define_singleton_method(
+ "_feature_alpha_dropout!",
+ *[](torch::Tensor& input, float p, bool train) {
+ return torch::feature_alpha_dropout_(input, p, train);
+ })
+ .define_singleton_method(
+ "_embedding",
+ // weight and indices are swapped from Python interface
+ *[](const torch::Tensor &indices, const torch::Tensor &weight, int64_t padding_idx, bool scale_grad_by_freq, bool sparse) {
+ return torch::embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse);
+ })
+ .define_singleton_method(
"mse_loss",
*[](torch::Tensor& input, torch::Tensor& target, std::string reduction) {
auto red = reduction == "mean" ? Reduction::Mean : Reduction::Sum;
return torch::mse_loss(input, target, red);
})
.define_singleton_method(
"nll_loss",
- *[](torch::Tensor& input, torch::Tensor& target) {
- return torch::nll_loss(input, target);
+ *[](torch::Tensor& input, torch::Tensor& target, std::string reduction) {
+ auto red = reduction == "mean" ? Reduction::Mean : Reduction::Sum;
+ return torch::nll_loss(input, target, {}, red);
})
+ .define_singleton_method("numel", &torch::numel)
.define_singleton_method(
+ "_from_blob",
+ *[](String s, IntArrayRef size, const torch::TensorOptions &options) {
+ void *data = const_cast<char *>(s.c_str());
+ return torch::from_blob(data, size, options);
+ })
+ .define_singleton_method(
"_tensor",
*[](Object o, IntArrayRef size, const torch::TensorOptions &options) {
Array a = Array(o);
std::vector<float> vec;
for (size_t i = 0; i < a.size(); i++) {
@@ -385,14 +514,24 @@
.define_method("floating_point?", &torch::Tensor::is_floating_point)
.define_method("signed?", &torch::Tensor::is_signed)
.define_method("sparse?", &torch::Tensor::is_sparse)
.define_method("quantized?", &torch::Tensor::is_quantized)
.define_method("dim", &torch::Tensor::dim)
- .define_method("numel", &torch::Tensor::numel)
.define_method("element_size", &torch::Tensor::element_size)
.define_method("requires_grad", &torch::Tensor::requires_grad)
+ .define_method("view_as", &torch::Tensor::view_as)
.define_method(
+ "addcmul!",
+ *[](torch::Tensor& self, Scalar value, const torch::Tensor & tensor1, const torch::Tensor & tensor2) {
+ return self.addcmul_(tensor1, tensor2, value);
+ })
+ .define_method(
+ "addcdiv!",
+ *[](torch::Tensor& self, Scalar value, const torch::Tensor & tensor1, const torch::Tensor & tensor2) {
+ return self.addcdiv_(tensor1, tensor2, value);
+ })
+ .define_method(
"zero!",
*[](torch::Tensor& self) {
return self.zero_();
})
.define_method(
@@ -458,30 +597,80 @@
"_view",
*[](torch::Tensor& self, IntArrayRef size) {
return self.view(size);
})
.define_method(
- "add!",
+ "resize_as!",
*[](torch::Tensor& self, torch::Tensor& other) {
- self.add_(other);
+ return self.resize_as_(other);
})
.define_method(
+ "fill!",
+ *[](torch::Tensor& self, Scalar value) {
+ return self.fill_(value);
+ })
+ .define_method(
+ "_add!",
+ *[](torch::Tensor& self, torch::Tensor& other) {
+ return self.add_(other);
+ })
+ .define_method(
+ "_add_alpha!",
+ *[](torch::Tensor& self, torch::Tensor& other, Scalar alpha) {
+ return self.add_(other, alpha);
+ })
+ .define_method(
+ "_add_scalar!",
+ *[](torch::Tensor& self, Scalar other) {
+ return self.add_(other);
+ })
+ .define_method(
+ "normal!",
+ *[](torch::Tensor& self, double mean, double std) {
+ return self.normal_(mean, std);
+ })
+ .define_method(
"sub!",
*[](torch::Tensor& self, torch::Tensor& other) {
- self.sub_(other);
+ return self.sub_(other);
})
.define_method(
- "mul!",
+ "_mul!",
*[](torch::Tensor& self, torch::Tensor& other) {
- self.mul_(other);
+ return self.mul_(other);
})
.define_method(
+ "_mul_scalar!",
+ *[](torch::Tensor& self, Scalar other) {
+ return self.mul_(other);
+ })
+ .define_method(
"div!",
*[](torch::Tensor& self, torch::Tensor& other) {
- self.div_(other);
+ return self.div_(other);
})
.define_method(
+ "sqrt!",
+ *[](torch::Tensor& self) {
+ return self.sqrt_();
+ })
+ .define_method(
+ "unsqueeze!",
+ *[](torch::Tensor& self, int64_t dim) {
+ return self.unsqueeze_(dim);
+ })
+ .define_method(
+ "copy!",
+ *[](torch::Tensor& self, torch::Tensor& src) {
+ return self.copy_(src);
+ })
+ .define_method(
+ "clone",
+ *[](torch::Tensor& self) {
+ return self.clone();
+ })
+ .define_method(
"log_softmax",
*[](torch::Tensor& self, int64_t dim) {
return self.log_softmax(dim);
})
.define_method(
@@ -530,22 +719,29 @@
double* data = self.data_ptr<double>();
for (int i = 0; i < self.numel(); i++) {
a.push(data[i]);
}
} else if (dtype == torch::kBool) {
- // bool
- throw std::runtime_error("Type not supported yet");
+ bool* data = self.data_ptr<bool>();
+ for (int i = 0; i < self.numel(); i++) {
+ a.push(data[i] ? True : False);
+ }
} else {
throw std::runtime_error("Unsupported type");
}
return a;
})
.define_method(
"_size",
*[](torch::Tensor& self, int i) {
return self.size(i);
})
+ .define_method(
+ "_to",
+ *[](torch::Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
+ return self.to(device, (torch::ScalarType) dtype, non_blocking, copy);
+ })
.define_singleton_method(
"_make_subclass",
*[](torch::Tensor& rd, bool requires_grad) {
auto data = torch::autograd::as_variable_ref(rd).detach();
data.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
@@ -595,16 +791,21 @@
Module rb_mNN = define_module_under(rb_mTorch, "NN");
Module rb_mInit = define_module_under(rb_mNN, "Init")
.define_singleton_method(
- "kaiming_uniform_",
+ "kaiming_uniform!",
*[](torch::Tensor& input, double a) {
return torch::nn::init::kaiming_uniform_(input, a);
})
.define_singleton_method(
- "uniform_",
+ "normal!",
+ *[](torch::Tensor& input) {
+ return torch::nn::init::normal_(input);
+ })
+ .define_singleton_method(
+ "uniform!",
*[](torch::Tensor& input, double to, double from) {
return torch::nn::init::uniform_(input, to, from);
});
Class rb_cParameter = define_class_under<torch::autograd::Variable, torch::Tensor>(rb_mNN, "Parameter")
@@ -617,6 +818,22 @@
.define_method(
"_grad_defined",
*[](torch::autograd::Variable& self) {
return self.grad().defined();
});
+
+ Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
+ .define_constructor(Constructor<torch::Device, std::string>())
+ .define_method("index", &torch::Device::index)
+ .define_method("index?", &torch::Device::has_index)
+ .define_method(
+ "type",
+ *[](torch::Device& self) {
+ std::stringstream s;
+ s << self.type();
+ return s.str();
+ });
+
+ Module rb_mCUDA = define_module_under(rb_mTorch, "CUDA")
+ .define_singleton_method("available?", &torch::cuda::is_available)
+ .define_singleton_method("device_count", &torch::cuda::device_count);
}