diff --git a/test/distributed/tensor/test_redistribute.py b/test/distributed/tensor/test_redistribute.py
index fe07b0dd6a24..8c5e30c2b137 100644
--- a/test/distributed/tensor/test_redistribute.py
+++ b/test/distributed/tensor/test_redistribute.py
@@ -695,5 +695,71 @@ def test_redistribute_shard_dim_multi_dim_mesh(self):
         self.assertEqual(local_out_dt, local_expected_dt)


+class DeviceOrderRedistributeTest(DTensorTestBase):
+    @property
+    def world_size(self) -> int:
+        return 8
+
+    @with_comms
+    def test_redistribute_mesh_dim_reorder(self):
+        mesh = init_device_mesh(self.device_type, (2, 2, 2))
+        input_data = torch.randn((8, 8, 8), device=self.device_type)
+        sharding_src_dst_pairs_with_order = [
+            # after reordering: S(0)S(0)S(0) -> RS(0)S(0) (S(0) to mesh axis
+            # I_{0,1,2}->I_{1,2}). 2: S0->R, 1: S0->R, 0: S0->R, 1: R->S0, 2:
+            # R->S0
+            (
+                ([Shard(0), Shard(0), Shard(0)], [0, 1, 2]),
+                ([Replicate(), Shard(0), Shard(0)], [0, 1, 2]),
+            ),
+            # same as above; the device order defaults to [0, 1, 2] if not
+            # specified
+            (
+                ([Shard(0), Shard(0), Shard(0)], None),
+                ([Replicate(), Shard(0), Shard(0)], None),
+            ),
+            # after reordering: S(0)S(0)S(0) -> RS(0)S(0) (S(0) to mesh axis
+            # I_{1,0,2}->I_{1,2}). 2: S0->R, 0: S0->R, 2: R->S0
+            (
+                ([Shard(0), Shard(0), Shard(0)], [1, 0, 2]),
+                ([Replicate(), Shard(0), Shard(0)], [0, 1, 2]),
+            ),
+            # after reordering: S(0)S(0)S(0) -> S(0)S(0)R (S(0) to mesh axis
+            # I_{0,1,2}->I_{0,1}). 2: S0->R
+            (
+                ([Shard(0), Shard(0), Shard(0)], [0, 1, 2]),
+                ([Replicate(), Shard(0), Shard(0)], [2, 0, 1]),
+            ),
+            # after reordering: RS(0)S(0) -> RS(1)S(0). (S(0) to mesh axis
+            # I_{1,2}->I_{2}) 2: S0->R, 1: S0->R, 2: R->S0, 1: R->S1
+            # TODO: this can be optimized by replacing one allgather with an alltoall.
+            (
+                ([Replicate(), Shard(0), Shard(0)], [0, 1, 2]),
+                ([Shard(1), Shard(0), Replicate()], [1, 2, 0]),
+            ),
+        ]
+        expected_comm_counts = [3, 3, 2, 1, 2]
+        comm_mode = CommDebugMode()
+        for idx, ((src_placement, src_order), (dst_placement, dst_order)) in enumerate(
+            sharding_src_dst_pairs_with_order
+        ):
+            sharded_dt = distribute_tensor(
+                input_data, mesh, src_placement, device_order=src_order
+            )
+            expected_dt = distribute_tensor(
+                input_data.clone(), mesh, dst_placement, device_order=dst_order
+            )
+            with comm_mode:
+                out_dt = sharded_dt.redistribute(
+                    mesh, dst_placement, device_order=dst_order
+                )
+                self.assertEqual(
+                    comm_mode.get_total_counts(), expected_comm_counts[idx]
+                )
+            local_out_dt = out_dt.to_local()
+            local_expected_dt = expected_dt.to_local()
+            self.assertEqual(local_out_dt, local_expected_dt)
+
+
 if __name__ == "__main__":
     run_tests()
diff --git a/torch/distributed/tensor/_api.py b/torch/distributed/tensor/_api.py
index b0ee136c135f..4f98d27aa45c 100644
--- a/torch/distributed/tensor/_api.py
+++ b/torch/distributed/tensor/_api.py
@@ -487,6 +487,7 @@ def redistribute(
         self,
         device_mesh: Optional[DeviceMesh] = None,
         placements: Optional[Sequence[Placement]] = None,
+        device_order: Optional[Sequence[int]] = None,
         *,
         async_op: bool = False,
         forward_dtype: Optional[torch.dtype] = None,
@@ -562,7 +563,13 @@

         # pyre-fixme[16]: `Redistribute` has no attribute `apply`.
         return Redistribute.apply(
-            self, device_mesh, placements, async_op, forward_dtype, backward_dtype
+            self,
+            device_mesh,
+            placements,
+            device_order,
+            async_op,
+            forward_dtype,
+            backward_dtype,
         )

     def full_tensor(
@@ -662,6 +669,7 @@ def distribute_tensor(
     tensor: torch.Tensor,
     device_mesh: Optional[DeviceMesh] = None,
     placements: Optional[Sequence[Placement]] = None,
+    device_order: Optional[Sequence[int]] = None,
     *,
     src_data_rank: Optional[int] = 0,
 ) -> DTensor:
@@ -761,10 +769,10 @@

     local_tensor = tensor.detach()

-    # TODO(xilun): address sharding order
-    # distribute the tensor according to the placements.
     placements = list(placements)
-    for idx, placement in enumerate(placements):
+    device_order = device_order or list(range(device_mesh.ndim))
+    assert len(device_order) == device_mesh.ndim
+    for idx, placement in zip(device_order, placements):
         if placement.is_shard():
             placement = cast(Shard, placement)
             if placement.dim < 0:
@@ -791,6 +799,7 @@
     spec = DTensorSpec(
         mesh=device_mesh,
         placements=placements,
+        device_order=tuple(device_order),
         tensor_meta=TensorMeta(
             shape=tensor.size(),
             stride=tensor.stride(),
diff --git a/torch/distributed/tensor/_dtensor_spec.py b/torch/distributed/tensor/_dtensor_spec.py
index eb528ee4f9af..d2befaeda299 100644
--- a/torch/distributed/tensor/_dtensor_spec.py
+++ b/torch/distributed/tensor/_dtensor_spec.py
@@ -25,13 +25,18 @@ class TensorMeta(NamedTuple):
 class DTensorSpec:
     mesh: DeviceMesh
     placements: tuple[Placement, ...]
-    # tensor meta will only be set during sharding propagation
     tensor_meta: Optional[TensorMeta] = None
+    # device_order[i] is the mesh dim that placements[i] is applied to; defaults to range(0, mesh.ndim)
+    device_order: Optional[tuple[int, ...]] = None

     def __post_init__(self) -> None:
         if not isinstance(self.placements, tuple):
             self.placements = tuple(self.placements)
+        if not self.device_order:
+            self.device_order = tuple(range(self.mesh.ndim))
+        if not isinstance(self.device_order, tuple):
+            self.device_order = tuple(self.device_order)
         self._hash: Optional[int] = None

     def __setattr__(self, attr: str, value: Any) -> None:
@@ -55,6 +60,7 @@ def _hash_impl(self) -> int:
                     self.tensor_meta.shape,
                     self.tensor_meta.stride,
                     self.tensor_meta.dtype,
+                    self.device_order,
                 )
             )
         return hash((self.mesh, self.placements))
@@ -82,6 +88,7 @@ def __eq__(self, other: object, /) -> bool:
             self.tensor_meta.shape == other.tensor_meta.shape  # type: ignore[union-attr]
             and self.tensor_meta.stride == other.tensor_meta.stride  # type: ignore[union-attr]
             and self.tensor_meta.dtype == other.tensor_meta.dtype  # type: ignore[union-attr]
+            and self.device_order == other.device_order  # type: ignore[union-attr]
         )

     def __str__(self) -> str:
diff --git a/torch/distributed/tensor/_redistribute.py b/torch/distributed/tensor/_redistribute.py
index 11fc2d11e1a8..0fc10650d7b9 100644
--- a/torch/distributed/tensor/_redistribute.py
+++ b/torch/distributed/tensor/_redistribute.py
@@ -30,6 +30,8 @@ class _TransformInfo(NamedTuple):
 def _gen_transform_infos_non_cached(
     src_spec: DTensorSpec,
     dst_spec: DTensorSpec,
+    src_device_order: tuple[int, ...],
+    dst_device_order: tuple[int, ...],
 ) -> list[_TransformInfo]:
     """
     Generate the transform infos from the source placements to the target placements.
@@ -51,7 +53,6 @@ def _gen_transform_infos_non_cached(
     # logical shape records the logic tensor shape on the mesh dimension
     # this is useful to ensure uneven sharding gets correct output shape
     initial_logical_shape = list(src_spec.shape)
-    mesh_dims_to_logical_shape = [initial_logical_shape]

     if device_mesh.ndim == 1:
         # if device_mesh is 1D, redistribute is a simple direct transformation
                 )
         return transform_infos

-    # Handle multi-dim device mesh placement redistribution
-    # First, we need to build the logical shape for each mesh dim
-    # for correct allgathering uneven shards on each mesh dim (with dynamic padding)
-    for i, src in enumerate(src_spec.placements):
-        current_logical_shape = mesh_dims_to_logical_shape[i]
-        if isinstance(src, Shard):
-            if i < device_mesh.ndim - 1:
-                # calculate and save the logical shape for this sharding
-                mesh_dim_size = device_mesh.size(mesh_dim=i)
-                local_shard_size, _ = src._local_shard_size_and_offset(
-                    current_logical_shape[src.dim],
-                    mesh_dim_size,
-                    my_coordinate[i],
-                )
-                new_logical_shape = list(current_logical_shape)
-                new_logical_shape[src.dim] = local_shard_size
-                mesh_dims_to_logical_shape.append(new_logical_shape)
-        else:
-            mesh_dims_to_logical_shape.append(current_logical_shape)
-
-    # Next, we need to derive the transform infos from src to dst placements,
-    # here we use a greedy search with step by step state transformations
-    current_placements = list(src_spec.placements)
-    target_placements = list(dst_spec.placements)
-
-    if src_spec.num_shards > 1:
-        # If src_spec have sharding, it could potentially have sharding that is misaligned with dst_spec
-        # a common case of this is nested sharding (i.e. (S(0), S(0)) -> (R, S(0))).
-        # In those cases, we first traverse from inner placement to outer placement
-        # to detect misaligned shardings and properly replicate nested sharding first.
-        for mesh_dim in reversed(range(len(current_placements))):
-            current = current_placements[mesh_dim]
-            target = target_placements[mesh_dim]
-            # If target is not Shard, we can directly redistribute since we are traversing from innner
-            # to outer placements here
-            if isinstance(target, Shard):
-                # If target is Shard, check for nested sharding on the tensor dim BEFORE the current mesh_dim
-                shard_dim = target.dim
-                current_mesh_sharding, target_mesh_sharding = [], []
-                for i, (s, p) in enumerate(zip(current_placements, target_placements)):
-                    if i >= mesh_dim:
-                        break
-                    if s.is_shard(shard_dim):
-                        current_mesh_sharding.append(i)
-                    if p.is_shard(shard_dim):
-                        target_mesh_sharding.append(i)
+    # reorder the placements so that sorted_*_placement[mesh_dim] is the placement applied to that mesh dim
+    def _reorder_placement(placements, device_order):
+        return [placement for _, placement in sorted(zip(device_order, placements))]
+
+    sorted_src_placement = _reorder_placement(src_spec.placements, src_device_order)
+    sorted_dst_placement = _reorder_placement(dst_spec.placements, dst_device_order)
+
+    # map each sharded tensor dim to the mesh dims it is sharded over, in device order
+    def _map_tensor_dim_to_mesh_dim(placements, device_order):
+        tensor_dim_to_mesh_dims: dict[int, list[int]] = {}
+        for placement, mesh_dim in zip(placements, device_order):
+            if placement.is_shard():
+                assert isinstance(placement, Shard)
+                if placement.dim not in tensor_dim_to_mesh_dims:
+                    tensor_dim_to_mesh_dims[placement.dim] = []
+                tensor_dim_to_mesh_dims[placement.dim].append(mesh_dim)
+        return tensor_dim_to_mesh_dims
+
+    src_tensor_dim_to_mesh_dims = _map_tensor_dim_to_mesh_dim(
+        src_spec.placements, src_device_order
+    )
+    dst_tensor_dim_to_mesh_dims = _map_tensor_dim_to_mesh_dim(
+        dst_spec.placements, dst_device_order
+    )

-                if current_mesh_sharding != target_mesh_sharding:
-                    # if current/target_placements have misaligned sharding on the tensor dim BEFORE the current
-                    # mesh_dim, we need to replicate the tensor on the mesh dim first to clear the nested sharding
-                    target = Replicate()
+    # derive the local logical shape of src_spec once all of its shardings are applied
+    final_logical_shape = list(src_spec.shape)
+    for mesh_dim, placement in enumerate(sorted_src_placement):
+        if placement.is_shard():
+            assert isinstance(placement, Shard)
+            mesh_dim_size = device_mesh.size(mesh_dim=mesh_dim)
+            local_shard_size, _ = placement._local_shard_size_and_offset(
+                final_logical_shape[placement.dim],
+                mesh_dim_size,
+                my_coordinate[mesh_dim],
+            )
+            final_logical_shape[placement.dim] = local_shard_size
+    # now final_logical_shape is the final local shape under src_spec.placements
+    current_logical_shape = list(final_logical_shape)
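+    # Summary of the loop below (descriptive comment only): walk the tensor dims one
+    # at a time. For a tensor dim sharded in both src and dst, allgather the mismatched
+    # suffix of its mesh-dim order and re-chunk it in the dst order; for a tensor dim
+    # sharded only in src, try an alltoall; any remaining per-mesh-dim differences are
+    # handled in the final pass at the end.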
+    for tensor_dim in range(src_spec.ndim):
+        if (
+            tensor_dim in src_tensor_dim_to_mesh_dims
+            and tensor_dim in dst_tensor_dim_to_mesh_dims
+        ):
+            # The rule is to only allow push/pop of the rightmost mesh dim to turn
+            # src_tensor_dim_to_mesh_dims[tensor_dim] into
+            # dst_tensor_dim_to_mesh_dims[tensor_dim]. For example, if
+            # src_tensor_dim_to_mesh_dims[tensor_dim] = [3, 0, 1, 2] and
+            # dst_tensor_dim_to_mesh_dims[tensor_dim] = [3, 1, 0], we need to
+            # [3, 0, 1, 2] -> (allgather) [3, 0, 1] -> (allgather) [3, 0] ->
+            # (allgather) [3] -> (chunk) [3, 1] -> (chunk) [3, 1, 0]
+            src_mesh_dims = (
+                src_tensor_dim_to_mesh_dims[tensor_dim].copy()
+                if tensor_dim in src_tensor_dim_to_mesh_dims
+                else []
+            )
+            dst_mesh_dims = (
+                dst_tensor_dim_to_mesh_dims[tensor_dim].copy()
+                if tensor_dim in dst_tensor_dim_to_mesh_dims
+                else []
+            )
+
+            # find i s.t. src_mesh_dims[:i] == dst_mesh_dims[:i] (the common prefix)
+            i = 0
+            for i in range(len(src_mesh_dims)):
+                if i < len(dst_mesh_dims):
+                    if src_mesh_dims[i] != dst_mesh_dims[i]:
+                        break
-            if current != target:
+            # Build the transform_infos for the allgather operations.
+            while len(src_mesh_dims) > i:
+                mesh_dim = src_mesh_dims.pop()
+                # allgather on the popped mesh_dim
+                mesh_dim_size = device_mesh.size(mesh_dim=mesh_dim)
                 transform_infos.append(
                     _TransformInfo(
                         mesh_dim=mesh_dim,
-                        src_dst_placements=(current, target),
-                        logical_shape=mesh_dims_to_logical_shape[mesh_dim],
+                        src_dst_placements=(Shard(tensor_dim), Replicate()),
+                        logical_shape=current_logical_shape,
                     )
                 )
-                current_placements[mesh_dim] = target
-    # We always traverse from outer placement to inner placement to collect the remaining
-    # needed transform infos (i.e. the replication from nested sharding might need to further
-    # perform resharding to Shard again)
-    for mesh_dim, (current, target) in enumerate(
-        zip(current_placements, target_placements)
+                current_logical_shape[tensor_dim] = min(
+                    current_logical_shape[tensor_dim] * mesh_dim_size,
+                    src_spec.shape[tensor_dim],
+                )
+
+            assert len(src_mesh_dims) == i
+            # Build the transform_infos for the chunk operations.
+            for mesh_dim in dst_mesh_dims[i:]:
+                # chunk on mesh_dim
+                mesh_dim_size = device_mesh.size(mesh_dim=mesh_dim)
+                current_placement = sorted_dst_placement[mesh_dim]
+                assert isinstance(current_placement, Shard)
+                local_shard_size, _ = current_placement._local_shard_size_and_offset(
+                    current_logical_shape[tensor_dim],
+                    mesh_dim_size,
+                    my_coordinate[mesh_dim],
+                )
+                transform_infos.append(
+                    _TransformInfo(
+                        mesh_dim=mesh_dim,
+                        src_dst_placements=(Replicate(), Shard(tensor_dim)),
+                        logical_shape=current_logical_shape,
+                    )
+                )
+                current_logical_shape[tensor_dim] = local_shard_size
+            # tensor_dim is done; mark it Replicate() in src/dst so the final pass skips it
+            for mesh_dim, placement in enumerate(sorted_src_placement):
+                if isinstance(placement, Shard) and placement.dim == tensor_dim:
+                    sorted_src_placement[mesh_dim] = Replicate()
+            for mesh_dim, placement in enumerate(sorted_dst_placement):
+                if isinstance(placement, Shard) and placement.dim == tensor_dim:
+                    sorted_dst_placement[mesh_dim] = Replicate()
+
+        elif tensor_dim in src_tensor_dim_to_mesh_dims:
+            # Check whether a Shard -> Shard pattern I_x -> J_x exists (x is the mesh dim).
+            # In this case we can apply an alltoall.
+            if len(src_tensor_dim_to_mesh_dims[tensor_dim]) == 1:
+                mesh_dim = src_tensor_dim_to_mesh_dims[tensor_dim][0]
+                # check if there exists j s.t. dst_tensor_dim_to_mesh_dims[j] == [mesh_dim]
+                for j in range(src_spec.ndim):
+                    if (
+                        j in dst_tensor_dim_to_mesh_dims
+                        and dst_tensor_dim_to_mesh_dims[j] == [mesh_dim]
+                    ):
+                        mesh_dim_size = device_mesh.size(mesh_dim=mesh_dim)
+                        current_placement = sorted_dst_placement[mesh_dim]
+                        assert isinstance(current_placement, Shard)
+                        # alltoall from Shard(tensor_dim) to Shard(j)
+                        transform_infos.append(
+                            _TransformInfo(
+                                mesh_dim=mesh_dim,
+                                src_dst_placements=(Shard(tensor_dim), Shard(j)),
+                                logical_shape=current_logical_shape,
+                            )
+                        )
+                        local_shard_size, _ = (
+                            current_placement._local_shard_size_and_offset(
+                                current_logical_shape[tensor_dim],
+                                mesh_dim_size,
+                                my_coordinate[mesh_dim],
+                            )
+                        )
+                        # TODO: confirm if this is correct
+                        current_logical_shape[tensor_dim] = local_shard_size
+                        current_logical_shape[j] = min(
+                            current_logical_shape[j] * mesh_dim_size, src_spec.shape[j]
+                        )
+                        sorted_src_placement[mesh_dim] = Shard(j)
+                        # use the first matching j only; delete it from
+                        # dst_tensor_dim_to_mesh_dims to prevent reuse (may not be
+                        # strictly necessary, just to be safe)
+                        del dst_tensor_dim_to_mesh_dims[j]
+                        break
+
+    # The Shard() -> Shard() cases are done; now process the remaining mesh dims.
+    for mesh_dim, (src_placement, dst_placement) in enumerate(
+        zip(sorted_src_placement, sorted_dst_placement)
     ):
-        if current != target:
+        if src_placement == dst_placement:
+            continue
+        assert not (
+            isinstance(src_placement, Shard) and isinstance(dst_placement, Shard)
+        )
+        mesh_dim_size = device_mesh.size(mesh_dim=mesh_dim)
+        if isinstance(src_placement, Shard):
+            # shard -> replicate/partial
             transform_infos.append(
                 _TransformInfo(
                     mesh_dim=mesh_dim,
-                    src_dst_placements=(current, target),
-                    logical_shape=mesh_dims_to_logical_shape[mesh_dim],
+                    src_dst_placements=(src_placement, dst_placement),
+                    logical_shape=current_logical_shape,
+                )
+            )
+            current_logical_shape[src_placement.dim] = min(
+                current_logical_shape[src_placement.dim] * mesh_dim_size,
+                src_spec.shape[src_placement.dim],
+            )
+        elif isinstance(dst_placement, Shard):
+            # replicate/partial -> shard
+            transform_infos.append(
+                _TransformInfo(
+                    mesh_dim=mesh_dim,
+                    src_dst_placements=(src_placement, dst_placement),
+                    logical_shape=current_logical_shape,
                 )
             )
-            current_placements[mesh_dim] = target
+            local_shard_size, _ = dst_placement._local_shard_size_and_offset(
+                current_logical_shape[dst_placement.dim],
+                mesh_dim_size,
+                my_coordinate[mesh_dim],
+            )
+            current_logical_shape[dst_placement.dim] = local_shard_size
+        else:
+            # replicate/partial -> replicate/partial
+            transform_infos.append(
+                _TransformInfo(
+                    mesh_dim=mesh_dim,
+                    src_dst_placements=(src_placement, dst_placement),
+                    logical_shape=current_logical_shape,
+                )
+            )

     return transform_infos
@@ -149,14 +276,20 @@
 def _gen_transform_infos(
     src_spec: DTensorSpec,
     dst_spec: DTensorSpec,
+    src_device_order: tuple[int, ...],
+    dst_device_order: tuple[int, ...],
 ) -> list[_TransformInfo]:
-    return _gen_transform_infos_non_cached(src_spec, dst_spec)
+    return _gen_transform_infos_non_cached(
+        src_spec, dst_spec, src_device_order, dst_device_order
+    )


 def redistribute_local_tensor(
     local_tensor: torch.Tensor,
     current_spec: DTensorSpec,
     target_spec: DTensorSpec,
+    src_device_order: Optional[tuple[int, ...]] = None,
+    dst_device_order: Optional[tuple[int, ...]] = None,
     *,
     async_op: bool = False,
     is_backward: bool = False,
@@ -171,6 +304,16 @@ def redistribute_local_tensor(
         # TODO: alltoall/permute reshuffling to change device_mesh if they are not the same
         raise NotImplementedError("Cross device mesh comm not supported yet!")

+    if not src_device_order:
+        src_device_order = tuple(range(current_spec.device_mesh.ndim))
+    if not dst_device_order:
+        dst_device_order = tuple(range(target_spec.device_mesh.ndim))
+
+    if not isinstance(src_device_order, tuple):
+        src_device_order = tuple(src_device_order)
+    if not isinstance(dst_device_order, tuple):
+        dst_device_order = tuple(dst_device_order)
+
     new_local_tensor = local_tensor
     device_mesh = current_spec.mesh
@@ -185,9 +328,13 @@
         isinstance(s, torch.SymInt) for s in target_spec.shape
     )
     if has_symints:
-        transform_infos = _gen_transform_infos_non_cached(current_spec, target_spec)
+        transform_infos = _gen_transform_infos_non_cached(
+            current_spec, target_spec, src_device_order, dst_device_order
+        )
     else:
-        transform_infos = _gen_transform_infos(current_spec, target_spec)
+        transform_infos = _gen_transform_infos(
+            current_spec, target_spec, src_device_order, dst_device_order
+        )

     for transform_info in transform_infos:
         i = transform_info.mesh_dim
@@ -288,6 +435,7 @@ def forward(  # type: ignore[override]
         input: "dtensor.DTensor",
         device_mesh: DeviceMesh,
         placements: tuple[Placement, ...],
+        device_order: Optional[tuple[int, ...]] = None,
         async_op: bool = False,
         forward_dtype: Optional[torch.dtype] = None,
         backward_dtype: Optional[torch.dtype] = None,
@@ -295,12 +443,14 @@ def forward(  # type: ignore[override]
         ctx.async_op = async_op
         ctx.backward_dtype = backward_dtype
         ctx.original_dtype = input._local_tensor.dtype
+        ctx.original_device_order = input._spec.device_order

         if forward_dtype is not None and forward_dtype != input._local_tensor.dtype:
             local_tensor = input._local_tensor.to(dtype=forward_dtype)
             current_spec = DTensorSpec(
                 mesh=device_mesh,
                 placements=input._spec.placements,
+                device_order=input._spec.device_order,
                 tensor_meta=TensorMeta(
                     shape=input.shape,
                     stride=input.stride(),
@@ -315,11 +465,19 @@ def forward(  # type: ignore[override]

         if current_spec.placements != placements:
             target_spec = DTensorSpec(
-                device_mesh, placements, tensor_meta=current_spec.tensor_meta
+                device_mesh,
+                placements,
+                device_order=device_order,
+                tensor_meta=current_spec.tensor_meta,
             )

             output = redistribute_local_tensor(
-                local_tensor, current_spec, target_spec, async_op=async_op
+                local_tensor,
+                current_spec,
+                target_spec,
+                src_device_order=input._spec.device_order,
+                dst_device_order=device_order,
+                async_op=async_op,
             )
         else:
             # use the same local tensor if placements are the same.
@@ -362,6 +520,8 @@ def backward(ctx, grad_output: "dtensor.DTensor"):  # type: ignore[override]
                 local_tensor,
                 current_spec,
                 previous_spec,
+                src_device_order=current_spec.device_order,
+                dst_device_order=previous_spec.device_order,
                 async_op=async_op,
                 is_backward=True,
             )
@@ -400,4 +560,5 @@ def backward(ctx, grad_output: "dtensor.DTensor"):  # type: ignore[override]
             None,
             None,
             None,
+            None,
         )
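Usage sketch (not part of the patch): how the new device_order argument is meant to be used, mirroring DeviceOrderRedistributeTest above. Assumes 8 ranks forming a 2x2x2 mesh and a CUDA device type; placements[i] is applied to the mesh dim given by device_order[i].

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, Shard, distribute_tensor

mesh = init_device_mesh("cuda", (2, 2, 2))
x = torch.randn(8, 8, 8, device="cuda")

# Shard tensor dim 0 over mesh dims 1, 0, 2 (applied in that order).
dt = distribute_tensor(x, mesh, [Shard(0), Shard(0), Shard(0)], device_order=[1, 0, 2])

# Drop the sharding on mesh dim 0, keeping dim 0 sharded over mesh dims 1 and 2;
# only the affected mesh dims should communicate (2 collectives in the test above).
out = dt.redistribute(mesh, [Replicate(), Shard(0), Shard(0)], device_order=[0, 1, 2])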