Skip to content

Commit c6b901b

Browse files
committed
Add Autoload Testcase for OpenReg
1 parent 9e44889 commit c6b901b

File tree

3 files changed

+42
-0
lines changed

3 files changed

+42
-0
lines changed

test/cpp_extensions/open_registration_extension/torch_openreg/setup.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,11 @@
1414

1515

1616
def get_pytorch_dir():
17+
os.environ["TORCH_DEVICE_BACKEND_AUTOLOAD"] = "0"
1718
import torch
1819

20+
os.environ.pop("TORCH_DEVICE_BACKEND_AUTOLOAD")
21+
1922
return os.path.dirname(os.path.realpath(torch.__file__))
2023

2124

test/run_test.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -924,8 +924,32 @@ def run_test_with_openreg(test_module, test_directory, options):
924924
if return_code != 0:
925925
return return_code
926926

927+
os.environ["TORCH_DEVICE_BACKEND_AUTOLOAD"] = "0"
928+
929+
with extend_python_path([install_dir]):
930+
return run_test(test_module, test_directory, options)
931+
932+
os.environ.pop("TORCH_DEVICE_BACKEND_AUTOLOAD")
933+
934+
935+
def run_test_with_openreg_autoload(test_module, test_directory, options):
936+
# TODO(FFFrog): Will remove this later when windows/macos are supported.
937+
if not IS_LINUX:
938+
return 0
939+
940+
openreg_dir = os.path.join(
941+
test_directory, "cpp_extensions", "open_registration_extension", "torch_openreg"
942+
)
943+
install_dir, return_code = install_cpp_extensions(openreg_dir)
944+
if return_code != 0:
945+
return return_code
946+
947+
os.environ["TORCH_DEVICE_BACKEND_AUTOLOAD"] = "1"
948+
927949
with extend_python_path([install_dir]):
928950
return run_test(test_module, test_directory, options)
951+
952+
os.environ.pop("TORCH_DEVICE_BACKEND_AUTOLOAD")
929953

930954

931955
def test_distributed(test_module, test_directory, options):
@@ -1256,6 +1280,7 @@ def run_ci_sanity_check(test: ShardedTest, test_directory, options):
12561280
"test_autoload_disable": test_autoload_disable,
12571281
"test_cpp_extensions_open_device_registration": run_test_with_openreg,
12581282
"test_openreg": run_test_with_openreg,
1283+
"test_openreg_autoload": run_test_with_openreg_autoload,
12591284
"test_transformers_privateuse1": run_test_with_openreg,
12601285
}
12611286

test/test_openreg_autoload.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import torch
2+
from torch.testing._internal.common_utils import TestCase, run_tests
3+
4+
5+
class TestAutoload(TestCase):
6+
"""Tests of autoloading the OpenReg backend"""
7+
8+
def test_autoload_enable(self):
9+
# Test that the backend is loaded automatically
10+
self.assertTrue(torch.openreg.is_available())
11+
12+
13+
if __name__ == "__main__":
14+
run_tests()

0 commit comments

Comments
 (0)