
DDP+TP composition does not work as expected #157445

@volcacius

Description

I encountered a variety of issues while trying to adopt a combination of DistributedDataParallel and DTensor-based tensor parallelism. Some are specific to DDP+TP, others are more general.

This seems to be somewhat known, since e.g. torchtitan disallows combining DDP with other parallelization strategies (https://github.com/pytorch/torchtitan/blob/c08c9d4962ea843dd786b850d1716861955b9a9f/torchtitan/models/llama3/infra/parallelize.py#L123), but I couldn't find an explicit issue tracking it, hence I'm putting my findings here.

  1. The logic designed to support tensor parallelism within DDP does not operate as documented in https://github.com/pytorch/pytorch/commits/v2.7.0/torch/distributed/tensor/parallel/ddp.py . In principle it should:
    A) Convert from DTensor to local tensor at init time.
    B) Convert from local tensor to DTensor in the pre-forward hook.
    C) Convert from DTensor back to local tensor in the post-forward hook.
    In practice, once (B) fires for the first time, all parameters are replaced with non-leaf DTensors by _unflatten_tensor(tensor, spec, *, device_handle=None, compute_stream=None), which leaves the model parameter-less, and therefore (C) never fires because it iterates over module.named_parameters(), which is now empty.
    As a consequence, the model appears to be stateless when invoking state_dict(), since all parameters have been deleted, so neither the standard state APIs nor DCP work. (A minimal repro sketch is included below.)

Even if (C) were called correctly, it's not clear to me that allocating a new parameter on every forward would make sense w.r.t. the optimizer. Doing a swap_tensors is also not possible (last time I checked), since torch.Tensor and DTensor have different slots, so I think the original parameters created in (A) need to be preserved. On top of that, the tensor conversion hook would be called multiple times.

Note that while I'm talking specifically about DistributedDataParallel, replicate uses the same conversion hooks to compose with TP (see the "# TODO: This is a temporary work around to enable DDP + TP." comment).
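
For reference, a minimal repro sketch of the behaviour described in (1). This is my own, not from the report, and makes a few assumptions: two ranks on CPU with the gloo backend, a 1-D mesh shared by TP and DDP, and _pre_dp_module_transform from the linked ddp.py applied explicitly before wrapping in DDP. The model, plan, and shapes are purely illustrative.

```python
# Hypothetical repro sketch (mine, not from the report); launch with e.g.:
#   torchrun --nproc_per_node=2 repro_ddp_tp.py
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module
from torch.distributed.tensor.parallel.ddp import _pre_dp_module_transform
from torch.nn.parallel import DistributedDataParallel as DDP


def main() -> None:
    dist.init_process_group("gloo")  # assumption: CPU + gloo for simplicity
    # Degenerate setup: the same 2-rank group is used for both TP and DDP.
    # Not a meaningful 2D layout, but enough to trigger the hooks.
    mesh = init_device_mesh("cpu", (dist.get_world_size(),))

    model = nn.Sequential(nn.Linear(8, 8))
    parallelize_module(model, mesh, {"0": ColwiseParallel()})

    # (A): convert DTensor params to local tensors and register the (B)/(C) hooks.
    _pre_dp_module_transform(model)
    model = DDP(model)

    print("params before forward:", [n for n, _ in model.named_parameters()])
    model(torch.randn(4, 8))
    # After the first forward, (B) has swapped the parameters for non-leaf
    # DTensors and (C) never ran, so both of these come out empty.
    print("params after forward: ", [n for n, _ in model.named_parameters()])
    print("state_dict keys:      ", list(model.state_dict().keys()))

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```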

  2. Composition with activation checkpointing is also broken, since the two hooks are called once per model forward rather than once per submodule forward. So if activation checkpointing invokes a submodule (e.g. during recomputation), the DTensor conversion doesn't happen.
    Additionally, even in a scenario where the hooks operated per submodule, early stopping being enabled by default (_enable_checkpoint_early_stop = True) is also problematic, since it can lead to post-forward hooks being skipped (unless they are registered with always_call, which is not the case). To me this appears to be problematic in general for any hook-based implementation, not just TP+DDP. (The relevant knobs are sketched below.)
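
To make (2) concrete, here is a small sketch of mine (not from the report) of the two knobs mentioned above: the always_call=True argument of nn.Module.register_forward_hook, which runs the hook even if the forward exits via an exception (such as the early-stop exception raised during non-reentrant checkpoint recomputation), and torch.utils.checkpoint.set_checkpoint_early_stop(False), which disables early stopping around a checkpointed region. The hook body is a placeholder for the DTensor conversion.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint, set_checkpoint_early_stop

submodule = nn.Linear(8, 8)


def post_forward(module, args, output):
    # Placeholder for the DTensor -> local tensor conversion discussed above.
    print("post-forward hook ran on", type(module).__name__)


# A default registration would skip the hook whenever the (recomputed) forward
# exits early via an exception; always_call=True asks PyTorch to run it anyway.
submodule.register_forward_hook(post_forward, always_call=True)

x = torch.randn(4, 8, requires_grad=True)

# Alternatively, disable early stopping around the checkpointed forward so the
# recomputation always runs to completion.
with set_checkpoint_early_stop(False):
    out = checkpoint(submodule, x, use_reentrant=False)
out.sum().backward()
```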

  3. Propagation of requires_grad is fairly inconsistent. It's not propagated at init time by the conversion logic, nor by any of the TP parallel styles (e.g. where dist_param = nn.Parameter(...) is constructed), and it's also set in place unconditionally via tensor._local_tensor.requires_grad_() on the DTensor's local tensor, for reasons not obvious to me. (A quick check is sketched below.)
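
And a quick check for (3), again a sketch of mine under the same single-host gloo assumptions as above: if requires_grad is not propagated by the parallel styles, a weight frozen before parallelization would silently come back trainable.

```python
# Launch with e.g.: torchrun --nproc_per_node=2 check_requires_grad.py
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

dist.init_process_group("gloo")
mesh = init_device_mesh("cpu", (dist.get_world_size(),))

linear = nn.Linear(8, 8)
linear.weight.requires_grad_(False)  # freeze the weight before parallelizing

parallelize_module(linear, mesh, ColwiseParallel())

# If requires_grad were propagated, this would still print False; per the
# report, the sharded parameter is rebuilt with the nn.Parameter default (True).
print(linear.weight.requires_grad)

dist.destroy_process_group()
```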

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k

Metadata

Labels: oncall: distributed, triaged
