ext/torch/ext.cpp in torch-rb-0.1.2 vs ext/torch/ext.cpp in torch-rb-0.1.3

- old
+ new

@@ -86,12 +86,54 @@
 {
   return IntArrayRef(x);
 }
 
 // for now
-typedef float Scalar;
+class Scalar {
+  torch::Scalar value;
+  public:
+    Scalar(Object o) {
+      // TODO cast based on Ruby type
+      if (o.rb_type() == T_FIXNUM) {
+        value = torch::Scalar(from_ruby<int64_t>(o));
+      } else {
+        value = torch::Scalar(from_ruby<float>(o));
+      }
+    }
+    operator torch::Scalar() {
+      return value;
+    }
+};
+
+template<>
+inline
+Scalar from_ruby<Scalar>(Object x)
+{
+  return Scalar(x);
+}
+
+class TensorList {
+  std::vector<torch::Tensor> vec;
+  public:
+    TensorList(Object o) {
+      Array a = Array(o);
+      for (size_t i = 0; i < a.size(); i++) {
+        vec.push_back(from_ruby<torch::Tensor>(a[i]));
+      }
+    }
+    operator torch::TensorList() {
+      return torch::TensorList(vec);
+    }
+};
+
+template<>
+inline
+TensorList from_ruby<TensorList>(Object x)
+{
+  return TensorList(x);
+}
 
 extern "C"
 void Init_ext()
 {
   Module rb_mTorch = define_module("Torch")
     .define_singleton_method(
@@ -205,10 +247,15 @@
         "_argmax_dim",
         *[](torch::Tensor& input, int64_t dim, bool keepdim) {
           return torch::argmax(input, dim, keepdim);
         })
     .define_singleton_method(
+        "_cat",
+        *[](TensorList tensors, int64_t dim) {
+          return torch::cat(tensors, dim);
+        })
+    .define_singleton_method(
         "_norm",
         *[](torch::Tensor& input) {
           return torch::norm(input);
         })
     .define_singleton_method(
@@ -220,20 +267,36 @@
         "_max",
         *[](torch::Tensor& input) {
           return torch::max(input);
         })
     .define_singleton_method(
+        "_max_out",
+        *[](torch::Tensor &max, torch::Tensor &max_indices, const torch::Tensor &input, int64_t dim, bool keepdim) {
+          // TODO add return value
+          torch::_max_out(max, max_indices, input, dim, keepdim);
+        })
+    .define_singleton_method(
+        "_sqrt",
+        *[](torch::Tensor& input) {
+          return torch::sqrt(input);
+        })
+    .define_singleton_method(
         "_exp",
         *[](torch::Tensor& input) {
           return torch::exp(input);
         })
     .define_singleton_method(
         "_log",
         *[](torch::Tensor& input) {
           return torch::log(input);
         })
     .define_singleton_method(
+        "_sign",
+        *[](torch::Tensor& input) {
+          return torch::sign(input);
+        })
+    .define_singleton_method(
         "_unsqueeze",
         *[](torch::Tensor& input, int64_t dim) {
           return torch::unsqueeze(input, dim);
         })
     .define_singleton_method(
@@ -250,17 +313,29 @@
         "_eq",
         *[](torch::Tensor& input, torch::Tensor& other) {
           return torch::eq(input, other);
         })
     .define_singleton_method(
+        "_gt",
+        // TODO support tensors
+        *[](torch::Tensor& input, Scalar other) {
+          return torch::gt(input, other);
+        })
+    .define_singleton_method(
+        "_lt",
+        // TODO support tensors
+        *[](torch::Tensor& input, Scalar other) {
+          return torch::lt(input, other);
+        })
+    .define_singleton_method(
         "_add",
         *[](torch::Tensor& input, torch::Tensor& other) {
           return torch::add(input, other);
         })
     .define_singleton_method(
         "_add_scalar",
-        *[](torch::Tensor& input, float other) {
+        *[](torch::Tensor& input, Scalar other) {
           return torch::add(input, other);
         })
     .define_singleton_method(
         "_add_out",
         *[](torch::Tensor& out, torch::Tensor& input, torch::Tensor& other) {
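Note: the Scalar wrapper above replaces the old typedef float Scalar, so Ruby integers now travel through the binding as int64_t instead of being coerced to float. A minimal standalone sketch (illustrative, not part of the diff) of the torch::Scalar behaviour this relies on:

    #include <torch/torch.h>
    #include <cassert>

    int main() {
      torch::Scalar i(int64_t(3));  // tagged as an integral scalar
      torch::Scalar f(1.5f);        // tagged as floating point
      assert(!i.isFloatingPoint());
      assert(f.isFloatingPoint());
      assert(i.toLong() == 3);      // payload converts on demand at the call site
      return 0;
    }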
@@ -271,77 +346,77 @@
         *[](torch::Tensor& input, torch::Tensor& other) {
           return torch::sub(input, other);
         })
     .define_singleton_method(
         "_sub_scalar",
-        *[](torch::Tensor& input, float other) {
+        *[](torch::Tensor& input, Scalar other) {
           return torch::sub(input, other);
         })
     .define_singleton_method(
         "_mul",
         *[](torch::Tensor& input, torch::Tensor& other) {
           return torch::mul(input, other);
         })
     .define_singleton_method(
         "_mul_scalar",
-        *[](torch::Tensor& input, float other) {
+        *[](torch::Tensor& input, Scalar other) {
           return torch::mul(input, other);
         })
     .define_singleton_method(
         "_div",
         *[](torch::Tensor& input, torch::Tensor& other) {
           return torch::div(input, other);
         })
     .define_singleton_method(
         "_div_scalar",
-        *[](torch::Tensor& input, float other) {
+        *[](torch::Tensor& input, Scalar other) {
           return torch::div(input, other);
         })
     .define_singleton_method(
         "_remainder",
         *[](torch::Tensor& input, torch::Tensor& other) {
           return torch::remainder(input, other);
         })
     .define_singleton_method(
         "_remainder_scalar",
-        *[](torch::Tensor& input, float other) {
+        *[](torch::Tensor& input, Scalar other) {
           return torch::remainder(input, other);
         })
     .define_singleton_method(
         "_pow",
         *[](torch::Tensor& input, Scalar exponent) {
           return torch::pow(input, exponent);
         })
     .define_singleton_method(
+        "_abs",
+        *[](torch::Tensor& input) {
+          return torch::abs(input);
+        })
+    .define_singleton_method(
         "_neg",
         *[](torch::Tensor& input) {
           return torch::neg(input);
         })
     .define_singleton_method(
         "_reshape",
         *[](torch::Tensor& input, IntArrayRef shape) {
           return torch::reshape(input, shape);
         })
     .define_singleton_method(
+        "_flatten",
+        *[](torch::Tensor& input, int64_t start_dim, int64_t end_dim) {
+          return torch::flatten(input, start_dim, end_dim);
+        })
+    .define_singleton_method(
         "relu",
         *[](torch::Tensor& input) {
           return torch::relu(input);
         })
     .define_singleton_method(
-        "prelu",
-        *[](torch::Tensor& input, torch::Tensor& weight) {
-          return torch::prelu(input, weight);
-        })
-    .define_singleton_method(
-        "leaky_relu",
-        *[](torch::Tensor& input, Scalar negative_slope = 0.01) {
-          return torch::leaky_relu(input, negative_slope);
-        })
-    .define_singleton_method(
         "conv2d",
-        *[](torch::Tensor& input, torch::Tensor& weight, torch::Tensor& bias, IntArrayRef stride, IntArrayRef padding) {
-          return torch::conv2d(input, weight, bias, stride, padding);
+        *[](torch::Tensor& input, torch::Tensor& weight, torch::Tensor& bias, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, int64_t groups) {
+          return torch::conv2d(input, weight, bias, stride, padding, dilation, groups);
         })
     .define_singleton_method(
         "linear",
         *[](torch::Tensor& input, torch::Tensor& weight, torch::Tensor& bias) {
           return torch::linear(input, weight, bias);
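Note: conv2d gains dilation and groups, matching libtorch's full signature. A standalone sketch (values are illustrative, not from the diff) showing that dilation {1, 1} and groups 1 reproduce the old five-argument behaviour:

    #include <torch/torch.h>

    int main() {
      auto input  = torch::randn({1, 3, 8, 8});   // NCHW
      auto weight = torch::randn({16, 3, 3, 3});  // out_channels x in_channels x kH x kW
      auto bias   = torch::randn({16});
      // dilation {1, 1} and groups 1 are the defaults the old
      // five-argument binding implicitly relied on
      auto out = torch::conv2d(input, weight, bias,
                               /*stride=*/{1, 1}, /*padding=*/{0, 0},
                               /*dilation=*/{1, 1}, /*groups=*/1);
      return 0;
    }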
@@ -355,21 +430,75 @@
         "avg_pool2d",
         *[](torch::Tensor& input, IntArrayRef kernel_size) {
           return torch::avg_pool2d(input, kernel_size);
         })
     .define_singleton_method(
+        "_dropout",
+        *[](torch::Tensor& input, float p, bool train) {
+          return torch::dropout(input, p, train);
+        })
+    .define_singleton_method(
+        "_dropout!",
+        *[](torch::Tensor& input, float p, bool train) {
+          return torch::dropout_(input, p, train);
+        })
+    .define_singleton_method(
+        "_feature_dropout",
+        *[](torch::Tensor& input, float p, bool train) {
+          return torch::feature_dropout(input, p, train);
+        })
+    .define_singleton_method(
+        "_feature_dropout!",
+        *[](torch::Tensor& input, float p, bool train) {
+          return torch::feature_dropout_(input, p, train);
+        })
+    .define_singleton_method(
+        "_alpha_dropout",
+        *[](torch::Tensor& input, float p, bool train) {
+          return torch::alpha_dropout(input, p, train);
+        })
+    .define_singleton_method(
+        "_alpha_dropout!",
+        *[](torch::Tensor& input, float p, bool train) {
+          return torch::alpha_dropout_(input, p, train);
+        })
+    .define_singleton_method(
+        "_feature_alpha_dropout",
+        *[](torch::Tensor& input, float p, bool train) {
+          return torch::feature_alpha_dropout(input, p, train);
+        })
+    .define_singleton_method(
+        "_feature_alpha_dropout!",
+        *[](torch::Tensor& input, float p, bool train) {
+          return torch::feature_alpha_dropout_(input, p, train);
+        })
+    .define_singleton_method(
+        "_embedding",
+        // weight and indices are swapped from Python interface
+        *[](const torch::Tensor &indices, const torch::Tensor &weight, int64_t padding_idx, bool scale_grad_by_freq, bool sparse) {
+          return torch::embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse);
+        })
+    .define_singleton_method(
         "mse_loss",
         *[](torch::Tensor& input, torch::Tensor& target, std::string reduction) {
           auto red = reduction == "mean" ? Reduction::Mean : Reduction::Sum;
           return torch::mse_loss(input, target, red);
         })
     .define_singleton_method(
         "nll_loss",
-        *[](torch::Tensor& input, torch::Tensor& target) {
-          return torch::nll_loss(input, target);
+        *[](torch::Tensor& input, torch::Tensor& target, std::string reduction) {
+          auto red = reduction == "mean" ? Reduction::Mean : Reduction::Sum;
+          return torch::nll_loss(input, target, {}, red);
         })
+    .define_singleton_method("numel", &torch::numel)
     .define_singleton_method(
+        "_from_blob",
+        *[](String s, IntArrayRef size, const torch::TensorOptions &options) {
+          void *data = const_cast<char *>(s.c_str());
+          return torch::from_blob(data, size, options);
+        })
+    .define_singleton_method(
         "_tensor",
         *[](Object o, IntArrayRef size, const torch::TensorOptions &options) {
          Array a = Array(o);
          std::vector<float> vec;
          for (size_t i = 0; i < a.size(); i++) {
@@ -385,14 +514,24 @@
     .define_method("floating_point?", &torch::Tensor::is_floating_point)
     .define_method("signed?", &torch::Tensor::is_signed)
     .define_method("sparse?", &torch::Tensor::is_sparse)
     .define_method("quantized?", &torch::Tensor::is_quantized)
     .define_method("dim", &torch::Tensor::dim)
-    .define_method("numel", &torch::Tensor::numel)
     .define_method("element_size", &torch::Tensor::element_size)
     .define_method("requires_grad", &torch::Tensor::requires_grad)
+    .define_method("view_as", &torch::Tensor::view_as)
     .define_method(
+        "addcmul!",
+        *[](torch::Tensor& self, Scalar value, const torch::Tensor & tensor1, const torch::Tensor & tensor2) {
+          return self.addcmul_(tensor1, tensor2, value);
+        })
+    .define_method(
+        "addcdiv!",
+        *[](torch::Tensor& self, Scalar value, const torch::Tensor & tensor1, const torch::Tensor & tensor2) {
+          return self.addcdiv_(tensor1, tensor2, value);
+        })
+    .define_method(
         "zero!",
         *[](torch::Tensor& self) {
           return self.zero_();
         })
     .define_method(
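Note: _from_blob wraps the Ruby String's bytes without copying, so the resulting tensor is only valid while that buffer stays alive. A standalone sketch (not part of the diff) of the underlying libtorch semantics:

    #include <torch/torch.h>

    int main() {
      float data[] = {1.0f, 2.0f, 3.0f, 4.0f};
      // from_blob does not take ownership: the tensor aliases data
      auto view = torch::from_blob(data, {2, 2}, torch::kFloat32);
      // clone() copies into tensor-owned storage, safe to keep
      // after the original buffer is gone
      auto owned = view.clone();
      return 0;
    }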
@@ -458,30 +597,80 @@
         "_view",
         *[](torch::Tensor& self, IntArrayRef size) {
           return self.view(size);
         })
     .define_method(
-        "add!",
+        "resize_as!",
         *[](torch::Tensor& self, torch::Tensor& other) {
-          self.add_(other);
+          return self.resize_as_(other);
         })
     .define_method(
+        "fill!",
+        *[](torch::Tensor& self, Scalar value) {
+          return self.fill_(value);
+        })
+    .define_method(
+        "_add!",
+        *[](torch::Tensor& self, torch::Tensor& other) {
+          return self.add_(other);
+        })
+    .define_method(
+        "_add_alpha!",
+        *[](torch::Tensor& self, torch::Tensor& other, Scalar alpha) {
+          return self.add_(other, alpha);
+        })
+    .define_method(
+        "_add_scalar!",
+        *[](torch::Tensor& self, Scalar other) {
+          return self.add_(other);
+        })
+    .define_method(
+        "normal!",
+        *[](torch::Tensor& self, double mean, double std) {
+          return self.normal_(mean, std);
+        })
+    .define_method(
         "sub!",
         *[](torch::Tensor& self, torch::Tensor& other) {
-          self.sub_(other);
+          return self.sub_(other);
         })
     .define_method(
-        "mul!",
+        "_mul!",
         *[](torch::Tensor& self, torch::Tensor& other) {
-          self.mul_(other);
+          return self.mul_(other);
         })
     .define_method(
+        "_mul_scalar!",
+        *[](torch::Tensor& self, Scalar other) {
+          return self.mul_(other);
+        })
+    .define_method(
         "div!",
         *[](torch::Tensor& self, torch::Tensor& other) {
-          self.div_(other);
+          return self.div_(other);
         })
     .define_method(
+        "sqrt!",
+        *[](torch::Tensor& self) {
+          return self.sqrt_();
+        })
+    .define_method(
+        "unsqueeze!",
+        *[](torch::Tensor& self, int64_t dim) {
+          return self.unsqueeze_(dim);
+        })
+    .define_method(
+        "copy!",
+        *[](torch::Tensor& self, torch::Tensor& src) {
+          return self.copy_(src);
+        })
+    .define_method(
+        "clone",
+        *[](torch::Tensor& self) {
+          return self.clone();
+        })
+    .define_method(
         "log_softmax",
         *[](torch::Tensor& self, int64_t dim) {
           return self.log_softmax(dim);
         })
     .define_method(
@@ -530,22 +719,29 @@
           double* data = self.data_ptr<double>();
           for (int i = 0; i < self.numel(); i++) {
             a.push(data[i]);
           }
         } else if (dtype == torch::kBool) {
-          // bool
-          throw std::runtime_error("Type not supported yet");
+          bool* data = self.data_ptr<bool>();
+          for (int i = 0; i < self.numel(); i++) {
+            a.push(data[i] ? True : False);
+          }
         } else {
           throw std::runtime_error("Unsupported type");
         }
         return a;
       })
     .define_method(
         "_size",
         *[](torch::Tensor& self, int i) {
           return self.size(i);
         })
+    .define_method(
+        "_to",
+        *[](torch::Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
+          return self.to(device, (torch::ScalarType) dtype, non_blocking, copy);
+        })
     .define_singleton_method(
         "_make_subclass",
         *[](torch::Tensor& rd, bool requires_grad) {
           auto data = torch::autograd::as_variable_ref(rd).detach();
           data.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
@@ -595,16 +791,21 @@
   Module rb_mNN = define_module_under(rb_mTorch, "NN");
 
   Module rb_mInit = define_module_under(rb_mNN, "Init")
     .define_singleton_method(
-        "kaiming_uniform_",
+        "kaiming_uniform!",
         *[](torch::Tensor& input, double a) {
           return torch::nn::init::kaiming_uniform_(input, a);
         })
     .define_singleton_method(
-        "uniform_",
+        "normal!",
+        *[](torch::Tensor& input) {
+          return torch::nn::init::normal_(input);
+        })
+    .define_singleton_method(
+        "uniform!",
         *[](torch::Tensor& input, double to, double from) {
           return torch::nn::init::uniform_(input, to, from);
         });
 
   Class rb_cParameter = define_class_under<torch::autograd::Variable, torch::Tensor>(rb_mNN, "Parameter")
@@ -617,6 +818,22 @@
     .define_method(
         "_grad_defined",
         *[](torch::autograd::Variable& self) {
          return self.grad().defined();
        });
+
+  Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
+    .define_constructor(Constructor<torch::Device, std::string>())
+    .define_method("index", &torch::Device::index)
+    .define_method("index?", &torch::Device::has_index)
+    .define_method(
+        "type",
+        *[](torch::Device& self) {
+          std::stringstream s;
+          s << self.type();
+          return s.str();
+        });
+
+  Module rb_mCUDA = define_module_under(rb_mTorch, "CUDA")
+    .define_singleton_method("available?", &torch::cuda::is_available)
+    .define_singleton_method("device_count", &torch::cuda::device_count);
 }
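Note: a standalone sketch (illustrative; falls back to CPU so it runs anywhere) of the libtorch calls the new Device class, CUDA module, and _to binding map onto:

    #include <torch/torch.h>
    #include <iostream>

    int main() {
      torch::Device device(torch::cuda::is_available() ? "cuda" : "cpu");
      auto t = torch::ones({2, 2});
      // mirrors the _to binding: device, dtype, non_blocking, copy
      auto u = t.to(device, torch::kFloat64, /*non_blocking=*/false, /*copy=*/true);
      std::cout << u.device() << " " << u.dtype() << std::endl;
      return 0;
    }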