lib/torch.rb in torch-rb-0.5.3 vs lib/torch.rb in torch-rb-0.6.0
- old
+ new
@@ -335,28 +335,27 @@
float32: Numo::SFloat,
float64: Numo::DFloat
}
end
- def no_grad
- previous_value = grad_enabled?
- begin
- _set_grad_enabled(false)
- yield
- ensure
- _set_grad_enabled(previous_value)
- end
+ def no_grad(&block)
+ grad_enabled(false, &block)
end
- def enable_grad
+ def enable_grad(&block)
+ grad_enabled(true, &block)
+ end
+
+ def grad_enabled(value)
previous_value = grad_enabled?
begin
- _set_grad_enabled(true)
+ _set_grad_enabled(value)
yield
ensure
_set_grad_enabled(previous_value)
end
end
+ alias_method :set_grad_enabled, :grad_enabled
def device(str)
Device.new(str)
end