ext/torch/ext.cpp in torch-rb-0.1.1 vs ext/torch/ext.cpp in torch-rb-0.1.2
- removed in 0.1.2
+ added in 0.1.2
@@ -195,10 +195,20 @@
"_sum_dim",
*[](torch::Tensor& input, int64_t dim, bool keepdim) {
return torch::sum(input, dim, keepdim);
})
.define_singleton_method(
+ "_argmax",
+ *[](torch::Tensor& input) {
+ return torch::argmax(input);
+ })
+ .define_singleton_method(
+ "_argmax_dim",
+ *[](torch::Tensor& input, int64_t dim, bool keepdim) {
+ return torch::argmax(input, dim, keepdim);
+ })
+ .define_singleton_method(
"_norm",
*[](torch::Tensor& input) {
return torch::norm(input);
})
.define_singleton_method(
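
Note: these two bindings mirror libtorch's two argmax overloads — torch::argmax(tensor) returns the index of the maximum in the flattened tensor, while torch::argmax(tensor, dim, keepdim) reduces along a single dimension. A minimal standalone sketch, assuming libtorch headers are available (the main and tensor values are illustrative, not part of the gem):

    #include <torch/torch.h>
    #include <iostream>

    int main() {
      // 2x3 tensor; maxima placed so the results are easy to check.
      auto t = torch::tensor({1.0, 9.0, 2.0, 7.0, 3.0, 8.0}).reshape({2, 3});

      // Overload behind "_argmax": index into the flattened tensor -> 1.
      std::cout << torch::argmax(t) << std::endl;

      // Overload behind "_argmax_dim": argmax down each column -> [1, 0, 1].
      std::cout << torch::argmax(t, /*dim=*/0, /*keepdim=*/false) << std::endl;
    }
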
@@ -235,10 +245,15 @@
"_matmul",
*[](torch::Tensor& input, torch::Tensor& other) {
return torch::matmul(input, other);
})
.define_singleton_method(
+ "_eq",
+ *[](torch::Tensor& input, torch::Tensor& other) {
+ return torch::eq(input, other);
+ })
+ .define_singleton_method(
"_add",
*[](torch::Tensor& input, torch::Tensor& other) {
return torch::add(input, other);
})
.define_singleton_method(
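
Note: torch::eq is element-wise equality with the usual broadcasting rules, and it returns a mask tensor rather than a single bool — which is what lets the Ruby side build a Tensor#== on top of it. A small sketch (values illustrative):

    #include <torch/torch.h>
    #include <iostream>

    int main() {
      auto a = torch::tensor({1, 2, 3});
      auto b = torch::tensor({1, 0, 3});
      // Element-wise comparison -> mask [1, 0, 1]
      // (uint8 in libtorch builds of this era; bool in newer releases).
      std::cout << torch::eq(a, b) << std::endl;
    }
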
@@ -300,18 +315,33 @@
"_neg",
*[](torch::Tensor& input) {
return torch::neg(input);
})
.define_singleton_method(
+ "_reshape",
+ *[](torch::Tensor& input, IntArrayRef shape) {
+ return torch::reshape(input, shape);
+ })
+ .define_singleton_method(
"relu",
*[](torch::Tensor& input) {
return torch::relu(input);
})
.define_singleton_method(
+ "prelu",
+ *[](torch::Tensor& input, torch::Tensor& weight) {
+ return torch::prelu(input, weight);
+ })
+ .define_singleton_method(
+ "leaky_relu",
+ *[](torch::Tensor& input, Scalar negative_slope = 0.01) {
+ return torch::leaky_relu(input, negative_slope);
+ })
+ .define_singleton_method(
"conv2d",
- *[](torch::Tensor& input, torch::Tensor& weight, torch::Tensor& bias) {
- return torch::conv2d(input, weight, bias);
+ *[](torch::Tensor& input, torch::Tensor& weight, torch::Tensor& bias, IntArrayRef stride, IntArrayRef padding) {
+ return torch::conv2d(input, weight, bias, stride, padding);
})
.define_singleton_method(
"linear",
*[](torch::Tensor& input, torch::Tensor& weight, torch::Tensor& bias) {
return torch::linear(input, weight, bias);
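
Note: the conv2d binding now threads stride and padding through to torch::conv2d instead of always taking the defaults (stride 1, padding 0). Each output spatial dimension follows (in + 2*padding - kernel) / stride + 1. Also worth flagging: the C++ default on leaky_relu's negative_slope almost certainly isn't visible from Ruby, since default arguments don't survive the lambda's conversion to a function pointer for Rice, so the Ruby layer presumably always passes it explicitly. A sketch of the new conv2d path (shapes and values illustrative):

    #include <torch/torch.h>
    #include <iostream>

    int main() {
      // N=1, C_in=1, 8x8 input; one output channel with a 3x3 kernel.
      auto input  = torch::randn({1, 1, 8, 8});
      auto weight = torch::randn({1, 1, 3, 3});
      auto bias   = torch::randn({1});

      // Per dimension: (8 + 2*1 - 3) / 2 + 1 = 4 -> output shape [1, 1, 4, 4].
      auto out = torch::conv2d(input, weight, bias, /*stride=*/{2, 2}, /*padding=*/{1, 1});
      std::cout << out.sizes() << std::endl;
    }
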
@@ -320,10 +350,15 @@
"max_pool2d",
*[](torch::Tensor& input, IntArrayRef kernel_size) {
return torch::max_pool2d(input, kernel_size);
})
.define_singleton_method(
+ "avg_pool2d",
+ *[](torch::Tensor& input, IntArrayRef kernel_size) {
+ return torch::avg_pool2d(input, kernel_size);
+ })
+ .define_singleton_method(
"mse_loss",
*[](torch::Tensor& input, torch::Tensor& target, std::string reduction) {
auto red = reduction == "mean" ? Reduction::Mean : Reduction::Sum;
return torch::mse_loss(input, target, red);
})
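
Note: avg_pool2d gets the same minimal binding as max_pool2d — kernel_size only, so stride defaults to the kernel size and the windows don't overlap. Sketch (input chosen so the pooled values are obvious):

    #include <torch/torch.h>
    #include <iostream>

    int main() {
      // 1x1x4x4 tensor of ones: every 2x2 window averages (and maxes) to 1.0.
      auto x = torch::ones({1, 1, 4, 4});
      auto avg = torch::avg_pool2d(x, /*kernel_size=*/{2, 2});  // [1, 1, 2, 2]
      auto max = torch::max_pool2d(x, /*kernel_size=*/{2, 2});  // [1, 1, 2, 2]
      std::cout << avg.sizes() << " " << max.sizes() << std::endl;
    }
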
@@ -364,35 +399,50 @@
"detach!",
*[](torch::Tensor& self) {
return self.detach_();
})
.define_method(
- "_access",
- *[](torch::Tensor& self, int64_t index) {
- return self[index];
+ "_select",
+ *[](torch::Tensor& self, int64_t dim, int64_t index) {
+ return self.select(dim, index);
})
.define_method(
+ "_slice",
+ *[](torch::Tensor& self, int64_t dim, int64_t start, int64_t end, int64_t step) {
+ return self.slice(dim, start, end, step);
+ })
+ .define_method(
"_requires_grad!",
*[](torch::Tensor& self, bool requires_grad) {
return self.set_requires_grad(requires_grad);
})
.define_method(
- "backward",
+ "_backward",
*[](torch::Tensor& self) {
return self.backward();
})
.define_method(
+ "_backward_gradient",
+ *[](torch::Tensor& self, const torch::Tensor& gradient) {
+ return self.backward(gradient);
+ })
+ .define_method(
"grad",
*[](torch::Tensor& self) {
return self.grad();
})
.define_method(
"_dtype",
*[](torch::Tensor& self) {
return (int) at::typeMetaToScalarType(self.dtype());
})
.define_method(
+ "_type",
+ *[](torch::Tensor& self, int dtype) {
+ return self.toType((torch::ScalarType) dtype);
+ })
+ .define_method(
"_layout",
*[](torch::Tensor& self) {
std::stringstream s;
s << self.layout();
return s.str();
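
Note: this hunk splits the old _access (plain integer indexing) into _select and _slice, which is what lets the Ruby [] accept ranges with steps, and it splits backward into _backward plus _backward_gradient, because backward() on a non-scalar tensor needs an explicit gradient argument. A sketch of the underlying semantics (values illustrative):

    #include <torch/torch.h>
    #include <iostream>

    int main() {
      auto t = torch::arange(12).reshape({3, 4});

      // select(dim, index) drops the dimension: row 1 -> [4, 5, 6, 7].
      std::cout << t.select(0, 1) << std::endl;

      // slice(dim, start, end, step) keeps it: columns 0 and 2 -> shape [3, 2].
      std::cout << t.slice(1, 0, 4, 2) << std::endl;

      // backward() needs a gradient when the output isn't a scalar:
      auto x = torch::ones({2, 2}, torch::requires_grad());
      auto y = x * 3;
      y.backward(torch::ones({2, 2}));
      std::cout << x.grad() << std::endl;  // 3 everywhere
    }
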
@@ -433,10 +483,15 @@
"log_softmax",
*[](torch::Tensor& self, int64_t dim) {
return self.log_softmax(dim);
})
.define_method(
+ "data",
+ *[](torch::Tensor& self) {
+ return self.data();
+ })
+ .define_method(
"_data",
*[](torch::Tensor& self) {
Array a;
auto dtype = self.dtype();
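
Note: the new data method exposes Tensor::data(), which returns a tensor sharing the same storage but detached from the autograd graph — the C++ counterpart of Python's tensor.data. Sketch:

    #include <torch/torch.h>
    #include <iostream>

    int main() {
      auto x = torch::ones({2}, torch::requires_grad());
      auto d = x.data();
      // Same storage, but no autograd history.
      std::cout << d.requires_grad() << std::endl;  // 0 (false)
    }
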
@@ -551,11 +606,17 @@
*[](torch::Tensor& input, double to, double from) {
return torch::nn::init::uniform_(input, to, from);
});
Class rb_cParameter = define_class_under<torch::autograd::Variable, torch::Tensor>(rb_mNN, "Parameter")
+ // TODO return grad or nil to remove need for 2nd function
.define_method(
- "grad",
+ "_grad",
*[](torch::autograd::Variable& self) {
return self.grad();
+ })
+ .define_method(
+ "_grad_defined",
+ *[](torch::autograd::Variable& self) {
+ return self.grad().defined();
});
}
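
Note: the TODO above explains the _grad/_grad_defined pair: calling grad() on a parameter that has never been through a backward pass yields an undefined tensor, which can't safely be handed to Ruby as-is, so _grad_defined lets the Ruby wrapper return nil instead. Sketch of the defined() check (values illustrative):

    #include <torch/torch.h>
    #include <iostream>

    int main() {
      auto p = torch::ones({2}, torch::requires_grad());
      std::cout << p.grad().defined() << std::endl;  // 0: no backward pass yet

      (p * 2).sum().backward();
      std::cout << p.grad().defined() << std::endl;  // 1: grad now populated
    }
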