use getattr for .optim and allow variable num of args #12184

Merged (3 commits) · Feb 14, 2025
22 changes: 15 additions & 7 deletions nemo/lightning/pytorch/callbacks/debugging.py
@@ -23,11 +23,19 @@


def collect_precision(tensor: torch.Tensor) -> Dict[str, str]:
return {"Precision": str(tensor.dtype)}
"""Returns tensor's precision"""
if isinstance(tensor, torch.Tensor):
return {"Precision": str(tensor.dtype)}
else:
return {"Precision": "not-a-tensor"}


def collect_precision_and_shape(tensor: torch.Tensor) -> Dict[str, str]:
return {"Shape": str(tensor.shape), "Precision": str(tensor.dtype)}
"""Returns tensor's shape & precision"""
if isinstance(tensor, torch.Tensor):
return {"Shape": str(tensor.shape), "Precision": str(tensor.dtype)}
else:
return {"Shape": "not-a-tensor", "Precision": "not-a-tensor"}


class ParameterDebugger(Callback):
@@ -106,20 +114,20 @@ def __init__(
        if isinstance(log_on_hooks, str):
            log_on_hooks = [log_on_hooks]
        for hook_name in log_on_hooks:
-            assert (
-                hook_name in valid_hooks
-            ), f"Hook {hook_name} supplied to log_on_hooks is not valid or can not be used. Valid hooks are {valid_hooks}"
+            assert hook_name in valid_hooks, (
+                "Hook {} supplied to log_on_hooks is not valid or " "can not be used. Valid hooks are {}"
+            ).format(hook_name, valid_hooks)
            setattr(self, hook_name, self._apply_user_funcs)
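Note: the loop above attaches the same handler to every requested hook name via setattr, after validating the name against valid_hooks. A simplified, self-contained sketch of this dynamic hook registration; the TinyCallback class and hook names are illustrative, not taken from NeMo or Lightning internals:

class TinyCallback:
    VALID_HOOKS = ("on_train_start", "on_train_end")

    def __init__(self, log_on_hooks):
        for hook_name in log_on_hooks:
            assert hook_name in self.VALID_HOOKS, f"unknown hook: {hook_name}"
            # Expose the shared handler under the hook's name so the framework
            # discovers and calls it at that point of the training loop.
            setattr(self, hook_name, self._handle)

    def _handle(self, *args, **kwargs):
        print(f"hook fired with {len(args)} positional args")


cb = TinyCallback(["on_train_start"])
cb.on_train_start("trainer", "module")  # prints: hook fired with 2 positional args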

-    def _apply_user_funcs(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
+    def _apply_user_funcs(self, trainer: pl.Trainer, pl_module: pl.LightningModule, *args, **kwargs) -> None:
        """
        Iterate over model parameters, find gradient tensor, apply and collect outputs of
        param_fn and grad_fn, and log outputs in a table.
        """

        def find_grad_tensor(param: torch.Tensor) -> Optional[torch.Tensor]:
            """If using MCore optimizer, search the grad buckets for param's grad tensor."""
-            if not isinstance(pl_module.optim, MegatronOptimizerModule):
+            if not isinstance(getattr(pl_module, 'optim', None), MegatronOptimizerModule):
                return param.grad

            for buf in pl_module.buffers:
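Note: the two headline changes of this PR are visible above. _apply_user_funcs now accepts *args and **kwargs, so the same method can back hooks that pass different numbers of arguments, and getattr(pl_module, 'optim', None) returns None rather than raising AttributeError when a LightningModule has no optim attribute, so the code simply falls back to param.grad. A toy sketch of the getattr behaviour; the classes below are stand-ins rather than NeMo types, and str plays the role of MegatronOptimizerModule:

class PlainModule:
    pass


class ModuleWithOptim:
    optim = "pretend-megatron-optimizer"  # stand-in for a MegatronOptimizerModule instance


def uses_mcore_optimizer(module) -> bool:
    # getattr with a default never raises, even when .optim is missing.
    return isinstance(getattr(module, "optim", None), str)


print(uses_mcore_optimizer(PlainModule()))      # False -> fall back to param.grad
print(uses_mcore_optimizer(ModuleWithOptim()))  # True  -> search the MCore grad buckets instead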