Skip to content

Commit bdd7ca9

Browse files
committed
Add Autoload Testcase for OpenReg
1 parent 9e44889 commit bdd7ca9

File tree

4 files changed

+39
-4
lines changed

4 files changed

+39
-4
lines changed

test/cpp_extensions/open_registration_extension/torch_openreg/setup.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,11 @@
1414

1515

1616
def get_pytorch_dir():
17+
os.environ["TORCH_DEVICE_BACKEND_AUTOLOAD"] = "0"
1718
import torch
1819

20+
os.environ.pop("TORCH_DEVICE_BACKEND_AUTOLOAD")
21+
1922
return os.path.dirname(os.path.realpath(torch.__file__))
2023

2124

test/run_test.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -923,9 +923,33 @@ def run_test_with_openreg(test_module, test_directory, options):
923923
install_dir, return_code = install_cpp_extensions(openreg_dir)
924924
if return_code != 0:
925925
return return_code
926+
927+
os.environ["TORCH_DEVICE_BACKEND_AUTOLOAD"] = "0"
926928

927929
with extend_python_path([install_dir]):
928930
return run_test(test_module, test_directory, options)
931+
932+
os.environ.pop("TORCH_DEVICE_BACKEND_AUTOLOAD")
933+
934+
935+
def run_test_with_openreg_autoload(test_module, test_directory, options):
936+
# TODO(FFFrog): Will remove this later when windows/macos are supported.
937+
if not IS_LINUX:
938+
return 0
939+
940+
openreg_dir = os.path.join(
941+
test_directory, "cpp_extensions", "open_registration_extension", "torch_openreg"
942+
)
943+
install_dir, return_code = install_cpp_extensions(openreg_dir)
944+
if return_code != 0:
945+
return return_code
946+
947+
os.environ["TORCH_DEVICE_BACKEND_AUTOLOAD"] = "1"
948+
949+
with extend_python_path([install_dir]):
950+
return run_test(test_module, test_directory, options)
951+
952+
os.environ.pop("TORCH_DEVICE_BACKEND_AUTOLOAD")
929953

930954

931955
def test_distributed(test_module, test_directory, options):
@@ -1256,7 +1280,7 @@ def run_ci_sanity_check(test: ShardedTest, test_directory, options):
12561280
"test_autoload_disable": test_autoload_disable,
12571281
"test_cpp_extensions_open_device_registration": run_test_with_openreg,
12581282
"test_openreg": run_test_with_openreg,
1259-
"test_transformers_privateuse1": run_test_with_openreg,
1283+
"test_transformers_privateuse1": run_test_with_openreg_autoload,
12601284
}
12611285

12621286

test/test_openreg.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
import numpy as np
1212
import psutil
13-
import torch_openreg # noqa: F401
1413

1514
import torch
1615
from torch.serialization import safe_globals
@@ -24,6 +23,17 @@
2423
)
2524

2625

26+
class TestAutoload(TestCase):
27+
"""Tests of autoloading the OpenReg backend"""
28+
29+
def test_autoload_disable(self):
30+
# Test that the backend is loaded automatically
31+
self.assertFalse(hasattr(torch, "openreg"))
32+
import torch_openreg # noqa: F401
33+
self.assertTrue(torch.openreg.is_available())
34+
35+
36+
2737
class TestPrivateUse1(TestCase):
2838
"""Tests of third-parth device integration mechinasm based PrivateUse1"""
2939

test/test_transformers_privateuse1.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44
from collections import namedtuple
55
from functools import partial
66

7-
import torch_openreg # noqa: F401
8-
97
import torch
108
from torch.nn.attention import SDPBackend
119
from torch.testing._internal.common_nn import NNTestCase

0 commit comments

Comments
 (0)