lib/torch.rb in torch-rb-0.5.3 vs lib/torch.rb in torch-rb-0.6.0

- old
+ new

@@ -335,28 +335,27 @@ float32: Numo::SFloat, float64: Numo::DFloat } end - def no_grad - previous_value = grad_enabled? - begin - _set_grad_enabled(false) - yield - ensure - _set_grad_enabled(previous_value) - end + def no_grad(&block) + grad_enabled(false, &block) end - def enable_grad + def enable_grad(&block) + grad_enabled(true, &block) + end + + def grad_enabled(value) previous_value = grad_enabled? begin - _set_grad_enabled(true) + _set_grad_enabled(value) yield ensure _set_grad_enabled(previous_value) end end + alias_method :set_grad_enabled, :grad_enabled def device(str) Device.new(str) end