Stop parsing command line arguments every time common_utils is imported. #156703

Open · wants to merge 25 commits into base: main

Changes from all commits (25 commits):

e5b1254 Stop parsing command line arguments every time common_utils is imported. (AnthonyBarbier, Jun 24, 2025)
adc6560 Use Optional instead of | (AnthonyBarbier, Jun 24, 2025)
edf4187 Fix test_jit_legacy test (AnthonyBarbier, Jun 25, 2025)
2a3b177 Move assert to when the variable is actually read (AnthonyBarbier, Jun 25, 2025)
36da9a9 Fix jit tests (AnthonyBarbier, Jun 26, 2025)
65210f6 Set seed for distributed_test.py (AnthonyBarbier, Jun 26, 2025)
36cbf8f Merge remote-tracking branch 'upstream/main' into argparse (AnthonyBarbier, Jul 10, 2025)
a4c5ee3 Fix seed setting (AnthonyBarbier, Jul 10, 2025)
0ece614 Fixing jit and distributed tests (AnthonyBarbier, Jul 10, 2025)
a9c8899 Clean up (AnthonyBarbier, Jul 10, 2025)
e6dc730 Fix test_jit_fuser (AnthonyBarbier, Jul 10, 2025)
a4b833a Relax checks in set_rng_seed() (AnthonyBarbier, Jul 10, 2025)
79faf6c Fix more tests (AnthonyBarbier, Jul 16, 2025)
6275b8d Merge remote-tracking branch 'upstream/main' into argparse (AnthonyBarbier, Jul 16, 2025)
fc93a2a Fix distributed model initialisation (AnthonyBarbier, Jul 16, 2025)
b87a0cd Make seed an optional argument (AnthonyBarbier, Jul 16, 2025)
786db21 Pass the seed for subclasses too (AnthonyBarbier, Jul 17, 2025)
6c9730c Handle seed in NCCLTraceTestBase (AnthonyBarbier, Jul 18, 2025)
28eba9a Merge remote-tracking branch 'upstream/main' into argparse (AnthonyBarbier, Jul 18, 2025)
2412e13 Fix seed in FSDP tests (AnthonyBarbier, Jul 21, 2025)
5fe1c4b Fix test_nn.py (AnthonyBarbier, Aug 1, 2025)
650a70f Merge remote-tracking branch 'upstream/main' into argparse (AnthonyBarbier, Aug 1, 2025)
073ee74 Merge remote-tracking branch 'upstream/main' into argparse (AnthonyBarbier, Aug 11, 2025)
59bf860 Trying to hardcode the seed for distributed tests in _start_processes… (AnthonyBarbier, Aug 11, 2025)
0663f08 Use logger.warning (AnthonyBarbier, Aug 11, 2025)
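
The diff to common_utils.py itself is not part of the excerpt below, but the pattern the PR title describes is to move the argparse work out of module import and behind an explicit, idempotent parse_cmd_line_args() call, leaving globals such as SEED and GRAPH_EXECUTOR unset until then. A minimal sketch of that idea, assuming hypothetical flag names, defaults, and a `_parsed` guard for illustration (only parse_cmd_line_args, SEED, and GRAPH_EXECUTOR appear in the diffs below):

```python
# Hypothetical sketch of the lazy-parsing pattern in common_utils; not the PR's code.
import argparse
import sys
from typing import Optional

SEED: Optional[int] = None  # previously populated at import time
GRAPH_EXECUTOR = None       # resolved from --jit-executor once arguments are parsed

_parsed = False             # assumed guard keeping the call idempotent


def parse_cmd_line_args() -> None:
    """Parse test-runner flags once, on demand, instead of at import."""
    global SEED, GRAPH_EXECUTOR, _parsed
    if _parsed:
        return
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument("--seed", type=int, default=1234)          # assumed default
    parser.add_argument("--jit-executor", type=str, default=None)
    args, _ = parser.parse_known_args(sys.argv[1:])
    SEED = args.seed
    GRAPH_EXECUTOR = args.jit_executor  # the real code maps this onto a ProfilingMode
    _parsed = True
```
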
3 changes: 2 additions & 1 deletion test/distributed/test_c10d_nccl.py

@@ -4280,10 +4280,11 @@ def _run(
         test_name: str,
         file_name: str,
         parent_pipe,
+        seed: int,
         **kwargs,
     ) -> None:
         cls.parent = parent_conn
-        super()._run(rank, test_name, file_name, parent_pipe)
+        super()._run(rank, test_name, file_name, parent_pipe, seed)

     @property
     def local_device(self):

3 changes: 3 additions & 0 deletions test/jit/test_autodiff_subgraph_slicing.py

@@ -27,6 +27,9 @@
 )


+assert GRAPH_EXECUTOR is not None
+
+
 @unittest.skipIf(
     GRAPH_EXECUTOR == ProfilingMode.SIMPLE, "Simple Executor doesn't support gradients"
 )

5 changes: 5 additions & 0 deletions test/test_cpp_api_parity.py

@@ -35,6 +35,11 @@ class TestCppApiParity(common.TestCase):
     functional_test_params_map = {}


+if __name__ == "__main__":
+    # The value of the SEED depends on command line arguments so make sure they're parsed
+    # before instantiating tests because some modules as part of get_new_module_tests() will call torch.randn
+    common.parse_cmd_line_args()
+
 expected_test_params_dicts = []

 for test_params_dicts, test_instance_class in [

7 changes: 7 additions & 0 deletions test/test_expanded_weights.py

@@ -1008,6 +1008,13 @@ def filter_supported_tests(t):
     return True


+if __name__ == "__main__":
+    from torch.testing._internal.common_utils import parse_cmd_line_args
+
+    # The value of the SEED depends on command line arguments so make sure they're parsed
+    # before instantiating tests because some modules as part of get_new_module_tests() will call torch.randn
+    parse_cmd_line_args()
+
 # TODO: Once all of these use ModuleInfo, replace with ModuleInfo tests
 # These currently use the legacy nn tests
 supported_tests = [

11 changes: 10 additions & 1 deletion test/test_jit.py

@@ -3,6 +3,13 @@

 import torch

+if __name__ == '__main__':
+    from torch.testing._internal.common_utils import parse_cmd_line_args
+
+    # The value of GRAPH_EXECUTOR and SEED depend on command line arguments so make sure they're parsed
+    # before instantiating tests.
+    parse_cmd_line_args()
+
 # This is how we include tests located in test/jit/...
 # They are included here so that they are invoked when you call `test_jit.py`,
 # do not run these test files directly.
@@ -97,7 +104,7 @@
 from torch.testing._internal import jit_utils
 from torch.testing._internal.common_jit import check_against_reference
 from torch.testing._internal.common_utils import run_tests, IS_WINDOWS, \
-    suppress_warnings, IS_SANDCASTLE, GRAPH_EXECUTOR, ProfilingMode, \
+    GRAPH_EXECUTOR, suppress_warnings, IS_SANDCASTLE, ProfilingMode, \
     TestCase, freeze_rng_state, slowTest, TemporaryFileName, \
     enable_profiling_mode_for_profiling_tests, TEST_MKL, set_default_dtype, num_profiled_runs, \
     skipIfCrossRef, skipIfTorchDynamo
@@ -158,6 +165,7 @@ def doAutodiffCheck(testname):
     if "test_t_" in testname or testname == "test_t":
         return False

+    assert GRAPH_EXECUTOR
     if GRAPH_EXECUTOR == ProfilingMode.SIMPLE:
         return False

@@ -201,6 +209,7 @@ def doAutodiffCheck(testname):
     return testname not in test_exceptions


+assert GRAPH_EXECUTOR
 # TODO: enable TE in PE when all tests are fixed
 torch._C._jit_set_texpr_fuser_enabled(GRAPH_EXECUTOR == ProfilingMode.PROFILING)
 torch._C._jit_set_profiling_executor(GRAPH_EXECUTOR != ProfilingMode.LEGACY)

9 changes: 7 additions & 2 deletions test/test_jit_autocast.py

@@ -5,12 +5,17 @@
 from typing import Optional

 import unittest
-from test_jit import JitTestCase
 from torch.testing._internal.common_cuda import TEST_CUDA
-from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo
+from torch.testing._internal.common_utils import parse_cmd_line_args, run_tests, skipIfTorchDynamo
 from torch.testing import FileCheck
 from jit.test_models import MnistNet

+if __name__ == '__main__':
+    # The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed
+    # before instantiating tests.
+    parse_cmd_line_args()
+
+from test_jit import JitTestCase
 TEST_BFLOAT16 = TEST_CUDA and torch.cuda.is_bf16_supported()

7 changes: 7 additions & 0 deletions test/test_jit_fuser.py

@@ -9,6 +9,13 @@
 from torch.testing import FileCheck
 from unittest import skipIf

+if __name__ == "__main__":
+    from torch.testing._internal.common_utils import parse_cmd_line_args
+
+    # The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed
+    # before instantiating tests.
+    parse_cmd_line_args()
+
 from torch.testing._internal.common_utils import run_tests, IS_SANDCASTLE, ProfilingMode, GRAPH_EXECUTOR, \
     enable_profiling_mode_for_profiling_tests, IS_WINDOWS, TemporaryDirectoryName, shell
 from torch.testing._internal.jit_utils import JitTestCase, enable_cpu_fuser, _inline_everything, \

8 changes: 8 additions & 0 deletions test/test_jit_fuser_legacy.py

@@ -2,6 +2,14 @@

 import sys
 sys.argv.append("--jit-executor=legacy")
+
+if __name__ == "__main__":
+    from torch.testing._internal.common_utils import parse_cmd_line_args
+
+    # The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed
+    # before instantiating tests.
+    parse_cmd_line_args()
+
 from test_jit_fuser import *  # noqa: F403

 if __name__ == '__main__':

7 changes: 7 additions & 0 deletions test/test_jit_fuser_te.py

@@ -22,6 +22,13 @@
 torch._C._jit_set_profiling_executor(True)
 torch._C._get_graph_executor_optimize(True)

+if __name__ == "__main__":
+    from torch.testing._internal.common_utils import parse_cmd_line_args
+
+    # The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed
+    # before instantiating tests.
+    parse_cmd_line_args()
+
 from itertools import combinations, permutations, product
 from textwrap import dedent

9 changes: 8 additions & 1 deletion test/test_jit_legacy.py

@@ -2,7 +2,14 @@

 import sys
 sys.argv.append("--jit-executor=legacy")
-from test_jit import *  # noqa: F403
+from torch.testing._internal.common_utils import parse_cmd_line_args, run_tests
+
+if __name__ == '__main__':
+    # The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed
+    # before instantiating tests.
+    parse_cmd_line_args()
+
+from test_jit import *  # noqa: F403, F401

 if __name__ == '__main__':
     run_tests()

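The two legacy wrappers above rely on flag injection: because parsing is now deferred, appending to sys.argv before the first parse_cmd_line_args() call is enough for GRAPH_EXECUTOR to resolve to the legacy executor. A condensed sketch of the ordering that matters, mirroring the diff above with the steps annotated:

```python
import sys

# 1. Inject the flag before anything parses the command line.
sys.argv.append("--jit-executor=legacy")

from torch.testing._internal.common_utils import parse_cmd_line_args, run_tests

if __name__ == '__main__':
    # 2. Parse explicitly, so GRAPH_EXECUTOR reflects the injected flag...
    parse_cmd_line_args()

# 3. ...before the star-import instantiates the test classes that read it.
from test_jit import *  # noqa: F403, F401

if __name__ == '__main__':
    run_tests()
```
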
7 changes: 7 additions & 0 deletions test/test_nn.py

@@ -7643,6 +7643,13 @@ def with_tf32_on(self, test=test, kwargs=kwargs):
     else:
         add(cuda_test_name, with_tf32_off)

+if __name__ == '__main__':
+    from torch.testing._internal.common_utils import parse_cmd_line_args
+
+    # The value of the SEED depends on command line arguments so make sure they're parsed
+    # before instantiating tests because some modules as part of get_new_module_tests() will call torch.randn
+    parse_cmd_line_args()
+
 for test_params in module_tests + get_new_module_tests():
     # TODO: CUDA is not implemented yet
     if 'constructor' not in test_params:

26 changes: 23 additions & 3 deletions torch/testing/_internal/common_distributed.py

@@ -32,6 +32,7 @@
 from torch._C._autograd import DeviceType
 from torch._C._distributed_c10d import _SymmetricMemory
 from torch._logging._internal import trace_log
+from torch.testing._internal import common_utils
 from torch.testing._internal.common_utils import (
     FILE_SCHEMA,
     find_free_port,
@@ -671,6 +672,7 @@ def __init__(
         if methodName != "runTest":
             method_name = methodName
         super().__init__(method_name)
+        self.seed = None
         try:
             fn = getattr(self, method_name)
             setattr(self, method_name, self.join_or_run(fn))
@@ -715,13 +717,26 @@ def _current_test_name(self) -> str:

     def _start_processes(self, proc) -> None:
         self.processes = []
+        # distributed tests don't support setting the seed via the command line so hardcode it here.
+        hardcoded_seed = 1234
+        if common_utils.SEED and common_utils.SEED != hardcoded_seed:
+            msg = ("Distributed tests do not support setting the seed via the command line. "
+                   f"The seed will be reset to its default value ({hardcoded_seed}) now.")
+            logger.warning(msg)
+        common_utils.SEED = hardcoded_seed
         for rank in range(int(self.world_size)):
             parent_conn, child_conn = torch.multiprocessing.Pipe()
             process = proc(
                 target=self.__class__._run,
                 name="process " + str(rank),
-                args=(rank, self._current_test_name(), self.file_name, child_conn),
+                args=(
+                    rank,
+                    self._current_test_name(),
+                    self.file_name,
+                    child_conn,
+                ),
                 kwargs={
+                    "seed": common_utils.SEED,
                     "fake_pg": getattr(self, "fake_pg", False),
                 },
             )
@@ -775,11 +790,12 @@ def _event_listener(parent_pipe, signal_pipe, rank: int):

     @classmethod
     def _run(
-        cls, rank: int, test_name: str, file_name: str, parent_pipe, **kwargs
+        cls, rank: int, test_name: str, file_name: str, parent_pipe, seed: int, **kwargs
     ) -> None:
         self = cls(test_name)
         self.rank = rank
         self.file_name = file_name
+        self.seed = seed
         self.run_test(test_name, parent_pipe)

     def run_test(self, test_name: str, parent_pipe) -> None:
@@ -798,6 +814,9 @@ def run_test(self, test_name: str, parent_pipe) -> None:
         # Show full C++ stacktraces when a Python error originating from C++ is raised.
         os.environ["TORCH_SHOW_CPP_STACKTRACES"] = "1"

+        if self.seed is not None:
+            common_utils.set_rng_seed(self.seed)
+
         # self.id() == e.g. '__main__.TestDistributed.test_get_rank'
         # We're retrieving a corresponding test and executing it.
         try:
@@ -1535,14 +1554,15 @@ def world_size(self) -> int:

     @classmethod
     def _run(
-        cls, rank: int, test_name: str, file_name: str, parent_pipe, **kwargs
+        cls, rank: int, test_name: str, file_name: str, parent_pipe, seed: int, **kwargs
     ) -> None:
         trace_log.addHandler(logging.NullHandler())

         # The rest is copypasta from MultiProcessTestCase._run
         self = cls(test_name)
         self.rank = rank
         self.file_name = file_name
+        self.seed = seed
         self.run_test(test_name, parent_pipe)

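With the base class now threading seed through _run, subclasses that override it (NCCLTraceTestBase in the test_c10d_nccl.py diff above is the in-tree example) must accept and forward the extra positional argument. A minimal sketch of the required shape, using a hypothetical subclass:

```python
from torch.testing._internal.common_distributed import MultiProcessTestCase


class MyDistTest(MultiProcessTestCase):  # hypothetical subclass for illustration
    @classmethod
    def _run(cls, rank, test_name, file_name, parent_pipe, seed, **kwargs):
        # Per-class setup would go here; then forward everything, including
        # the new `seed`, so run_test() can call set_rng_seed() in the child.
        super()._run(rank, test_name, file_name, parent_pipe, seed, **kwargs)
```
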
4 changes: 3 additions & 1 deletion torch/testing/_internal/common_fsdp.py

@@ -57,6 +57,7 @@
 from torch.testing._internal.common_utils import (
     FILE_SCHEMA,
     get_cycles_per_ms,
+    set_rng_seed,
     TEST_CUDA,
     TEST_HPU,
     TEST_XPU,
@@ -1180,7 +1181,7 @@ def run_subtests(self, *args, **kwargs):
         return run_subtests(self, *args, **kwargs)

     @classmethod
-    def _run(cls, rank, test_name, file_name, pipe, **kwargs):  # type: ignore[override]
+    def _run(cls, rank, test_name, file_name, pipe, seed, **kwargs):  # type: ignore[override]
         self = cls(test_name)
         self.rank = rank
         self.file_name = file_name
@@ -1226,6 +1227,7 @@ def _run(cls, rank, test_name, file_name, pipe, **kwargs):  # type: ignore[override]
             dist.barrier(device_ids=device_ids)

         torch._dynamo.reset()
+        set_rng_seed(seed)
         self.run_test(test_name, pipe)
         torch._dynamo.reset()

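End to end, the seed now takes one path in all the multiprocess harnesses: the parent pins common_utils.SEED in _start_processes, passes it to each child's _run, and the child applies it via set_rng_seed right before the test body runs, so every rank starts from identical RNG state. A runnable toy simulation of that hand-off (not the real harness; set_rng_seed and _run are stand-ins):

```python
# Toy simulation of the seed hand-off this PR adds; not the real harness.
import torch

SEED = 1234  # stands in for common_utils.SEED pinned in _start_processes


def set_rng_seed(seed: int) -> None:
    # stands in for common_utils.set_rng_seed
    torch.manual_seed(seed)


def _run(rank: int, seed: int) -> torch.Tensor:
    # child-process entry point: apply the seed before the test body runs
    set_rng_seed(seed)
    return torch.randn(3)


# Every "rank" draws identical random numbers because each applies the seed.
assert torch.equal(_run(0, SEED), _run(1, SEED))
```
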
2 changes: 2 additions & 0 deletions torch/testing/_internal/common_nn.py

@@ -15,6 +15,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn import _reduction as _Reduction
+from torch.testing._internal import common_utils
 from torch.testing._internal.common_utils import TestCase, to_gpu, freeze_rng_state, is_iterable, \
     gradcheck, gradgradcheck, set_default_dtype, skipIfTorchDynamo, TEST_WITH_ROCM
 from torch.testing._internal.common_cuda import TEST_CUDA, SM90OrLater
@@ -1078,6 +1079,7 @@ def unsqueeze_inp(inp):


 def get_new_module_tests():
+    assert common_utils.SEED is not None, "Make sure the seed is set before calling get_new_module_tests()"
     new_module_tests = [
         poissonnllloss_no_reduce_test(),
         bceloss_no_reduce_test(),

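Taken together with the test_nn.py and test_cpp_api_parity.py changes above, the new assert encodes a call-order contract: parse the command line before building the module test list, because get_new_module_tests() draws random tensors via torch.randn while it constructs test cases. A minimal usage sketch of the contract (the surrounding driver script is hypothetical):

```python
# Hypothetical driver script honoring the new ordering contract.
from torch.testing._internal.common_utils import parse_cmd_line_args

if __name__ == "__main__":
    parse_cmd_line_args()  # sets common_utils.SEED (and GRAPH_EXECUTOR)

    # Safe only after parsing: the assert inside get_new_module_tests()
    # checks common_utils.SEED is not None before any torch.randn call.
    from torch.testing._internal.common_nn import get_new_module_tests
    tests = get_new_module_tests()
    print(f"collected {len(tests)} module test specs")
```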