lib/torch.rb in torch-rb-0.3.4 vs lib/torch.rb in torch-rb-0.3.5
- old
+ new
@@ -386,9 +386,13 @@
def randn(*size, **options)
_randn(tensor_size(size), tensor_options(**options))
end
def randperm(n, **options)
+ # dtype hack in Python
+ # https://github.com/pytorch/pytorch/blob/v1.6.0/tools/autograd/gen_python_functions.py#L1307-L1311
+ options[:dtype] ||= :int64
+
_randperm(n, tensor_options(**options))
end
def zeros(*size, **options)
_zeros(tensor_size(size), tensor_options(**options))