# mypy: allow-untyped-defs
import copy
from typing import Any, cast, Optional

import torch
import torch.distributed as dist
import torch.distributed._shard.sharding_spec as shard_spec
import torch.distributed.distributed_c10d as c10d
from torch.distributed._shard.sharded_tensor import (
    Shard,
    ShardedTensor,
    ShardedTensorMetadata,
    TensorProperties,
)
from torch.distributed._shard.sharding_spec import ShardMetadata
from torch.distributed._shard.sharding_spec.chunk_sharding_spec import ChunkShardingSpec
from torch.distributed.device_mesh import _mesh_resources
from torch.distributed.fsdp._common_utils import _set_fsdp_flattened
from torch.distributed.fsdp._fsdp_extensions import FSDPExtensions
from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor
from torch.distributed.remote_device import _remote_device
from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard as DShard
from torch.distributed.tensor.parallel._data_parallel_utils import (
    _flatten_tensor,
    _unflatten_tensor,
)


__all__ = ["DTensorExtensions"]


def _get_box(tensor: DTensor) -> tuple[torch.Size, torch.Size]:
    device_mesh = tensor.device_mesh
    assert device_mesh.ndim == 1, "Only 1D DeviceMeshes currently handled"

    placement = tensor.placements[0]
    offsets = [0] * len(tensor.size())
    num_chunks = device_mesh.size(mesh_dim=0)

    if tensor.placements[0].is_shard():
        shard_dim = cast(DShard, placement).dim
        chunk_size = tensor.size(shard_dim) // num_chunks
        offsets[shard_dim] = chunk_size

    return (torch.Size(offsets), tensor._local_tensor.size())


def _get_box_for(tensor: DTensor, idx: int) -> tuple[torch.Size, torch.Size]:
    offsets, size = _get_box(tensor)
    return (torch.Size([val * idx for val in offsets]), size)


def _get_local_box(tensor: DTensor) -> tuple[torch.Size, torch.Size]:
    device_mesh = tensor.device_mesh
    coord = device_mesh.get_coordinate()
    assert coord is not None
    return _get_box_for(tensor, coord[0])


def _create_shard_md_from_dt(dt: DTensor, current_rank: int) -> ShardMetadata:
    mesh = dt.device_mesh
    assert mesh.ndim == 1, "Only 1D DeviceMeshes currently handled"

    offsets, sizes = _get_local_box(dt)
    return ShardMetadata(
        shard_offsets=list(offsets),
        shard_sizes=list(sizes),
        placement=f"rank:{current_rank}/{dt._local_tensor.device}",
    )


def _create_sharded_tensor_md_from_dt(
    dt: DTensor, dt_pg: c10d.ProcessGroup
) -> ShardedTensorMetadata:
    # This is where it gets tricky, we have to produce a ShardedTensor that has full coverage
    # and yet has only one valid shard for the current rank.

    shards_md = []
    my_rank = dist.get_rank(dt_pg)
    scapegoat_rank = 0 if my_rank > 0 else 1

    if dt.placements[0].is_shard():
        shard_count = dt_pg.size()
    else:
        shard_count = 1

    for i in range(shard_count):
        offsets, sizes = _get_box_for(dt, i)
        shards_md.append(
            ShardMetadata(
                shard_offsets=list(offsets),
                shard_sizes=list(sizes),
                placement=(
                    f"rank:{scapegoat_rank if i > 0 else my_rank}/{dt._local_tensor.device}"
                ),
            )
        )

    return ShardedTensorMetadata(
        shards_metadata=shards_md,
        size=dt.size(),
        tensor_properties=TensorProperties(
            dtype=dt.dtype,
            layout=dt.layout,
            requires_grad=dt.requires_grad,
            # ignore memory_format and pin_memory as those are not supported by DT
        ),
    )


def _get_dt_pg(dt: DTensor) -> c10d.ProcessGroup:
    mesh = dt.device_mesh
    assert mesh.ndim == 1, "Only 1D DeviceMeshes currently handled"
    return mesh.get_group()


def _rewrite_spec_if_needed(
    spec: shard_spec.ShardingSpec, tensor: torch.Tensor, rank: int
) -> shard_spec.ShardingSpec:
    """
    Rewrite ``spec`` to match the device of ``tensor``.

    FSDP.sharded_optim_state_dict sneakily ships optimizer state to CPU, so if the original
    ShardingSpec produces CUDA metadata, ST construction bombs.
    """
    if not isinstance(spec, ChunkShardingSpec):
        return spec

    # let's see if we need to rewrite this spec at all
    rewrite = False
    for p in spec.placements:
        p = cast(_remote_device, p)
        if p.rank() == rank and p.device() != tensor.device:
            rewrite = True
            break
    if rewrite:
        spec = copy.deepcopy(spec)
        for i, placement in enumerate(spec.placements):
            placement = cast(_remote_device, placement)
            if placement.rank() == rank and placement.device() != tensor.device:
                spec.placements[i] = _remote_device(f"rank:{rank}/{tensor.device}")

    return spec


def _chunk_tensor(
    tensor: torch.Tensor,
    rank: int,
    world_size: int,
    num_devices_per_node: int,
    pg: dist.ProcessGroup,
) -> torch.Tensor:
    if type(tensor) is ShardedTensor:
        assert len(tensor.local_shards()) == 1

        inner_param = tensor.local_tensor()
        inner_st = _create_chunk_sharded_tensor(
            inner_param,
            rank,
            world_size,
            num_devices_per_node,
            pg,
        )

        outer_local_shard = tensor.local_shards()[0]
        shards: list[Shard] = [
            Shard(inner_st, copy.deepcopy(outer_local_shard.metadata))
        ]
        st_meta = copy.deepcopy(tensor.metadata())
        st_meta.tensor_properties.requires_grad = False

        st_outer = ShardedTensor._init_from_local_shards_and_global_metadata(
            shards,
            sharded_tensor_metadata=st_meta,
            process_group=tensor._process_group,
            init_rrefs=False,
        )
        return st_outer
    elif type(tensor) is DTensor:
        device_mesh = tensor.device_mesh
        assert device_mesh.ndim == 1, "Only 1D DeviceMeshes currently handled"

        inner_param = tensor._local_tensor
        inner_st = _create_chunk_sharded_tensor(
            inner_param,
            rank,
            world_size,
            torch.accelerator.device_count(),
            pg,
        )

        dt_pg = _get_dt_pg(tensor)
        # We do this differently here: we create an ST with no local shards and then patch it.
        shards = [
            Shard(inner_st, _create_shard_md_from_dt(tensor, dist.get_rank(dt_pg)))
        ]
        st_meta = _create_sharded_tensor_md_from_dt(tensor, dt_pg)
        st_meta.tensor_properties.requires_grad = False

        st_outer = ShardedTensor._init_from_local_shards_and_global_metadata(
            shards,
            sharded_tensor_metadata=st_meta,
            process_group=dt_pg,
            init_rrefs=False,
        )
        return st_outer
    else:
        return _create_chunk_sharded_tensor(
            tensor,
            rank,
            world_size,
            num_devices_per_node,
            pg,
        )


def _chunk_dtensor(
    tensor: torch.Tensor,
    rank: int,
    device_mesh: DeviceMesh,
) -> DTensor:
    """
    Shard a tensor into chunks along the first dimension.

    The local rank gets its corresponding chunk as the local tensor to create a DTensor.
    """
    root_mesh = _mesh_resources.get_root_mesh(device_mesh)
    if root_mesh is None:
        raise RuntimeError("No parent device_mesh is found for FSDP device_mesh.")
    if root_mesh.ndim < 2:
        raise RuntimeError(
            f"Found parent device_mesh of ndim={root_mesh.ndim}, "
            "but meshes must be at least 2D."
        )

    # We need to explicitly call .detach() to return a new tensor detached from the current graph.
    tensor = tensor.detach().clone()

    # When a layer is not involved in TP, the tensor will not be a DTensor.
    # e.g. When a layer is not specified in the parallelize_plan, TP has no effect on the layer.
    # e.g. When you do PairwiseParallel on a 3-layer model, TP has no effect on the third layer.
    if isinstance(tensor, torch.Tensor) and not isinstance(tensor, DTensor):
        # For plain tensors, it is replicated across the tp dimension and sharded across the FSDP dimension.
        # TP is the inner dimension and FSDP is the outer dimension.
        # Therefore, the shard placements for the tensor are (Shard(0), Replicate()).
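        # Illustrative note (hypothetical sizes, not from the original code): on a 2D root
        # mesh of shape (dp=4, tp=2), a plain (1024, 512) parameter is first wrapped as
        # fully replicated, [Replicate(), Replicate()], and then redistributed to
        # [Shard(0), Replicate()], so each rank holds a (256, 512) local shard that is
        # identical across the tp dimension.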
        replicate_placements = [Replicate() for _ in range(root_mesh.ndim)]
        shard_placements = [Replicate() for _ in range(root_mesh.ndim)]
        shard_placements[0] = DShard(0)  # type: ignore[call-overload]

        return DTensor.from_local(
            tensor, root_mesh, replicate_placements, run_check=False
        ).redistribute(
            device_mesh=root_mesh,
            placements=shard_placements,
        )
    else:
        tp_placements = tensor.placements
        tp_placement = tp_placements[0]

        tensor = tensor.to_local()

        # For DTensors, it is sharded across the tp dimension first and then across the FSDP dimension.
        # TP is the inner dimension and FSDP is the outer dimension.
        # Therefore, the shard placements for the tensor are (Shard(0), tp_placement).
        # For higher-dimensional meshes, it is replicated across the other dimensions. For example,
        # with HSDP the shard placements for the tensor are (Replicate(), Shard(0), tp_placement).
        replicate_placements = [Replicate() for _ in range(root_mesh.ndim)]
        replicate_placements[-1] = tp_placement  # type: ignore[call-overload]
        shard_placements = [Replicate() for i in range(root_mesh.ndim)]  # type: ignore[misc]
        shard_placements[-2] = DShard(0)  # type: ignore[call-overload]
        shard_placements[-1] = tp_placement  # type: ignore[call-overload]

        return DTensor.from_local(
            tensor, root_mesh, replicate_placements, run_check=False
        ).redistribute(
            device_mesh=root_mesh,
            placements=shard_placements,
        )


def _pre_load_state_dict(
    tensor: torch.Tensor,
) -> tuple[torch.Tensor, list[Shard]]:
    shards = cast(ShardedTensor, tensor).local_shards()
    if len(shards) == 1 and type(shards[0].tensor) is ShardedTensor:
        inner_tensor = shards[0].tensor
        shards = inner_tensor.local_shards()  # pyre-ignore[16]
        tensor = inner_tensor

    return (tensor, shards if len(shards) > 0 else [])


def _all_gather_dtensor(
    tensor: DTensor,
    parent_mesh: Optional[DeviceMesh],
) -> torch.Tensor:
    """All gather a DTensor in its FSDP dimension and return the local tensor."""
    assert parent_mesh == tensor.device_mesh

    placements = list(copy.deepcopy(tensor.placements))
    # FSDP + TP: [Shard(0), tp_placement] -> [Replicate(), tp_placement]
    # HSDP + TP: [Replicate(), Shard(0), tp_placement] -> [Replicate(), Replicate(), tp_placement]
    for i in range(0, len(placements) - 1):
        placements[i] = Replicate()
    tensor = tensor.redistribute(
        device_mesh=tensor.device_mesh,
        placements=placements,
    )

    return tensor.to_local()


class DTensorExtensions(FSDPExtensions):
    """
    DTensorExtensions is the TensorFlattener extension needed for 2D FSDP + TP.

    This is the implementation for FSDPExtensions defined in
    https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fsdp_extensions.py
    """

    def __init__(self, device_handle) -> None:
        super().__init__()
        self.compute_stream = None
        self.device_handle = device_handle
        # We have to disable dynamo this way, since using the decorator would trigger a
        # build failure with torch deploy...
        self.post_unflatten_transform = torch._dynamo.disable(  # type: ignore[method-assign]
            self.post_unflatten_transform
        )

    def pre_flatten_transform(
        self,
        tensor: torch.Tensor,
    ) -> tuple[torch.Tensor, Optional[Any]]:
        return _flatten_tensor(tensor)

    def post_unflatten_transform(
        self, tensor: torch.Tensor, param_extension: Any
    ) -> torch.Tensor:
        stream = self.compute_stream or self.device_handle.current_stream()
        with self.device_handle.stream(stream):
            # At runtime we put the unflatten call on the compute stream, since the
            # unflattened tensor might be used in fwd/bwd computations that we need
            # to sync with properly.
            # TODO: this is a short-term fix; we should make get_unflat_views
            # happen directly in the compute stream.
            result = _unflatten_tensor(
                tensor,
                param_extension,
                device_handle=self.device_handle,
                compute_stream=self.compute_stream,
            )
            _set_fsdp_flattened(result)
            return result

    def chunk_tensor(
        self,
        tensor: torch.Tensor,
        rank: int,
        world_size: int,
        num_devices_per_node: int,
        pg: dist.ProcessGroup,
        device: Optional[torch.device] = None,
    ) -> torch.Tensor:
        return _chunk_tensor(tensor, rank, world_size, num_devices_per_node, pg)

    def chunk_dtensor(
        self,
        tensor: torch.Tensor,
        rank: int,
        device_mesh: DeviceMesh,
    ) -> torch.Tensor:
        return _chunk_dtensor(tensor, rank, device_mesh)

    def pre_load_state_dict_transform(
        self,
        tensor: torch.Tensor,
    ) -> tuple[torch.Tensor, list[Shard]]:
        return _pre_load_state_dict(tensor)

    def all_gather_dtensor(
        self,
        tensor: DTensor,
        parent_mesh: Optional[DeviceMesh],
    ) -> torch.Tensor:
        return _all_gather_dtensor(tensor, parent_mesh)
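

# Illustration (comments only, not executed; mesh names below are hypothetical): how the
# placement bookkeeping above composes for 2D FSDP + TP, based on the placement comments
# in _chunk_dtensor and _all_gather_dtensor.
#
#   Root mesh ("dp", "tp"): a TP-sharded parameter lives on the "tp" submesh with
#   placements [tp_placement] (e.g. Shard(0)).
#
#   _chunk_dtensor re-expresses the parameter on the root mesh, adding the FSDP dimension:
#       FSDP + TP:  [Shard(0), tp_placement]
#       HSDP + TP:  [Replicate(), Shard(0), tp_placement]
#
#   _all_gather_dtensor undoes only the FSDP/HSDP sharding, replacing every non-TP
#   placement with Replicate() before returning the local tensor:
#       [Shard(0), tp_placement]              -> [Replicate(), tp_placement]
#       [Replicate(), Shard(0), tp_placement] -> [Replicate(), Replicate(), tp_placement]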