
Commit 52cbd46

Support DTensor redistribute with device order

ghstack-source-id: c5dd358
Pull Request resolved: #160266
Parent: 24257f5

5 files changed (+317, -75 lines)
test/distributed/tensor/test_redistribute.py

Lines changed: 66 additions & 0 deletions
@@ -695,5 +695,71 @@ def test_redistribute_shard_dim_multi_dim_mesh(self):
             self.assertEqual(local_out_dt, local_expected_dt)
 
 
+class DeviceOrderRedistributeTest(DTensorTestBase):
+    @property
+    def world_size(self) -> int:
+        return 8
+
+    @with_comms
+    def test_redistribute_mesh_dim_reorder(self):
+        mesh = init_device_mesh(self.device_type, (2, 2, 2))
+        input_data = torch.randn((8, 8, 8), device=self.device_type)
+        sharding_src_dst_pairs_with_order = [
+            # after reordering: S(0)S(0)S(0) -> RS(0)S(0) (S(0) to mesh axis
+            # I_{0,1,2}->I_{1,2}). 2: S0->R, 1: S0->R, 0: S0->R, 1: R->S0, 2:
+            # R->S0
+            (
+                ([Shard(0), Shard(0), Shard(0)], [0, 1, 2]),
+                ([Replicate(), Shard(0), Shard(0)], [0, 1, 2]),
+            ),
+            # same as above; the device order defaults to [0, 1, 2] if not
+            # specified
+            (
+                ([Shard(0), Shard(0), Shard(0)], None),
+                ([Replicate(), Shard(0), Shard(0)], None),
+            ),
+            # after reordering: S(0)S(0)S(0) -> RS(0)S(0) (S(0) to mesh axis
+            # I_{1,0,2}->I_{1,2}). 2: S0->R, 0: S0->R, 2: R->S0
+            (
+                ([Shard(0), Shard(0), Shard(0)], [1, 0, 2]),
+                ([Replicate(), Shard(0), Shard(0)], [0, 1, 2]),
+            ),
+            # after reordering: S(0)S(0)S(0) -> S(0)S(0)R (S(0) to mesh axis
+            # I_{0,1,2}->I_{0,1}). 2: S0->R
+            (
+                ([Shard(0), Shard(0), Shard(0)], [0, 1, 2]),
+                ([Replicate(), Shard(0), Shard(0)], [2, 0, 1]),
+            ),
+            # after reordering: RS(0)S(0) -> RS(1)S(0). (S(0) to mesh axis
+            # I_{1,2}->I_{2}) 2: S0->R, 1: S0->R, 2: R->S0, 1: R->S1
+            # TODO: this can be optimized by replacing one allreduce with an alltoall.
+            (
+                ([Replicate(), Shard(0), Shard(0)], [0, 1, 2]),
+                ([Shard(1), Shard(0), Replicate()], [1, 2, 0]),
+            ),
+        ]
+        expected_comm_counts = [3, 3, 2, 1, 2]
+        comm_mode = CommDebugMode()
+        for idx, ((src_placement, src_order), (dst_placement, dst_order)) in enumerate(
+            sharding_src_dst_pairs_with_order
+        ):
+            sharded_dt = distribute_tensor(
+                input_data, mesh, src_placement, device_order=src_order
+            )
+            expected_dt = distribute_tensor(
+                input_data.clone(), mesh, dst_placement, device_order=dst_order
+            )
+            with comm_mode:
+                out_dt = sharded_dt.redistribute(
+                    mesh, dst_placement, device_order=dst_order
+                )
+            self.assertEqual(
+                comm_mode.get_total_counts(), expected_comm_counts[idx]
+            )
+            local_out_dt = out_dt.to_local()
+            local_expected_dt = expected_dt.to_local()
+            self.assertEqual(local_out_dt, local_expected_dt)
+
+
 if __name__ == "__main__":
     run_tests()
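
For reference, a minimal usage sketch of the new device_order keyword as exercised by the test above. It assumes 8 ranks launched with torchrun on a CUDA backend and uses the standard torch.distributed.tensor imports; none of the names below are introduced by this commit except the device_order argument itself.

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, Shard, distribute_tensor

# Build a 2x2x2 mesh over 8 ranks, then shard tensor dim 0 across the mesh
# dims in the order 1, 0, 2 instead of the default 0, 1, 2.
mesh = init_device_mesh("cuda", (2, 2, 2))
x = torch.randn(8, 8, 8)
dt = distribute_tensor(x, mesh, [Shard(0), Shard(0), Shard(0)], device_order=[1, 0, 2])

# Redistribute to a layout that replicates on the first mesh dim, again with an
# explicit order; the expected_comm_counts list in the test above records how
# many collectives each (src, dst) pair of orders is expected to need.
out = dt.redistribute(mesh, [Replicate(), Shard(0), Shard(0)], device_order=[0, 1, 2])
print(out.to_local().shape)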

torch/distributed/tensor/_api.py

Lines changed: 13 additions & 4 deletions
@@ -487,6 +487,7 @@ def redistribute(
         self,
         device_mesh: Optional[DeviceMesh] = None,
         placements: Optional[Sequence[Placement]] = None,
+        device_order: Optional[Sequence[int]] = None,
         *,
         async_op: bool = False,
         forward_dtype: Optional[torch.dtype] = None,
@@ -562,7 +563,13 @@ def redistribute(
 
         # pyre-fixme[16]: `Redistribute` has no attribute `apply`.
         return Redistribute.apply(
-            self, device_mesh, placements, async_op, forward_dtype, backward_dtype
+            self,
+            device_mesh,
+            placements,
+            device_order,
+            async_op,
+            forward_dtype,
+            backward_dtype,
         )
 
     def full_tensor(
@@ -662,6 +669,7 @@ def distribute_tensor(
     tensor: torch.Tensor,
     device_mesh: Optional[DeviceMesh] = None,
     placements: Optional[Sequence[Placement]] = None,
+    device_order: Optional[Sequence[int]] = None,
     *,
     src_data_rank: Optional[int] = 0,
 ) -> DTensor:
@@ -761,10 +769,10 @@ def distribute_tensor(
 
     local_tensor = tensor.detach()
 
-    # TODO(xilun): address sharding order
-    # distribute the tensor according to the placements.
     placements = list(placements)
-    for idx, placement in enumerate(placements):
+    device_order = device_order or list(range(device_mesh.ndim))
+    assert len(device_order) == device_mesh.ndim
+    for idx, placement in zip(device_order, placements):
         if placement.is_shard():
             placement = cast(Shard, placement)
             if placement.dim < 0:
@@ -791,6 +799,7 @@ def distribute_tensor(
     spec = DTensorSpec(
         mesh=device_mesh,
         placements=placements,
+        device_order=tuple(device_order),
        tensor_meta=TensorMeta(
            shape=tensor.size(),
            stride=tensor.stride(),
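
The substantive change in distribute_tensor is that placements are no longer applied to mesh dims strictly in list order: the loop now zips them with device_order, defaulting to the identity order so existing callers are unaffected. A small standalone sketch of that pairing, using placeholder values rather than anything from the diff:

from torch.distributed.tensor import Replicate, Shard

# Placeholder stand-ins for distribute_tensor's arguments.
placements = [Shard(0), Shard(0), Replicate()]
mesh_ndim = 3
device_order = None  # callers that omit it get the old behavior

# Mirrors the diff: default to the identity order, then pair placements[i]
# with mesh dim device_order[i], so the i-th split is taken along that dim.
device_order = device_order or list(range(mesh_ndim))
assert len(device_order) == mesh_ndim
for mesh_dim, placement in zip(device_order, placements):
    print(f"apply {placement} along mesh dim {mesh_dim}")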

torch/distributed/tensor/_dtensor_spec.py

Lines changed: 8 additions & 1 deletion
@@ -25,13 +25,18 @@ class TensorMeta(NamedTuple):
 class DTensorSpec:
     mesh: DeviceMesh
     placements: tuple[Placement, ...]
-
     # tensor meta will only be set during sharding propagation
     tensor_meta: Optional[TensorMeta] = None
+    # device order specifies the order of the device mesh dims; defaults to range(0, mesh.ndim)
+    device_order: Optional[tuple[int, ...]] = None
 
     def __post_init__(self) -> None:
         if not isinstance(self.placements, tuple):
             self.placements = tuple(self.placements)
+        if not self.device_order:
+            self.device_order = tuple(range(self.mesh.ndim))
+        if not isinstance(self.device_order, tuple):
+            self.device_order = tuple(self.device_order)
         self._hash: Optional[int] = None
 
     def __setattr__(self, attr: str, value: Any) -> None:
@@ -55,6 +60,7 @@ def _hash_impl(self) -> int:
                     self.tensor_meta.shape,
                     self.tensor_meta.stride,
                     self.tensor_meta.dtype,
+                    self.device_order,
                 )
             )
         return hash((self.mesh, self.placements))
@@ -82,6 +88,7 @@ def __eq__(self, other: object, /) -> bool:
             self.tensor_meta.shape == other.tensor_meta.shape  # type: ignore[union-attr]
             and self.tensor_meta.stride == other.tensor_meta.stride  # type: ignore[union-attr]
             and self.tensor_meta.dtype == other.tensor_meta.dtype  # type: ignore[union-attr]
+            and self.device_order == other.device_order  # type: ignore[union-attr]
         )
 
     def __str__(self) -> str:
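
DTensorSpec now stores device_order and folds it into hashing and equality. The following is a minimal, self-contained mimic of the __post_init__ normalization above; the real class also carries a DeviceMesh and placements, and mesh_ndim here is only a stand-in for self.mesh.ndim.

from dataclasses import dataclass
from typing import Optional

@dataclass
class SpecOrderSketch:
    # Illustration-only stand-in for DTensorSpec.
    mesh_ndim: int
    device_order: Optional[tuple[int, ...]] = None

    def __post_init__(self) -> None:
        # A missing (or empty) order falls back to the identity order, and
        # non-tuple sequences are coerced to tuples so the field stays
        # hashable and comparable.
        if not self.device_order:
            self.device_order = tuple(range(self.mesh_ndim))
        if not isinstance(self.device_order, tuple):
            self.device_order = tuple(self.device_order)

assert SpecOrderSketch(3).device_order == (0, 1, 2)
assert SpecOrderSketch(3, (2, 0, 1)).device_order == (2, 0, 1)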
