I encountered a variety of issues while trying to adopt a combination of DistributedDataParallel and DTensor-based tensor parallelism; some are specific to DDP+TP, some are more general.
This seems to be somewhat known, since e.g. torchtitan disallows combining DDP with other parallelization strategies (https://github.com/pytorch/torchtitan/blob/c08c9d4962ea843dd786b850d1716861955b9a9f/torchtitan/models/llama3/infra/parallelize.py#L123), but I couldn't find an explicit issue tracking it, hence I'm putting my findings here.
- The logic designed to support tensor parallelism within DDP does not operate as documented (https://github.com/pytorch/pytorch/commits/v2.7.0/torch/distributed/tensor/parallel/ddp.py). In principle, it should:
A) Convert from DTensor to local tensor at init time.
B) Convert from local tensor to DTensor in pre-forward.
C) Convert from DTensor to local tensor in post-forward.
In practice, once (B) fires for the first time, all parameters are replaced with non-leaf DTensors by `_unflatten_tensor(tensor, spec, *, device_handle=None, compute_stream=None)`, applied in a loop over `module.named_parameters()`.
As a consequence, the model appears to be stateless when invoking `state_dict()`, since all parameters have been deleted, so neither the standard state-dict APIs nor DCP work (a minimal repro sketch is included below).
Even if (C) were called correctly, it's not clear to me that allocating a new parameter on every forward would make sense w.r.t. the optimizer. Doing a `swap_tensors` is also not possible (last time I checked), since `torch.Tensor` and `DTensor` have different slots, so I think the original parameters created in (A) need to be preserved. On top of that, the tensor hook registered via `tensor.register_hook(...)` would end up being called multiple times.
Note that while I'm talking specifically about `DistributedDataParallel`, `replicate` uses the same conversion hooks to compose with TP (see the `# TODO: This is a temporary work around to enable DDP + TP.` comment).
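For reference, here is roughly how I hit the `state_dict()` problem. This is a minimal sketch, not my actual setup: the toy model, the Colwise/Rowwise plan, the single-dim mesh, and the explicit call to the private `_pre_dp_module_transform` from the ddp.py file linked above are only for illustration, and it assumes a 2-GPU launch via `torchrun --nproc-per-node=2`.

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)
from torch.distributed.tensor.parallel.ddp import _pre_dp_module_transform
from torch.nn.parallel import DistributedDataParallel as DDP


class Toy(nn.Module):
    def __init__(self, dim=16, hidden=32):
        super().__init__()
        self.w_in = nn.Linear(dim, hidden)
        self.w_out = nn.Linear(hidden, dim)

    def forward(self, x):
        return self.w_out(torch.relu(self.w_in(x)))


dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

# TP over all ranks, just to trigger the DDP + DTensor interaction;
# a real setup would use a 2D mesh with separate DP and TP dims.
mesh = init_device_mesh("cuda", (dist.get_world_size(),))
model = Toy().cuda()
parallelize_module(
    model, mesh, {"w_in": ColwiseParallel(), "w_out": RowwiseParallel()}
)

_pre_dp_module_transform(model)  # (A): DTensor params -> local tensors, hooks registered
ddp_model = DDP(model, device_ids=[rank])

print(f"[rank {rank}] before forward: {len(ddp_model.state_dict())} entries")
out = ddp_model(torch.randn(4, 16, device="cuda"))
out.sum().backward()
# After (B) has fired once, the parameters are non-leaf DTensors and
# state_dict() comes back (essentially) empty on my build.
print(f"[rank {rank}] after forward: {len(ddp_model.state_dict())} entries")
```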
- Composition with activation checkpointing is also broken, since the two hooks are called once per model forward rather than once per submodule forward. So if activation checkpointing re-invokes a submodule, DTensor conversion doesn't happen (see the sketch below).
Additionally, even in a scenario where the hooks operated per submodule, early stopping being enabled by default would also be problematic (`_enable_checkpoint_early_stop = True`, torch/utils/checkpoint.py line 734 at 1341794).
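To make the checkpointing interaction concrete, here is a hypothetical toy module (the names are mine, purely illustrative) where a submodule is wrapped with `torch.utils.checkpoint.checkpoint`. The DTensor<->local hooks are registered on the root module only, so the recomputation that checkpointing triggers during backward re-runs `self.block` directly, after the root post-forward hook has already swapped the parameters back, and no conversion happens for the recomputed forward.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class Block(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.lin = nn.Linear(dim, dim)

    def forward(self, x):
        return torch.relu(self.lin(x))


class ToyWithAC(nn.Module):
    def __init__(self, dim=16):
        super().__init__()
        self.block = Block(dim)
        self.head = nn.Linear(dim, dim)

    def forward(self, x):
        # When checkpoint() recomputes self.block during backward, neither the
        # pre- nor post-forward conversion hook on the root module runs for
        # that recomputation.
        y = checkpoint(self.block, x, use_reentrant=False)
        return self.head(y)
```

As far as I can tell, the non-reentrant path (`use_reentrant=False`) is also the one governed by the early-stop flag mentioned above.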
- Propagation of `requires_grad` is fairly inconsistent. It's not propagated at init time, e.g. in `t = nn.Parameter(t)` and `dist_param = nn.Parameter(...)`, while `tensor._local_tensor.requires_grad_()` forces it to `True` regardless of the original flag (a small check along these lines is sketched below).
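One way to spot where the flag gets dropped is a small helper along these lines (the helper names and the freezing example are hypothetical, just to illustrate the check):

```python
import torch.nn as nn


def snapshot_requires_grad(module: nn.Module) -> dict:
    """Record requires_grad per parameter name."""
    return {name: p.requires_grad for name, p in module.named_parameters()}


def diff_requires_grad(before: dict, module: nn.Module):
    """Report parameters that disappeared or whose requires_grad flipped."""
    after = {name: p.requires_grad for name, p in module.named_parameters()}
    missing = sorted(set(before) - set(after))
    changed = sorted(n for n in after if n in before and after[n] != before[n])
    return missing, changed


# Usage sketch: freeze one layer, then re-check after each transform, e.g.
#   model.w_in.weight.requires_grad_(False)
#   before = snapshot_requires_grad(model)
#   parallelize_module(model, mesh, plan)
#   print(diff_requires_grad(before, model))
#   _pre_dp_module_transform(model); ddp = DDP(model, device_ids=[rank])
#   print(diff_requires_grad(before, ddp.module))
```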
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k