Skip to content

Commit 356ac31

Browse files
Revert "Stop parsing command line arguments every time common_utils is imported. (#156703)"
This reverts commit 310f901. Reverted #156703 on behalf of https://github.com/izaitsevfb due to breaking tests internally with `assert common_utils.SEED is not None` ([comment](#156703 (comment)))
1 parent d4109a0 commit 356ac31

16 files changed

+113
-259
lines changed

test/distributed/test_c10d_nccl.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4280,11 +4280,10 @@ def _run(
42804280
test_name: str,
42814281
file_name: str,
42824282
parent_pipe,
4283-
seed: int,
42844283
**kwargs,
42854284
) -> None:
42864285
cls.parent = parent_conn
4287-
super()._run(rank, test_name, file_name, parent_pipe, seed)
4286+
super()._run(rank, test_name, file_name, parent_pipe)
42884287

42894288
@property
42904289
def local_device(self):

test/jit/test_autodiff_subgraph_slicing.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,6 @@
2727
)
2828

2929

30-
assert GRAPH_EXECUTOR is not None
31-
32-
3330
@unittest.skipIf(
3431
GRAPH_EXECUTOR == ProfilingMode.SIMPLE, "Simple Executor doesn't support gradients"
3532
)

test/test_cpp_api_parity.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,6 @@ class TestCppApiParity(common.TestCase):
3535
functional_test_params_map = {}
3636

3737

38-
if __name__ == "__main__":
39-
# The value of the SEED depends on command line arguments so make sure they're parsed
40-
# before instantiating tests because some modules as part of get_new_module_tests() will call torch.randn
41-
common.parse_cmd_line_args()
42-
4338
expected_test_params_dicts = []
4439

4540
for test_params_dicts, test_instance_class in [

test/test_expanded_weights.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1008,13 +1008,6 @@ def filter_supported_tests(t):
10081008
return True
10091009

10101010

1011-
if __name__ == "__main__":
1012-
from torch.testing._internal.common_utils import parse_cmd_line_args
1013-
1014-
# The value of the SEED depends on command line arguments so make sure they're parsed
1015-
# before instantiating tests because some modules as part of get_new_module_tests() will call torch.randn
1016-
parse_cmd_line_args()
1017-
10181011
# TODO: Once all of these use ModuleInfo, replace with ModuleInfo tests
10191012
# These currently use the legacy nn tests
10201013
supported_tests = [

test/test_jit.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,6 @@
33

44
import torch
55

6-
if __name__ == '__main__':
7-
from torch.testing._internal.common_utils import parse_cmd_line_args
8-
9-
# The value of GRAPH_EXECUTOR and SEED depend on command line arguments so make sure they're parsed
10-
# before instantiating tests.
11-
parse_cmd_line_args()
12-
136
# This is how we include tests located in test/jit/...
147
# They are included here so that they are invoked when you call `test_jit.py`,
158
# do not run these test files directly.
@@ -104,7 +97,7 @@
10497
from torch.testing._internal import jit_utils
10598
from torch.testing._internal.common_jit import check_against_reference
10699
from torch.testing._internal.common_utils import run_tests, IS_WINDOWS, \
107-
GRAPH_EXECUTOR, suppress_warnings, IS_SANDCASTLE, ProfilingMode, \
100+
suppress_warnings, IS_SANDCASTLE, GRAPH_EXECUTOR, ProfilingMode, \
108101
TestCase, freeze_rng_state, slowTest, TemporaryFileName, \
109102
enable_profiling_mode_for_profiling_tests, TEST_MKL, set_default_dtype, num_profiled_runs, \
110103
skipIfCrossRef, skipIfTorchDynamo
@@ -165,7 +158,6 @@ def doAutodiffCheck(testname):
165158
if "test_t_" in testname or testname == "test_t":
166159
return False
167160

168-
assert GRAPH_EXECUTOR
169161
if GRAPH_EXECUTOR == ProfilingMode.SIMPLE:
170162
return False
171163

@@ -209,7 +201,6 @@ def doAutodiffCheck(testname):
209201
return testname not in test_exceptions
210202

211203

212-
assert GRAPH_EXECUTOR
213204
# TODO: enable TE in PE when all tests are fixed
214205
torch._C._jit_set_texpr_fuser_enabled(GRAPH_EXECUTOR == ProfilingMode.PROFILING)
215206
torch._C._jit_set_profiling_executor(GRAPH_EXECUTOR != ProfilingMode.LEGACY)

test/test_jit_autocast.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,12 @@
55
from typing import Optional
66

77
import unittest
8+
from test_jit import JitTestCase
89
from torch.testing._internal.common_cuda import TEST_CUDA
9-
from torch.testing._internal.common_utils import parse_cmd_line_args, run_tests, skipIfTorchDynamo
10+
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo
1011
from torch.testing import FileCheck
1112
from jit.test_models import MnistNet
1213

13-
if __name__ == '__main__':
14-
# The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed
15-
# before instantiating tests.
16-
parse_cmd_line_args()
17-
18-
from test_jit import JitTestCase
1914
TEST_BFLOAT16 = TEST_CUDA and torch.cuda.is_bf16_supported()
2015

2116
@skipIfTorchDynamo("Not a TorchDynamo suitable test")

test/test_jit_fuser.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,6 @@
99
from torch.testing import FileCheck
1010
from unittest import skipIf
1111

12-
if __name__ == "__main__":
13-
from torch.testing._internal.common_utils import parse_cmd_line_args
14-
15-
# The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed
16-
# before instantiating tests.
17-
parse_cmd_line_args()
18-
1912
from torch.testing._internal.common_utils import run_tests, IS_SANDCASTLE, ProfilingMode, GRAPH_EXECUTOR, \
2013
enable_profiling_mode_for_profiling_tests, IS_WINDOWS, TemporaryDirectoryName, shell
2114
from torch.testing._internal.jit_utils import JitTestCase, enable_cpu_fuser, _inline_everything, \

test/test_jit_fuser_legacy.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,6 @@
22

33
import sys
44
sys.argv.append("--jit-executor=legacy")
5-
6-
if __name__ == "__main__":
7-
from torch.testing._internal.common_utils import parse_cmd_line_args
8-
9-
# The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed
10-
# before instantiating tests.
11-
parse_cmd_line_args()
12-
135
from test_jit_fuser import * # noqa: F403
146

157
if __name__ == '__main__':

test/test_jit_fuser_te.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,6 @@
2222
torch._C._jit_set_profiling_executor(True)
2323
torch._C._get_graph_executor_optimize(True)
2424

25-
if __name__ == "__main__":
26-
from torch.testing._internal.common_utils import parse_cmd_line_args
27-
28-
# The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed
29-
# before instantiating tests.
30-
parse_cmd_line_args()
31-
3225
from itertools import combinations, permutations, product
3326
from textwrap import dedent
3427

test/test_jit_legacy.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,7 @@
22

33
import sys
44
sys.argv.append("--jit-executor=legacy")
5-
from torch.testing._internal.common_utils import parse_cmd_line_args, run_tests
6-
7-
if __name__ == '__main__':
8-
# The value of GRAPH_EXECUTOR depends on command line arguments so make sure they're parsed
9-
# before instantiating tests.
10-
parse_cmd_line_args()
11-
12-
from test_jit import * # noqa: F403, F401
5+
from test_jit import * # noqa: F403
136

147
if __name__ == '__main__':
158
run_tests()

0 commit comments

Comments
 (0)