
fix type documentation for context_parallel no_restore_buffers, to prevent user from passing in the wrong type #159808


Open · wants to merge 1 commit into base: main

Conversation

@ad8e (Contributor) commented Aug 4, 2025

If you pass in a list or tuple, as in the following testcase, you'll receive RuntimeError: Boolean value of Tensor with more than one value is ambiguous:

# torchrun --standalone --nnodes=1 --nproc-per-node=1 this_file.py
import os

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.experimental import context_parallel
from torch.distributed.tensor.experimental._attention import context_parallel_unshard
from torch.nn.attention import sdpa_kernel, SDPBackend


def context_parallel_sdpa_example(world_size: int, rank: int):
    assert torch.cuda.is_available()
    assert dist.is_nccl_available()
    torch.cuda.set_device(f"cuda:{rank}")
    torch.cuda.manual_seed(0)

    dist.init_process_group(
        backend="nccl",
        init_method="env://",
        world_size=world_size,
        rank=rank,
    )
    device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,), mesh_dim_names=("cp",))

    batch = 1
    nheads = 1
    qkv_len = 4
    dim = 8
    backend = SDPBackend.FLASH_ATTENTION
    dtype = (
        torch.bfloat16
        if backend in (SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION)
        else torch.float32
    )

    qkv = [
        torch.rand(
            (batch, nheads, qkv_len, dim),
            dtype=dtype,
            requires_grad=True,
            device="cuda",
        )
        for _ in range(3)
    ]
    # specify the SDPBackend to use
    with sdpa_kernel(backend):
        out = F.scaled_dot_product_attention(*qkv, is_causal=True)

    # make a clean copy of QKV for output comparison
    cp_qkv = [t.detach().clone() for t in qkv]
    cp_qkv = tuple(cp_qkv)  # a tuple here (like a list) triggers the bug below

    with sdpa_kernel(backend):
        # This `context_parallel()` performs two actions:
        # 1. Shard the tensor objects in `buffers` in-place along the dimension
        #    specified in `buffer_seq_dims`, the tensors in `buffers` and their
        #    sharding dims in `buffer_seq_dims` are organized in the same order.
        # 2. Replace the execution of `F.scaled_dot_product_attention` with a
        #    context-paralleled-enabled Ring Attention.
        # NOTE: `no_restore_buffers=cp_qkv` passes a tuple and raises the
        # RuntimeError; `no_restore_buffers=set(cp_qkv)` works.
        with context_parallel(device_mesh, buffers=cp_qkv, buffer_seq_dims=(2, 2, 2), no_restore_buffers=cp_qkv):
            cp_out = F.scaled_dot_product_attention(*cp_qkv, is_causal=True)

        # The output `cp_out` is still sharded in the same way as QKV
        # the `context_parallel_unshard` API allows users to easily
        # unshard to gain the full tensor.
        (cp_out,) = context_parallel_unshard(device_mesh, [cp_out], [2])

    assert torch.allclose(
        cp_out,
        out,
        atol=(1e-08 if dtype == torch.float32 else 1e-03 * world_size),
    )


if __name__ == "__main__":
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    try:
        context_parallel_sdpa_example(world_size, rank)
    finally:
        dist.barrier()
        dist.destroy_process_group()

Changing no_restore_buffers=cp_qkv to no_restore_buffers=set(cp_qkv) fixes it.
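
Applied to the testcase above, the working call reads:

with context_parallel(device_mesh, buffers=cp_qkv, buffer_seq_dims=(2, 2, 2), no_restore_buffers=set(cp_qkv)):
    cp_out = F.scaled_dot_product_attention(*cp_qkv, is_causal=True)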

Another viable option is to reject this patch and instead change no_restore_buffers = set() if no_restore_buffers is None else no_restore_buffers to no_restore_buffers = set() if no_restore_buffers is None else set(no_restore_buffers). Other users are likely to make the same mistake of passing in a list.
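
For reference, that alternative would look like the following inside context_parallel (the quoted line is the only change; everything around it is elided):

# Coerce whatever iterable the caller passed (list, tuple, or set) into a
# set; since Tensor.__hash__ is id-based, building the set and later
# membership checks never evaluate an elementwise ==.
no_restore_buffers = set() if no_restore_buffers is None else set(no_restore_buffers)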

I assume "b in list" falls back to an elementwise tensor "==" comparison (whose truth value is ambiguous), while "b in set" compares id-based hashes, so no elementwise comparison happens.
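
A minimal standalone illustration of that membership behavior (my sketch, not code from this patch; note that list containment does short-circuit on object identity before falling back to ==):

import torch

a = torch.rand(2)
b = torch.rand(2)

print(a in [a])  # True: list containment short-circuits on object identity
print(b in {a})  # False: set containment compares id-based hashes; no == runs
print(b in [a])  # RuntimeError: falls back to b == a, whose elementwise result
                 # has an ambiguous truth value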

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta

pytorch-bot bot commented Aug 4, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/159808

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit f0c3fbb with merge base a7f3bdf:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

  • pull / linux-jammy-py3_9-clang9-xla / test (xla, 1, 1, lf.linux.12xlarge, unstable) (gh) (#158876)
    /var/lib/jenkins/workspace/xla/torch_xla/csrc/runtime/BUILD:476:14: Compiling torch_xla/csrc/runtime/xla_util_test.cpp failed: (Exit 1): gcc failed: error executing CppCompile command (from target //torch_xla/csrc/runtime:xla_util_test) /usr/bin/gcc -U_FORTIFY_SOURCE -fstack-protector -Wall -Wunused-but-set-parameter -Wno-free-nonheap-object -fno-omit-frame-pointer -g0 -O2 '-D_FORTIFY_SOURCE=1' -DNDEBUG -ffunction-sections ... (remaining 229 arguments skipped)

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the oncall: distributed label Aug 4, 2025
@ad8e (Contributor, Author) commented Aug 4, 2025

@pytorchbot label "topic: not user facing"

@pytorch-bot pytorch-bot bot added the topic: not user facing label Aug 4, 2025
@janeyx99 janeyx99 requested review from H-Huang and kwen2501 August 4, 2025 22:56
@janeyx99 janeyx99 added the triaged label Aug 4, 2025
@H-Huang (Member) commented Aug 5, 2025

cc @XilunWu

@H-Huang H-Huang added the module: context parallel label Aug 5, 2025
@H-Huang H-Huang requested a review from XilunWu August 5, 2025 01:57