Fix the retention for existing gradients in the Grad Acc API (#8658)
rpsilva-aws authored Jan 31, 2025
1 parent e583c2c · commit 39dd795
Showing 1 changed file with 5 additions and 2 deletions.
torch_xla/experimental/gradient_accumulation.py: 5 additions & 2 deletions
@@ -39,10 +39,13 @@ def gradient_accumulation(
   Notes:
-    The model tracing will happen entirely within the loop. Hence, it is
+    * The model tracing will happen entirely within the loop. Hence, it is
     assumed that `train_step` is purposefully encapsulated inside of the
     loop. Hence, it is not recommended to have any operation involving the
     model parameters outside of `train_step`.
+    * Note that setting the gradients to zero instead of None (e.g.
+      `.zero_grad(set_to_none=False)`) will avoid the device transfer of the
+      initial gradients in every call.

   Args:
     train_step: Training function that takes iterable tensors and carried
@@ -380,7 +383,7 @@ def body_fn(iteri: torch.Tensor, _: torch.Tensor,
   for param in model_parameters:
     if not param.requires_grad:
       continue
-    if param.grad:
+    if param.grad is not None:
       grad = param.grad
     else:
       grad = torch.zeros(param.size()).to(param.device).requires_grad_(False)
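The functional change is the last hunk: `param.grad` is either a tensor or None, and the old truthiness test could not reliably detect an existing gradient. A minimal sketch of the difference in plain CPU PyTorch (outside the XLA-traced path; the parameter shape and values are illustrative only):

import torch

param = torch.nn.Parameter(torch.zeros(3))
param.grad = torch.zeros(3)  # an existing, already-allocated gradient buffer

try:
    if param.grad:  # old check: Tensor.__bool__ raises for multi-element tensors
        print("existing gradient reused")
except RuntimeError as err:
    print(err)  # "Boolean value of Tensor with more than one element is ambiguous"

if param.grad is not None:  # new check: the existing gradient is retained
    grad = param.grad
else:
    grad = torch.zeros(param.size()).to(param.device).requires_grad_(False)
print(grad is param.grad)  # True: no fresh zero tensor is created or transferred

With the old check, a zeroed single-element gradient would also have been treated as missing, so the else branch would recreate and re-transfer a zero gradient that already existed on the device.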

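The new docstring note pairs with this check: if callers zero their gradients with `.zero_grad(set_to_none=False)` (a standard PyTorch option on optimizers and modules) rather than dropping them to None, `param.grad` stays allocated between calls, so the helper reuses it instead of transferring fresh zero gradients to the device. A hedged usage sketch with an illustrative model and optimizer; the gradient-accumulated train step itself is elided:

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for _ in range(2):
    # ... run the accumulated train step here (see the `gradient_accumulation`
    # docstring in this module for the actual call) ...
    loss = model(torch.randn(8, 4)).sum()
    loss.backward()
    optimizer.step()

    # Keep the gradient buffers allocated and zeroed instead of setting them to
    # None; the next call then finds `param.grad is not None` and reuses them.
    optimizer.zero_grad(set_to_none=False)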