Skip to content

Commit 5ee0ad5

Browse files
committed
Fix some issues
1 parent a9194c6 commit 5ee0ad5

File tree

11 files changed

+205
-176
lines changed

11 files changed

+205
-176
lines changed

test/cpp_extensions/open_registration_extension/torch_openreg/csrc/CMakeLists.txt

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,10 @@ add_library(${LIBRARY_NAME} SHARED ${SOURCE_FILES})
99
target_link_libraries(${LIBRARY_NAME} PRIVATE torch_cpu_library openreg)
1010
target_include_directories(${LIBRARY_NAME} PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
1111

12-
if (WIN32)
13-
target_compile_definitions(${LIBRARY_NAME} PRIVATE OPENREG_EXPORTS)
14-
endif()
12+
set(CMAKE_CXX_VISIBILITY_PRESET hidden)
1513

16-
if(WIN32)
17-
install(TARGETS ${LIBRARY_NAME}
18-
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
19-
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
20-
RUNTIME DESTINATION ${CMAKE_INSTALL_LIBDIR}
21-
)
22-
else()
23-
install(TARGETS ${LIBRARY_NAME} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR})
24-
endif()
14+
install(TARGETS ${LIBRARY_NAME}
15+
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
16+
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
17+
RUNTIME DESTINATION ${CMAKE_INSTALL_LIBDIR}
18+
)

test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ int device_count_impl() {
3030
return count;
3131
}
3232

33-
OPENREG_API c10::DeviceIndex device_count() noexcept {
33+
OPENREG_EXPORT c10::DeviceIndex device_count() noexcept {
3434
// initialize number of devices only once
3535
static int count = []() {
3636
try {
@@ -49,17 +49,17 @@ OPENREG_API c10::DeviceIndex device_count() noexcept {
4949
return static_cast<c10::DeviceIndex>(count);
5050
}
5151

52-
OPENREG_API c10::DeviceIndex current_device() {
52+
OPENREG_EXPORT c10::DeviceIndex current_device() {
5353
c10::DeviceIndex cur_device = -1;
5454
GetDevice(&cur_device);
5555
return cur_device;
5656
}
5757

58-
OPENREG_API void set_device(c10::DeviceIndex device) {
58+
OPENREG_EXPORT void set_device(c10::DeviceIndex device) {
5959
SetDevice(device);
6060
}
6161

62-
OPENREG_API DeviceIndex ExchangeDevice(DeviceIndex device) {
62+
OPENREG_EXPORT DeviceIndex ExchangeDevice(DeviceIndex device) {
6363
int current_device = -1;
6464
orGetDevice(&current_device);
6565

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,9 @@
11
#pragma once
22

33
#ifdef _WIN32
4-
#ifdef OPENREG_EXPORTS
5-
#define OPENREG_API __declspec(dllexport)
6-
#else
7-
#define OPENREG_API __declspec(dllimport)
8-
#endif
4+
#define OPENREG_EXPORT __declspec(dllexport)
95
#else
10-
#define OPENREG_API
6+
#define OPENREG_EXPORT __attribute__((visibility("default")))
117
#endif
128

139
#include <c10/core/Device.h>
@@ -17,10 +13,10 @@
1713

1814
namespace c10::openreg {
1915

20-
OPENREG_API c10::DeviceIndex device_count() noexcept;
21-
OPENREG_API c10::DeviceIndex current_device();
22-
OPENREG_API void set_device(c10::DeviceIndex device);
16+
OPENREG_EXPORT c10::DeviceIndex device_count() noexcept;
17+
OPENREG_EXPORT c10::DeviceIndex current_device();
18+
OPENREG_EXPORT void set_device(c10::DeviceIndex device);
2319

24-
OPENREG_API DeviceIndex ExchangeDevice(DeviceIndex device);
20+
OPENREG_EXPORT DeviceIndex ExchangeDevice(DeviceIndex device);
2521

2622
} // namespace c10::openreg

test/cpp_extensions/open_registration_extension/torch_openreg/setup.py

Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,19 @@
1919
RUN_BUILD_DEPS = any(arg in {"clean", "dist_info"} for arg in sys.argv)
2020

2121

22+
def check_env_flag(name, default = ""):
23+
return os.getenv(name, default).upper() in ["ON", "1", "YES", "TRUE", "Y"]
24+
25+
26+
if "CMAKE_BUILD_TYPE" not in os.environ:
27+
if check_env_flag("DEBUG"):
28+
os.environ["CMAKE_BUILD_TYPE"] = "Debug"
29+
elif check_env_flag("REL_WITH_DEB_INFO"):
30+
os.environ["CMAKE_BUILD_TYPE"] = "RelWithDebInfo"
31+
else:
32+
os.environ["CMAKE_BUILD_TYPE"] = "Release"
33+
34+
2235
def make_relative_rpath_args(path):
2336
if IS_DARWIN:
2437
return ["-Wl,-rpath,@loader_path/" + path]
@@ -54,7 +67,7 @@ def build_deps():
5467
".",
5568
"--target",
5669
"install",
57-
"--config", "Release",
70+
"--config", os.environ["CMAKE_BUILD_TYPE"],
5871
"--",
5972
]
6073

@@ -83,11 +96,44 @@ def run(self):
8396
def main():
8497
if not RUN_BUILD_DEPS:
8598
build_deps()
86-
87-
if sys.platform == "win32":
88-
extra_compile_args = ["/W3"]
99+
100+
if IS_WINDOWS:
101+
# /NODEFAULTLIB makes sure we only link to DLL runtime
102+
# and matches the flags set for protobuf and ONNX
103+
extra_link_args: list[str] = ["/NODEFAULTLIB:LIBCMT.LIB"]
104+
# /MD links against DLL runtime
105+
# and matches the flags set for protobuf and ONNX
106+
# /EHsc is about standard C++ exception handling
107+
extra_compile_args: list[str] = ["/MD", "/FS", "/EHsc"]
89108
else:
90-
extra_compile_args = ["-g", "-Wall", "-Werror"]
109+
extra_link_args = []
110+
extra_compile_args = [
111+
"-Wall",
112+
"-Wextra",
113+
"-Wno-strict-overflow",
114+
"-Wno-unused-parameter",
115+
"-Wno-missing-field-initializers",
116+
"-Wno-unknown-pragmas",
117+
# Python 2.6 requires -fno-strict-aliasing, see
118+
# http://legacy.python.org/dev/peps/pep-3123/
119+
# We also depend on it in our code (even Python 3).
120+
"-fno-strict-aliasing",
121+
]
122+
123+
if os.environ["CMAKE_BUILD_TYPE"] == "Debug":
124+
if IS_WINDOWS:
125+
extra_compile_args += ["/Z7"]
126+
extra_link_args += ["/DEBUG:FULL"]
127+
else:
128+
extra_compile_args += ["-O0", "-g"]
129+
extra_link_args += ["-O0", "-g"]
130+
elif os.environ["CMAKE_BUILD_TYPE"] == "RelWithDebInfo":
131+
if IS_WINDOWS:
132+
extra_compile_args += ["/Z7"]
133+
extra_link_args += ["/DEBUG:FULL"]
134+
else:
135+
extra_compile_args += ["-g"]
136+
extra_link_args += ["-g"]
91137

92138
ext_modules = [
93139
Extension(

test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/CMakeLists.txt

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,14 @@ file(GLOB_RECURSE SOURCE_FILES
44
"${CMAKE_CURRENT_SOURCE_DIR}/*.cpp"
55
)
66

7-
if(WIN32)
8-
add_library(${LIBRARY_NAME} STATIC ${SOURCE_FILES})
9-
else()
10-
add_library(${LIBRARY_NAME} SHARED ${SOURCE_FILES})
11-
endif()
7+
add_library(${LIBRARY_NAME} SHARED ${SOURCE_FILES})
128

139
target_include_directories(${LIBRARY_NAME} PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
1410

15-
if(WIN32)
16-
install(TARGETS ${LIBRARY_NAME}
17-
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
18-
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
19-
RUNTIME DESTINATION ${CMAKE_INSTALL_LIBDIR}
20-
)
21-
else()
22-
install(TARGETS ${LIBRARY_NAME} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR})
23-
endif()
11+
set(CMAKE_CXX_VISIBILITY_PRESET hidden)
12+
13+
install(TARGETS ${LIBRARY_NAME}
14+
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
15+
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
16+
RUNTIME DESTINATION ${CMAKE_INSTALL_LIBDIR}
17+
)

test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/include/openreg.h

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@
22

33
#include <cstddef>
44

5+
#ifdef _WIN32
6+
#define OPENREG_EXPORT __declspec(dllexport)
7+
#else
8+
#define OPENREG_EXPORT __attribute__((visibility("default")))
9+
#endif
10+
511
#ifdef __cplusplus
612
extern "C" {
713
#endif
@@ -28,19 +34,19 @@ struct orPointerAttributes {
2834
size_t size;
2935
};
3036

31-
orError_t orMalloc(void** devPtr, size_t size);
32-
orError_t orFree(void* devPtr);
33-
orError_t orMallocHost(void** hostPtr, size_t size);
34-
orError_t orFreeHost(void* hostPtr);
35-
orError_t orMemcpy(void* dst, const void* src, size_t count, orMemcpyKind kind);
36-
orError_t orMemoryUnprotect(void* devPtr);
37-
orError_t orMemoryProtect(void* devPtr);
37+
OPENREG_EXPORT orError_t orMalloc(void** devPtr, size_t size);
38+
OPENREG_EXPORT orError_t orFree(void* devPtr);
39+
OPENREG_EXPORT orError_t orMallocHost(void** hostPtr, size_t size);
40+
OPENREG_EXPORT orError_t orFreeHost(void* hostPtr);
41+
OPENREG_EXPORT orError_t orMemcpy(void* dst, const void* src, size_t count, orMemcpyKind kind);
42+
OPENREG_EXPORT orError_t orMemoryUnprotect(void* devPtr);
43+
OPENREG_EXPORT orError_t orMemoryProtect(void* devPtr);
3844

39-
orError_t orGetDeviceCount(int* count);
40-
orError_t orSetDevice(int device);
41-
orError_t orGetDevice(int* device);
45+
OPENREG_EXPORT orError_t orGetDeviceCount(int* count);
46+
OPENREG_EXPORT orError_t orSetDevice(int device);
47+
OPENREG_EXPORT orError_t orGetDevice(int* device);
4248

43-
orError_t orPointerGetAttributes(
49+
OPENREG_EXPORT orError_t orPointerGetAttributes(
4450
orPointerAttributes* attributes,
4551
const void* ptr);
4652

test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/__init__.py

Lines changed: 3 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -1,102 +1,13 @@
1-
import ctypes
2-
import glob
3-
import os
4-
import platform
51
import sys
6-
import textwrap
7-
import sysconfig
8-
92
import torch
103

11-
if sys.platform == "win32":
12-
13-
def _load_dll_libraries() -> None:
14-
15-
py_dll_path = os.path.join(sys.exec_prefix, "Library", "bin")
16-
th_dll_path = os.path.join(os.path.dirname(__file__), "lib")
17-
usebase_path = os.path.join(sysconfig.get_config_var("userbase"), "Library", "bin")
18-
py_root_bin_path = os.path.join(sys.exec_prefix, "bin")
19-
20-
# When users create a virtualenv that inherits the base environment,
21-
# we will need to add the corresponding library directory into
22-
# DLL search directories. Otherwise, it will rely on `PATH` which
23-
# is dependent on user settings.
24-
if sys.exec_prefix != sys.base_exec_prefix:
25-
base_py_dll_path = os.path.join(sys.base_exec_prefix, "Library", "bin")
26-
else:
27-
base_py_dll_path = ""
28-
29-
dll_paths = [
30-
p
31-
for p in (
32-
th_dll_path,
33-
py_dll_path,
34-
base_py_dll_path,
35-
usebase_path,
36-
py_root_bin_path,
37-
)
38-
if os.path.exists(p)
39-
]
40-
41-
42-
kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True)
43-
with_load_library_flags = hasattr(kernel32, "AddDllDirectory")
44-
prev_error_mode = kernel32.SetErrorMode(0x0001)
45-
46-
kernel32.LoadLibraryW.restype = ctypes.c_void_p
47-
if with_load_library_flags:
48-
kernel32.LoadLibraryExW.restype = ctypes.c_void_p
49-
50-
for dll_path in dll_paths:
51-
os.add_dll_directory(dll_path)
52-
53-
try:
54-
ctypes.CDLL("vcruntime140.dll")
55-
ctypes.CDLL("msvcp140.dll")
56-
if platform.machine() != "ARM64":
57-
ctypes.CDLL("vcruntime140_1.dll")
58-
except OSError:
59-
print(
60-
textwrap.dedent(
61-
"""
62-
Microsoft Visual C++ Redistributable is not installed, this may lead to the DLL load failure.
63-
It can be downloaded at https://aka.ms/vs/16/release/vc_redist.x64.exe
64-
"""
65-
).strip()
66-
)
67-
68-
dlls = glob.glob(os.path.join(th_dll_path, "*.dll"))
69-
path_patched = False
70-
for dll in dlls:
71-
is_loaded = False
72-
if with_load_library_flags:
73-
res = kernel32.LoadLibraryExW(dll, None, 0x00001100)
74-
last_error = ctypes.get_last_error()
75-
if res is None and last_error != 126:
76-
err = ctypes.WinError(last_error)
77-
err.strerror += (
78-
f' Error loading "{dll}" or one of its dependencies.'
79-
)
80-
raise err
81-
elif res is not None:
82-
is_loaded = True
83-
if not is_loaded:
84-
if not path_patched:
85-
os.environ["PATH"] = ";".join(dll_paths + [os.environ["PATH"]])
86-
path_patched = True
87-
res = kernel32.LoadLibraryW(dll)
88-
if res is None:
89-
err = ctypes.WinError(ctypes.get_last_error())
90-
err.strerror += (
91-
f' Error loading "{dll}" or one of its dependencies.'
92-
)
93-
raise err
94-
95-
kernel32.SetErrorMode(prev_error_mode)
964

5+
if sys.platform == "win32":
6+
from ._utils import _load_dll_libraries
977
_load_dll_libraries()
988
del _load_dll_libraries
999

10+
10011
import torch_openreg._C # type: ignore[misc]
10112
import torch_openreg.openreg
10213

0 commit comments

Comments
 (0)