diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index fd9e7594828d..d15cfb1c75b3 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -4280,10 +4280,11 @@ def _run( test_name: str, file_name: str, parent_pipe, + seed: int, **kwargs, ) -> None: cls.parent = parent_conn - super()._run(rank, test_name, file_name, parent_pipe) + super()._run(rank, test_name, file_name, parent_pipe, seed) @property def local_device(self): diff --git a/test/jit/test_autodiff_subgraph_slicing.py b/test/jit/test_autodiff_subgraph_slicing.py index f42aa7f8f436..f128f9c7eec3 100644 --- a/test/jit/test_autodiff_subgraph_slicing.py +++ b/test/jit/test_autodiff_subgraph_slicing.py @@ -27,6 +27,9 @@ ) +assert GRAPH_EXECUTOR is not None + + @unittest.skipIf( GRAPH_EXECUTOR == ProfilingMode.SIMPLE, "Simple Executor doesn't support gradients" ) diff --git a/test/test_cpp_api_parity.py b/test/test_cpp_api_parity.py index 2193243b751e..480df4780121 100644 --- a/test/test_cpp_api_parity.py +++ b/test/test_cpp_api_parity.py @@ -35,6 +35,11 @@ class TestCppApiParity(common.TestCase): functional_test_params_map = {} +if __name__ == "__main__": + # The value of SEED depends on command line arguments so make sure they're parsed + # before instantiating tests: some modules created by get_new_module_tests() call torch.randn. + common.parse_cmd_line_args() + expected_test_params_dicts = [] for test_params_dicts, test_instance_class in [ diff --git a/test/test_expanded_weights.py b/test/test_expanded_weights.py index 02bf6d776568..3696a1c43f43 100644 --- a/test/test_expanded_weights.py +++ b/test/test_expanded_weights.py @@ -1008,6 +1008,13 @@ def filter_supported_tests(t): return True +if __name__ == "__main__": + from torch.testing._internal.common_utils import parse_cmd_line_args + + # The value of SEED depends on command line arguments so make sure they're parsed + # before instantiating tests: some modules created by get_new_module_tests() call torch.randn. + parse_cmd_line_args() + # TODO: Once all of these use ModuleInfo, replace with ModuleInfo tests # These currently use the legacy nn tests supported_tests = [ diff --git a/test/test_jit.py b/test/test_jit.py index c86fb111bfb8..814b48449b14 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -3,6 +3,13 @@ import torch +if __name__ == '__main__': + from torch.testing._internal.common_utils import parse_cmd_line_args + + # The values of GRAPH_EXECUTOR and SEED depend on command line arguments so make sure they're parsed + # before instantiating tests. + parse_cmd_line_args() + # This is how we include tests located in test/jit/... # They are included here so that they are invoked when you call `test_jit.py`, # do not run these test files directly. 
@@ -97,7 +104,7 @@ from torch.testing._internal import jit_utils from torch.testing._internal.common_jit import check_against_reference from torch.testing._internal.common_utils import run_tests, IS_WINDOWS, \ - suppress_warnings, IS_SANDCASTLE, GRAPH_EXECUTOR, ProfilingMode, \ + GRAPH_EXECUTOR, suppress_warnings, IS_SANDCASTLE, ProfilingMode, \ TestCase, freeze_rng_state, slowTest, TemporaryFileName, \ enable_profiling_mode_for_profiling_tests, TEST_MKL, set_default_dtype, num_profiled_runs, \ skipIfCrossRef, skipIfTorchDynamo @@ -158,6 +165,7 @@ def doAutodiffCheck(testname): if "test_t_" in testname or testname == "test_t": return False + assert GRAPH_EXECUTOR if GRAPH_EXECUTOR == ProfilingMode.SIMPLE: return False @@ -201,6 +209,7 @@ def doAutodiffCheck(testname): return testname not in test_exceptions +assert GRAPH_EXECUTOR # TODO: enable TE in PE when all tests are fixed torch._C._jit_set_texpr_fuser_enabled(GRAPH_EXECUTOR == ProfilingMode.PROFILING) torch._C._jit_set_profiling_executor(GRAPH_EXECUTOR != ProfilingMode.LEGACY) diff --git a/test/test_jit_autocast.py b/test/test_jit_autocast.py index b3cf4d9bee8f..dcdf78ff4b89 100644 --- a/test/test_jit_autocast.py +++ b/test/test_jit_autocast.py @@ -5,12 +5,17 @@ from typing import Optional import unittest -from test_jit import JitTestCase from torch.testing._internal.common_cuda import TEST_CUDA -from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo +from torch.testing._internal.common_utils import parse_cmd_line_args, run_tests, skipIfTorchDynamo from torch.testing import FileCheck from jit.test_models import MnistNet +if __name__ == '__main__': + # The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed + # before instantiating tests. + parse_cmd_line_args() + +from test_jit import JitTestCase TEST_BFLOAT16 = TEST_CUDA and torch.cuda.is_bf16_supported() @skipIfTorchDynamo("Not a TorchDynamo suitable test") diff --git a/test/test_jit_fuser.py b/test/test_jit_fuser.py index 1ac7803a9d46..5446770695c4 100644 --- a/test/test_jit_fuser.py +++ b/test/test_jit_fuser.py @@ -9,6 +9,13 @@ from torch.testing import FileCheck from unittest import skipIf +if __name__ == "__main__": + from torch.testing._internal.common_utils import parse_cmd_line_args + + # The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed + # before instantiating tests. + parse_cmd_line_args() + from torch.testing._internal.common_utils import run_tests, IS_SANDCASTLE, ProfilingMode, GRAPH_EXECUTOR, \ enable_profiling_mode_for_profiling_tests, IS_WINDOWS, TemporaryDirectoryName, shell from torch.testing._internal.jit_utils import JitTestCase, enable_cpu_fuser, _inline_everything, \ diff --git a/test/test_jit_fuser_legacy.py b/test/test_jit_fuser_legacy.py index 3bd8c9497ce0..4100bcc3e182 100644 --- a/test/test_jit_fuser_legacy.py +++ b/test/test_jit_fuser_legacy.py @@ -2,6 +2,14 @@ import sys sys.argv.append("--jit-executor=legacy") + +if __name__ == "__main__": + from torch.testing._internal.common_utils import parse_cmd_line_args + + # The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed + # before instantiating tests. 
+ parse_cmd_line_args() + from test_jit_fuser import * # noqa: F403 if __name__ == '__main__': diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py index c3e26d37da1b..1bda41f7f8f1 100644 --- a/test/test_jit_fuser_te.py +++ b/test/test_jit_fuser_te.py @@ -22,6 +22,13 @@ torch._C._jit_set_profiling_executor(True) torch._C._get_graph_executor_optimize(True) +if __name__ == "__main__": + from torch.testing._internal.common_utils import parse_cmd_line_args + + # The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed + # before instantiating tests. + parse_cmd_line_args() + from itertools import combinations, permutations, product from textwrap import dedent diff --git a/test/test_jit_legacy.py b/test/test_jit_legacy.py index 5576f1645349..480b57a55bd4 100644 --- a/test/test_jit_legacy.py +++ b/test/test_jit_legacy.py @@ -2,7 +2,14 @@ import sys sys.argv.append("--jit-executor=legacy") -from test_jit import * # noqa: F403 +from torch.testing._internal.common_utils import parse_cmd_line_args, run_tests + +if __name__ == '__main__': + # The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed + # before instantiating tests. + parse_cmd_line_args() + +from test_jit import * # noqa: F403, F401 if __name__ == '__main__': run_tests() diff --git a/test/test_nn.py b/test/test_nn.py index 904b819a6fc4..207b1902a37d 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -7643,6 +7643,13 @@ def with_tf32_on(self, test=test, kwargs=kwargs): else: add(cuda_test_name, with_tf32_off) +if __name__ == '__main__': + from torch.testing._internal.common_utils import parse_cmd_line_args + + # The value of SEED depends on command line arguments so make sure they're parsed + # before instantiating tests: some modules created by get_new_module_tests() call torch.randn. + parse_cmd_line_args() + for test_params in module_tests + get_new_module_tests(): # TODO: CUDA is not implemented yet if 'constructor' not in test_params: diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py index d4cc6cde3cc5..6bb9fe1b4fa7 100644 --- a/torch/testing/_internal/common_distributed.py +++ b/torch/testing/_internal/common_distributed.py @@ -32,6 +32,7 @@ from torch._C._autograd import DeviceType from torch._C._distributed_c10d import _SymmetricMemory from torch._logging._internal import trace_log +from torch.testing._internal import common_utils from torch.testing._internal.common_utils import ( FILE_SCHEMA, find_free_port, @@ -671,6 +672,7 @@ def __init__( if methodName != "runTest": method_name = methodName super().__init__(method_name) + self.seed = None try: fn = getattr(self, method_name) setattr(self, method_name, self.join_or_run(fn)) @@ -715,13 +717,26 @@ def _current_test_name(self) -> str: def _start_processes(self, proc) -> None: self.processes = [] + # Distributed tests don't support setting the seed via the command line, so hardcode it here. + hardcoded_seed = 1234 + if common_utils.SEED and common_utils.SEED != hardcoded_seed: + msg = ("Distributed tests do not support setting the seed via the command line. 
" + f"the seed will be reset to its default value ({hardcoded_seed} now") + logger.warning(msg) + common_utils.SEED = hardcoded_seed for rank in range(int(self.world_size)): parent_conn, child_conn = torch.multiprocessing.Pipe() process = proc( target=self.__class__._run, name="process " + str(rank), - args=(rank, self._current_test_name(), self.file_name, child_conn), + args=( + rank, + self._current_test_name(), + self.file_name, + child_conn, + ), kwargs={ + "seed": common_utils.SEED, "fake_pg": getattr(self, "fake_pg", False), }, ) @@ -775,11 +790,12 @@ def _event_listener(parent_pipe, signal_pipe, rank: int): @classmethod def _run( - cls, rank: int, test_name: str, file_name: str, parent_pipe, **kwargs + cls, rank: int, test_name: str, file_name: str, parent_pipe, seed: int, **kwargs ) -> None: self = cls(test_name) self.rank = rank self.file_name = file_name + self.seed = seed self.run_test(test_name, parent_pipe) def run_test(self, test_name: str, parent_pipe) -> None: @@ -798,6 +814,9 @@ def run_test(self, test_name: str, parent_pipe) -> None: # Show full C++ stacktraces when a Python error originating from C++ is raised. os.environ["TORCH_SHOW_CPP_STACKTRACES"] = "1" + if self.seed is not None: + common_utils.set_rng_seed(self.seed) + # self.id() == e.g. '__main__.TestDistributed.test_get_rank' # We're retrieving a corresponding test and executing it. try: @@ -1535,7 +1554,7 @@ def world_size(self) -> int: @classmethod def _run( - cls, rank: int, test_name: str, file_name: str, parent_pipe, **kwargs + cls, rank: int, test_name: str, file_name: str, parent_pipe, seed: int, **kwargs ) -> None: trace_log.addHandler(logging.NullHandler()) @@ -1543,6 +1562,7 @@ def _run( self = cls(test_name) self.rank = rank self.file_name = file_name + self.seed = seed self.run_test(test_name, parent_pipe) diff --git a/torch/testing/_internal/common_fsdp.py b/torch/testing/_internal/common_fsdp.py index 0e50762893d7..25094049c2ee 100644 --- a/torch/testing/_internal/common_fsdp.py +++ b/torch/testing/_internal/common_fsdp.py @@ -57,6 +57,7 @@ from torch.testing._internal.common_utils import ( FILE_SCHEMA, get_cycles_per_ms, + set_rng_seed, TEST_CUDA, TEST_HPU, TEST_XPU, @@ -1180,7 +1181,7 @@ def run_subtests(self, *args, **kwargs): return run_subtests(self, *args, **kwargs) @classmethod - def _run(cls, rank, test_name, file_name, pipe, **kwargs): # type: ignore[override] + def _run(cls, rank, test_name, file_name, pipe, seed, **kwargs): # type: ignore[override] self = cls(test_name) self.rank = rank self.file_name = file_name @@ -1226,6 +1227,7 @@ def _run(cls, rank, test_name, file_name, pipe, **kwargs): # type: ignore[overr dist.barrier(device_ids=device_ids) torch._dynamo.reset() + set_rng_seed(seed) self.run_test(test_name, pipe) torch._dynamo.reset() diff --git a/torch/testing/_internal/common_nn.py b/torch/testing/_internal/common_nn.py index 135cc6a7bd66..b42114d7d0cd 100644 --- a/torch/testing/_internal/common_nn.py +++ b/torch/testing/_internal/common_nn.py @@ -15,6 +15,7 @@ import torch.nn as nn import torch.nn.functional as F from torch.nn import _reduction as _Reduction +from torch.testing._internal import common_utils from torch.testing._internal.common_utils import TestCase, to_gpu, freeze_rng_state, is_iterable, \ gradcheck, gradgradcheck, set_default_dtype, skipIfTorchDynamo, TEST_WITH_ROCM from torch.testing._internal.common_cuda import TEST_CUDA, SM90OrLater @@ -1078,6 +1079,7 @@ def unsqueeze_inp(inp): def get_new_module_tests(): + assert common_utils.SEED is not None, "Make sure 
the seed is set before calling get_new_module_tests()" new_module_tests = [ poissonnllloss_no_reduce_test(), bceloss_no_reduce_test(), diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index bfc568bc1464..070d4657f8bb 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -104,6 +104,31 @@ MI300_ARCH = ("gfx942",) +class ProfilingMode(Enum): + LEGACY = 1 + SIMPLE = 2 + PROFILING = 3 + +# Set by parse_cmd_line_args() if called +CI_FUNCTORCH_ROOT = "" +CI_PT_ROOT = "" +CI_TEST_PREFIX = "" +DISABLED_TESTS_FILE = "" +GRAPH_EXECUTOR : Optional[ProfilingMode] = None +LOG_SUFFIX = "" +PYTEST_SINGLE_TEST = "" +REPEAT_COUNT = 0 +RERUN_DISABLED_TESTS = False +RUN_PARALLEL = 0 +SEED : Optional[int] = None +SHOWLOCALS = False +SLOW_TESTS_FILE = "" +TEST_BAILOUTS = False +TEST_DISCOVER = False +TEST_IN_SUBPROCESS = False +TEST_SAVE_XML = "" +UNITTEST_ARGS : list[str] = [] +USE_PYTEST = False def freeze_rng_state(*args, **kwargs): return torch.testing._utils.freeze_rng_state(*args, **kwargs) @@ -839,11 +864,6 @@ def test_wrapper(*args, **kwargs): yield (test_wrapper, test_name, {}, decorator_fn) -class ProfilingMode(Enum): - LEGACY = 1 - SIMPLE = 2 - PROFILING = 3 - def cppProfilingFlagsToProfilingMode(): old_prof_exec_state = torch._C._jit_set_profiling_executor(True) old_prof_mode_state = torch._C._get_graph_executor_optimize(True) @@ -862,6 +882,7 @@ def cppProfilingFlagsToProfilingMode(): def enable_profiling_mode_for_profiling_tests(): old_prof_exec_state = False old_prof_mode_state = False + assert GRAPH_EXECUTOR if GRAPH_EXECUTOR == ProfilingMode.PROFILING: old_prof_exec_state = torch._C._jit_set_profiling_executor(True) old_prof_mode_state = torch._C._get_graph_executor_optimize(True) @@ -896,6 +917,7 @@ def num_profiled_runs(num_runs): def prof_callable(callable, *args, **kwargs): if 'profile_and_replay' in kwargs: del kwargs['profile_and_replay'] + assert GRAPH_EXECUTOR if GRAPH_EXECUTOR == ProfilingMode.PROFILING: with enable_profiling_mode_for_profiling_tests(): callable(*args, **kwargs) @@ -925,72 +947,93 @@ def _get_test_report_path(): test_source = override if override is not None else 'python-unittest' return os.path.join('test-reports', test_source) -is_running_via_run_test = "run_test.py" in getattr(__main__, "__file__", "") -parser = argparse.ArgumentParser(add_help=not is_running_via_run_test, allow_abbrev=False) -parser.add_argument('--subprocess', action='store_true', - help='whether to run each test in a subprocess') -parser.add_argument('--seed', type=int, default=1234) -parser.add_argument('--accept', action='store_true') -parser.add_argument('--jit-executor', '--jit_executor', type=str) -parser.add_argument('--repeat', type=int, default=1) -parser.add_argument('--test-bailouts', '--test_bailouts', action='store_true') -parser.add_argument('--use-pytest', action='store_true') -parser.add_argument('--save-xml', nargs='?', type=str, - const=_get_test_report_path(), - default=_get_test_report_path() if IS_CI else None) -parser.add_argument('--discover-tests', action='store_true') -parser.add_argument('--log-suffix', type=str, default="") -parser.add_argument('--run-parallel', type=int, default=1) -parser.add_argument('--import-slow-tests', type=str, nargs='?', const=DEFAULT_SLOW_TESTS_FILE) -parser.add_argument('--import-disabled-tests', type=str, nargs='?', const=DEFAULT_DISABLED_TESTS_FILE) -parser.add_argument('--rerun-disabled-tests', action='store_true') 
-parser.add_argument('--pytest-single-test', type=str, nargs=1) -parser.add_argument('--showlocals', action=argparse.BooleanOptionalAction, default=False) +def parse_cmd_line_args(): + global CI_FUNCTORCH_ROOT + global CI_PT_ROOT + global CI_TEST_PREFIX + global DISABLED_TESTS_FILE + global GRAPH_EXECUTOR + global LOG_SUFFIX + global PYTEST_SINGLE_TEST + global REPEAT_COUNT + global RERUN_DISABLED_TESTS + global RUN_PARALLEL + global SEED + global SHOWLOCALS + global SLOW_TESTS_FILE + global TEST_BAILOUTS + global TEST_DISCOVER + global TEST_IN_SUBPROCESS + global TEST_SAVE_XML + global UNITTEST_ARGS + global USE_PYTEST + + is_running_via_run_test = "run_test.py" in getattr(__main__, "__file__", "") + parser = argparse.ArgumentParser(add_help=not is_running_via_run_test, allow_abbrev=False) + parser.add_argument('--subprocess', action='store_true', + help='whether to run each test in a subprocess') + parser.add_argument('--seed', type=int, default=1234) + parser.add_argument('--accept', action='store_true') + parser.add_argument('--jit-executor', '--jit_executor', type=str) + parser.add_argument('--repeat', type=int, default=1) + parser.add_argument('--test-bailouts', '--test_bailouts', action='store_true') + parser.add_argument('--use-pytest', action='store_true') + parser.add_argument('--save-xml', nargs='?', type=str, + const=_get_test_report_path(), + default=_get_test_report_path() if IS_CI else None) + parser.add_argument('--discover-tests', action='store_true') + parser.add_argument('--log-suffix', type=str, default="") + parser.add_argument('--run-parallel', type=int, default=1) + parser.add_argument('--import-slow-tests', type=str, nargs='?', const=DEFAULT_SLOW_TESTS_FILE) + parser.add_argument('--import-disabled-tests', type=str, nargs='?', const=DEFAULT_DISABLED_TESTS_FILE) + parser.add_argument('--rerun-disabled-tests', action='store_true') + parser.add_argument('--pytest-single-test', type=str, nargs=1) + parser.add_argument('--showlocals', action=argparse.BooleanOptionalAction, default=False) # Only run when -h or --help flag is active to display both unittest and parser help messages. 
-def run_unittest_help(argv): - unittest.main(argv=argv) - -if '-h' in sys.argv or '--help' in sys.argv: - help_thread = threading.Thread(target=run_unittest_help, args=(sys.argv,)) - help_thread.start() - help_thread.join() - -args, remaining = parser.parse_known_args() -if args.jit_executor == 'legacy': - GRAPH_EXECUTOR = ProfilingMode.LEGACY -elif args.jit_executor == 'profiling': - GRAPH_EXECUTOR = ProfilingMode.PROFILING -elif args.jit_executor == 'simple': - GRAPH_EXECUTOR = ProfilingMode.SIMPLE -else: - # infer flags based on the default settings - GRAPH_EXECUTOR = cppProfilingFlagsToProfilingMode() - -RERUN_DISABLED_TESTS = args.rerun_disabled_tests - -SLOW_TESTS_FILE = args.import_slow_tests -DISABLED_TESTS_FILE = args.import_disabled_tests -LOG_SUFFIX = args.log_suffix -RUN_PARALLEL = args.run_parallel -TEST_BAILOUTS = args.test_bailouts -USE_PYTEST = args.use_pytest -PYTEST_SINGLE_TEST = args.pytest_single_test -TEST_DISCOVER = args.discover_tests -TEST_IN_SUBPROCESS = args.subprocess -TEST_SAVE_XML = args.save_xml -REPEAT_COUNT = args.repeat -SEED = args.seed -SHOWLOCALS = args.showlocals -if not getattr(expecttest, "ACCEPT", False): - expecttest.ACCEPT = args.accept -UNITTEST_ARGS = [sys.argv[0]] + remaining -torch.manual_seed(SEED) + def run_unittest_help(argv): + unittest.main(argv=argv) + + if '-h' in sys.argv or '--help' in sys.argv: + help_thread = threading.Thread(target=run_unittest_help, args=(sys.argv,)) + help_thread.start() + help_thread.join() + + args, remaining = parser.parse_known_args() + if args.jit_executor == 'legacy': + GRAPH_EXECUTOR = ProfilingMode.LEGACY + elif args.jit_executor == 'profiling': + GRAPH_EXECUTOR = ProfilingMode.PROFILING + elif args.jit_executor == 'simple': + GRAPH_EXECUTOR = ProfilingMode.SIMPLE + else: + # infer flags based on the default settings + GRAPH_EXECUTOR = cppProfilingFlagsToProfilingMode() + + RERUN_DISABLED_TESTS = args.rerun_disabled_tests + + SLOW_TESTS_FILE = args.import_slow_tests + DISABLED_TESTS_FILE = args.import_disabled_tests + LOG_SUFFIX = args.log_suffix + RUN_PARALLEL = args.run_parallel + TEST_BAILOUTS = args.test_bailouts + USE_PYTEST = args.use_pytest + PYTEST_SINGLE_TEST = args.pytest_single_test + TEST_DISCOVER = args.discover_tests + TEST_IN_SUBPROCESS = args.subprocess + TEST_SAVE_XML = args.save_xml + REPEAT_COUNT = args.repeat + SEED = args.seed + SHOWLOCALS = args.showlocals + if not getattr(expecttest, "ACCEPT", False): + expecttest.ACCEPT = args.accept + UNITTEST_ARGS = [sys.argv[0]] + remaining + torch.manual_seed(SEED) # CI Prefix path used only on CI environment -CI_TEST_PREFIX = str(Path(os.getcwd())) -CI_PT_ROOT = str(Path(os.getcwd()).parent) -CI_FUNCTORCH_ROOT = str(os.path.join(Path(os.getcwd()).parent, "functorch")) + CI_TEST_PREFIX = str(Path(os.getcwd())) + CI_PT_ROOT = str(Path(os.getcwd()).parent) + CI_FUNCTORCH_ROOT = str(os.path.join(Path(os.getcwd()).parent, "functorch")) def wait_for_process(p, timeout=None): try: @@ -1139,7 +1182,9 @@ def lint_test_case_extension(suite): return succeed -def get_report_path(argv=UNITTEST_ARGS, pytest=False): +def get_report_path(argv=None, pytest=False): + if argv is None: + argv = UNITTEST_ARGS test_filename = sanitize_test_filename(argv[0]) test_report_path = TEST_SAVE_XML + LOG_SUFFIX test_report_path = os.path.join(test_report_path, test_filename) @@ -1190,7 +1235,11 @@ def pytest_collection_finish(self, session): return test_collector_plugin.tests -def run_tests(argv=UNITTEST_ARGS): +def run_tests(argv=None): + parse_cmd_line_args() + if argv is 
None: + argv = UNITTEST_ARGS + # import test files. if SLOW_TESTS_FILE: if os.path.exists(SLOW_TESTS_FILE): @@ -1756,6 +1805,7 @@ def decorator(fn): if not isinstance(fn, type): @wraps(fn) def wrapper(*args, **kwargs): + assert GRAPH_EXECUTOR if GRAPH_EXECUTOR == ProfilingMode.LEGACY: raise unittest.SkipTest(msg) else: @@ -2386,7 +2436,19 @@ def get_function_arglist(func): return inspect.getfullargspec(func).args -def set_rng_seed(seed): +def set_rng_seed(seed=None): + if seed is None: + if SEED is not None: + seed = SEED + else: + # Can't assert here: this function is called by TestCase.setUp(), and some out-of-tree tests inherit from that class. + # So just warn and fall back to a hardcoded seed. + seed = 1234 + msg = ("set_rng_seed() was called without providing a seed and the command line " + f"arguments haven't been parsed, so the seed will be set to {seed}. " + "To remove this warning, run your test via run_tests() or call " + "parse_cmd_line_args() before set_rng_seed().") + warnings.warn(msg) torch.manual_seed(seed) random.seed(seed) if TEST_NUMPY: @@ -3409,7 +3471,7 @@ def run(self, result=None): def setUp(self): check_if_enable(self) - set_rng_seed(SEED) + set_rng_seed() # Save global check sparse tensor invariants state that can be # restored from tearDown: diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index 28b761a37d58..28cd7efc3226 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -137,16 +137,18 @@ def eq(value, other): f = Foo(10) f.bar = 1 -foo_cpu_tensor = Foo(torch.randn(3, 3)) +# Defer instantiation until the seed is set so that randn() returns the same +# values in all processes. +def create_collectives_object_test_list(): + return [ + {"key1": 3, "key2": 4, "key3": {"nested": True}}, + f, + Foo(torch.randn(3, 3)), + "foo", + [1, 2, True, "string", [4, 5, "nested"]], + ] -COLLECTIVES_OBJECT_TEST_LIST = [ - {"key1": 3, "key2": 4, "key3": {"nested": True}}, - f, - foo_cpu_tensor, - "foo", - [1, 2, True, "string", [4, 5, "nested"]], -] # Allowlist of distributed backends where profiling collectives is supported. 
PROFILING_SUPPORTED_BACKENDS = [ @@ -396,12 +398,6 @@ def forward(self, x): return F.relu(self.lin1(x)) -DDP_NET = Net() -BN_NET = BatchNormNet() -BN_NET_NO_AFFINE = BatchNormNet(affine=False) -ONLY_SBN_NET = nn.SyncBatchNorm(2, momentum=0.99) - - def get_timeout(test_id): test_name = test_id.split(".")[-1] if test_name in CUSTOMIZED_TIMEOUT: @@ -594,12 +590,13 @@ def destroy_pg_upon_exit(self) -> bool: return False @classmethod - def _run(cls, rank, test_name, file_name, pipe, **kwargs): + def _run(cls, rank, test_name, file_name, pipe, seed, **kwargs): if BACKEND == "nccl" and not torch.cuda.is_available(): sys.exit(TEST_SKIPS["no_cuda"].exit_code) self = cls(test_name) self.rank = rank self.file_name = file_name + self.seed = seed if torch.cuda.is_available() and torch.cuda.device_count() < int( self.world_size @@ -4287,7 +4284,7 @@ def _test_DistributedDataParallel( # as baseline # cpu training setup - model = DDP_NET + model = Net() # single gpu training setup model_gpu = copy.deepcopy(model) @@ -4342,7 +4339,7 @@ def _test_DistributedDataParallelCPU(self, gradient_as_bucket_view=False): _group, _group_id, rank = self._init_global_test() # cpu training setup - model_base = DDP_NET + model_base = Net() # DDP-CPU training setup model_DDP = copy.deepcopy(model_base) @@ -5491,7 +5488,7 @@ def test_DistributedDataParallel(self): def _test_DistributedDataParallel_with_amp(self, grad_is_view=False): torch.manual_seed(31415) # Creates model and optimizer in default precision - model = copy.deepcopy(DDP_NET).cuda() + model = Net().cuda() optimizer = torch.optim.SGD(model.parameters(), lr=0.03) # Creates a GradScaler once at the beginning of training. @@ -5576,7 +5573,7 @@ def _test_DistributedDataParallel_SyncBatchNorm( # as baseline # cpu training setup - model = BN_NET if affine else BN_NET_NO_AFFINE + model = BatchNormNet() if affine else BatchNormNet(affine=False) # single gpu training setup model_gpu = copy.deepcopy(model) @@ -5626,6 +5623,7 @@ def _test_DistributedDataParallel_SyncBatchNorm( def _test_post_localSGD_optimizer_parity(self, create_averager, grad_is_view): learning_rate = 0.03 + DDP_NET = Net() net = torch.nn.parallel.DistributedDataParallel( copy.deepcopy(DDP_NET).cuda(), device_ids=[self.rank], @@ -5692,7 +5690,7 @@ def _test_post_localSGD_optimizer_step_reload( learning_rate = 0.03 net_using_post_localSGD_opt = torch.nn.parallel.DistributedDataParallel( - copy.deepcopy(DDP_NET).cuda(), device_ids=[self.rank] + Net().cuda(), device_ids=[self.rank] ) averager = create_averager() @@ -5842,7 +5840,7 @@ def _test_DistributedDataParallel_SyncBatchNorm_with_memory_format( bs_offset = int(rank * 2) global_bs = int(num_processes * 2) - model = ONLY_SBN_NET + model = nn.SyncBatchNorm(2, momentum=0.99) model_gpu = copy.deepcopy(model).cuda(rank) model_DDP = nn.parallel.DistributedDataParallel( model_gpu, device_ids=[rank] @@ -6052,6 +6050,7 @@ def test_DistributedDataParallel_SyncBatchNorm_Single_Input_Per_Process(self): def test_DistributedDataParallel_SyncBatchNorm_Diff_Input_Sizes_Running_Value( self, ): + ONLY_SBN_NET = nn.SyncBatchNorm(2, momentum=0.99) _group, _group_id, rank = self._init_global_test() model = nn.parallel.DistributedDataParallel( ONLY_SBN_NET.cuda(rank), device_ids=[rank] @@ -6119,7 +6118,7 @@ def test_DistributedDataParallel_SyncBatchNorm_Diff_Input_Sizes_gradient(self): def test_DistributedDataParallel_SyncBatchNorm_half(self): _group, _group_id, rank = self._init_global_test() - model = copy.deepcopy(BN_NET) + model = BatchNormNet() model = model.half() 
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) model = nn.parallel.DistributedDataParallel( @@ -6135,7 +6134,7 @@ def test_DistributedDataParallel_SyncBatchNorm_half(self): def _test_ddp_logging_data(self, is_gpu): rank = dist.get_rank() - model_DDP = copy.deepcopy(DDP_NET) + model_DDP = Net() if is_gpu: model_DDP = nn.parallel.DistributedDataParallel( model_DDP.cuda(rank), device_ids=[rank] @@ -6411,7 +6410,7 @@ def test_ddp_logging_data_gpu(self): BACKEND == "nccl", "nccl does not support DDP on CPU models" ) def test_static_graph_api_cpu(self): - model_DDP = nn.parallel.DistributedDataParallel(DDP_NET) + model_DDP = nn.parallel.DistributedDataParallel(Net()) expected_err = "should be called before training loop starts" with self.assertRaisesRegex(RuntimeError, expected_err): local_bs = 2 @@ -6644,7 +6643,7 @@ def validate_global_samples(local_num_samples): def _test_allgather_object(self, subgroup=None): # Only set device for NCCL backend since it must use GPUs. - gather_objects = COLLECTIVES_OBJECT_TEST_LIST.copy() + gather_objects = create_collectives_object_test_list() backend = os.environ["BACKEND"] if backend == "nccl": @@ -6688,7 +6687,7 @@ def test_all_gather_object_subgroup(self): def _test_gather_object(self, pg=None): # Ensure stateful objects can be gathered - gather_objects = COLLECTIVES_OBJECT_TEST_LIST.copy() + gather_objects = create_collectives_object_test_list() my_rank = dist.get_rank(pg) backend = os.environ["BACKEND"] @@ -7258,7 +7257,7 @@ def forward(self, x): return x torch.cuda.set_device(self.rank) - model_bn = BN_NET + model_bn = BatchNormNet() model_bn = nn.SyncBatchNorm.convert_sync_batchnorm( copy.deepcopy(model_bn) ).cuda(self.rank) @@ -7554,7 +7553,7 @@ def forward(self, _): loss.backward() def _test_broadcast_object_list(self, group=None): - gather_objects = COLLECTIVES_OBJECT_TEST_LIST.copy() + gather_objects = create_collectives_object_test_list() # Only set device for NCCL backend since it must use GPUs. # Case where rank != GPU device. @@ -8278,10 +8277,11 @@ def forward(self, x): @require_backend_is_available({"gloo"}) def test_scatter_object_list(self): src_rank = 0 + collectives_object_test_list = create_collectives_object_test_list() scatter_list = ( - COLLECTIVES_OBJECT_TEST_LIST + collectives_object_test_list if self.rank == src_rank - else [None for _ in COLLECTIVES_OBJECT_TEST_LIST] + else [None for _ in collectives_object_test_list] ) world_size = dist.get_world_size() scatter_list = scatter_list[:world_size] @@ -8294,8 +8294,8 @@ def test_scatter_object_list(self): dist.scatter_object_list(output_obj_list, scatter_list, src=src_rank) self.assertEqual( output_obj_list[0], - COLLECTIVES_OBJECT_TEST_LIST[ - self.rank % len(COLLECTIVES_OBJECT_TEST_LIST) + collectives_object_test_list[ + self.rank % len(collectives_object_test_list) ], ) # Ensure errors are raised upon incorrect arguments. @@ -9981,7 +9981,7 @@ def forward(self, x): "Only Nccl & Gloo backend support DistributedDataParallel", ) def test_sync_bn_logged(self): - model = BN_NET + model = BatchNormNet() rank = self.rank # single gpu training setup model_gpu = model.cuda(rank)
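---

For reference, a minimal sketch of the calling pattern this patch expects from a downstream test file: parse the command line before any import-time code instantiates tests or touches the RNG, then hand control to run_tests(). This sketch is not part of the patch; the file name test_example.py and the test body are hypothetical, while parse_cmd_line_args, run_tests, TestCase, and set_rng_seed are the torch.testing._internal.common_utils entry points changed above, so it assumes a PyTorch checkout with this patch applied.

# test_example.py -- hypothetical file name; a minimal sketch, not part of the patch.
import torch

if __name__ == "__main__":
    from torch.testing._internal.common_utils import parse_cmd_line_args

    # SEED and GRAPH_EXECUTOR start out as None sentinels and are only
    # populated by parse_cmd_line_args(), so parse before any import-time
    # code calls torch.randn().
    parse_cmd_line_args()

from torch.testing._internal.common_utils import TestCase, run_tests


class TestExample(TestCase):
    def test_randn_is_seeded(self):
        # TestCase.setUp() calls set_rng_seed() with no argument, which after
        # this patch falls back to the parsed --seed value (default 1234).
        self.assertEqual(torch.randn(2, 2).shape, (2, 2))


if __name__ == "__main__":
    run_tests()

As the patch is written, calling parse_cmd_line_args() from both the import guard and run_tests() appears to be safe: re-parsing reads the same sys.argv and reassigns the same module globals, including re-seeding via torch.manual_seed(SEED).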