From e5b1254baaf94cddbccee01246bdeb0496a8c5b5 Mon Sep 17 00:00:00 2001 From: Anthony Barbier Date: Tue, 24 Jun 2025 14:04:40 +0100 Subject: [PATCH 01/20] Stop parsing command line arguments every time common_utils is imported. --- torch/testing/_internal/common_utils.py | 189 +++++++++++++++--------- 1 file changed, 120 insertions(+), 69 deletions(-) diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 45b7378f88cc..15034f543503 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -104,6 +104,32 @@ MI300_ARCH = ("gfx942",) +class ProfilingMode(Enum): + LEGACY = 1 + SIMPLE = 2 + PROFILING = 3 + +# Set by parse_cmd_line_args() if called +CI_FUNCTORCH_ROOT = "" +CI_PT_ROOT = "" +CI_TEST_PREFIX = "" +DISABLED_TESTS_FILE = "" +GRAPH_EXECUTOR : ProfilingMode | None = None +LOG_SUFFIX = "" +PYTEST_SINGLE_TEST = "" +REPEAT_COUNT = 0 +RERUN_DISABLED_TESTS = False +RUN_PARALLEL = 0 +SEED = 0 +SHOWLOCALS = False +SLOW_TESTS_FILE = "" +TEST_BAILOUTS = False +TEST_DISCOVER = False +TEST_IN_SUBPROCESS = False +TEST_SAVE_XML = "" +UNITTEST_ARGS : list[str] = [] +USE_PYTEST = False + def freeze_rng_state(*args, **kwargs): return torch.testing._utils.freeze_rng_state(*args, **kwargs) @@ -838,11 +864,6 @@ def test_wrapper(*args, **kwargs): yield (test_wrapper, test_name, {}, decorator_fn) -class ProfilingMode(Enum): - LEGACY = 1 - SIMPLE = 2 - PROFILING = 3 - def cppProfilingFlagsToProfilingMode(): old_prof_exec_state = torch._C._jit_set_profiling_executor(True) old_prof_mode_state = torch._C._get_graph_executor_optimize(True) @@ -861,6 +882,7 @@ def cppProfilingFlagsToProfilingMode(): def enable_profiling_mode_for_profiling_tests(): old_prof_exec_state = False old_prof_mode_state = False + assert GRAPH_EXECUTOR if GRAPH_EXECUTOR == ProfilingMode.PROFILING: old_prof_exec_state = torch._C._jit_set_profiling_executor(True) old_prof_mode_state = torch._C._get_graph_executor_optimize(True) @@ -895,6 +917,7 @@ def num_profiled_runs(num_runs): def prof_callable(callable, *args, **kwargs): if 'profile_and_replay' in kwargs: del kwargs['profile_and_replay'] + assert GRAPH_EXECUTOR if GRAPH_EXECUTOR == ProfilingMode.PROFILING: with enable_profiling_mode_for_profiling_tests(): callable(*args, **kwargs) @@ -924,72 +947,93 @@ def _get_test_report_path(): test_source = override if override is not None else 'python-unittest' return os.path.join('test-reports', test_source) -is_running_via_run_test = "run_test.py" in getattr(__main__, "__file__", "") -parser = argparse.ArgumentParser(add_help=not is_running_via_run_test, allow_abbrev=False) -parser.add_argument('--subprocess', action='store_true', - help='whether to run each test in a subprocess') -parser.add_argument('--seed', type=int, default=1234) -parser.add_argument('--accept', action='store_true') -parser.add_argument('--jit-executor', '--jit_executor', type=str) -parser.add_argument('--repeat', type=int, default=1) -parser.add_argument('--test-bailouts', '--test_bailouts', action='store_true') -parser.add_argument('--use-pytest', action='store_true') -parser.add_argument('--save-xml', nargs='?', type=str, - const=_get_test_report_path(), - default=_get_test_report_path() if IS_CI else None) -parser.add_argument('--discover-tests', action='store_true') -parser.add_argument('--log-suffix', type=str, default="") -parser.add_argument('--run-parallel', type=int, default=1) -parser.add_argument('--import-slow-tests', type=str, nargs='?', const=DEFAULT_SLOW_TESTS_FILE) 
-parser.add_argument('--import-disabled-tests', type=str, nargs='?', const=DEFAULT_DISABLED_TESTS_FILE) -parser.add_argument('--rerun-disabled-tests', action='store_true') -parser.add_argument('--pytest-single-test', type=str, nargs=1) -parser.add_argument('--showlocals', action=argparse.BooleanOptionalAction, default=False) +def parse_cmd_line_args(): + global CI_FUNCTORCH_ROOT + global CI_PT_ROOT + global CI_TEST_PREFIX + global DISABLED_TESTS_FILE + global GRAPH_EXECUTOR + global LOG_SUFFIX + global PYTEST_SINGLE_TEST + global REPEAT_COUNT + global RERUN_DISABLED_TESTS + global RUN_PARALLEL + global SEED + global SHOWLOCALS + global SLOW_TESTS_FILE + global TEST_BAILOUTS + global TEST_DISCOVER + global TEST_IN_SUBPROCESS + global TEST_SAVE_XML + global UNITTEST_ARGS + global USE_PYTEST + + is_running_via_run_test = "run_test.py" in getattr(__main__, "__file__", "") + parser = argparse.ArgumentParser(add_help=not is_running_via_run_test, allow_abbrev=False) + parser.add_argument('--subprocess', action='store_true', + help='whether to run each test in a subprocess') + parser.add_argument('--seed', type=int, default=1234) + parser.add_argument('--accept', action='store_true') + parser.add_argument('--jit-executor', '--jit_executor', type=str) + parser.add_argument('--repeat', type=int, default=1) + parser.add_argument('--test-bailouts', '--test_bailouts', action='store_true') + parser.add_argument('--use-pytest', action='store_true') + parser.add_argument('--save-xml', nargs='?', type=str, + const=_get_test_report_path(), + default=_get_test_report_path() if IS_CI else None) + parser.add_argument('--discover-tests', action='store_true') + parser.add_argument('--log-suffix', type=str, default="") + parser.add_argument('--run-parallel', type=int, default=1) + parser.add_argument('--import-slow-tests', type=str, nargs='?', const=DEFAULT_SLOW_TESTS_FILE) + parser.add_argument('--import-disabled-tests', type=str, nargs='?', const=DEFAULT_DISABLED_TESTS_FILE) + parser.add_argument('--rerun-disabled-tests', action='store_true') + parser.add_argument('--pytest-single-test', type=str, nargs=1) + parser.add_argument('--showlocals', action=argparse.BooleanOptionalAction, default=False) # Only run when -h or --help flag is active to display both unittest and parser help messages. 
-def run_unittest_help(argv): - unittest.main(argv=argv) - -if '-h' in sys.argv or '--help' in sys.argv: - help_thread = threading.Thread(target=run_unittest_help, args=(sys.argv,)) - help_thread.start() - help_thread.join() - -args, remaining = parser.parse_known_args() -if args.jit_executor == 'legacy': - GRAPH_EXECUTOR = ProfilingMode.LEGACY -elif args.jit_executor == 'profiling': - GRAPH_EXECUTOR = ProfilingMode.PROFILING -elif args.jit_executor == 'simple': - GRAPH_EXECUTOR = ProfilingMode.SIMPLE -else: - # infer flags based on the default settings - GRAPH_EXECUTOR = cppProfilingFlagsToProfilingMode() - -RERUN_DISABLED_TESTS = args.rerun_disabled_tests - -SLOW_TESTS_FILE = args.import_slow_tests -DISABLED_TESTS_FILE = args.import_disabled_tests -LOG_SUFFIX = args.log_suffix -RUN_PARALLEL = args.run_parallel -TEST_BAILOUTS = args.test_bailouts -USE_PYTEST = args.use_pytest -PYTEST_SINGLE_TEST = args.pytest_single_test -TEST_DISCOVER = args.discover_tests -TEST_IN_SUBPROCESS = args.subprocess -TEST_SAVE_XML = args.save_xml -REPEAT_COUNT = args.repeat -SEED = args.seed -SHOWLOCALS = args.showlocals -if not getattr(expecttest, "ACCEPT", False): - expecttest.ACCEPT = args.accept -UNITTEST_ARGS = [sys.argv[0]] + remaining -torch.manual_seed(SEED) + def run_unittest_help(argv): + unittest.main(argv=argv) + + if '-h' in sys.argv or '--help' in sys.argv: + help_thread = threading.Thread(target=run_unittest_help, args=(sys.argv,)) + help_thread.start() + help_thread.join() + + args, remaining = parser.parse_known_args() + if args.jit_executor == 'legacy': + GRAPH_EXECUTOR = ProfilingMode.LEGACY + elif args.jit_executor == 'profiling': + GRAPH_EXECUTOR = ProfilingMode.PROFILING + elif args.jit_executor == 'simple': + GRAPH_EXECUTOR = ProfilingMode.SIMPLE + else: + # infer flags based on the default settings + GRAPH_EXECUTOR = cppProfilingFlagsToProfilingMode() + + RERUN_DISABLED_TESTS = args.rerun_disabled_tests + + SLOW_TESTS_FILE = args.import_slow_tests + DISABLED_TESTS_FILE = args.import_disabled_tests + LOG_SUFFIX = args.log_suffix + RUN_PARALLEL = args.run_parallel + TEST_BAILOUTS = args.test_bailouts + USE_PYTEST = args.use_pytest + PYTEST_SINGLE_TEST = args.pytest_single_test + TEST_DISCOVER = args.discover_tests + TEST_IN_SUBPROCESS = args.subprocess + TEST_SAVE_XML = args.save_xml + REPEAT_COUNT = args.repeat + SEED = args.seed + SHOWLOCALS = args.showlocals + if not getattr(expecttest, "ACCEPT", False): + expecttest.ACCEPT = args.accept + UNITTEST_ARGS = [sys.argv[0]] + remaining + torch.manual_seed(SEED) # CI Prefix path used only on CI environment -CI_TEST_PREFIX = str(Path(os.getcwd())) -CI_PT_ROOT = str(Path(os.getcwd()).parent) -CI_FUNCTORCH_ROOT = str(os.path.join(Path(os.getcwd()).parent, "functorch")) + CI_TEST_PREFIX = str(Path(os.getcwd())) + CI_PT_ROOT = str(Path(os.getcwd()).parent) + CI_FUNCTORCH_ROOT = str(os.path.join(Path(os.getcwd()).parent, "functorch")) def wait_for_process(p, timeout=None): try: @@ -1138,7 +1182,9 @@ def lint_test_case_extension(suite): return succeed -def get_report_path(argv=UNITTEST_ARGS, pytest=False): +def get_report_path(argv=None, pytest=False): + if argv is None: + argv = UNITTEST_ARGS test_filename = sanitize_test_filename(argv[0]) test_report_path = TEST_SAVE_XML + LOG_SUFFIX test_report_path = os.path.join(test_report_path, test_filename) @@ -1189,7 +1235,11 @@ def pytest_collection_finish(self, session): return test_collector_plugin.tests -def run_tests(argv=UNITTEST_ARGS): +def run_tests(argv=None): + parse_cmd_line_args() + if argv is 
None: + argv = UNITTEST_ARGS + # import test files. if SLOW_TESTS_FILE: if os.path.exists(SLOW_TESTS_FILE): @@ -1748,6 +1798,7 @@ def skipRocmIfTorchInductor(msg="test doesn't currently work with torchinductor def skipIfLegacyJitExecutor(msg="test doesn't currently work with legacy JIT executor"): def decorator(fn): + assert GRAPH_EXECUTOR if not isinstance(fn, type): @wraps(fn) def wrapper(*args, **kwargs): From adc6560316b96cd834c3b580bf3d789b26440f80 Mon Sep 17 00:00:00 2001 From: Anthony Barbier Date: Tue, 24 Jun 2025 16:13:10 +0100 Subject: [PATCH 02/20] Use Optional instead of | --- torch/testing/_internal/common_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 15034f543503..93594adf0d03 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -114,7 +114,7 @@ class ProfilingMode(Enum): CI_PT_ROOT = "" CI_TEST_PREFIX = "" DISABLED_TESTS_FILE = "" -GRAPH_EXECUTOR : ProfilingMode | None = None +GRAPH_EXECUTOR : Optional[ProfilingMode] = None LOG_SUFFIX = "" PYTEST_SINGLE_TEST = "" REPEAT_COUNT = 0 From edf4187f963a87452c97dc7961e7475e7d1aa7eb Mon Sep 17 00:00:00 2001 From: Anthony Barbier Date: Wed, 25 Jun 2025 13:39:09 +0100 Subject: [PATCH 03/20] Fix test_jit_legacy test --- test/test_jit_legacy.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/test/test_jit_legacy.py b/test/test_jit_legacy.py index 5576f1645349..9bf9291e8886 100644 --- a/test/test_jit_legacy.py +++ b/test/test_jit_legacy.py @@ -2,7 +2,13 @@ import sys sys.argv.append("--jit-executor=legacy") -from test_jit import * # noqa: F403 +from torch.testing._internal.common_utils import parse_cmd_line_args, run_tests + +# The tests decorators depend on command line arguments +if __name__ == '__main__': + parse_cmd_line_args() + +from test_jit import * # noqa: F403, F401 if __name__ == '__main__': run_tests() From 2a3b177c5c0abe3e1912c868ea63310a9b9bc2c9 Mon Sep 17 00:00:00 2001 From: Anthony Barbier Date: Wed, 25 Jun 2025 14:48:25 +0100 Subject: [PATCH 04/20] Move assert to when the variable is actually read --- test/test_jit_legacy.py | 6 +----- torch/testing/_internal/common_utils.py | 2 +- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/test/test_jit_legacy.py b/test/test_jit_legacy.py index 9bf9291e8886..76fa1c3e91ff 100644 --- a/test/test_jit_legacy.py +++ b/test/test_jit_legacy.py @@ -2,11 +2,7 @@ import sys sys.argv.append("--jit-executor=legacy") -from torch.testing._internal.common_utils import parse_cmd_line_args, run_tests - -# The tests decorators depend on command line arguments -if __name__ == '__main__': - parse_cmd_line_args() +from torch.testing._internal.common_utils import run_tests from test_jit import * # noqa: F403, F401 diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 93594adf0d03..5285e3d7d61f 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -1798,10 +1798,10 @@ def skipRocmIfTorchInductor(msg="test doesn't currently work with torchinductor def skipIfLegacyJitExecutor(msg="test doesn't currently work with legacy JIT executor"): def decorator(fn): - assert GRAPH_EXECUTOR if not isinstance(fn, type): @wraps(fn) def wrapper(*args, **kwargs): + assert GRAPH_EXECUTOR if GRAPH_EXECUTOR == ProfilingMode.LEGACY: raise unittest.SkipTest(msg) else: From 36da9a91574b2a7068779498444f0cf98d161400 Mon Sep 17 00:00:00 
2001 From: Anthony Barbier Date: Thu, 26 Jun 2025 15:11:13 +0100 Subject: [PATCH 05/20] Fix jit tests --- test/test_jit.py | 11 ++++++++++- test/test_jit_legacy.py | 7 ++++++- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/test/test_jit.py b/test/test_jit.py index 3af3521f4fce..999cd0c3cf00 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -97,7 +97,7 @@ from torch.testing._internal import jit_utils from torch.testing._internal.common_jit import check_against_reference from torch.testing._internal.common_utils import run_tests, IS_WINDOWS, \ - suppress_warnings, IS_SANDCASTLE, GRAPH_EXECUTOR, ProfilingMode, \ + parse_cmd_line_args, suppress_warnings, IS_SANDCASTLE, ProfilingMode, \ TestCase, freeze_rng_state, slowTest, TemporaryFileName, \ enable_profiling_mode_for_profiling_tests, TEST_MKL, set_default_dtype, num_profiled_runs, \ skipIfCrossRef, skipIfTorchDynamo @@ -147,6 +147,13 @@ import tracemalloc +if __name__ == '__main__': + # The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed + # before instantiating tests. + parse_cmd_line_args() + +from torch.testing._internal.common_utils import GRAPH_EXECUTOR + def canonical(graph): return torch._C._jit_pass_canonicalize(graph).str(False) @@ -158,6 +165,7 @@ def doAutodiffCheck(testname): if "test_t_" in testname or testname == "test_t": return False + assert GRAPH_EXECUTOR if GRAPH_EXECUTOR == ProfilingMode.SIMPLE: return False @@ -201,6 +209,7 @@ def doAutodiffCheck(testname): return testname not in test_exceptions +assert GRAPH_EXECUTOR # TODO: enable TE in PE when all tests are fixed torch._C._jit_set_texpr_fuser_enabled(GRAPH_EXECUTOR == ProfilingMode.PROFILING) torch._C._jit_set_profiling_executor(GRAPH_EXECUTOR != ProfilingMode.LEGACY) diff --git a/test/test_jit_legacy.py b/test/test_jit_legacy.py index 76fa1c3e91ff..480b57a55bd4 100644 --- a/test/test_jit_legacy.py +++ b/test/test_jit_legacy.py @@ -2,7 +2,12 @@ import sys sys.argv.append("--jit-executor=legacy") -from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.common_utils import parse_cmd_line_args, run_tests + +if __name__ == '__main__': + # The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed + # before instantiating tests. 
+ parse_cmd_line_args() from test_jit import * # noqa: F403, F401 From 65210f6608f1b859c24714ac63e00b25d5d3c851 Mon Sep 17 00:00:00 2001 From: Anthony Barbier Date: Thu, 26 Jun 2025 17:01:20 +0100 Subject: [PATCH 06/20] Set seed for distributed_test.py --- torch/testing/_internal/distributed/distributed_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index 28b761a37d58..2428545c44d4 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -396,6 +396,7 @@ def forward(self, x): return F.relu(self.lin1(x)) +torch.manual_seed(1234) DDP_NET = Net() BN_NET = BatchNormNet() BN_NET_NO_AFFINE = BatchNormNet(affine=False) From a4c5ee3444ce1dd013ca567a72067792db2ae131 Mon Sep 17 00:00:00 2001 From: Anthony Barbier Date: Thu, 10 Jul 2025 12:52:19 +0100 Subject: [PATCH 07/20] Fix seed setting --- torch/testing/_internal/common_utils.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index a48d85d58d09..ffda4b5066ab 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -120,7 +120,7 @@ class ProfilingMode(Enum): REPEAT_COUNT = 0 RERUN_DISABLED_TESTS = False RUN_PARALLEL = 0 -SEED = 0 +SEED : Optional[int] = None SHOWLOCALS = False SLOW_TESTS_FILE = "" TEST_BAILOUTS = False @@ -130,7 +130,6 @@ class ProfilingMode(Enum): UNITTEST_ARGS : list[str] = [] USE_PYTEST = False - def freeze_rng_state(*args, **kwargs): return torch.testing._utils.freeze_rng_state(*args, **kwargs) @@ -2428,7 +2427,10 @@ def get_function_arglist(func): return inspect.getfullargspec(func).args -def set_rng_seed(seed): +def set_rng_seed(seed =None): + if seed is None: + assert SEED is not None + seed = SEED torch.manual_seed(seed) random.seed(seed) if TEST_NUMPY: @@ -3449,7 +3451,7 @@ def run(self, result=None): def setUp(self): check_if_enable(self) - set_rng_seed(SEED) + set_rng_seed() # Save global check sparse tensor invariants state that can be # restored from tearDown: @@ -5805,3 +5807,4 @@ def recover(): torch.backends.cuda.matmul.fp32_precision = old_cuda_matmul_p return recover()(fn) + From 0ece6144ee522c630dab5e3e91a7ed88fbe1ef0b Mon Sep 17 00:00:00 2001 From: Anthony Barbier Date: Thu, 10 Jul 2025 15:02:30 +0100 Subject: [PATCH 08/20] Fixing jit and distributed tests --- test/test_jit_autocast.py | 9 +++++++-- test/test_jit_fuser_te.py | 7 +++++++ torch/testing/_internal/common_utils.py | 3 +-- torch/testing/_internal/distributed/distributed_test.py | 1 + 4 files changed, 16 insertions(+), 4 deletions(-) diff --git a/test/test_jit_autocast.py b/test/test_jit_autocast.py index b3cf4d9bee8f..dcdf78ff4b89 100644 --- a/test/test_jit_autocast.py +++ b/test/test_jit_autocast.py @@ -5,12 +5,17 @@ from typing import Optional import unittest -from test_jit import JitTestCase from torch.testing._internal.common_cuda import TEST_CUDA -from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo +from torch.testing._internal.common_utils import parse_cmd_line_args, run_tests, skipIfTorchDynamo from torch.testing import FileCheck from jit.test_models import MnistNet +if __name__ == '__main__': + # The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed + # before instantiating tests. 
+ parse_cmd_line_args() + +from test_jit import JitTestCase TEST_BFLOAT16 = TEST_CUDA and torch.cuda.is_bf16_supported() @skipIfTorchDynamo("Not a TorchDynamo suitable test") diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py index 8d3a8090c67a..16645422e080 100644 --- a/test/test_jit_fuser_te.py +++ b/test/test_jit_fuser_te.py @@ -22,6 +22,13 @@ torch._C._jit_set_profiling_executor(True) torch._C._get_graph_executor_optimize(True) +if __name__ == "__main__": + from torch.testing._internal.common_utils import parse_cmd_line_args + + # The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed + # before instantiating tests. + parse_cmd_line_args() + from itertools import combinations, permutations, product from textwrap import dedent diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index ffda4b5066ab..251da67a35ef 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -2427,7 +2427,7 @@ def get_function_arglist(func): return inspect.getfullargspec(func).args -def set_rng_seed(seed =None): +def set_rng_seed(seed=None): if seed is None: assert SEED is not None seed = SEED @@ -5807,4 +5807,3 @@ def recover(): torch.backends.cuda.matmul.fp32_precision = old_cuda_matmul_p return recover()(fn) - diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index 2428545c44d4..3c9e85bf64f1 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -137,6 +137,7 @@ def eq(value, other): f = Foo(10) f.bar = 1 +torch.manual_seed(0) foo_cpu_tensor = Foo(torch.randn(3, 3)) From a9c889905746cc57f41cc70c43d8dc477d7b3212 Mon Sep 17 00:00:00 2001 From: Anthony Barbier Date: Thu, 10 Jul 2025 15:09:17 +0100 Subject: [PATCH 09/20] Clean up --- torch/testing/_internal/distributed/distributed_test.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index 3c9e85bf64f1..16534af82111 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -137,7 +137,7 @@ def eq(value, other): f = Foo(10) f.bar = 1 -torch.manual_seed(0) +torch.manual_seed(1234) foo_cpu_tensor = Foo(torch.randn(3, 3)) @@ -397,7 +397,6 @@ def forward(self, x): return F.relu(self.lin1(x)) -torch.manual_seed(1234) DDP_NET = Net() BN_NET = BatchNormNet() BN_NET_NO_AFFINE = BatchNormNet(affine=False) From e6dc730691da8dcc96c7e249d92267ec05a9eb25 Mon Sep 17 00:00:00 2001 From: Anthony Barbier Date: Thu, 10 Jul 2025 17:17:00 +0100 Subject: [PATCH 10/20] Fix test_jit_fuser --- test/test_jit_fuser.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/test/test_jit_fuser.py b/test/test_jit_fuser.py index 1ac7803a9d46..5446770695c4 100644 --- a/test/test_jit_fuser.py +++ b/test/test_jit_fuser.py @@ -9,6 +9,13 @@ from torch.testing import FileCheck from unittest import skipIf +if __name__ == "__main__": + from torch.testing._internal.common_utils import parse_cmd_line_args + + # The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed + # before instantiating tests. 
+ parse_cmd_line_args() + from torch.testing._internal.common_utils import run_tests, IS_SANDCASTLE, ProfilingMode, GRAPH_EXECUTOR, \ enable_profiling_mode_for_profiling_tests, IS_WINDOWS, TemporaryDirectoryName, shell from torch.testing._internal.jit_utils import JitTestCase, enable_cpu_fuser, _inline_everything, \ From a4b833afcbd98dfaa17000c2c7e97d30c4d1dbd6 Mon Sep 17 00:00:00 2001 From: Anthony Barbier Date: Thu, 10 Jul 2025 18:02:38 +0100 Subject: [PATCH 11/20] Relax checks in set_rng_seed() --- test/jit/test_autodiff_subgraph_slicing.py | 1 + test/test_jit.py | 16 ++++++++-------- torch/testing/_internal/common_utils.py | 13 +++++++++++-- 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/test/jit/test_autodiff_subgraph_slicing.py b/test/jit/test_autodiff_subgraph_slicing.py index f42aa7f8f436..ff44c1e0f8f4 100644 --- a/test/jit/test_autodiff_subgraph_slicing.py +++ b/test/jit/test_autodiff_subgraph_slicing.py @@ -27,6 +27,7 @@ ) +assert GRAPH_EXECUTOR is not None @unittest.skipIf( GRAPH_EXECUTOR == ProfilingMode.SIMPLE, "Simple Executor doesn't support gradients" ) diff --git a/test/test_jit.py b/test/test_jit.py index 914ff484b2db..6a765102a596 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -3,6 +3,13 @@ import torch +if __name__ == '__main__': + from torch.testing._internal.common_utils import parse_cmd_line_args + + # The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed + # before instantiating tests. + parse_cmd_line_args() + # This is how we include tests located in test/jit/... # They are included here so that they are invoked when you call `test_jit.py`, # do not run these test files directly. @@ -97,7 +104,7 @@ from torch.testing._internal import jit_utils from torch.testing._internal.common_jit import check_against_reference from torch.testing._internal.common_utils import run_tests, IS_WINDOWS, \ - parse_cmd_line_args, suppress_warnings, IS_SANDCASTLE, ProfilingMode, \ + GRAPH_EXECUTOR, suppress_warnings, IS_SANDCASTLE, ProfilingMode, \ TestCase, freeze_rng_state, slowTest, TemporaryFileName, \ enable_profiling_mode_for_profiling_tests, TEST_MKL, set_default_dtype, num_profiled_runs, \ skipIfCrossRef, skipIfTorchDynamo @@ -147,13 +154,6 @@ import tracemalloc -if __name__ == '__main__': - # The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed - # before instantiating tests. - parse_cmd_line_args() - -from torch.testing._internal.common_utils import GRAPH_EXECUTOR - def canonical(graph): return torch._C._jit_pass_canonicalize(graph).str(False) diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 251da67a35ef..f7c124391ddf 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -2429,8 +2429,17 @@ def get_function_arglist(func): def set_rng_seed(seed=None): if seed is None: - assert SEED is not None - seed = SEED + if SEED is not None: + seed = SEED + else: + # Can't assert here: this function is called by TestCase.setUp() and some out of tree tests inherit from that class. + # So just print a warning and hardcode the seed. + seed = 1234 + msg = ("set_rng_seed() was called without providing a seed and the command line " + f"arguments haven't been parsed so the seed will be set to {seed}. 
" + "To remove this warning make sure your test is run via run_tests() or " + "parse_cmd_line_args() is called before set_rng_seed() is called.") + warnings.warn(msg) torch.manual_seed(seed) random.seed(seed) if TEST_NUMPY: From 79faf6c2116a0718bbf2491a24608849b94df3e3 Mon Sep 17 00:00:00 2001 From: Anthony Barbier Date: Wed, 16 Jul 2025 11:46:04 +0100 Subject: [PATCH 12/20] Fix more tests --- test/jit/test_autodiff_subgraph_slicing.py | 2 + test/test_jit_fuser_legacy.py | 8 ++++ torch/testing/_internal/common_distributed.py | 15 +++++++- .../_internal/distributed/distributed_test.py | 37 ++++++++++--------- 4 files changed, 43 insertions(+), 19 deletions(-) diff --git a/test/jit/test_autodiff_subgraph_slicing.py b/test/jit/test_autodiff_subgraph_slicing.py index ff44c1e0f8f4..f128f9c7eec3 100644 --- a/test/jit/test_autodiff_subgraph_slicing.py +++ b/test/jit/test_autodiff_subgraph_slicing.py @@ -28,6 +28,8 @@ assert GRAPH_EXECUTOR is not None + + @unittest.skipIf( GRAPH_EXECUTOR == ProfilingMode.SIMPLE, "Simple Executor doesn't support gradients" ) diff --git a/test/test_jit_fuser_legacy.py b/test/test_jit_fuser_legacy.py index 3bd8c9497ce0..4100bcc3e182 100644 --- a/test/test_jit_fuser_legacy.py +++ b/test/test_jit_fuser_legacy.py @@ -2,6 +2,14 @@ import sys sys.argv.append("--jit-executor=legacy") + +if __name__ == "__main__": + from torch.testing._internal.common_utils import parse_cmd_line_args + + # The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed + # before instantiating tests. + parse_cmd_line_args() + from test_jit_fuser import * # noqa: F403 if __name__ == '__main__': diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py index 670891dfe77f..3271229e06a3 100644 --- a/torch/testing/_internal/common_distributed.py +++ b/torch/testing/_internal/common_distributed.py @@ -32,6 +32,7 @@ from torch._C._autograd import DeviceType from torch._C._distributed_c10d import _SymmetricMemory from torch._logging._internal import trace_log +from torch.testing._internal import common_utils from torch.testing._internal.common_utils import ( FILE_SCHEMA, find_free_port, @@ -715,12 +716,19 @@ def _current_test_name(self) -> str: def _start_processes(self, proc) -> None: self.processes = [] + assert common_utils.SEED is not None for rank in range(int(self.world_size)): parent_conn, child_conn = torch.multiprocessing.Pipe() process = proc( target=self.__class__._run, name="process " + str(rank), - args=(rank, self._current_test_name(), self.file_name, child_conn), + args=( + rank, + self._current_test_name(), + self.file_name, + child_conn, + common_utils.SEED, + ), kwargs={ "fake_pg": getattr(self, "fake_pg", False), }, @@ -775,11 +783,12 @@ def _event_listener(parent_pipe, signal_pipe, rank: int): @classmethod def _run( - cls, rank: int, test_name: str, file_name: str, parent_pipe, **kwargs + cls, rank: int, test_name: str, file_name: str, parent_pipe, seed: int, **kwargs ) -> None: self = cls(test_name) self.rank = rank self.file_name = file_name + self.seed = seed self.run_test(test_name, parent_pipe) def run_test(self, test_name: str, parent_pipe) -> None: @@ -798,6 +807,8 @@ def run_test(self, test_name: str, parent_pipe) -> None: # Show full C++ stacktraces when a Python error originating from C++ is raised. os.environ["TORCH_SHOW_CPP_STACKTRACES"] = "1" + common_utils.set_rng_seed(self.seed) + # self.id() == e.g. 
'__main__.TestDistributed.test_get_rank' # We're retrieving a corresponding test and executing it. try: diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index 16534af82111..d7adcdf92407 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -137,17 +137,18 @@ def eq(value, other): f = Foo(10) f.bar = 1 -torch.manual_seed(1234) -foo_cpu_tensor = Foo(torch.randn(3, 3)) +# Defer instantiation until the seed is set so that randn() returns the same +# values in all processes. +def create_collectives_object_test_list(): + return [ + {"key1": 3, "key2": 4, "key3": {"nested": True}}, + f, + Foo(torch.randn(3, 3)), + "foo", + [1, 2, True, "string", [4, 5, "nested"]], + ] -COLLECTIVES_OBJECT_TEST_LIST = [ - {"key1": 3, "key2": 4, "key3": {"nested": True}}, - f, - foo_cpu_tensor, - "foo", - [1, 2, True, "string", [4, 5, "nested"]], -] # Allowlist of distributed backends where profiling collectives is supported. PROFILING_SUPPORTED_BACKENDS = [ @@ -595,12 +596,13 @@ def destroy_pg_upon_exit(self) -> bool: return False @classmethod - def _run(cls, rank, test_name, file_name, pipe, **kwargs): + def _run(cls, rank, test_name, file_name, pipe, seed, **kwargs): if BACKEND == "nccl" and not torch.cuda.is_available(): sys.exit(TEST_SKIPS["no_cuda"].exit_code) self = cls(test_name) self.rank = rank self.file_name = file_name + self.seed = seed if torch.cuda.is_available() and torch.cuda.device_count() < int( self.world_size @@ -6645,7 +6647,7 @@ def validate_global_samples(local_num_samples): def _test_allgather_object(self, subgroup=None): # Only set device for NCCL backend since it must use GPUs. - gather_objects = COLLECTIVES_OBJECT_TEST_LIST.copy() + gather_objects = create_collectives_object_test_list() backend = os.environ["BACKEND"] if backend == "nccl": @@ -6689,7 +6691,7 @@ def test_all_gather_object_subgroup(self): def _test_gather_object(self, pg=None): # Ensure stateful objects can be gathered - gather_objects = COLLECTIVES_OBJECT_TEST_LIST.copy() + gather_objects = create_collectives_object_test_list() my_rank = dist.get_rank(pg) backend = os.environ["BACKEND"] @@ -7555,7 +7557,7 @@ def forward(self, _): loss.backward() def _test_broadcast_object_list(self, group=None): - gather_objects = COLLECTIVES_OBJECT_TEST_LIST.copy() + gather_objects = create_collectives_object_test_list() # Only set device for NCCL backend since it must use GPUs. # Case where rank != GPU device. @@ -8279,10 +8281,11 @@ def forward(self, x): @require_backend_is_available({"gloo"}) def test_scatter_object_list(self): src_rank = 0 + collectives_object_test_list = create_collectives_object_test_list() scatter_list = ( - COLLECTIVES_OBJECT_TEST_LIST + collectives_object_test_list if self.rank == src_rank - else [None for _ in COLLECTIVES_OBJECT_TEST_LIST] + else [None for _ in collectives_object_test_list] ) world_size = dist.get_world_size() scatter_list = scatter_list[:world_size] @@ -8295,8 +8298,8 @@ def test_scatter_object_list(self): dist.scatter_object_list(output_obj_list, scatter_list, src=src_rank) self.assertEqual( output_obj_list[0], - COLLECTIVES_OBJECT_TEST_LIST[ - self.rank % len(COLLECTIVES_OBJECT_TEST_LIST) + collectives_object_test_list[ + self.rank % len(collectives_object_test_list) ], ) # Ensure errors are raised upon incorrect arguments. 
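Patches 01-12 above converge on a single pattern: parse_cmd_line_args() now owns all argument parsing, and module-level globals such as GRAPH_EXECUTOR and SEED stay unset until it has run, so any test file that reads them at import time must parse the command line first and only then import the globals. The sketch below is a minimal illustration of that pattern, assuming a hypothetical test file and test class; only the import order and helper names are taken from the patches above.

    # test_example.py - hypothetical test file following the pattern from this series
    import unittest

    import torch

    if __name__ == "__main__":
        from torch.testing._internal.common_utils import parse_cmd_line_args

        # GRAPH_EXECUTOR and SEED are only populated once the command line is parsed,
        # so parse it before the module-level import below reads them.
        parse_cmd_line_args()

    from torch.testing._internal.common_utils import (
        GRAPH_EXECUTOR,
        ProfilingMode,
        TestCase,
        run_tests,
    )


    class TestExample(TestCase):
        @unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.LEGACY,
                         "not supported by the legacy executor")
        def test_add(self):
            # TestCase.setUp() has already called set_rng_seed(), so randn() is deterministic.
            x = torch.randn(3, 3)
            self.assertEqual(x + x, 2 * x)


    if __name__ == "__main__":
        run_tests()

Invoked as, for example, python test_example.py --jit-executor=profiling --seed 42, using flags registered in parse_cmd_line_args(). run_tests() also calls parse_cmd_line_args() itself, so a file that already parsed simply parses the same arguments twice, which is what test_jit.py does after these patches.
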
From fc93a2a0b39704ced4896d789db7b2ea2bce483f Mon Sep 17 00:00:00 2001 From: Anthony Barbier Date: Wed, 16 Jul 2025 14:25:56 +0100 Subject: [PATCH 13/20] Fix distributed model initialisation --- .../_internal/distributed/distributed_test.py | 30 ++++++++----------- 1 file changed, 13 insertions(+), 17 deletions(-) diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index d7adcdf92407..28cd7efc3226 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -398,12 +398,6 @@ def forward(self, x): return F.relu(self.lin1(x)) -DDP_NET = Net() -BN_NET = BatchNormNet() -BN_NET_NO_AFFINE = BatchNormNet(affine=False) -ONLY_SBN_NET = nn.SyncBatchNorm(2, momentum=0.99) - - def get_timeout(test_id): test_name = test_id.split(".")[-1] if test_name in CUSTOMIZED_TIMEOUT: @@ -4290,7 +4284,7 @@ def _test_DistributedDataParallel( # as baseline # cpu training setup - model = DDP_NET + model = Net() # single gpu training setup model_gpu = copy.deepcopy(model) @@ -4345,7 +4339,7 @@ def _test_DistributedDataParallelCPU(self, gradient_as_bucket_view=False): _group, _group_id, rank = self._init_global_test() # cpu training setup - model_base = DDP_NET + model_base = Net() # DDP-CPU training setup model_DDP = copy.deepcopy(model_base) @@ -5494,7 +5488,7 @@ def test_DistributedDataParallel(self): def _test_DistributedDataParallel_with_amp(self, grad_is_view=False): torch.manual_seed(31415) # Creates model and optimizer in default precision - model = copy.deepcopy(DDP_NET).cuda() + model = Net().cuda() optimizer = torch.optim.SGD(model.parameters(), lr=0.03) # Creates a GradScaler once at the beginning of training. @@ -5579,7 +5573,7 @@ def _test_DistributedDataParallel_SyncBatchNorm( # as baseline # cpu training setup - model = BN_NET if affine else BN_NET_NO_AFFINE + model = BatchNormNet() if affine else BatchNormNet(affine=False) # single gpu training setup model_gpu = copy.deepcopy(model) @@ -5629,6 +5623,7 @@ def _test_DistributedDataParallel_SyncBatchNorm( def _test_post_localSGD_optimizer_parity(self, create_averager, grad_is_view): learning_rate = 0.03 + DDP_NET = Net() net = torch.nn.parallel.DistributedDataParallel( copy.deepcopy(DDP_NET).cuda(), device_ids=[self.rank], @@ -5695,7 +5690,7 @@ def _test_post_localSGD_optimizer_step_reload( learning_rate = 0.03 net_using_post_localSGD_opt = torch.nn.parallel.DistributedDataParallel( - copy.deepcopy(DDP_NET).cuda(), device_ids=[self.rank] + Net().cuda(), device_ids=[self.rank] ) averager = create_averager() @@ -5845,7 +5840,7 @@ def _test_DistributedDataParallel_SyncBatchNorm_with_memory_format( bs_offset = int(rank * 2) global_bs = int(num_processes * 2) - model = ONLY_SBN_NET + model = nn.SyncBatchNorm(2, momentum=0.99) model_gpu = copy.deepcopy(model).cuda(rank) model_DDP = nn.parallel.DistributedDataParallel( model_gpu, device_ids=[rank] @@ -6055,6 +6050,7 @@ def test_DistributedDataParallel_SyncBatchNorm_Single_Input_Per_Process(self): def test_DistributedDataParallel_SyncBatchNorm_Diff_Input_Sizes_Running_Value( self, ): + ONLY_SBN_NET = nn.SyncBatchNorm(2, momentum=0.99) _group, _group_id, rank = self._init_global_test() model = nn.parallel.DistributedDataParallel( ONLY_SBN_NET.cuda(rank), device_ids=[rank] @@ -6122,7 +6118,7 @@ def test_DistributedDataParallel_SyncBatchNorm_Diff_Input_Sizes_gradient(self): def test_DistributedDataParallel_SyncBatchNorm_half(self): _group, _group_id, rank = 
self._init_global_test() - model = copy.deepcopy(BN_NET) + model = BatchNormNet() model = model.half() model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) model = nn.parallel.DistributedDataParallel( @@ -6138,7 +6134,7 @@ def test_DistributedDataParallel_SyncBatchNorm_half(self): def _test_ddp_logging_data(self, is_gpu): rank = dist.get_rank() - model_DDP = copy.deepcopy(DDP_NET) + model_DDP = Net() if is_gpu: model_DDP = nn.parallel.DistributedDataParallel( model_DDP.cuda(rank), device_ids=[rank] @@ -6414,7 +6410,7 @@ def test_ddp_logging_data_gpu(self): BACKEND == "nccl", "nccl does not support DDP on CPU models" ) def test_static_graph_api_cpu(self): - model_DDP = nn.parallel.DistributedDataParallel(DDP_NET) + model_DDP = nn.parallel.DistributedDataParallel(Net()) expected_err = "should be called before training loop starts" with self.assertRaisesRegex(RuntimeError, expected_err): local_bs = 2 @@ -7261,7 +7257,7 @@ def forward(self, x): return x torch.cuda.set_device(self.rank) - model_bn = BN_NET + model_bn = BatchNormNet() model_bn = nn.SyncBatchNorm.convert_sync_batchnorm( copy.deepcopy(model_bn) ).cuda(self.rank) @@ -9985,7 +9981,7 @@ def forward(self, x): "Only Nccl & Gloo backend support DistributedDataParallel", ) def test_sync_bn_logged(self): - model = BN_NET + model = BatchNormNet() rank = self.rank # single gpu training setup model_gpu = model.cuda(rank) From b87a0cdc77ed336b9f78f3cb0f4127fc12462d66 Mon Sep 17 00:00:00 2001 From: Anthony Barbier Date: Wed, 16 Jul 2025 17:08:17 +0100 Subject: [PATCH 14/20] Make seed an optional argument --- torch/fx/traceback.py | 4 ++-- torch/testing/_internal/common_distributed.py | 2 +- torch/testing/_internal/distributed/distributed_test.py | 3 +-- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/torch/fx/traceback.py b/torch/fx/traceback.py index e57e89ea8d4b..b2a0e0f681e8 100644 --- a/torch/fx/traceback.py +++ b/torch/fx/traceback.py @@ -81,8 +81,8 @@ def __init__( self.from_node = [] # cache the action string and dict representation for performance. 
- self._action_string = None - self._dict = None + self._action_string: Optional[str] = None + self._dict: Optional[dict[str, Any]] = None @property def name(self) -> str: diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py index a1e2cb11a746..4c151f8f557d 100644 --- a/torch/testing/_internal/common_distributed.py +++ b/torch/testing/_internal/common_distributed.py @@ -727,9 +727,9 @@ def _start_processes(self, proc) -> None: self._current_test_name(), self.file_name, child_conn, - common_utils.SEED, ), kwargs={ + "seed": common_utils.SEED, "fake_pg": getattr(self, "fake_pg", False), }, ) diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index 28cd7efc3226..5edbedc912fa 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -590,13 +590,12 @@ def destroy_pg_upon_exit(self) -> bool: return False @classmethod - def _run(cls, rank, test_name, file_name, pipe, seed, **kwargs): + def _run(cls, rank, test_name, file_name, pipe, **kwargs): if BACKEND == "nccl" and not torch.cuda.is_available(): sys.exit(TEST_SKIPS["no_cuda"].exit_code) self = cls(test_name) self.rank = rank self.file_name = file_name - self.seed = seed if torch.cuda.is_available() and torch.cuda.device_count() < int( self.world_size From 786db21fb86a76a98273f4efffaa89894bdae49b Mon Sep 17 00:00:00 2001 From: Anthony Barbier Date: Thu, 17 Jul 2025 08:54:14 +0100 Subject: [PATCH 15/20] Pass the seed for subclasses too --- torch/fx/traceback.py | 4 ++-- torch/testing/_internal/common_distributed.py | 7 +++++-- torch/testing/_internal/distributed/distributed_test.py | 3 ++- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/torch/fx/traceback.py b/torch/fx/traceback.py index b2a0e0f681e8..e57e89ea8d4b 100644 --- a/torch/fx/traceback.py +++ b/torch/fx/traceback.py @@ -81,8 +81,8 @@ def __init__( self.from_node = [] # cache the action string and dict representation for performance. - self._action_string: Optional[str] = None - self._dict: Optional[dict[str, Any]] = None + self._action_string = None + self._dict = None @property def name(self) -> str: diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py index 4c151f8f557d..e48f2351aa27 100644 --- a/torch/testing/_internal/common_distributed.py +++ b/torch/testing/_internal/common_distributed.py @@ -672,6 +672,7 @@ def __init__( if methodName != "runTest": method_name = methodName super().__init__(method_name) + self.seed = None try: fn = getattr(self, method_name) setattr(self, method_name, self.join_or_run(fn)) @@ -807,7 +808,8 @@ def run_test(self, test_name: str, parent_pipe) -> None: # Show full C++ stacktraces when a Python error originating from C++ is raised. os.environ["TORCH_SHOW_CPP_STACKTRACES"] = "1" - common_utils.set_rng_seed(self.seed) + if self.seed is not None: + common_utils.set_rng_seed(self.seed) # self.id() == e.g. '__main__.TestDistributed.test_get_rank' # We're retrieving a corresponding test and executing it. 
@@ -1546,7 +1548,7 @@ def world_size(self) -> int: @classmethod def _run( - cls, rank: int, test_name: str, file_name: str, parent_pipe, **kwargs + cls, rank: int, test_name: str, file_name: str, parent_pipe, seed: int, **kwargs ) -> None: trace_log.addHandler(logging.NullHandler()) @@ -1554,6 +1556,7 @@ def _run( self = cls(test_name) self.rank = rank self.file_name = file_name + self.seed = seed self.run_test(test_name, parent_pipe) diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index 5edbedc912fa..28cd7efc3226 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -590,12 +590,13 @@ def destroy_pg_upon_exit(self) -> bool: return False @classmethod - def _run(cls, rank, test_name, file_name, pipe, **kwargs): + def _run(cls, rank, test_name, file_name, pipe, seed, **kwargs): if BACKEND == "nccl" and not torch.cuda.is_available(): sys.exit(TEST_SKIPS["no_cuda"].exit_code) self = cls(test_name) self.rank = rank self.file_name = file_name + self.seed = seed if torch.cuda.is_available() and torch.cuda.device_count() < int( self.world_size From 6c9730cd318a912fd48cd7b5cd14bf17be77505c Mon Sep 17 00:00:00 2001 From: Anthony Barbier Date: Fri, 18 Jul 2025 15:32:16 +0100 Subject: [PATCH 16/20] Handle seed in NCCLTraceTestBase --- test/distributed/test_c10d_nccl.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index fd9e7594828d..d15cfb1c75b3 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -4280,10 +4280,11 @@ def _run( test_name: str, file_name: str, parent_pipe, + seed: int, **kwargs, ) -> None: cls.parent = parent_conn - super()._run(rank, test_name, file_name, parent_pipe) + super()._run(rank, test_name, file_name, parent_pipe, seed) @property def local_device(self): From 2412e1373bbdaf1637143ddcd4a715eff701320f Mon Sep 17 00:00:00 2001 From: Anthony Barbier Date: Mon, 21 Jul 2025 09:38:02 +0100 Subject: [PATCH 17/20] Fix seed in FSDP tests --- torch/testing/_internal/common_fsdp.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torch/testing/_internal/common_fsdp.py b/torch/testing/_internal/common_fsdp.py index a9e24eb90ef8..626a9b8494e4 100644 --- a/torch/testing/_internal/common_fsdp.py +++ b/torch/testing/_internal/common_fsdp.py @@ -57,6 +57,7 @@ from torch.testing._internal.common_utils import ( FILE_SCHEMA, get_cycles_per_ms, + set_rng_seed, TEST_CUDA, TEST_HPU, TEST_XPU, @@ -1180,7 +1181,7 @@ def run_subtests(self, *args, **kwargs): return run_subtests(self, *args, **kwargs) @classmethod - def _run(cls, rank, test_name, file_name, pipe, **kwargs): # type: ignore[override] + def _run(cls, rank, test_name, file_name, pipe, seed, **kwargs): # type: ignore[override] self = cls(test_name) self.rank = rank self.file_name = file_name @@ -1226,6 +1227,7 @@ def _run(cls, rank, test_name, file_name, pipe, **kwargs): # type: ignore[overr dist.barrier(device_ids=device_ids) torch._dynamo.reset() + set_rng_seed(seed) self.run_test(test_name, pipe) torch._dynamo.reset() From 5fe1c4bd38d258589d0b0d234674a77b8c1b2415 Mon Sep 17 00:00:00 2001 From: Anthony Barbier Date: Fri, 1 Aug 2025 17:04:15 +0100 Subject: [PATCH 18/20] Fix test_nn.py --- test/test_cpp_api_parity.py | 5 +++++ test/test_expanded_weights.py | 5 +++++ test/test_jit.py | 2 +- test/test_nn.py | 7 +++++++ 
torch/testing/_internal/common_nn.py | 2 ++ 5 files changed, 20 insertions(+), 1 deletion(-) diff --git a/test/test_cpp_api_parity.py b/test/test_cpp_api_parity.py index 2193243b751e..00971c665a15 100644 --- a/test/test_cpp_api_parity.py +++ b/test/test_cpp_api_parity.py @@ -35,6 +35,11 @@ class TestCppApiParity(common.TestCase): functional_test_params_map = {} +if __name__ == '__main__': + # The value of the SEED depends on command line arguments so make sure they're parsed + # before instantiating tests because some modules as part of get_new_module_tests() will call torch.randn + common.parse_cmd_line_args() + expected_test_params_dicts = [] for test_params_dicts, test_instance_class in [ diff --git a/test/test_expanded_weights.py b/test/test_expanded_weights.py index 02bf6d776568..ed04da1e65e0 100644 --- a/test/test_expanded_weights.py +++ b/test/test_expanded_weights.py @@ -1007,6 +1007,11 @@ def filter_supported_tests(t): if "module_name" in t and t["module_name"] in supported_modules: return True +if __name__ == '__main__': + from torch.testing._internal.common_utils import parse_cmd_line_args + # The value of the SEED depends on command line arguments so make sure they're parsed + # before instantiating tests because some modules as part of get_new_module_tests() will call torch.randn + parse_cmd_line_args() # TODO: Once all of these use ModuleInfo, replace with ModuleInfo tests # These currently use the legacy nn tests diff --git a/test/test_jit.py b/test/test_jit.py index 6a765102a596..814b48449b14 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -6,7 +6,7 @@ if __name__ == '__main__': from torch.testing._internal.common_utils import parse_cmd_line_args - # The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed + # The value of GRAPH_EXECUTOR and SEED depend on command line arguments so make sure they're parsed # before instantiating tests. 
parse_cmd_line_args() diff --git a/test/test_nn.py b/test/test_nn.py index 218a65f388f0..e5a3357b1bec 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -7627,6 +7627,13 @@ def with_tf32_on(self, test=test, kwargs=kwargs): else: add(cuda_test_name, with_tf32_off) +if __name__ == '__main__': + from torch.testing._internal.common_utils import parse_cmd_line_args + + # The value of the SEED depends on command line arguments so make sure they're parsed + # before instantiating tests because some modules as part of get_new_module_tests() will call torch.randn + parse_cmd_line_args() + for test_params in module_tests + get_new_module_tests(): # TODO: CUDA is not implemented yet if 'constructor' not in test_params: diff --git a/torch/testing/_internal/common_nn.py b/torch/testing/_internal/common_nn.py index 135cc6a7bd66..b42114d7d0cd 100644 --- a/torch/testing/_internal/common_nn.py +++ b/torch/testing/_internal/common_nn.py @@ -15,6 +15,7 @@ import torch.nn as nn import torch.nn.functional as F from torch.nn import _reduction as _Reduction +from torch.testing._internal import common_utils from torch.testing._internal.common_utils import TestCase, to_gpu, freeze_rng_state, is_iterable, \ gradcheck, gradgradcheck, set_default_dtype, skipIfTorchDynamo, TEST_WITH_ROCM from torch.testing._internal.common_cuda import TEST_CUDA, SM90OrLater @@ -1078,6 +1079,7 @@ def unsqueeze_inp(inp): def get_new_module_tests(): + assert common_utils.SEED is not None, "Make sure the seed is set before calling get_new_module_tests()" new_module_tests = [ poissonnllloss_no_reduce_test(), bceloss_no_reduce_test(), From 59bf8602c73e5c8cebf2ba3010275528e57ea1c4 Mon Sep 17 00:00:00 2001 From: Anthony Barbier Date: Mon, 11 Aug 2025 10:12:08 +0100 Subject: [PATCH 19/20] Trying to hardcode the seed for distributed tests in _start_processes to see if it's enough --- torch/testing/_internal/common_distributed.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py index 48cbe30f5967..f24f92886a90 100644 --- a/torch/testing/_internal/common_distributed.py +++ b/torch/testing/_internal/common_distributed.py @@ -15,6 +15,7 @@ import traceback import types import unittest +import warning from contextlib import contextmanager from dataclasses import dataclass from datetime import timedelta @@ -717,7 +718,13 @@ def _current_test_name(self) -> str: def _start_processes(self, proc) -> None: self.processes = [] - assert common_utils.SEED is not None + # distributed tests don't support setting the seed via the command line so hardcode it here. + hardcoded_seed = 1234 + if common_utils.SEED and common_utils.SEED != hardcoded_seed: + msg = ("Distributed tests do not support setting the seed via the command line. 
" + f"the seed will be reset to its default value ({hardcoded_seed} now") + warning.warn(msg) + common_utils.SEED = hardcoded_seed for rank in range(int(self.world_size)): parent_conn, child_conn = torch.multiprocessing.Pipe() process = proc( From 0663f0850d473fba93ab9feb7d448f8ce1842644 Mon Sep 17 00:00:00 2001 From: Anthony Barbier Date: Mon, 11 Aug 2025 11:59:22 +0100 Subject: [PATCH 20/20] Use logger.warning --- torch/testing/_internal/common_distributed.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py index f24f92886a90..6bb9fe1b4fa7 100644 --- a/torch/testing/_internal/common_distributed.py +++ b/torch/testing/_internal/common_distributed.py @@ -15,7 +15,6 @@ import traceback import types import unittest -import warning from contextlib import contextmanager from dataclasses import dataclass from datetime import timedelta @@ -723,7 +722,7 @@ def _start_processes(self, proc) -> None: if common_utils.SEED and common_utils.SEED != hardcoded_seed: msg = ("Distributed tests do not support setting the seed via the command line. " f"the seed will be reset to its default value ({hardcoded_seed} now") - warning.warn(msg) + logger.warning(msg) common_utils.SEED = hardcoded_seed for rank in range(int(self.world_size)): parent_conn, child_conn = torch.multiprocessing.Pipe()