@jxmorris12
Last active April 19, 2024 15:54
verify parameter weights & gradients in pytorch
import torch

# `gather`, `get_world_size`, and `print0` are distributed helpers assumed to be defined
# elsewhere in the codebase: `gather` all-gathers a tensor from every rank (stacked along
# dim 0), `get_world_size` returns the number of ranks, and `print0` prints only on rank 0.


def verify_ddp_weights_equal(model: torch.nn.Module, atol: float = 1e-5) -> None:
    # Unwrap DistributedDataParallel to reach the underlying module.
    if hasattr(model, "module"):
        model = model.module

    world_size = get_world_size()

    for name, param in model.named_parameters():
        # Compare each rank's copy of the parameter against rank 0's copy.
        gathered_param = gather(param).reshape((world_size, -1))
        absolute_diffs = (gathered_param[None, 0, :] - gathered_param).abs()
        rank_params_eq = (absolute_diffs < atol).all()
        assert rank_params_eq, f"❌ param [{name}] not equal - got max_absolute_diff={absolute_diffs.max()}"
        ###################################################################################################################
        # Compare each rank's gradient against rank 0's gradient.
        gathered_param_grad = gather(param.grad).reshape((world_size, -1))
        absolute_grad_diffs = (gathered_param_grad[None, 0, :] - gathered_param_grad).abs()
        rank_grad_params_eq = (absolute_grad_diffs < atol).all()
        assert rank_grad_params_eq, f"❌ param [{name}] grad not equal - got max_absolute_diff={absolute_grad_diffs.max()}"
        ###################################################################################################################

    print0("Verified DDP parameter correctness ✅")