lib/torch.rb in torch-rb-0.3.1 vs lib/torch.rb in torch-rb-0.3.2

- old
+ new

@@ -177,12 +177,14 @@ # nn other require "torch/nn/functional" require "torch/nn/init" # utils +require "torch/utils/data" require "torch/utils/data/data_loader" require "torch/utils/data/dataset" +require "torch/utils/data/subset" require "torch/utils/data/tensor_dataset" # hub require "torch/hub" @@ -308,9 +310,19 @@ def no_grad previous_value = grad_enabled? begin _set_grad_enabled(false) + yield + ensure + _set_grad_enabled(previous_value) + end + end + + def enable_grad + previous_value = grad_enabled? + begin + _set_grad_enabled(true) yield ensure _set_grad_enabled(previous_value) end end