lib/torch.rb in torch-rb-0.3.1 vs lib/torch.rb in torch-rb-0.3.2
- old
+ new
@@ -177,12 +177,14 @@
# nn other
require "torch/nn/functional"
require "torch/nn/init"
# utils
+require "torch/utils/data"
require "torch/utils/data/data_loader"
require "torch/utils/data/dataset"
+require "torch/utils/data/subset"
require "torch/utils/data/tensor_dataset"
# hub
require "torch/hub"
@@ -308,9 +310,19 @@
def no_grad
previous_value = grad_enabled?
begin
_set_grad_enabled(false)
+ yield
+ ensure
+ _set_grad_enabled(previous_value)
+ end
+ end
+
+ def enable_grad
+ previous_value = grad_enabled?
+ begin
+ _set_grad_enabled(true)
yield
ensure
_set_grad_enabled(previous_value)
end
end