🐛 Describe the bug
MRE:
```python
import os

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Shard


def main():
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    mesh = init_device_mesh("cuda", (world_size,))

    x = torch.randn(1, 16).cuda()
    distx = distribute_tensor(x, device_mesh=mesh, placements=[Shard(0)])
    rand_int = torch.compile(lambda x: torch.randint_like(x, 0, 16))(distx)

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```
Run with:

```
torchrun --nproc_per_node=2 randint_like_mre.py
```
The pattern above would be useful in compiled functions that need random tensors of the same Tensor subclass. Ideally this compiles and runs without error (even with `fullgraph=True`, as opposed to graph breaking). Without `torch.compile`, the operation runs as expected, returning a DTensor of the correct shape with the correct placements.
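A possible stopgap, suggested by the traceback below, is to construct the DTensor RNG tracker eagerly before compiling, so that its lazy initialization (which calls `get_rng_state()` on a real tensor) does not run under FakeTensorMode. This is an untested sketch that pokes at private internals (`torch.distributed.tensor._random._rng_tracker` and `OffsetBasedRNGTracker`, both of which appear in the traceback):

```python
# Hypothetical workaround sketch (untested, uses private internals): create
# the RNG tracker that DTensor._op_dispatcher.dispatch would otherwise build
# lazily mid-trace, before torch.compile ever sees the random op.
import torch.distributed.tensor._random as dtensor_random

if dtensor_random._rng_tracker is None:
    dtensor_random._rng_tracker = dtensor_random.OffsetBasedRNGTracker(mesh)

# With the tracker pre-built, the lazy init no longer fires under tracing.
rand_int = torch.compile(lambda x: torch.randint_like(x, 0, 16))(distx)
```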
Error logs
```
[rank0]:V0623 22:19:54.758000 34373 site-packages/torch/_dynamo/convert_frame.py:1055] [0/0] torchdynamo start compiling <lambda> /workspace/dtensor_randint_like.py:19, stack (elided 5 frames):
[rank0]:V0623 22:19:54.758000 34373 site-packages/torch/_dynamo/convert_frame.py:1055] [0/0] File "/workspace/dtensor_randint_like.py", line 24, in <module>
[rank0]:V0623 22:19:54.758000 34373 site-packages/torch/_dynamo/convert_frame.py:1055] [0/0] main()
[rank0]:V0623 22:19:54.758000 34373 site-packages/torch/_dynamo/convert_frame.py:1055] [0/0] File "/workspace/dtensor_randint_like.py", line 19, in main
[rank0]:V0623 22:19:54.758000 34373 site-packages/torch/_dynamo/convert_frame.py:1055] [0/0] rand_int = torch.compile(lambda x: torch.randint_like(x, 0, 16))(distx)
[rank0]:V0623 22:19:54.758000 34373 site-packages/torch/_dynamo/convert_frame.py:1055] [0/0]
[rank0]:I0623 22:19:54.759000 34373 site-packages/torch/_dynamo/symbolic_convert.py:3320] [0/0] Step 1: torchdynamo start tracing <lambda> /workspace/dtensor_randint_like.py:19
[rank0]:I0623 22:19:54.760000 34373 site-packages/torch/fx/experimental/symbolic_shapes.py:3767] [0/0] create_env
[rank0]:V0623 22:19:54.762000 34373 site-packages/torch/_dynamo/symbolic_convert.py:1237] [0/0] [__trace_source] TRACE starts_line /workspace/dtensor_randint_like.py:19 in <lambda> (main)
[rank0]:V0623 22:19:54.762000 34373 site-packages/torch/_dynamo/symbolic_convert.py:1237] [0/0] [__trace_source] rand_int = torch.compile(lambda x: torch.randint_like(x, 0, 16))(distx)
[rank0]:V0623 22:19:54.763000 34373 site-packages/torch/_dynamo/symbolic_convert.py:1260] [0/0] [__trace_bytecode] TRACE RESUME 0 []
[rank0]:V0623 22:19:54.763000 34373 site-packages/torch/_dynamo/symbolic_convert.py:1260] [0/0] [__trace_bytecode] TRACE LOAD_GLOBAL torch []
[rank0]:V0623 22:19:54.763000 34373 site-packages/torch/_dynamo/symbolic_convert.py:1260] [0/0] [__trace_bytecode] TRACE LOAD_ATTR randint_like [NullVariable, LazyVariableTracker()]
[rank0]:V0623 22:19:54.765000 34373 site-packages/torch/_dynamo/symbolic_convert.py:1260] [0/0] [__trace_bytecode] TRACE LOAD_FAST x [NullVariable, LazyVariableTracker()]
[rank0]:V0623 22:19:54.765000 34373 site-packages/torch/_dynamo/symbolic_convert.py:1260] [0/0] [__trace_bytecode] TRACE LOAD_CONST 0 [NullVariable, LazyVariableTracker(), LazyVariableTracker()]
[rank0]:V0623 22:19:54.765000 34373 site-packages/torch/_dynamo/symbolic_convert.py:1260] [0/0] [__trace_bytecode] TRACE LOAD_CONST 16 [NullVariable, LazyVariableTracker(), LazyVariableTracker(), ConstantVariable(int: 0)]
[rank0]:V0623 22:19:54.765000 34373 site-packages/torch/_dynamo/symbolic_convert.py:1260] [0/0] [__trace_bytecode] TRACE PRECALL 3 [NullVariable, LazyVariableTracker(), LazyVariableTracker(), ConstantVariable(int: 0), ConstantVariable(int: 16)]
[rank0]:V0623 22:19:54.765000 34373 site-packages/torch/_dynamo/symbolic_convert.py:1260] [0/0] [__trace_bytecode] TRACE CALL 3 [NullVariable, LazyVariableTracker(), LazyVariableTracker(), ConstantVariable(int: 0), ConstantVariable(int: 16)]
[rank0]:V0623 22:19:54.772000 34373 site-packages/torch/_dynamo/variables/builder.py:3373] [0/0] wrap_to_fake L['x'] (1, 16) SubclassSymbolicContext(dynamic_sizes=[<DimDynamic.STATIC: 2>, <DimDynamic.STATIC: 2>], dynamic_strides=[<DimDynamic.INFER_STRIDE: 4>, <DimDynamic.INFER_STRIDE: 4>], constraint_sizes=[None, None], constraint_strides=[None, None], specialize_on=[[], []], view_base_context=None, tensor_source=LocalSource(local_name='x', is_input=True, dynamism=None, is_derefed_cell_contents=False), shape_env_to_source_to_symbol_cache={}, inner_contexts={'_local_tensor': StatefulSymbolicContext(dynamic_sizes=[<DimDynamic.STATIC: 2>, <DimDynamic.STATIC: 2>], dynamic_strides=[<DimDynamic.INFER_STRIDE: 4>, <DimDynamic.INFER_STRIDE: 4>], constraint_sizes=[None, None], constraint_strides=[None, None], specialize_on=[[], []], view_base_context=None, tensor_source=AttrSource(base=LocalSource(local_name='x', is_input=True, dynamism=None, is_derefed_cell_contents=False), member='_local_tensor'), shape_env_to_source_to_symbol_cache={})}) <class 'torch.distributed.tensor.DTensor'>
[rank0]:V0623 22:19:54.773000 34373 site-packages/torch/_dynamo/variables/builder.py:3373] [0/0] wrap_to_fake L['x']._local_tensor (1, 16) StatefulSymbolicContext(dynamic_sizes=[<DimDynamic.STATIC: 2>, <DimDynamic.STATIC: 2>], dynamic_strides=[<DimDynamic.INFER_STRIDE: 4>, <DimDynamic.INFER_STRIDE: 4>], constraint_sizes=[None, None], constraint_strides=[None, None], specialize_on=[[], []], view_base_context=None, tensor_source=AttrSource(base=LocalSource(local_name='x', is_input=True, dynamism=None, is_derefed_cell_contents=False), member='_local_tensor'), shape_env_to_source_to_symbol_cache={133243555096592: {"L['x']._local_tensor.size()[0]": 1, "L['x']._local_tensor.size()[1]": 16, "L['x']._local_tensor.storage_offset()": 0}}) <class 'torch.Tensor'>
[rank0]:V0623 22:19:54.774000 34373 site-packages/torch/_dynamo/output_graph.py:2614] [0/0] create_graph_input L_x_ L['x'] DTensor(local_tensor=FakeTensor(..., device='cuda:0', size=(1, 16)), device_mesh=DeviceMesh('cuda', [0, 1]), placements=(Shard(dim=0),)) at debug_level 0 before=False
[rank0]:V0623 22:19:54.775000 34373 site-packages/torch/_dynamo/variables/builder.py:3373] [0/0] wrap_to_fake L['x']._local_tensor (1, 16) StatefulSymbolicContext(dynamic_sizes=[<DimDynamic.STATIC: 2>, <DimDynamic.STATIC: 2>], dynamic_strides=[<DimDynamic.INFER_STRIDE: 4>, <DimDynamic.INFER_STRIDE: 4>], constraint_sizes=[None, None], constraint_strides=[None, None], specialize_on=[[], []], view_base_context=None, tensor_source=AttrSource(base=LocalSource(local_name='x', is_input=True, dynamism=None, is_derefed_cell_contents=False), member='_local_tensor'), shape_env_to_source_to_symbol_cache={133243555096592: {"L['x']._local_tensor.size()[0]": 1, "L['x']._local_tensor.size()[1]": 16, "L['x']._local_tensor.storage_offset()": 0}}) <class 'torch.Tensor'>
[rank0]:V0623 22:19:54.775000 34373 site-packages/torch/_dynamo/output_graph.py:2614] [0/0] create_graph_input L_x_local_tensor L['x']._local_tensor FakeTensor(..., device='cuda:0', size=(1, 16)) at debug_level 0 before=False
[rank0]:V0623 22:19:54.776000 34373 site-packages/torch/_dynamo/output_graph.py:2462] [0/0] [__trace_call] TRACE FX call randint_like from /workspace/dtensor_randint_like.py:19 in <lambda> (main)
[rank0]:V0623 22:19:54.776000 34373 site-packages/torch/_dynamo/output_graph.py:2462] [0/0] [__trace_call] rand_int = torch.compile(lambda x: torch.randint_like(x, 0, 16))(distx)
[rank0]:V0623 22:19:54.776000 34373 site-packages/torch/_dynamo/output_graph.py:2462] [0/0] [__trace_call] ~~~~~~~~~~~~~~~~~~^^^^^^^^^^
[rank0]:I0623 22:19:54.782000 34373 site-packages/torch/_dynamo/convert_frame.py:1175] [0/0] run_gc_after_compile: running gc
[rank0]: Traceback (most recent call last):
[rank0]: File "/workspace/dtensor_randint_like.py", line 24, in <module>
[rank0]: main()
[rank0]: File "/workspace/dtensor_randint_like.py", line 19, in main
[rank0]: rand_int = torch.compile(lambda x: torch.randint_like(x, 0, 16))(distx)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 703, in compile_wrapper
[rank0]: return fn(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1495, in __call__
[rank0]: return self._torchdynamo_orig_callable(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1272, in __call__
[rank0]: result = self._inner_convert(
[rank0]: ^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 629, in __call__
[rank0]: return _compile(
[rank0]: ^^^^^^^^^
[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1111, in _compile
[rank0]: guarded_code = compile_inner(code, one_graph, hooks, transform)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_utils_internal.py", line 97, in wrapper_function
[rank0]: return function(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 793, in compile_inner
[rank0]: return _compile_inner(code, one_graph, hooks, transform)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 832, in _compile_inner
[rank0]: out_code = transform_code_object(code, transform)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1424, in transform_code_object
[rank0]: transformations(instructions, code_options)
[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 267, in _fn
[rank0]: return fn(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 753, in transform
[rank0]: tracer.run()
[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3497, in run
[rank0]: super().run()
[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1363, in run
[rank0]: while self.step():
[rank0]: ^^^^^^^^^^^
[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1267, in step
[rank0]: self.dispatch_table[inst.opcode](self, inst)
[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 834, in wrapper
[rank0]: return inner_fn(self, inst)
[rank0]: ^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2910, in CALL
[rank0]: self._call(inst)
[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2904, in _call
[rank0]: self.call_function(fn, args, kwargs)
[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1193, in call_function
[rank0]: self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/lazy.py", line 201, in realize_and_forward
[rank0]: return getattr(self.realize(), name)(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/torch.py", line 1338, in call_function
[rank0]: tensor_variable = wrap_fx_proxy(
[rank0]: ^^^^^^^^^^^^^^
[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 2559, in wrap_fx_proxy
[rank0]: return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 2625, in wrap_fx_proxy_cls
[rank0]: return _wrap_fx_proxy(
[rank0]: ^^^^^^^^^^^^^^^
[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 2723, in _wrap_fx_proxy
[rank0]: example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 3351, in get_fake_value
[rank0]: raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 3249, in get_fake_value
[rank0]: ret_val = wrap_fake_exception(
[rank0]: ^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 2749, in wrap_fake_exception
[rank0]: return fn()
[rank0]: ^^^^
[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 3250, in <lambda>
[rank0]: lambda: run_node(tx.output, node, args, kwargs, nnmodule)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 3458, in run_node
[rank0]: raise RuntimeError(make_error_message(e)).with_traceback(
[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 3417, in run_node
[rank0]: return node.target(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_compile.py", line 53, in inner
[rank0]: return disable_fn(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 896, in _fn
[rank0]: return fn(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/_api.py", line 350, in __torch_dispatch__
[rank0]: return DTensor._op_dispatcher.dispatch(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/_dispatch.py", line 184, in dispatch
[rank0]: random._rng_tracker = random.OffsetBasedRNGTracker(mesh)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/_random.py", line 179, in __init__
[rank0]: rng_state = self._device_handle.get_rng_state().to(self._device)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/utils/_stats.py", line 28, in wrapper
[rank0]: return fn(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1352, in __torch_dispatch__
[rank0]: return self.dispatch(func, types, args, kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 2058, in dispatch
[rank0]: return self._cached_dispatch_impl(func, types, args, kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1457, in _cached_dispatch_impl
[rank0]: return self._dispatch_impl(func, types, args, kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 2352, in _dispatch_impl
[rank0]: (flat_args, flat_arg_fake_tensors) = self.validate_and_convert_non_fake_tensors(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 2803, in validate_and_convert_non_fake_tensors
[rank0]: validated_args = [validate(a) for a in flat_args]
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 2803, in <listcomp>
[rank0]: validated_args = [validate(a) for a in flat_args]
[rank0]: ^^^^^^^^^^^
[rank0]: File "/opt/conda/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 2791, in validate
[rank0]: raise AssertionError(
[rank0]: torch._dynamo.exc.TorchRuntimeError: Dynamo failed to run FX node with fake tensors: call_function <built-in method randint_like of type object at 0x7930684c93a0>(*(DTensor(local_tensor=FakeTensor(..., device='cuda:0', size=(1, 16)), device_mesh=DeviceMesh('cuda', [0, 1]), placements=(Shard(dim=0),)), 0, 16), **{}): got AssertionError("Please convert all Tensors to FakeTensors first or instantiate FakeTensorMode with 'allow_non_fake_inputs'. Found in aten.to.dtype_layout(tensor([...], size=(16,), dtype=torch.uint8), dtype=torch.uint8, layout=torch.strided, device=device(type='cuda', index=0))")
[rank0]: from user code:
[rank0]: File "/workspace/dtensor_randint_like.py", line 19, in <lambda>
[rank0]: rand_int = torch.compile(lambda x: torch.randint_like(x, 0, 16))(distx)
```
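Reading the traceback, the failure seems to come from DTensor's lazy RNG setup rather than from `randint_like` itself: `DTensor._op_dispatcher.dispatch` instantiates `OffsetBasedRNGTracker`, whose `__init__` calls `get_rng_state()` and passes the resulting real (non-fake) tensor to `aten.to.dtype_layout` while FakeTensorMode is active. If a graph break is acceptable as a mitigation (it defeats the `fullgraph=True` goal above), the random op can be kept out of the compiled region entirely; a minimal sketch:

```python
import torch

# Mitigation sketch: force a graph break so the DTensor random op runs
# eagerly, where the repro already behaves correctly. This gives up on
# fullgraph=True, so it is a workaround, not a fix.
@torch.compiler.disable
def randint_like_eager(x):
    return torch.randint_like(x, 0, 16)

rand_int = torch.compile(lambda x: randint_like_eager(x))(distx)
```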
Versions
```
PyTorch version: 2.8.0.dev20250623+cu128
Is debug build: False
CUDA used to build PyTorch: 12.8
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 4.0.3
Libc version: glibc-2.35
Python version: 3.11.13 | packaged by conda-forge | (main, Jun 4 2025, 14:48:23) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-6.8.0-1024-aws-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.8.93
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA L40S
GPU 1: NVIDIA L40S
GPU 2: NVIDIA L40S
GPU 3: NVIDIA L40S
Nvidia driver version: 570.133.20
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 48 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 48
On-line CPU(s) list: 0-47
Vendor ID: AuthenticAMD
Model name: AMD EPYC 7R13 Processor
CPU family: 25
Model: 1
Thread(s) per core: 2
Core(s) per socket: 24
Socket(s): 1
Stepping: 1
BogoMIPS: 5299.99
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch topoext perfctr_core ssbd ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 invpcid rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 clzero xsaveerptr rdpru wbnoinvd arat npt nrip_save vaes vpclmulqdq rdpid
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 768 KiB (24 instances)
L1i cache: 768 KiB (24 instances)
L2 cache: 12 MiB (24 instances)
L3 cache: 96 MiB (3 instances)
NUMA node(s): 1
NUMA node0 CPU(s): 0-47
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Vulnerable: Safe RET, no microcode
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines; IBPB conditional; IBRS_FW; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] numpy==2.3.0
[pip3] nvidia-cublas-cu12==12.8.4.1
[pip3] nvidia-cuda-cupti-cu12==12.8.90
[pip3] nvidia-cuda-nvrtc-cu12==12.8.93
[pip3] nvidia-cuda-runtime-cu12==12.8.90
[pip3] nvidia-cudnn-cu12==9.10.2.21
[pip3] nvidia-cufft-cu12==11.3.3.83
[pip3] nvidia-curand-cu12==10.3.9.90
[pip3] nvidia-cusolver-cu12==11.7.3.90
[pip3] nvidia-cusparse-cu12==12.5.8.93
[pip3] nvidia-cusparselt-cu12==0.7.1
[pip3] nvidia-nccl-cu12==2.27.3
[pip3] nvidia-nvjitlink-cu12==12.8.93
[pip3] nvidia-nvtx-cu12==12.8.90
[pip3] optree==0.16.0
[pip3] pytorch-triton==3.3.1+gitc8757738
[pip3] torch==2.8.0.dev20250623+cu128
[pip3] torchao==0.12.0+git28989031
[pip3] torchaudio==2.8.0.dev20250623+cu128
[pip3] torchelastic==0.2.2
[pip3] torchvision==0.23.0.dev20250623+cu128
[conda] numpy 2.3.0 py311h519dc76_0 conda-forge
[conda] nvidia-cublas-cu12 12.8.4.1 pypi_0 pypi
[conda] nvidia-cuda-cupti-cu12 12.8.90 pypi_0 pypi
[conda] nvidia-cuda-nvrtc-cu12 12.8.93 pypi_0 pypi
[conda] nvidia-cuda-runtime-cu12 12.8.90 pypi_0 pypi
[conda] nvidia-cudnn-cu12 9.10.2.21 pypi_0 pypi
[conda] nvidia-cufft-cu12 11.3.3.83 pypi_0 pypi
[conda] nvidia-curand-cu12 10.3.9.90 pypi_0 pypi
[conda] nvidia-cusolver-cu12 11.7.3.90 pypi_0 pypi
[conda] nvidia-cusparse-cu12 12.5.8.93 pypi_0 pypi
[conda] nvidia-cusparselt-cu12 0.7.1 pypi_0 pypi
[conda] nvidia-nccl-cu12 2.27.3 pypi_0 pypi
[conda] nvidia-nvjitlink-cu12 12.8.93 pypi_0 pypi
[conda] nvidia-nvtx-cu12 12.8.90 pypi_0 pypi
[conda] optree 0.16.0 pypi_0 pypi
[conda] pytorch-triton 3.3.1+gitc8757738 pypi_0 pypi
[conda] torch 2.8.0.dev20250623+cu128 pypi_0 pypi
[conda] torchao 0.12.0+git28989031 dev_0 <develop>
[conda] torchaudio 2.8.0.dev20250623+cu128 pypi_0 pypi
[conda] torchelastic 0.2.2 pypi_0 pypi
[conda] torchvision 0.23.0.dev20250623+cu128 pypi_0 pypi
```
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames @tianyu-l @XilunWu