# Owner(s): ["module: dynamo"] import contextlib import copy import functools import random import unittest from contextlib import contextmanager from datetime import timedelta from io import StringIO from unittest.mock import patch import numpy as np import torch import torch._dynamo import torch._dynamo.logging import torch._dynamo.test_case import torch.distributed as dist import torch.optim as optim from torch import nn from torch._C import FileCheck from torch._dynamo import config from torch._dynamo.backends.distributed import DDPOptimizer from torch._dynamo.comptime import comptime from torch._dynamo.testing import collect_results from torch._dynamo.utils import same from torch._higher_order_ops.wrap import tag_activation_checkpoint from torch.distributed._functional_collectives import _maybe_wrap_tensor from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp.wrap import ( lambda_auto_wrap_policy, transformer_auto_wrap_policy, ) from torch.nn.attention.flex_attention import flex_attention from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing._internal.common_cuda import ( PLATFORM_SUPPORTS_FLASH_ATTENTION, PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, ) from torch.testing._internal.common_distributed import ( _dynamo_dist_per_rank_init, DynamoDistributedMultiProcTestCase, DynamoDistributedSingleProcTestCase, import_transformers_or_skip, requires_nccl, skip_if_lt_x_gpu, ) from torch.testing._internal.common_utils import requires_cuda from torch.testing._internal.inductor_utils import HAS_GPU def reset_rng_state(): torch.manual_seed(1337) random.seed(1337) np.random.seed(1337) def init_weights(m): if isinstance(m, nn.Linear): nn.init.xavier_uniform_(m.weight) m.bias.data.fill_(0.01) class ToyModel(nn.Module): def __init__(self, in_feat=10, hidden_feat=5000, out_feat=5, ctx_manager=None): super().__init__() self.ctx_manager = ctx_manager self.net = nn.Sequential( *[nn.Linear(in_feat, hidden_feat), nn.ReLU()] + [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()] + [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()] + [nn.Linear(hidden_feat, out_feat), nn.ReLU()] ) def forward(self, inputs): if self.ctx_manager is not None: with self.ctx_manager(): return self.net(inputs) else: return self.net(inputs) def get_model( device, bsz=20, in_feat=10, hidden_feat=5000, out_feat=5, ctx_manager=None ): m = ToyModel( in_feat=in_feat, hidden_feat=hidden_feat, out_feat=out_feat, ctx_manager=ctx_manager, ).to(device) m.apply(init_weights) inputs = torch.rand(bsz, in_feat).to(device) outputs = m(inputs) return m, inputs, outputs class MutatingModel(nn.Module): def __init__(self, in_feat=10, hidden_feat=5000, out_feat=5, ctx_manager=None): super().__init__() self.ctx_manager = ctx_manager self.net = nn.Sequential( *[nn.Linear(in_feat, hidden_feat), nn.ReLU()] + [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()] + [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()] + [nn.Linear(hidden_feat, out_feat), nn.ReLU()] ) self.state = 1 def forward(self, inputs): self.state = 2 return self.net(inputs) * self.state def get_mutating_model( device, bsz=20, in_feat=10, hidden_feat=5000, out_feat=5, ctx_manager=None ): m = MutatingModel( in_feat=in_feat, hidden_feat=hidden_feat, out_feat=out_feat, ctx_manager=ctx_manager, ).to(device) m.apply(init_weights) inputs = torch.rand(bsz, in_feat).to(device) outputs = m(inputs) return m, inputs, outputs class ForcedGetAttrMod(torch.nn.Module): def __init__(self, device): super().__init__() self.linear = torch.nn.Linear(1, 1) self.__dict__["forced_linear"] = torch.nn.Linear(1, 1).to(device=device) self.counter = 0 def forward(self, x): self.counter += 1 return x * self.linear(x) * self.forced_linear.weight def get_forced_getattr_module(device): mod = ForcedGetAttrMod(device).to(device=device) x = torch.randn(1, 1, device=device) return mod, x, mod(x) class ToyInnerModel(nn.Module): def __init__(self) -> None: super().__init__() self.layers = [nn.Linear(100, 100), nn.Linear(100, 100)] self.layers = nn.Sequential(*self.layers) def forward(self, inputs): return self.layers(inputs) class ToyOuterModel(nn.Module): def __init__(self, device): super().__init__() self.layers = [ToyInnerModel().to(device) for _ in range(2)] self.layers = nn.Sequential( self.layers[0], nn.ReLU(), self.layers[1], nn.ReLU() ) def forward(self, inputs): return self.layers(inputs) def get_toy_model_for_activation_checkpointing(device): m = ToyOuterModel(device).to(device) m.apply(init_weights) inputs = torch.rand(100, 100).to(device) return m, inputs def find_first_node(gm, func): for node in gm.graph.nodes: if node.target is func: return node return None def apply_fsdp_with_checkpointing( model, wrap_policy, checkpoint_policy, use_activation_checkpointing=True ): from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( apply_activation_checkpointing, checkpoint_wrapper, CheckpointImpl, ) model = FSDP( copy.deepcopy(model), auto_wrap_policy=wrap_policy, use_orig_params=True ) if use_activation_checkpointing: checkpoint_wrapper_fn = functools.partial( checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT, ) apply_activation_checkpointing( model, checkpoint_wrapper_fn=checkpoint_wrapper_fn, check_fn=checkpoint_policy, ) return model def get_custom_model(device): class MyCustomLinear(torch.nn.Module): def __init__(self) -> None: super().__init__() self.weight = nn.Parameter(torch.randn(512, 512)) def forward(self, x): tmp = torch.mm(x, self.weight.t()) # test an edge case where torch.where.scalar was decomposed to aten.where.self(tensor, tensor, tensor) # and the tensors T(0.4) and T(0.5) were not wrapped in FakeTensors during DDPOptimizer compilation return tmp + torch.where(tmp < 0.5, 0.3, 0.6) class MyLinear(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear = torch.nn.Linear(512, 512) def forward(self, x): return self.linear(x) class MyModule(torch.nn.Module): def __init__(self) -> None: super().__init__() mods = [ (MyLinear(), torch.nn.ReLU()), # sandwich the custom in the middle so it comes before and after (MyCustomLinear(), torch.nn.ReLU()), (MyLinear(), torch.nn.ReLU()), ] self.seq = torch.nn.Sequential(*[x for items in mods for x in items]) def forward(self, x, y): # test special case where the 0th bucket (layers close to graph input) is at capacity, which would # trigger a new bucket, but there are only trivial ops without parameters to put into the new bucket. # optimize this case by fusing that 'empty bucket' back together with the previous full one return self.seq(x + y) m = MyModule().to(device) m.apply(init_weights) inputs = torch.rand((512, 512)).to(device) # test duplicated inputs inputs = (inputs, inputs) correct_outputs = m(*inputs) return m, inputs, correct_outputs def get_hf_bert(rank): # Note: use @import_transformers_or_skip on your test case if you use this # in a multiprocessing test try: from transformers import AutoModelForMaskedLM, BertConfig except ImportError as e: raise unittest.SkipTest("Unable to import transformers") from e batch_size, max_length, config, device = 4, 512, BertConfig(), f"cuda:{rank}" model = AutoModelForMaskedLM.from_config(config).to(device) input_ids = torch.randint(0, config.vocab_size, (batch_size, max_length)).to(device) decoder_ids = torch.randint(0, config.vocab_size, (batch_size, max_length)).to( device ) inputs = {"input_ids": input_ids, "labels": decoder_ids} model.train() return model, inputs class CheckSplitsCompiler: def __init__(self) -> None: self.compiler_called = 0 def compile_fn(self, gm, example_inputs): self.compiler_called += 1 return gm # This simulates DDP, but it doesn't actually do any process communication; # it just has enough properties so that the dynamo distributed optimization is # able to optimize. Feel free to simulate more properties as necessary. The # other important thing is patching _active_ddp_module, which is what actually # triggers DDP optimization class FakeDDP(nn.Module): def __init__(self, module, bucket_cap_mb=25): super().__init__() self.module = module self.bucket_bytes_cap = int(bucket_cap_mb * 1024 * 1024) @contextmanager def _inside_ddp_forward(self): DDP._active_ddp_module = self try: yield finally: DDP._active_ddp_module = None def forward(self, *inputs, **kwargs): if not DDP._active_ddp_module: with self._inside_ddp_forward(): return self.module.forward(*inputs, **kwargs) else: return self.module.forward(*inputs, **kwargs) def run_hf_bert_ddp(self, model, inputs, backend): reset_rng_state() correct_outputs = model(**inputs) correct_loss = correct_outputs.loss correct_loss.backward() reset_rng_state() opt_model = torch.compile(model, backend=backend) opt_outputs = opt_model(**inputs) opt_loss = opt_outputs.loss opt_loss.backward() inputs_flat = [inputs[k] for k in inputs] correct_results = collect_results( model, correct_outputs.logits, correct_loss, inputs_flat ) opt_results = collect_results(opt_model, opt_outputs.logits, opt_loss, inputs_flat) self.assertTrue(same(correct_results, opt_results)) class TestFakeDistributedSingleProc(torch._dynamo.test_case.TestCase): @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @patch.object(config, "optimize_ddp", True) @patch.object(torch._inductor.config, "fallback_random", True) def test_hf_bert_ddp_inductor(self): model, inputs = get_hf_bert(0) model = FakeDDP(model) run_hf_bert_ddp(self, model, inputs, "inductor") @patch.object(config, "optimize_ddp", True) def test_hf_bert_ddp_aot_eager(self): model, inputs = get_hf_bert(0) model = FakeDDP(model) run_hf_bert_ddp(self, model, inputs, "aot_eager") @patch.object(config, "optimize_ddp", True) def test_issue90375(self): class Model(nn.Module): def forward(self): return torch.randn(3) * torch.randn(3) model = Model() model = FakeDDP(model) opt_model = torch.compile(model, backend="aot_eager") opt_model() @patch.object(config, "optimize_ddp", True) def test_symbol_splitting(self): class Model(nn.Module): def __init__(self) -> None: super().__init__() self.weight1 = nn.Parameter(torch.randn(512, 512)) self.weight2 = nn.Parameter(torch.randn(512, 512)) def forward(self, x): x = torch.cat([x, x]) y = x @ self.weight1 z = x + y @ self.weight2 return z model = Model() model = FakeDDP(model) opt_model = torch.compile(dynamic=True)(model) opt_model(torch.randn(20, 512)) @patch.object(config, "optimize_ddp", True) def test_ddp_optimizer_inductor_strides_dont_specialize(self): class Model(nn.Module): def __init__(self): super().__init__() self.fc_0 = nn.Linear(768, 768) self.fc_1 = nn.Linear(768, 768) def forward(self, x): x = self.fc_0(x) x = self.fc_1(x) return x model = Model() model = FakeDDP(model) inp = torch.randn((16, 18, 768)) inp2 = torch.randn((16, 20, 768)) torch._dynamo.mark_dynamic(inp, 1) torch._dynamo.mark_dynamic(inp2, 1) torch._dynamo.utils.clear_compilation_metrics() torch._dynamo.reset() try: DDP._active_ddp_module = model opt_model = torch.compile(model) self.assertEqual(0, len(torch._dynamo.utils.get_compilation_metrics())) opt_model(inp) compile_count_before = len(torch._dynamo.utils.get_compilation_metrics()) opt_model(inp2) compile_count_after = len(torch._dynamo.utils.get_compilation_metrics()) # no recompiles self.assertEqual(compile_count_before, compile_count_after) finally: DDP._active_ddp_module = None @config.patch(optimize_ddp=True, capture_scalar_outputs=True) def test_unbacked_symbol_splitting_direct(self): class Model(nn.Module): def __init__(self) -> None: super().__init__() self.weight1 = nn.Parameter(torch.randn(512, 512)) self.weight2 = nn.Parameter(torch.randn(512, 512)) def forward(self, x, y): u0, _ = y.tolist() x = torch.cat([x, x]) y = x @ self.weight1 z = (x + y @ self.weight2) * u0 return z model = Model() model = FakeDDP(model) opt_model = torch.compile(dynamic=True)(model) opt_model(torch.randn(20, 512), torch.tensor([12, 13])) @config.patch(optimize_ddp=True, capture_scalar_outputs=True) def test_unbacked_symbol_splitting_indirect(self): class Model(nn.Module): def __init__(self) -> None: super().__init__() self.weight1 = nn.Parameter(torch.randn(512, 512)) self.weight2 = nn.Parameter(torch.randn(512, 512)) def forward(self, x, y): u0, _ = y.tolist() a = torch.ones(u0) x = torch.cat([x, x]) y = x @ self.weight1 z = (x + y @ self.weight2) * a.sum() return z model = Model() model = FakeDDP(model) opt_model = torch.compile(dynamic=True)(model) opt_model(torch.randn(20, 512), torch.tensor([12, 13])) @config.patch(optimize_ddp=True, capture_scalar_outputs=True) def test_unbacked_symbol_splitting_torture_multi(self): class Model(nn.Module): def __init__(self) -> None: super().__init__() self.weight1 = nn.Parameter(torch.randn(512, 512)) self.weight2 = nn.Parameter(torch.randn(512, 512)) self.weight3 = nn.Parameter(torch.randn(512, 512)) def forward(self, x, y): # partition one (contains the u0 def) u0, _ = y.tolist() x = torch.cat([x, x]) y1 = x @ self.weight1 # partition two (contains the variable) y2 = y1 @ self.weight2 a = torch.ones(u0) # partition three z = (x + y2 @ self.weight3) * a.sum() return z model = Model() model = FakeDDP(model, bucket_cap_mb=1) opt_model = torch.compile(dynamic=True)(model) opt_model(torch.randn(20, 512), torch.tensor([12, 13])) @config.patch(optimize_ddp=True, capture_dynamic_output_shape_ops=True) def test_unbacked_symbol_splitting_no_binding(self): class Model(nn.Module): def __init__(self) -> None: super().__init__() self.weight1 = nn.Parameter(torch.randn(512, 512)) self.weight2 = nn.Parameter(torch.randn(512, 512)) def forward(self, x, y): nz = y.nonzero() x = torch.cat([x, x]) y = x @ self.weight1 z = (x + y @ self.weight2) * (nz + 1).sum() return z model = Model() model = FakeDDP(model) opt_model = torch.compile(dynamic=True)(model) opt_model(torch.randn(20, 512), torch.tensor([0.0, 12.0, 0.0, 11.0])) @patch.object(config, "optimize_ddp", True) def test_call_method_forward(self): class Model(nn.Module): def __init__( self, ): super().__init__() layers = [] for _ in range(2): layer = nn.ModuleList( [ nn.LayerNorm(96), nn.MultiheadAttention( embed_dim=96, num_heads=4, batch_first=True ), ] ) layers.append(layer) self.layers = nn.ModuleList(layers) def forward(self, x: torch.Tensor) -> torch.Tensor: # x: [Batch, Freq, Time, Feature] B, F, T, H = x.shape for m in self.layers: x = x.reshape(B * F, T, H) x = m[0](x) x, _ = m[1].forward(x, x, x) x = x.reshape(B, F, T, H) return x model = Model() model = FakeDDP(model) opt_model = torch.compile(model) opt_model(torch.randn(2, 129, 100, 96)) # Are these tests failing? Check and see if TestFakeDistributedSingleProc has a # single process version; if it's just a problem in the Dynamo distributed # optimizer, you should be able to repro it single process! @requires_nccl() class TestMultiProc(DynamoDistributedMultiProcTestCase): """ Note: MultiProcTestCase spawns processes per test and is slow. Prefer MultiThreadedTestCase for most tests. Perhaps use this one sparingly for integration tests. """ @skip_if_lt_x_gpu(2) @config.patch(optimize_ddp=False, enable_compiler_collectives=True) def test_ddp_baseline_aot_eager_multiprocess(self): with _dynamo_dist_per_rank_init(self.rank, self.world_size): self.assertFalse(config.optimize_ddp) m, inputs, correct_outputs = get_model(f"cuda:{self.rank}") m = DDP(m, device_ids=[self.rank]) m = torch.compile(m, backend="aot_eager") outputs = m(inputs) self.assertTrue(same(correct_outputs, outputs)) def _test_hf_bert_ddp_inductor(self, static_graph): with _dynamo_dist_per_rank_init(self.rank, self.world_size): model, inputs = get_hf_bert(self.rank) model = DDP(model, static_graph=static_graph) run_hf_bert_ddp(self, model, inputs, "inductor") @skip_if_lt_x_gpu(2) @import_transformers_or_skip() @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @config.patch(optimize_ddp=True, enable_compiler_collectives=True) @patch.object(torch._inductor.config, "fallback_random", True) def test_hf_bert_ddp_inductor(self): self._test_hf_bert_ddp_inductor(static_graph=False) @skip_if_lt_x_gpu(2) @import_transformers_or_skip() @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @config.patch(optimize_ddp=True, enable_compiler_collectives=True) @patch.object(torch._inductor.config, "fallback_random", True) def test_hf_bert_ddp_inductor_static_graph(self): self._test_hf_bert_ddp_inductor(static_graph=True) def _test_hf_bert_aot_eager(self, static_graph): with _dynamo_dist_per_rank_init(self.rank, self.world_size): model, inputs = get_hf_bert(self.rank) model = DDP(model, static_graph=static_graph) run_hf_bert_ddp(self, model, inputs, "aot_eager") @skip_if_lt_x_gpu(2) @import_transformers_or_skip() @config.patch(optimize_ddp=True, enable_compiler_collectives=True) def test_hf_bert_ddp_aot_eager(self): self._test_hf_bert_aot_eager(static_graph=False) @skip_if_lt_x_gpu(2) @import_transformers_or_skip() @config.patch(optimize_ddp=True, enable_compiler_collectives=True) def test_hf_bert_ddp_aot_eager_static_graph(self): self._test_hf_bert_aot_eager(static_graph=True) @skip_if_lt_x_gpu(2) @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @config.patch(optimize_ddp=False, enable_compiler_collectives=True) def test_ddp_activation_checkpointing(self): from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( apply_activation_checkpointing, checkpoint_wrapper, CheckpointImpl, ) class MyModel(torch.nn.Module): def __init__(self) -> None: super().__init__() self.fc1 = torch.nn.Linear(64, 32) self.fc2 = torch.nn.Linear(32, 16) self.fc3 = torch.nn.Linear(16, 8) def forward(self, inp): return self.fc3(self.fc2(self.fc1(inp))) with _dynamo_dist_per_rank_init(self.rank, self.world_size): self.assertFalse(config.optimize_ddp) model = MyModel().to(device="cuda") # Activation checkpointing for Linear layers. non_reentrant_wrapper = functools.partial( checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT, ) check_fn = lambda submodule: isinstance( # noqa: E731 submodule, torch.nn.Linear ) apply_activation_checkpointing( model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn ) model = DDP(model) x = torch.randn(10, 64).cuda() correct_outputs = model(x) opt_model = torch.compile(model) outputs = opt_model(x) self.assertTrue(same(correct_outputs, outputs)) @config.patch(enable_compiler_collectives=True) @skip_if_lt_x_gpu(1) def test_fsdp_aot_eager(self): with _dynamo_dist_per_rank_init(self.rank, self.world_size): # Test with basic FSDP wrapping (outer wrap around whole model) m, inputs, correct_outputs = get_model(f"cuda:{self.rank}") fsdp_m = FSDP(m, use_orig_params=True) fsdp_m = torch.compile(fsdp_m, backend="aot_eager") outputs = fsdp_m(inputs) self.assertTrue(same(correct_outputs, outputs)) # Test with recursive wrapping, nested FSDP around each Linear m, inputs, correct_outputs = get_model(f"cuda:{self.rank}") fsdp_m = FSDP( m, auto_wrap_policy=functools.partial( transformer_auto_wrap_policy, transformer_layer_cls=(nn.Linear,) ), use_orig_params=True, ) fsdp_m = torch.compile(fsdp_m, backend="aot_eager") outputs = fsdp_m(inputs) self.assertTrue(same(correct_outputs, outputs)) @config.patch(enable_compiler_collectives=True) @skip_if_lt_x_gpu(1) def test_fsdp_setattr(self): with _dynamo_dist_per_rank_init(self.rank, self.world_size): # Test with basic FSDP wrapping (outer wrap around whole model) from torch._dynamo.utils import counters counters.clear() m, inputs, correct_outputs = get_mutating_model(f"cuda:{self.rank}") fsdp_m = FSDP(m, use_orig_params=True) fsdp_m = torch.compile(fsdp_m, backend="eager", fullgraph=False) outputs = fsdp_m(inputs) self.assertTrue(same(correct_outputs, outputs)) self.assertEqual(len(counters["graph_break"]), 1) first_graph_break = list(counters["graph_break"].keys())[0] # noqa: RUF015 self.assertIn("setattr() on Tensor.requires_grad", first_graph_break) @config.patch(inline_inbuilt_nn_modules=False) @config.patch(enable_compiler_collectives=True) @skip_if_lt_x_gpu(1) def test_fsdp_unspecialized_forced_getattr_no_inline(self): with _dynamo_dist_per_rank_init(self.rank, self.world_size): # Test with basic FSDP wrapping (outer wrap around whole model) from torch._dynamo.utils import counters counters.clear() m, inputs, correct_outputs = get_forced_getattr_module(f"cuda:{self.rank}") fsdp_m = FSDP(m, use_orig_params=True) fsdp_m = torch.compile(fsdp_m, backend="eager", fullgraph=False) outputs = fsdp_m(inputs) self.assertTrue(same(correct_outputs, outputs)) @config.patch(enable_compiler_collectives=True) @skip_if_lt_x_gpu(1) def test_fsdp_unspecialized_forced_getattr_inline(self): with _dynamo_dist_per_rank_init(self.rank, self.world_size): # Test with basic FSDP wrapping (outer wrap around whole model) from torch._dynamo.utils import counters counters.clear() m, inputs, correct_outputs = get_forced_getattr_module(f"cuda:{self.rank}") fsdp_m = FSDP(m, use_orig_params=True) fsdp_m = torch.compile(fsdp_m, backend="eager", fullgraph=False) outputs = fsdp_m(inputs) self.assertTrue(same(correct_outputs, outputs)) @config.patch(enable_compiler_collectives=True) @skip_if_lt_x_gpu(1) @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") def test_fsdp_inductor(self): with _dynamo_dist_per_rank_init(self.rank, self.world_size): # Test with basic FSDP wrapping (outer wrap around whole model) m, inputs, correct_outputs = get_model(f"cuda:{self.rank}") fsdp_m = FSDP(m, use_orig_params=True) fsdp_m = torch.compile(fsdp_m, backend="inductor") outputs = fsdp_m(inputs) self.assertTrue(same(correct_outputs, outputs)) # Test with recursive wrapping, nested FSDP around each Linear m, inputs, correct_outputs = get_model(f"cuda:{self.rank}") fsdp_m = FSDP( m, auto_wrap_policy=functools.partial( transformer_auto_wrap_policy, transformer_layer_cls=(nn.Linear,) ), use_orig_params=True, ) fsdp_m = torch.compile(fsdp_m, backend="inductor") outputs = fsdp_m(inputs) self.assertTrue(same(correct_outputs, outputs)) @config.patch(enable_compiler_collectives=True) @skip_if_lt_x_gpu(1) @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") def test_fsdp_activation_checkpointing(self): with _dynamo_dist_per_rank_init(self.rank, self.world_size): model, inputs = get_toy_model_for_activation_checkpointing( f"cuda:{self.rank}" ) is_inner = lambda module: isinstance(module, ToyInnerModel) # noqa: E731 wrap_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=is_inner) model = apply_fsdp_with_checkpointing(model, wrap_policy, is_inner) correct_outputs = model(inputs) cnt = torch._dynamo.testing.CompileCounterWithBackend("inductor") opt_model = torch.compile(model, backend=cnt) outputs = opt_model(inputs) self.assertTrue(same(correct_outputs, outputs)) # Each FSDP module is a separate graph self.assertEqual(cnt.frame_count, 2) self.assertTrue( find_first_node(cnt.graphs[0], tag_activation_checkpoint) is not None ) @import_transformers_or_skip() @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") # TODO(whc) Investigate why cudagraphs breaks inductor+fsdp for hf_bert @patch.object(torch._inductor.config.triton, "cudagraphs", False) @patch.object(torch._inductor.config, "fallback_random", True) @config.patch(enable_compiler_collectives=True) @unittest.skipIf( PLATFORM_SUPPORTS_FLASH_ATTENTION or PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Inaccurate results with fused SDPA kernels", ) def test_hf_bert_fsdp(self): def apply_fsdp(model, wrap_policy): model = FSDP( copy.deepcopy(model), auto_wrap_policy=wrap_policy, use_orig_params=True ) return model with _dynamo_dist_per_rank_init(self.rank, self.world_size): for wrap_policy, test_instance in ( (None, "FSDP without recursive wrapping"), ): print(f"Running hf_bert test for {test_instance}") model, inputs = get_hf_bert(self.rank) reset_rng_state() eager_model = apply_fsdp(model, wrap_policy) correct_outputs = eager_model(**inputs) correct_loss = correct_outputs.loss correct_loss.backward() reset_rng_state() opt_model = apply_fsdp(model, wrap_policy) opt_model = torch.compile(opt_model, backend="inductor") opt_outputs = opt_model(**inputs) opt_loss = opt_outputs.loss opt_loss.backward() inputs_flat = [inputs[k] for k in inputs] correct_results = collect_results( eager_model, correct_outputs.logits, correct_loss, inputs_flat ) opt_results = collect_results( opt_model, opt_outputs.logits, opt_loss, inputs_flat ) self.assertTrue(same(correct_results, opt_results)) @import_transformers_or_skip() @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") # TODO(whc) Investigate why cudagraphs breaks inductor+fsdp for hf_bert @patch.object(torch._inductor.config.triton, "cudagraphs", False) @patch.object(torch._inductor.config, "fallback_random", True) @config.patch(guard_nn_modules=True, enable_compiler_collectives=True) def test_hf_bert_fsdp_activation_checkpointing(self): from transformers.models.bert.modeling_bert import BertLayer with _dynamo_dist_per_rank_init(self.rank, self.world_size): for wrap_policy, test_instance in ( ( functools.partial( transformer_auto_wrap_policy, transformer_layer_cls=(BertLayer,) ), "FSDP with recursive wrapping BertLayer instances", ), ): print( f"Running hf_bert_activation_checkpointing test for {test_instance}" ) model, inputs = get_hf_bert(self.rank) check_fn = lambda submodule: isinstance( # noqa: E731 submodule, BertLayer ) reset_rng_state() eager_model = apply_fsdp_with_checkpointing( model, wrap_policy, check_fn ) correct_outputs = eager_model(**inputs) correct_loss = correct_outputs.loss correct_loss.backward() reset_rng_state() opt_model = apply_fsdp_with_checkpointing(model, wrap_policy, check_fn) opt_model = torch.compile(opt_model, backend="inductor") opt_outputs = opt_model(**inputs) opt_loss = opt_outputs.loss opt_loss.backward() inputs_flat = [inputs[k] for k in inputs] correct_results = collect_results( eager_model, correct_outputs.logits, correct_loss, inputs_flat ) opt_results = collect_results( opt_model, opt_outputs.logits, opt_loss, inputs_flat ) self.assertTrue(same(correct_results, opt_results)) @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @config.patch(enable_compiler_collectives=True) def test_compiler_collectives_automatic_dynamic_tensor(self): with _dynamo_dist_per_rank_init(self.rank, self.world_size): class SimpleModel(nn.Module): def __init__(self, input_size, output_size): super().__init__() self.linear = nn.Linear(input_size, output_size) def forward(self, x): return self.linear(x) torch._dynamo.utils.clear_compilation_metrics() model = SimpleModel(10, 2).to(self.rank) model.forward = torch.compile(model.forward) ddp_model = DDP(model, device_ids=[self.rank]) loss_fn = nn.CrossEntropyLoss() optimizer = optim.SGD(ddp_model.parameters(), lr=0.001) def B(s): return [torch.randn(s, 10), torch.randint(0, 2, (s,))] if self.rank == 0: dataloader = [B(5), B(8), B(6)] else: dataloader = [B(6), B(6), B(3)] for data, labels in dataloader: data, labels = data.to(self.rank), labels.to(self.rank) optimizer.zero_grad() output = ddp_model(data) loss = loss_fn(output, labels) loss.backward() optimizer.step() metrics = torch._dynamo.utils.get_compilation_metrics() # Number of compiles same on all nodes res = [None] * self.world_size torch.distributed.all_gather_object(res, len(metrics)) for r in res[1:]: self.assertEqual(res[0], r) @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @config.patch(enable_compiler_collectives=True) def test_compiler_collectives_automatic_dynamic_scalar(self): with _dynamo_dist_per_rank_init(self.rank, self.world_size): torch._dynamo.utils.clear_compilation_metrics() # TODO: This should be possible to do inside the function, but device = f"cuda:{self.rank}" @torch.compile() def f(x, y): return x + torch.ones(y, device=device).sum() if self.rank == 0: dataloader = [3, 3, 7] else: dataloader = [3, 4, 9] for data in dataloader: f(torch.randn(5, device=self.rank), data) metrics = torch._dynamo.utils.get_compilation_metrics() # Number of compiles same on all nodes res = [None] * self.world_size torch.distributed.all_gather_object(res, len(metrics)) for r in res[1:]: self.assertEqual(res[0], r) @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @config.patch(enable_compiler_collectives=True) def test_compiler_collectives_automatic_dynamic_speculation_divergence(self): with _dynamo_dist_per_rank_init(self.rank, self.world_size): torch._dynamo.utils.clear_compilation_metrics() @torch.compile() def f(x, y): zx = x.shape # noqa: F841 zy = y.shape # noqa: F841 return x.sum() + y.sum() if self.rank == 0: dataloader = [4, 4] else: dataloader = [3, 4] for data in dataloader: f( torch.randn(data, device=self.rank), torch.randn(data, device=self.rank), ) metrics = torch._dynamo.utils.get_compilation_metrics() # Number of compiles same on all nodes res = [None] * self.world_size torch.distributed.all_gather_object(res, len(metrics)) for r in res[1:]: self.assertEqual(res[0], r) @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @config.patch(enable_compiler_collectives=True) def test_compiler_collectives_graph_break_empty_graph_still_collective(self): with _dynamo_dist_per_rank_init(self.rank, self.world_size): torch._dynamo.utils.clear_compilation_metrics() @torch.compile() def f(x, y): z = y # noqa: F841 print("woof") zx = x.shape # noqa: F841 zy = y.shape # noqa: F841 return x.sum() + y.sum() if self.rank == 0: dataloader = [5, 5, 6] else: dataloader = [3, 4, 5] for data in dataloader: f( torch.randn(data, device=self.rank), torch.randn(data, device=self.rank), ) metrics = torch._dynamo.utils.get_compilation_metrics() # Number of compiles same on all nodes res = [None] * self.world_size torch.distributed.all_gather_object(res, len(metrics)) for r in res[1:]: self.assertEqual(res[0], r) @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @config.patch(enable_compiler_collectives=True) def test_compiler_collectives_dim_mismatch(self): with _dynamo_dist_per_rank_init(self.rank, self.world_size): torch._dynamo.utils.clear_compilation_metrics() @torch.compile() def f(x, y): zx = x.shape # noqa: F841 zy = y.shape # noqa: F841 return x.sum() + y.sum() if self.rank == 0: dataloader = [[4, 2]] else: dataloader = [[3]] for data in dataloader: f( torch.randn(data, device=self.rank), torch.randn(data, device=self.rank), ) metrics = torch._dynamo.utils.get_compilation_metrics() res = [None] * self.world_size torch.distributed.all_gather_object(res, len(metrics)) for r in res[1:]: self.assertEqual(res[0], r) @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @config.patch(enable_compiler_collectives=True) def test_compiler_collectives_missing_source(self): with _dynamo_dist_per_rank_init(self.rank, self.world_size): torch._dynamo.utils.clear_compilation_metrics() @torch.compile() def f(rank, xs): return xs[rank].sum() xs = [] for _ in range(self.world_size): xs.append(torch.randn(10, device=self.rank)) f(self.rank, xs) metrics = torch._dynamo.utils.get_compilation_metrics() res = [None] * self.world_size torch.distributed.all_gather_object(res, len(metrics)) for r in res[1:]: self.assertEqual(res[0], r) @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @config.patch(enable_compiler_collectives=True) def test_compiler_collectives_scalar_missing_source(self): with _dynamo_dist_per_rank_init(self.rank, self.world_size): torch._dynamo.utils.clear_compilation_metrics() @torch.compile() def f(rank, xs): return torch.tensor(xs[rank], device=self.rank) xs = [] for i in range(self.world_size): xs.append(10 + i) f(self.rank, xs) metrics = torch._dynamo.utils.get_compilation_metrics() res = [None] * self.world_size torch.distributed.all_gather_object(res, len(metrics)) for r in res[1:]: self.assertEqual(res[0], r) @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @config.patch(enable_compiler_collectives=True) def test_compiler_collectives_type_mismatch(self): with _dynamo_dist_per_rank_init(self.rank, self.world_size): torch._dynamo.utils.clear_compilation_metrics() @torch.compile() def f(x): if isinstance(x, int): return torch.tensor(x, device=self.rank) else: return x.sum() if self.rank == 0: x = torch.randn(10, device=self.rank) else: x = 12 f(x) # This deadlocks, I guess we don't support this """ if self.rank == 0: x = torch.randn(12, device=self.rank) else: x = 10 f(x) """ metrics = torch._dynamo.utils.get_compilation_metrics() res = [None] * self.world_size torch.distributed.all_gather_object(res, len(metrics)) for r in res[1:]: self.assertEqual(res[0], r) @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") def test_get_pg_attr(self): with _dynamo_dist_per_rank_init(self.rank, self.world_size): pg = dist.distributed_c10d._get_default_group() device = f"cuda:{self.rank}" @torch.compile(fullgraph=True) def f(x): if dist.distributed_c10d._rank_not_in_group(pg): return x + 1 else: return x - 1 x = torch.ones(4, device=device) self.assertEqual(f(x), x - 1) pg = dist.distributed_c10d.GroupMember.NON_GROUP_MEMBER self.assertEqual(f(x), x + 1) @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @patch.object(torch._inductor.config, "fx_graph_cache", False) @patch.object(torch._inductor.config, "fx_graph_remote_cache", False) def test_asymmetric_compilation(self): from torch._dynamo.comptime import comptime with _dynamo_dist_per_rank_init(self.rank, self.world_size): torch._dynamo.utils.clear_compilation_metrics() device = f"cuda:{self.rank}" pg = dist.distributed_c10d._get_default_group() cnt = torch._dynamo.testing.CompileCounter() sleep_time = 5 @torch.compile(backend=cnt) def f(x): if self.rank == 0: comptime.sleep(sleep_time) y = 2 * x return y.sum() backend = pg._get_backend(torch.device(device)) backend._set_default_timeout(timedelta(seconds=sleep_time - 2)) x = torch.ones(4, device=device) # NCCL startup is lazy w = pg.allreduce(x) w.wait() f(x) if self.rank != 0: # test fails with NCCL timeout without this line dist.distributed_c10d._add_ephemeral_timeout_for_all_pgs( timedelta(seconds=sleep_time) ) w = pg.allreduce(x) w.wait() torch.cuda.synchronize(device) metrics = torch._dynamo.utils.get_compilation_metrics() # Number of compiles same on all nodes res = [None] * self.world_size torch.distributed.all_gather_object(res, len(metrics)) for r in res[1:]: self.assertEqual(res[0], r) @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @patch.object(torch._inductor.config, "fx_graph_cache", True) @patch.object(torch._inductor.config, "fx_graph_remote_cache", False) @patch.object(torch._inductor.config, "sleep_sec_TESTING_ONLY", 10) def test_asymmetric_compilation_with_fx_cache(self): from torch._dynamo.utils import counters from torch._inductor.utils import fresh_inductor_cache with fresh_inductor_cache(), _dynamo_dist_per_rank_init( self.rank, self.world_size ): torch._dynamo.utils.clear_compilation_metrics() device = f"cuda:{self.rank}" pg = dist.distributed_c10d._get_default_group() @torch.compile def f(x): y = 2 * x return y.sum() backend = pg._get_backend(torch.device(device)) backend._set_default_timeout(timedelta(seconds=5)) counters.clear() x = torch.ones(4, device=device) f(x) self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1) self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0) self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0) w = pg.allreduce(x) w.wait() torch.cuda.synchronize(device) torch._dynamo.reset() if self.rank == 0: with fresh_inductor_cache(): f(x) self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 2) self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0) self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0) else: f(x) self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1) self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1) self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0) w = pg.allreduce(x) w.wait() torch.cuda.synchronize(device) @requires_nccl() @requires_cuda class TestSingleProc(DynamoDistributedSingleProcTestCase): """ Test harness initializes dist process group. Test simple things here since they are simpler to debug. Use TestMultiProc for things that really need to run on multiple nodes """ def get_model( self, bsz=20, in_feat=10, hidden_feat=5000, out_feat=5, ctx_manager=None ): m = ToyModel( in_feat=in_feat, hidden_feat=hidden_feat, out_feat=out_feat, ctx_manager=ctx_manager, ).to(self.device) m.apply(init_weights) inputs = torch.rand(bsz, in_feat).to(self.device) outputs = m(inputs) return m, inputs, outputs @patch.object(config, "optimize_ddp", False) def test_ddp_baseline_aot_eager(self): from torch.nn.parallel import DistributedDataParallel as DDP m, inputs, correct_outputs = self.get_model() ddp_m = DDP(m, device_ids=self.device_ids) ddp_m = torch.compile(ddp_m, backend="aot_eager") outputs = ddp_m(inputs) self.assertTrue(same(correct_outputs, outputs)) @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @patch.object(config, "optimize_ddp", False) def test_ddp_baseline_inductor(self): from torch.nn.parallel import DistributedDataParallel as DDP m, inputs, correct_outputs = self.get_model() ddp_m = DDP(m, device_ids=self.device_ids) ddp_m = torch.compile(ddp_m, backend="inductor") outputs = ddp_m(inputs) self.assertTrue(same(correct_outputs, outputs)) @patch.object(config, "optimize_ddp", True) def test_graph_split(self): assert config.optimize_ddp """ Just ensures that the appropriate number of splits happen (based on bucket size and model parameters) - verifies the number of times the user-provided compiler is called by the DDPOptimizer which is doing the graph splitting """ m, inputs, correct_outputs = self.get_model() ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25) check_splits_compiler = CheckSplitsCompiler() @torch.compile(backend=check_splits_compiler.compile_fn) def opt_fn(inputs): return ddp_m(inputs) opt_outputs = opt_fn(inputs) self.assertTrue(same(correct_outputs, opt_outputs)) self.assertEqual(check_splits_compiler.compiler_called, 3) # ensure compatibility with dynamo explain explain_out = torch._dynamo.explain(ddp_m)(inputs) break_reasons = explain_out.break_reasons self.assertEqual(len(break_reasons), 3) self.assertTrue(all("DDPOptimizer" in r.reason for r in break_reasons)) @patch.object(config, "optimize_ddp", True) def test_graph_split_ctx_manager(self): """ Ensures that we get the right number of splits and that the respective context managers' effects are applied to the computation. """ for get_compiler in [ lambda: CheckSplitsCompiler(), lambda: None, ]: for ctx_manager, output_test in [ ( lambda: torch.autocast( torch.device(self.device).type, torch.float16 ), lambda out: self.assertEqual(out.dtype, torch.float16), ), (torch.enable_grad, lambda out: self.assertTrue(out.requires_grad)), (torch.no_grad, lambda out: self.assertTrue(not out.requires_grad)), ]: m, inputs, correct_outputs = self.get_model( out_feat=1000, hidden_feat=1000, in_feat=1000, ctx_manager=ctx_manager, ) # inp - 1000 * 1000 matrix of float32 (4 bytes) = 4MB # hidden - 1000 * 1000 matrix of float32 (4 bytes) = 4MB bucket_cap_mb = 3.5 # 4MB ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=bucket_cap_mb) compiler = get_compiler() @torch.compile(backend=compiler.compile_fn if compiler else "aot_eager") def opt_fn(inputs): return ddp_m(inputs) opt_outputs = opt_fn(inputs) self.assertTrue(same(correct_outputs, opt_outputs)) if compiler: self.assertEqual(compiler.compiler_called, 4) output_test(opt_outputs) # ensure compatibility with dynamo explain explain_out = torch._dynamo.explain(ddp_m)(inputs) break_reasons = explain_out.break_reasons self.assertEqual(len(break_reasons), 4) self.assertTrue(all("DDPOptimizer" in r.reason for r in break_reasons)) @patch.object(config, "optimize_ddp", True) def test_compiled_flex_attention_full_model_ddp(self): class Model(torch.nn.Module): def __init__(self, S, H, D): super().__init__() self.S = S self.H = H self.D = D alibi_bias = self.generate_alibi_bias(H) self.register_buffer("alibi_bias", alibi_bias, persistent=True) self.attention = flex_attention self.project_qk = torch.nn.Linear(H * D, H * D * 2) self.project_v = torch.nn.Linear(H * D, H * D) def forward(self, hidden_states): batch_size, _, _ = hidden_states.size() query, key = self.project_qk(hidden_states).chunk(2, dim=2) query = query.view(self.S, batch_size, self.H, self.D) query = query.permute(1, 2, 0, 3) key = key.view(self.S, batch_size, self.H, self.D) key = key.permute(1, 2, 0, 3) value = self.project_v(hidden_states) value = value.view(self.S, batch_size, self.H, self.D) value = value.permute(1, 2, 0, 3) return self.attention(query, key, value, score_mod=self.alibi_score_mod) def generate_alibi_bias(self, num_heads): alibi_bias = [-((i + 1) * 8.0) / num_heads for i in range(num_heads)] return torch.tensor(alibi_bias) def alibi_score_mod(self, score, b, h, q_idx, kv_idx): bias = (q_idx - kv_idx) * self.alibi_bias[h] return score + bias B = 16 H = 12 S = 512 D = 64 device = "cuda" model = Model(S, H, D) model.to(device) model = torch.compile(model) model = DDP(model, device_ids=self.device_ids) hidden_states = torch.randn(B, S, H * D).to(device) model(hidden_states) torch.cuda.synchronize() @patch.object(config, "optimize_ddp", True) def test_compiled_flex_attention_local_ddp(self): class Model(torch.nn.Module): def __init__(self, S, H, D): super().__init__() self.S = S self.H = H self.D = D alibi_bias = self.generate_alibi_bias(H) self.register_buffer("alibi_bias", alibi_bias, persistent=True) self.attention = torch.compile(flex_attention) self.project_qk = torch.nn.Linear(H * D, H * D * 2) self.project_v = torch.nn.Linear(H * D, H * D) def forward(self, hidden_states): batch_size, _, _ = hidden_states.size() query, key = self.project_qk(hidden_states).chunk(2, dim=2) query = query.view(self.S, batch_size, self.H, self.D) query = query.permute(1, 2, 0, 3) key = key.view(self.S, batch_size, self.H, self.D) key = key.permute(1, 2, 0, 3) value = self.project_v(hidden_states) value = value.view(self.S, batch_size, self.H, self.D) value = value.permute(1, 2, 0, 3) return self.attention(query, key, value, score_mod=self.alibi_score_mod) def generate_alibi_bias(self, num_heads): alibi_bias = [-((i + 1) * 8.0) / num_heads for i in range(num_heads)] return torch.tensor(alibi_bias) def alibi_score_mod(self, score, b, h, q_idx, kv_idx): bias = (q_idx - kv_idx) * self.alibi_bias[h] return score + bias B = 16 H = 12 S = 512 D = 64 device = "cuda" model = Model(S, H, D) model.to(device) model = torch.compile(model) model = DDP(model, device_ids=self.device_ids) hidden_states = torch.randn(B, S, H * D).to(device) model(hidden_states) torch.cuda.synchronize() @patch.object(config, "optimize_ddp", True) @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") def test_graph_split_inductor(self): assert config.optimize_ddp """ Same as above, but using inductor backend. We observed issues with inductor/fx interface in the past. """ m, inputs, correct_outputs = self.get_model() ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25) @torch.compile(backend="inductor") def opt_fn(inputs): return ddp_m(inputs) opt_outputs = opt_fn(inputs) self.assertTrue(same(correct_outputs, opt_outputs)) @torch._inductor.config.patch( {"layout_optimization": True, "keep_output_stride": False} ) @patch.object(config, "optimize_ddp", True) def _test_graph_split_inductor_layout_optimizations_impl(self, context): assert config.optimize_ddp channel_dim = 512 # channel dim must be > 64 for inductor to do layout optimization and use NHWC class ToyModelConv(nn.Module): def __init__(self) -> None: super().__init__() self.net = nn.Sequential( *[ nn.Conv2d(channel_dim, channel_dim, 1, stride=1, bias=False), nn.ReLU(), ] + [ nn.Conv2d(channel_dim, channel_dim, 1, stride=1, bias=False), nn.ReLU(), ] + [ nn.Conv2d(channel_dim, channel_dim, 1, stride=1, bias=False), nn.ReLU(), ] + [ nn.Conv2d(channel_dim, channel_dim, 1, stride=1, bias=False), nn.ReLU(), ] ) def forward(self, inputs): return self.net(inputs) def get_model(): m = ToyModelConv().to(self.device) m.apply(init_weights) inputs = torch.rand(2, channel_dim, channel_dim, 128).to(self.device) outputs = m(inputs) return m, inputs, outputs with context(): m, inputs, correct_outputs = get_model() ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25) @torch.compile(backend="inductor") def opt_fn(inputs): return ddp_m(inputs) opt_outputs = opt_fn(inputs) self.assertTrue(same(correct_outputs, opt_outputs)) @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") def test_graph_split_inductor_layout_optimizations_training(self): self._test_graph_split_inductor_layout_optimizations_impl( contextlib.nullcontext ) @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") def test_graph_split_inductor_layout_optimizations_inference(self): self._test_graph_split_inductor_layout_optimizations_impl(torch.no_grad) @patch.object(config, "optimize_ddp", True) @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") def test_graph_split_inductor_transpose(self): assert config.optimize_ddp B = 100 N = 30 D = 50 K = 70 class Foo(nn.Module): def __init__(self) -> None: super().__init__() self.linear0 = nn.Linear(N, K) self.linear1 = torch.nn.Linear(D * K, 2048) def forward(self, x): xt = x.transpose(2, 1) xt = self.linear0(xt).flatten(1) return self.linear1(xt) mod = Foo().to(self.device) compiled_mod = torch.compile(mod, backend="inductor") ddp_compiled_mod = DDP(compiled_mod, device_ids=self.device_ids) x = torch.randn((B, N, D), dtype=torch.float32, device=self.device) self.assertTrue(same(mod(x), ddp_compiled_mod(x))) x_1 = torch.randn((B * 2, N, D), dtype=torch.float32, device=self.device) self.assertTrue(same(mod(x_1), ddp_compiled_mod(x_1))) x_2 = torch.randn((B * 3, N, D), dtype=torch.float32, device=self.device) self.assertTrue(same(mod(x_2), ddp_compiled_mod(x_2))) @patch.object(config, "optimize_ddp", True) def test_no_split(self): """ Ensures the DDPOptimizer returns a correct, compiled module without introducing graph splits. (Based on model parameters fitting in the bucket) """ # DDP will always do a 'first bucket' with a really small size; so only a tiny model will escape this m, inputs, correct_outputs = self.get_model(hidden_feat=5) ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=250) check_splits_compiler = CheckSplitsCompiler() @torch.compile(backend=check_splits_compiler.compile_fn) def opt_fn(inputs): return ddp_m(inputs) opt_outputs = opt_fn(inputs) self.assertTrue(same(correct_outputs, opt_outputs)) self.assertEqual(check_splits_compiler.compiler_called, 1) @patch.object(config, "optimize_ddp", True) def test_aot_autograd(self): """ Explicitly check AotAutograd family of compilers work, since they require example inputs propagated between graph splits. """ m, inputs, correct_outputs = self.get_model() ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25) @torch.compile(backend="aot_eager") def opt_fn(inputs): return ddp_m(inputs) opt_outputs = opt_fn(inputs) opt_outputs.sum().backward() self.assertTrue(same(correct_outputs, opt_outputs)) @patch.object(config, "optimize_ddp", True) def test_custom_layer(self): """ Just ensures that the appropriate number of splits happen (based on bucket size and model parameters) - verifies the number of times the user-provided compiler is called by the DDPOptimizer which is doing the graph splitting """ m, inputs, correct_outputs = get_custom_model(self.device) ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=1) check_splits_compiler = CheckSplitsCompiler() @torch.compile(backend=check_splits_compiler.compile_fn) def opt_fn(inputs): return ddp_m(*inputs) opt_outputs = opt_fn(inputs) self.assertTrue(same(correct_outputs, opt_outputs)) self.assertEqual(check_splits_compiler.compiler_called, 3) @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") def test_empty_graph_inductor(self): def fn(): get_world_size = torch.distributed.distributed_c10d.get_world_size() return (get_world_size,) opt_fn = torch.compile(fn, backend="inductor") res = None try: res = opt_fn()[0] except Exception: pass self.assertEqual(res, 1) @patch.object(config, "optimize_ddp", False) def test_ignored_parameters(self): """ Verifies ddp graph-split logic ignores parameters marked to ignore on DDP module. Hooks up graph-split optimizer manually so it can peek at internal state. """ m, inputs, correct_outputs = get_custom_model(self.device) parameters_to_ignore = ["seq.2.weight", "seq.4.linear.bias"] DDP._set_params_and_buffers_to_ignore_for_model(m, parameters_to_ignore) ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25) parameter_ids_to_ignore = [ id(ddp_m.module.get_parameter(p)) for p in ddp_m.parameters_to_ignore ] check_splits_compiler = CheckSplitsCompiler() ddp_optimizer = DDPOptimizer( bucket_bytes_cap=ddp_m.bucket_bytes_cap, backend_compile_fn=check_splits_compiler.compile_fn, ) @torch.compile(backend=ddp_optimizer.compile_fn) def opt_fn(inputs): return ddp_m(*inputs) opt_outputs = opt_fn(inputs) self.assertTrue(same(correct_outputs, opt_outputs)) self.assertEqual(check_splits_compiler.compiler_called, 2) for b in ddp_optimizer.buckets: for p_id in b.param_ids: self.assertFalse(p_id in parameter_ids_to_ignore) @patch.object(config, "optimize_ddp", True) def test_higher_order_op(self): from torch.utils.checkpoint import checkpoint N = 1000 class InnerModule(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear1 = torch.nn.Linear(N, N) self.linear2 = torch.nn.Linear(N, N) def forward(self, x): a = self.linear1(x) a = self.linear2(a) return a class MockModule(torch.nn.Module): def __init__(self) -> None: super().__init__() self.inner_mod1 = InnerModule() self.inner_mod2 = InnerModule() def forward(self, x): a = checkpoint(self.inner_mod1, x, use_reentrant=False) a = torch.cos(a) a = checkpoint(self.inner_mod2, a, use_reentrant=False) a = torch.cos(a) return a mod = MockModule().cuda() mod = DDP(mod, bucket_cap_mb=1) x = torch.randn(N, N, device="cuda", requires_grad=True) args = (x,) backend = "aot_eager" cnt = torch._dynamo.testing.CompileCounterWithBackend(backend) torch.compile(mod, backend=cnt)(*args) def test_fsdp_orig_params_assert(self): # Test with basic FSDP wrapping (outer wrap around whole model) m, inputs, _ = get_model(f"cuda:{self.rank}") fsdp_m = FSDP(m, use_orig_params=False) # Test is that this function call does not throw an exception. fsdp_m = torch.compile(fsdp_m) def test_fsdp_skip_guards(self): """ It's currently difficult to test dynamo guards. Most guards tests are indirect- modify something and observe that the guard in question failed. In this case, since the FSDP guards were already deemed useless and skipping them is expected to have no practical effect, it's pretty contrived to even try to make those guards fail. Instead, we observe the 'guard source' printed by dynamo's comptime print_guards function. Note: comptime prints the guards before the time they get installed or not installed, so in both cases (skip or no skip) the same guards get printed. The difference is that in the skip case, they show up with a special 'guard source' which will cuase them to not be installed. So all we check for is the expected guard source 'local_fsdp_module'. """ global GUARDS_FILE GUARDS_FILE = StringIO() for skip_guards, expected_guard_source in ( (True, "local_fsdp_module"), (False, "local_unspecialized_nn_module"), ): torch._dynamo.reset() class ToyModel(nn.Module): def __init__(self, in_feat=10, hidden_feat=5000, out_feat=5): super().__init__() self.net = nn.Sequential( *[nn.Linear(in_feat, hidden_feat), nn.ReLU()] + [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()] + [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()] + [nn.Linear(hidden_feat, out_feat), nn.ReLU()] ) def forward(self, inputs): out = self.net(inputs) @comptime def _(ctx): ctx.print_guards(file=GUARDS_FILE) return out device = f"cuda:{self.rank}" m = ToyModel( in_feat=10, hidden_feat=5000, out_feat=5, ).to(device) inputs = torch.rand(20, 10).to(device) m.apply(init_weights) correct_outputs = m(inputs) fsdp_m = FSDP(m, use_orig_params=True) with torch._dynamo.config.patch(skip_fsdp_guards=skip_guards): opt_m = torch.compile(fsdp_m, backend="aot_eager") outputs = opt_m(inputs) # far from an exhaustive check of all the expected guards, just check a couple of them. FileCheck().check("""local "L['self']" TYPE_MATCH""").check( f"""{expected_guard_source} "L['self']._modules['net']" TYPE_MATCH""" ).check( f"""{expected_guard_source} "L['self']._modules['net']._modules['0']" TYPE_MATCH""" ).run( GUARDS_FILE.getvalue() ) self.assertTrue(same(correct_outputs, outputs)) def test_fsdp_skip_register_attr_or_module(self): """ ensure FSDP module is not registered as attrbutes in the fx graph see `not source.guard_source().is_fsdp_module()` before calling `register_attr_or_module` in variables/builder.py """ class ToyModel(nn.Module): def __init__(self, in_feat=10, hidden_feat=5000, out_feat=5): super().__init__() self.net = nn.Sequential( *[nn.Linear(in_feat, hidden_feat), nn.ReLU()] + [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()] ) def forward(self, inputs): out = self.net(inputs) return out torch._dynamo.reset() device = f"cuda:{self.rank}" m = ToyModel( in_feat=10, hidden_feat=5000, out_feat=5, ).to(device) inputs = torch.rand(20, 10).to(device) m.apply(init_weights) correct_outputs = m(inputs) fsdp_m = FSDP(m, use_orig_params=True) def debug_compiler(gm, _): for node in gm.graph.nodes: if node.op == "get_attr": for name in [ "l__self___net_0_weight", "l__self___net_0_bias", "l__self___net_2_weight", "l__self___net_2_bias", ]: self.assertFalse( name in node.name, f"FSDP module {name} should not be registered as attributes", ) return gm opt_m = torch.compile(fsdp_m, backend=debug_compiler) outputs = opt_m(inputs) self.assertTrue(same(correct_outputs, outputs)) def test_fsdp_dup_tensors_same_source(self): """ Tests that FSDP-managed modules' parameters and buffers with the same source are de-duplicated, meaning that they are each only passed once as a graph input. """ class DuplicateModule(nn.Module): def __init__(self) -> None: super().__init__() self._param = torch.randn((3,), device="cuda") self._buf = torch.nn.Buffer( torch.randn((3,), requires_grad=False, device="cuda") ) def forward(self, x: torch.Tensor) -> torch.Tensor: # Use `_param` and `_buf` each twice in this compiled forward # to exercise if they are de-duplicated by TorchDynamo z = x + self._buf + self._buf z += self._param + self._param return z model = DuplicateModule() fsdp_model = FSDP(copy.deepcopy(model), use_orig_params=True) fsdp_model = torch.compile(fsdp_model, backend="aot_eager") inp = torch.randn((2, 3), device="cuda") local_out = model(inp) fsdp_out = fsdp_model(inp) self.assertEqual(local_out, fsdp_out) @patch.object(config, "guard_nn_modules", True) def test_fsdp_dup_tensors_diff_source(self): """ Tests that FSDP-managed modules' parameters and buffers with different source do not result in incorrect AOTAutograd de-dup guards like ``a is b``, where ``a`` and ``b`` are certainly not the same. We check this by checking for per-invocation recompiles. """ class BufModule(nn.Module): def __init__(self) -> None: super().__init__() self._buf = nn.Buffer( torch.randn((3,), requires_grad=False, device="cuda") ) def forward(self, x: torch.Tensor) -> torch.Tensor: return x + self._buf class Model(nn.Module): def __init__(self) -> None: super().__init__() self._param = nn.Parameter(torch.randn((1,), device="cuda")) self._buf_module = BufModule() # Share the buffer, meaning same tensor but different source self._buf = self._buf_module._buf def forward(self, x: torch.Tensor) -> torch.Tensor: # Use the same buffer tensor twice in the compiled forward, # including a data mutation to trigger de-dup logic self._buf.mul_(2) z = x + self._buf z = self._buf_module(z) z += self._param return z fsdp_model = FSDP(Model(), use_orig_params=True) cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager") fsdp_model = torch.compile(fsdp_model, backend=cnt) inp = torch.randn((2, 3), device="cuda") for _ in range(15): fsdp_model(inp) # Check for no recompiles (if there were incorrect de-dup guards, then # the frame count would be equal to the number of forward calls) self.assertEqual(cnt.frame_count, 1) def test_fsdp_staticmethod(self): """ Tests that Dynamo compiles staticmethods for FSDP-managed modules correctly both when the staticmethod is invoked from the class and from the object itself. """ class ModuleWithStaticMethod(nn.Module): def __init__(self, use_self: bool): super().__init__() self._use_self = use_self torch.manual_seed(42) # force `_param` to be deterministic self._param = nn.Parameter(torch.randn((3,), device="cuda")) def forward(self, x: torch.Tensor) -> torch.Tensor: if self._use_self: z = self._add(x, self._param) else: z = ModuleWithStaticMethod._add(x, self._param) z *= 2 return z @staticmethod def _add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return x + y model = ModuleWithStaticMethod(False) x = torch.randn((2, 3), device="cuda") ref_out = model(x) test_outs: list[torch.Tensor] = [] for use_self in (False, True): model = ModuleWithStaticMethod(use_self) fsdp_model = FSDP(model, use_orig_params=True) cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager") fsdp_model = torch.compile(fsdp_model, backend=cnt) test_outs.append(fsdp_model(x)) # Check for no recompiles, which could happen if incorrectly # passing args to the staticmethod (e.g. doubly passing `self`) # 3 is expected here for 1 forward. # Graph 1 should be add and imul self.assertEqual(cnt.frame_count, 1) for test_out in test_outs: self.assertEqual(test_out, ref_out) def test_async_subclass_no_specialize(self): cnt = torch._dynamo.testing.CompileCounterWithBackend("eager") @torch.compile(backend=cnt, fullgraph=True, dynamic=True) def f(x): return x + 1 f(_maybe_wrap_tensor(torch.randn(10))) f(_maybe_wrap_tensor(torch.randn(12))) self.assertEqual(cnt.frame_count, 1) if __name__ == "__main__": from torch._dynamo.test_case import run_tests run_tests()