Commit 42d2205

[DTensor] Assert DTensorSpec has valid placements
This helped identify buggy sharding rules during debugging, so why not check it in.

ghstack-source-id: 3fde8ba
Pull Request resolved: #158133
Parent: 709139c

2 files changed (+5, −1)

test/distributed/tensor/test_dtensor_compile.py (1 addition, 1 deletion)

@@ -298,7 +298,7 @@ def test_dtensor_constructor_w_graph_break(self):
         x = torch.randn(64, 32, requires_grad=True)
         spec = DTensorSpec(
             mesh,
-            (Replicate(), Shard(0)),
+            (Replicate(),),
             tensor_meta=TensorMeta(
                 shape=torch.Size([128, 32]), stride=(32, 1), dtype=x.dtype
             ),

The mesh in this test is one-dimensional, so the old two-element placements tuple would now trip the new check added below.

torch/distributed/tensor/_dtensor_spec.py (4 additions, 0 deletions)

@@ -32,6 +32,10 @@ class DTensorSpec:
     def __post_init__(self) -> None:
         if not isinstance(self.placements, tuple):
             self.placements = tuple(self.placements)
+        if not len(self.placements) == self.mesh.ndim:
+            raise ValueError(
+                f"DTensorSpec requires one placement per mesh dim (mesh.ndim={self.mesh.ndim}), got {self.placements=}"
+            )
         self._hash: Optional[int] = None
 
     def __setattr__(self, attr: str, value: Any) -> None:
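
For illustration, here is a minimal self-contained sketch of the validation pattern this commit adds. FakeMesh and FakeSpec are hypothetical stand-ins (constructing a real DTensorSpec needs an initialized DeviceMesh, which in turn needs a process group), but the __post_init__ check mirrors the diff above:

from dataclasses import dataclass
from typing import Any

@dataclass
class FakeMesh:
    # Stand-in for DeviceMesh; only ndim matters for this check.
    ndim: int

@dataclass
class FakeSpec:
    # Stand-in for DTensorSpec, reproducing only the new __post_init__ logic.
    mesh: FakeMesh
    placements: tuple

    def __post_init__(self) -> None:
        if not isinstance(self.placements, tuple):
            self.placements = tuple(self.placements)
        # The check added by this commit: one placement per mesh dimension.
        if not len(self.placements) == self.mesh.ndim:
            raise ValueError(
                f"Spec requires one placement per mesh dim "
                f"(mesh.ndim={self.mesh.ndim}), got {self.placements=}"
            )

FakeSpec(FakeMesh(ndim=1), ("Replicate()",))  # one placement for a 1-D mesh: OK
try:
    FakeSpec(FakeMesh(ndim=1), ("Replicate()", "Shard(0)"))  # two placements: rejected
except ValueError as exc:
    print(exc)

With this check in place, a sharding rule that emits the wrong number of placements fails loudly at spec construction rather than surfacing later in harder-to-debug ways, which matches the motivation in the commit message.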
