diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/CMakeLists.txt b/test/cpp_extensions/open_registration_extension/torch_openreg/CMakeLists.txt index 73163b8cb1ae..c1cc0eeeb3b1 100644 --- a/test/cpp_extensions/open_registration_extension/torch_openreg/CMakeLists.txt +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/CMakeLists.txt @@ -4,28 +4,29 @@ project(TORCH_OPENREG CXX C) include(GNUInstallDirs) include(CheckCXXCompilerFlag) -include(CMakeDependentOption) - -set(CMAKE_SKIP_BUILD_RPATH FALSE) -set(CMAKE_BUILD_WITH_INSTALL_RPATH TRUE) -set(CMAKE_INSTALL_RPATH_USE_LINK_PATH FALSE) -set(CMAKE_INSTALL_RPATH "$ORIGIN/lib/:$ORIGIN/") - -set(LINUX TRUE) -set(CMAKE_INSTALL_MESSAGE NEVER) -set(CMAKE_EXPORT_COMPILE_COMMANDS ON) set(CMAKE_CXX_STANDARD 17) set(CMAKE_C_STANDARD 11) set(CMAKE_CXX_EXTENSIONS OFF) +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) +set(CMAKE_SKIP_BUILD_RPATH FALSE) +set(CMAKE_BUILD_WITH_INSTALL_RPATH TRUE) +set(CMAKE_INSTALL_RPATH_USE_LINK_PATH FALSE) +set(CMAKE_CXX_VISIBILITY_PRESET hidden) + +if(APPLE) + set(CMAKE_INSTALL_RPATH "@loader_path/lib;@loader_path") +elseif(UNIX) + set(CMAKE_INSTALL_RPATH "$ORIGIN/lib:$ORIGIN") +elseif(WIN32) + set(CMAKE_INSTALL_RPATH "") +endif() set(CMAKE_INSTALL_LIBDIR lib) - -add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=1) +set(CMAKE_INSTALL_MESSAGE NEVER) set(Torch_DIR ${PYTORCH_INSTALL_DIR}/share/cmake/Torch) find_package(Torch REQUIRED) -include_directories(${PYTORCH_INSTALL_DIR}/include) if(DEFINED PYTHON_INCLUDE_DIR) include_directories(${PYTHON_INCLUDE_DIR}) @@ -33,6 +34,8 @@ else() message(FATAL_ERROR "Cannot find Python directory") endif() +include(${PROJECT_SOURCE_DIR}/cmake/TorchPythonTargets.cmake) + add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/openreg) add_subdirectory(${PROJECT_SOURCE_DIR}/csrc) add_subdirectory(${PROJECT_SOURCE_DIR}/torch_openreg/csrc) diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/cmake/TorchPythonTargets.cmake b/test/cpp_extensions/open_registration_extension/torch_openreg/cmake/TorchPythonTargets.cmake new file mode 100644 index 000000000000..181dee20e0bb --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/cmake/TorchPythonTargets.cmake @@ -0,0 +1,20 @@ +if(WIN32) + set(TORCH_PYTHON_IMPORTED_LOCATION "${PYTORCH_INSTALL_DIR}/lib/torch_python.lib") +else() + set(TORCH_PYTHON_IMPORTED_LOCATION "${PYTORCH_INSTALL_DIR}/lib/libtorch_python.so") +endif() + +add_library(torch_python SHARED IMPORTED) + +set_target_properties(torch_python PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES "${PYTORCH_INSTALL_DIR}/include" + INTERFACE_LINK_LIBRARIES "c10;torch_cpu" + IMPORTED_LOCATION "${TORCH_PYTHON_IMPORTED_LOCATION}" +) + +add_library(torch_python_library INTERFACE IMPORTED) + +set_target_properties(torch_python_library PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES "\$" + INTERFACE_LINK_LIBRARIES "\$;\$" +) \ No newline at end of file diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/CMakeLists.txt b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/CMakeLists.txt index 077f4cf3b640..e2ae2b3f3667 100644 --- a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/CMakeLists.txt +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/CMakeLists.txt @@ -6,7 +6,11 @@ file(GLOB_RECURSE SOURCE_FILES add_library(${LIBRARY_NAME} SHARED ${SOURCE_FILES}) -target_link_libraries(${LIBRARY_NAME} PRIVATE openreg torch_cpu) +target_link_libraries(${LIBRARY_NAME} PRIVATE torch_cpu_library openreg) target_include_directories(${LIBRARY_NAME} PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) -install(TARGETS ${LIBRARY_NAME} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}) +install(TARGETS ${LIBRARY_NAME} + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_LIBDIR} +) diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.cpp b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.cpp index 240c2d8ce1aa..6b928f4ad9cc 100644 --- a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.cpp +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.cpp @@ -30,7 +30,7 @@ int device_count_impl() { return count; } -c10::DeviceIndex device_count() noexcept { +OPENREG_EXPORT c10::DeviceIndex device_count() noexcept { // initialize number of devices only once static int count = []() { try { @@ -49,17 +49,17 @@ c10::DeviceIndex device_count() noexcept { return static_cast(count); } -c10::DeviceIndex current_device() { +OPENREG_EXPORT c10::DeviceIndex current_device() { c10::DeviceIndex cur_device = -1; GetDevice(&cur_device); return cur_device; } -void set_device(c10::DeviceIndex device) { +OPENREG_EXPORT void set_device(c10::DeviceIndex device) { SetDevice(device); } -DeviceIndex ExchangeDevice(DeviceIndex device) { +OPENREG_EXPORT DeviceIndex ExchangeDevice(DeviceIndex device) { int current_device = -1; orGetDevice(¤t_device); diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.h b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.h index b6b991ff6d3a..8d8e9cd1e302 100644 --- a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.h +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.h @@ -1,5 +1,11 @@ #pragma once +#ifdef _WIN32 + #define OPENREG_EXPORT __declspec(dllexport) +#else + #define OPENREG_EXPORT __attribute__((visibility("default"))) +#endif + #include #include @@ -7,10 +13,10 @@ namespace c10::openreg { -c10::DeviceIndex device_count() noexcept; -DeviceIndex current_device(); -void set_device(c10::DeviceIndex device); +OPENREG_EXPORT c10::DeviceIndex device_count() noexcept; +OPENREG_EXPORT c10::DeviceIndex current_device(); +OPENREG_EXPORT void set_device(c10::DeviceIndex device); -DeviceIndex ExchangeDevice(DeviceIndex device); +OPENREG_EXPORT DeviceIndex ExchangeDevice(DeviceIndex device); } // namespace c10::openreg diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/setup.py b/test/cpp_extensions/open_registration_extension/torch_openreg/setup.py index 07d31e73d76b..48124c079e48 100644 --- a/test/cpp_extensions/open_registration_extension/torch_openreg/setup.py +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/setup.py @@ -1,5 +1,6 @@ import multiprocessing import os +import platform import shutil import subprocess import sys @@ -9,10 +10,23 @@ from setuptools import Extension, find_packages, setup +# Env Variables +IS_DARWIN = platform.system() == "Darwin" +IS_WINDOWS = platform.system() == "Windows" + BASE_DIR = os.path.dirname(os.path.realpath(__file__)) RUN_BUILD_DEPS = any(arg in {"clean", "dist_info"} for arg in sys.argv) +def make_relative_rpath_args(path): + if IS_DARWIN: + return ["-Wl,-rpath,@loader_path/" + path] + elif IS_WINDOWS: + return [] + else: + return ["-Wl,-rpath,$ORIGIN/" + path] + + def get_pytorch_dir(): import torch @@ -39,9 +53,15 @@ def build_deps(): ".", "--target", "install", + "--config", + "Release", "--", ] - build_args += ["-j", str(multiprocessing.cpu_count())] + + if IS_WINDOWS: + build_args += ["/m:" + str(multiprocessing.cpu_count())] + else: + build_args += ["-j", str(multiprocessing.cpu_count())] command = ["cmake"] + build_args subprocess.check_call(command, cwd=build_dir, env=os.environ) @@ -64,19 +84,51 @@ def main(): if not RUN_BUILD_DEPS: build_deps() + if IS_WINDOWS: + # /NODEFAULTLIB makes sure we only link to DLL runtime + # and matches the flags set for protobuf and ONNX + extra_link_args: list[str] = ["/NODEFAULTLIB:LIBCMT.LIB"] + [ + *make_relative_rpath_args("lib") + ] + # /MD links against DLL runtime + # and matches the flags set for protobuf and ONNX + # /EHsc is about standard C++ exception handling + extra_compile_args: list[str] = ["/MD", "/FS", "/EHsc"] + else: + extra_link_args = [*make_relative_rpath_args("lib")] + extra_compile_args = [ + "-Wall", + "-Wextra", + "-Wno-strict-overflow", + "-Wno-unused-parameter", + "-Wno-missing-field-initializers", + "-Wno-unknown-pragmas", + # Python 2.6 requires -fno-strict-aliasing, see + # http://legacy.python.org/dev/peps/pep-3123/ + # We also depend on it in our code (even Python 3). + "-fno-strict-aliasing", + ] + ext_modules = [ Extension( name="torch_openreg._C", sources=["torch_openreg/csrc/stub.c"], language="c", - extra_compile_args=["-g", "-Wall", "-Werror"], + extra_compile_args=extra_compile_args, libraries=["torch_bindings"], library_dirs=[os.path.join(BASE_DIR, "torch_openreg/lib")], - extra_link_args=["-Wl,-rpath,$ORIGIN/lib"], + extra_link_args=extra_link_args, ) ] - package_data = {"torch_openreg": ["lib/*.so*"]} + package_data = { + "torch_openreg": [ + "lib/*.so*", + "lib/*.dylib*", + "lib/*.dll", + "lib/*.lib", + ] + } setup( packages=find_packages(), diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/CMakeLists.txt b/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/CMakeLists.txt index 7fec109eeb1c..5450b49be164 100644 --- a/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/CMakeLists.txt +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/CMakeLists.txt @@ -8,4 +8,8 @@ add_library(${LIBRARY_NAME} SHARED ${SOURCE_FILES}) target_include_directories(${LIBRARY_NAME} PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) -install(TARGETS ${LIBRARY_NAME} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}) +install(TARGETS ${LIBRARY_NAME} + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_LIBDIR} +) diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/csrc/memory.cpp b/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/csrc/memory.cpp index 762cd96d23bb..942b04b3b50a 100644 --- a/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/csrc/memory.cpp +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/csrc/memory.cpp @@ -1,38 +1,9 @@ -#include +#include "memory.h" -#include -#include -#include -#include #include #include -namespace openreg { -namespace internal { - -class ScopedMemoryProtector { - public: - ScopedMemoryProtector(const orPointerAttributes& info) - : m_info(info), m_protected(false) { - if (m_info.type == orMemoryType::orMemoryTypeDevice) { - if (mprotect(m_info.pointer, m_info.size, PROT_READ | PROT_WRITE) == - 0) { - m_protected = true; - } - } - } - ~ScopedMemoryProtector() { - if (m_protected) { - mprotect(m_info.pointer, m_info.size, PROT_NONE); - } - } - ScopedMemoryProtector(const ScopedMemoryProtector&) = delete; - ScopedMemoryProtector& operator=(const ScopedMemoryProtector&) = delete; - - private: - orPointerAttributes m_info; - bool m_protected; -}; +namespace { class MemoryManager { public: @@ -46,7 +17,7 @@ class MemoryManager { return orErrorUnknown; std::lock_guard lock(m_mutex); - long page_size = sysconf(_SC_PAGESIZE); + long page_size = openreg::get_pagesize(); size_t aligned_size = ((size - 1) / page_size + 1) * page_size; void* mem = nullptr; int current_device = -1; @@ -54,21 +25,15 @@ class MemoryManager { if (type == orMemoryType::orMemoryTypeDevice) { orGetDevice(¤t_device); - mem = mmap( - nullptr, - aligned_size, - PROT_READ | PROT_WRITE, - MAP_PRIVATE | MAP_ANONYMOUS, - -1, - 0); - if (mem == MAP_FAILED) + mem = openreg::mmap(aligned_size); + if (mem == nullptr) return orErrorUnknown; - if (mprotect(mem, aligned_size, PROT_NONE) != 0) { - munmap(mem, aligned_size); + if (openreg::mprotect(mem, aligned_size, F_PROT_NONE) != 0) { + openreg::munmap(mem, aligned_size); return orErrorUnknown; } } else { - if (posix_memalign(&mem, page_size, aligned_size) != 0) { + if (openreg::alloc(&mem, page_size, aligned_size) != 0) { return orErrorUnknown; } } @@ -87,11 +52,12 @@ class MemoryManager { if (it == m_registry.end()) return orErrorUnknown; const auto& info = it->second; + if (info.type == orMemoryType::orMemoryTypeDevice) { - mprotect(info.pointer, info.size, PROT_READ | PROT_WRITE); - munmap(info.pointer, info.size); + openreg::mprotect(info.pointer, info.size, F_PROT_READ | F_PROT_WRITE); + openreg::munmap(info.pointer, info.size); } else { - ::free(info.pointer); + openreg::free(info.pointer); } m_registry.erase(it); return orSuccess; @@ -167,7 +133,8 @@ class MemoryManager { if (info.type != orMemoryType::orMemoryTypeDevice) { return orErrorUnknown; } - if (mprotect(info.pointer, info.size, PROT_READ | PROT_WRITE) != 0) { + if (openreg::mprotect( + info.pointer, info.size, F_PROT_READ | F_PROT_WRITE) != 0) { return orErrorUnknown; } return orSuccess; @@ -179,49 +146,75 @@ class MemoryManager { if (info.type != orMemoryType::orMemoryTypeDevice) { return orErrorUnknown; } - if (mprotect(info.pointer, info.size, PROT_NONE) != 0) { + if (openreg::mprotect(info.pointer, info.size, F_PROT_NONE) != 0) { return orErrorUnknown; } return orSuccess; } private: + class ScopedMemoryProtector { + public: + ScopedMemoryProtector(const orPointerAttributes& info) + : m_info(info), m_protected(false) { + if (m_info.type == orMemoryType::orMemoryTypeDevice) { + if (openreg::mprotect( + m_info.pointer, m_info.size, F_PROT_READ | F_PROT_WRITE) == 0) { + m_protected = true; + } + } + } + ~ScopedMemoryProtector() { + if (m_protected) { + openreg::mprotect(m_info.pointer, m_info.size, F_PROT_NONE); + } + } + ScopedMemoryProtector(const ScopedMemoryProtector&) = delete; + ScopedMemoryProtector& operator=(const ScopedMemoryProtector&) = delete; + + private: + orPointerAttributes m_info; + bool m_protected; + }; + MemoryManager() = default; + orPointerAttributes getPointerInfo(const void* ptr) { auto it = m_registry.upper_bound(const_cast(ptr)); - if (it == m_registry.begin()) - return {}; - --it; - const char* p_char = static_cast(ptr); - const char* base_char = static_cast(it->first); - if (p_char >= base_char && p_char < (base_char + it->second.size)) { - return it->second; + if (it != m_registry.begin()) { + --it; + const char* p_char = static_cast(ptr); + const char* base_char = static_cast(it->first); + if (p_char >= base_char && p_char < (base_char + it->second.size)) { + return it->second; + } } + return {}; } + std::map m_registry; std::mutex m_mutex; }; -} // namespace internal -} // namespace openreg +} // namespace orError_t orMalloc(void** devPtr, size_t size) { - return openreg::internal::MemoryManager::getInstance().allocate( + return MemoryManager::getInstance().allocate( devPtr, size, orMemoryType::orMemoryTypeDevice); } orError_t orFree(void* devPtr) { - return openreg::internal::MemoryManager::getInstance().free(devPtr); + return MemoryManager::getInstance().free(devPtr); } orError_t orMallocHost(void** hostPtr, size_t size) { - return openreg::internal::MemoryManager::getInstance().allocate( + return MemoryManager::getInstance().allocate( hostPtr, size, orMemoryType::orMemoryTypeHost); } orError_t orFreeHost(void* hostPtr) { - return openreg::internal::MemoryManager::getInstance().free(hostPtr); + return MemoryManager::getInstance().free(hostPtr); } orError_t orMemcpy( @@ -229,21 +222,19 @@ orError_t orMemcpy( const void* src, size_t count, orMemcpyKind kind) { - return openreg::internal::MemoryManager::getInstance().memcpy( - dst, src, count, kind); + return MemoryManager::getInstance().memcpy(dst, src, count, kind); } orError_t orPointerGetAttributes( orPointerAttributes* attributes, const void* ptr) { - return openreg::internal::MemoryManager::getInstance().getPointerAttributes( - attributes, ptr); + return MemoryManager::getInstance().getPointerAttributes(attributes, ptr); } orError_t orMemoryUnprotect(void* devPtr) { - return openreg::internal::MemoryManager::getInstance().unprotect(devPtr); + return MemoryManager::getInstance().unprotect(devPtr); } orError_t orMemoryProtect(void* devPtr) { - return openreg::internal::MemoryManager::getInstance().protect(devPtr); + return MemoryManager::getInstance().protect(devPtr); } diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/csrc/memory.h b/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/csrc/memory.h new file mode 100644 index 000000000000..9de13acc2350 --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/csrc/memory.h @@ -0,0 +1,98 @@ +#pragma once + +#include +#include +#include + +#include + +#if defined(_WIN32) +#include +#else +#include +#include +#endif + +#define F_PROT_NONE 0x0 +#define F_PROT_READ 0x1 +#define F_PROT_WRITE 0x2 + +namespace openreg { + +void* mmap(size_t size) { +#if defined(_WIN32) + return VirtualAlloc(nullptr, size, MEM_RESERVE | MEM_COMMIT, PAGE_READWRITE); +#else + void* addr = ::mmap( + nullptr, + size, + PROT_READ | PROT_WRITE, + MAP_PRIVATE | MAP_ANONYMOUS, + -1, + 0); + return (addr == MAP_FAILED) ? nullptr : addr; +#endif +} + +void munmap(void* addr, size_t size) { +#if defined(_WIN32) + VirtualFree(addr, 0, MEM_RELEASE); +#else + ::munmap(addr, size); +#endif +} + +int mprotect(void* addr, size_t size, int prot) { +#if defined(_WIN32) + DWORD win_prot = 0; + DWORD old; + if (prot == F_PROT_NONE) { + win_prot = PAGE_NOACCESS; + } else { + win_prot = PAGE_READWRITE; + } + + return VirtualProtect(addr, size, win_prot, &old) ? 0 : -1; +#else + int native_prot = 0; + if (prot == F_PROT_NONE) + native_prot = PROT_NONE; + else { + if (prot & F_PROT_READ) + native_prot |= PROT_READ; + if (prot & F_PROT_WRITE) + native_prot |= PROT_WRITE; + } + + return ::mprotect(addr, size, native_prot); +#endif +} + +int alloc(void** mem, size_t alignment, size_t size) { +#ifdef _WIN32 + *mem = _aligned_malloc(size, alignment); + return *mem ? 0 : -1; +#else + return posix_memalign(mem, alignment, size); +#endif +} + +void free(void* mem) { +#ifdef _WIN32 + _aligned_free(mem); +#else + ::free(mem); +#endif +} + +long get_pagesize() { +#ifdef _WIN32 + SYSTEM_INFO si; + GetSystemInfo(&si); + return static_cast(si.dwPageSize); +#else + return sysconf(_SC_PAGESIZE); +#endif +} + +} // namespace openreg diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/include/openreg.h b/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/include/openreg.h index b6b0b3da4295..a5e8b77c421c 100644 --- a/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/include/openreg.h +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/include/openreg.h @@ -2,6 +2,12 @@ #include +#ifdef _WIN32 + #define OPENREG_EXPORT __declspec(dllexport) +#else + #define OPENREG_EXPORT __attribute__((visibility("default"))) +#endif + #ifdef __cplusplus extern "C" { #endif @@ -28,19 +34,19 @@ struct orPointerAttributes { size_t size; }; -orError_t orMalloc(void** devPtr, size_t size); -orError_t orFree(void* devPtr); -orError_t orMallocHost(void** hostPtr, size_t size); -orError_t orFreeHost(void* hostPtr); -orError_t orMemcpy(void* dst, const void* src, size_t count, orMemcpyKind kind); -orError_t orMemoryUnprotect(void* devPtr); -orError_t orMemoryProtect(void* devPtr); +OPENREG_EXPORT orError_t orMalloc(void** devPtr, size_t size); +OPENREG_EXPORT orError_t orFree(void* devPtr); +OPENREG_EXPORT orError_t orMallocHost(void** hostPtr, size_t size); +OPENREG_EXPORT orError_t orFreeHost(void* hostPtr); +OPENREG_EXPORT orError_t orMemcpy(void* dst, const void* src, size_t count, orMemcpyKind kind); +OPENREG_EXPORT orError_t orMemoryUnprotect(void* devPtr); +OPENREG_EXPORT orError_t orMemoryProtect(void* devPtr); -orError_t orGetDeviceCount(int* count); -orError_t orSetDevice(int device); -orError_t orGetDevice(int* device); +OPENREG_EXPORT orError_t orGetDeviceCount(int* count); +OPENREG_EXPORT orError_t orSetDevice(int device); +OPENREG_EXPORT orError_t orGetDevice(int* device); -orError_t orPointerGetAttributes( +OPENREG_EXPORT orError_t orPointerGetAttributes( orPointerAttributes* attributes, const void* ptr); diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/__init__.py b/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/__init__.py index 3ed73794b06d..45b2343070fe 100644 --- a/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/__init__.py +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/__init__.py @@ -1,5 +1,15 @@ +import sys + import torch + +if sys.platform == "win32": + from ._utils import _load_dll_libraries + + _load_dll_libraries() + del _load_dll_libraries + + import torch_openreg._C # type: ignore[misc] import torch_openreg.openreg diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/_utils.py b/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/_utils.py new file mode 100644 index 000000000000..1c26f475ba7a --- /dev/null +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/_utils.py @@ -0,0 +1,42 @@ +import ctypes +import glob +import os + + +def _load_dll_libraries() -> None: + openreg_dll_path = os.path.join(os.path.dirname(__file__), "lib") + + kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True) + with_load_library_flags = hasattr(kernel32, "AddDllDirectory") + prev_error_mode = kernel32.SetErrorMode(0x0001) + + kernel32.LoadLibraryW.restype = ctypes.c_void_p + if with_load_library_flags: + kernel32.LoadLibraryExW.restype = ctypes.c_void_p + + os.add_dll_directory(openreg_dll_path) + + dlls = glob.glob(os.path.join(openreg_dll_path, "*.dll")) + path_patched = False + for dll in dlls: + is_loaded = False + if with_load_library_flags: + res = kernel32.LoadLibraryExW(dll, None, 0x00001100) + last_error = ctypes.get_last_error() + if res is None and last_error != 126: + err = ctypes.WinError(last_error) + err.strerror += f' Error loading "{dll}" or one of its dependencies.' + raise err + elif res is not None: + is_loaded = True + if not is_loaded: + if not path_patched: + os.environ["PATH"] = ";".join([openreg_dll_path] + [os.environ["PATH"]]) + path_patched = True + res = kernel32.LoadLibraryW(dll) + if res is None: + err = ctypes.WinError(ctypes.get_last_error()) + err.strerror += f' Error loading "{dll}" or one of its dependencies.' + raise err + + kernel32.SetErrorMode(prev_error_mode) diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/CMakeLists.txt b/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/CMakeLists.txt index 574b5b1c748a..4ff321c43f2c 100644 --- a/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/CMakeLists.txt +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/CMakeLists.txt @@ -6,7 +6,19 @@ file(GLOB_RECURSE SOURCE_FILES add_library(${LIBRARY_NAME} SHARED ${SOURCE_FILES}) -target_link_libraries(${LIBRARY_NAME} PRIVATE torch_python torch_openreg) +target_link_libraries(${LIBRARY_NAME} PRIVATE torch_python_library torch_openreg) + +if(WIN32) + find_package(Python3 COMPONENTS Interpreter Development REQUIRED) + target_link_libraries(${LIBRARY_NAME} PRIVATE ${Python3_LIBRARIES}) +elseif(APPLE) + set_target_properties(${LIBRARY_NAME} PROPERTIES LINK_FLAGS "-undefined dynamic_lookup") +endif() + target_link_directories(${LIBRARY_NAME} PRIVATE ${PYTORCH_INSTALL_DIR}/lib) -install(TARGETS ${LIBRARY_NAME} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}) +install(TARGETS ${LIBRARY_NAME} + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_LIBDIR} +) diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/Module.cpp b/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/Module.cpp index 4acdbfc8e1dc..38c456339003 100644 --- a/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/Module.cpp +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/Module.cpp @@ -90,7 +90,7 @@ static PyMethodDef methods[] = { * Therefore, it cannot be named initModule here, otherwise initModule * in torch/csrc/Module.cpp will be called, resulting in failure. */ -extern "C" PyObject* initOpenRegModule(void) { +extern "C" OPENREG_EXPORT PyObject* initOpenRegModule(void) { static struct PyModuleDef openreg_C_module = { PyModuleDef_HEAD_INIT, "torch_openreg._C", nullptr, -1, methods}; PyObject* mod = PyModule_Create(&openreg_C_module); diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/stub.c b/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/stub.c index cd3eb4fe1ecc..243a43a37e5e 100644 --- a/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/stub.c +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/stub.c @@ -1,13 +1,18 @@ #include -extern PyObject* initOpenRegModule(void); +#ifdef _WIN32 + #define OPENREG_EXPORT __declspec(dllexport) +#else + #define OPENREG_EXPORT __attribute__((visibility("default"))) +#endif + +extern OPENREG_EXPORT PyObject* initOpenRegModule(void); -#ifndef _WIN32 #ifdef __cplusplus extern "C" #endif -__attribute__((visibility("default"))) PyObject* PyInit__C(void); -#endif + +OPENREG_EXPORT PyObject* PyInit__C(void); PyMODINIT_FUNC PyInit__C(void) { diff --git a/test/run_test.py b/test/run_test.py index e0bde4e6d52d..4983071ab23f 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -28,7 +28,6 @@ from torch.testing._internal.common_utils import ( get_report_path, IS_CI, - IS_LINUX, IS_MACOS, retry_shell, set_cwd, @@ -911,10 +910,6 @@ def _test_autoload(test_directory, options, enable=True): def run_test_with_openreg(test_module, test_directory, options): - # TODO(FFFrog): Will remove this later when windows/macos are supported. - if not IS_LINUX: - return 0 - openreg_dir = os.path.join( test_directory, "cpp_extensions", "open_registration_extension", "torch_openreg" ) diff --git a/test/test_openreg.py b/test/test_openreg.py index cae20b16f479..a2418fbdef20 100644 --- a/test/test_openreg.py +++ b/test/test_openreg.py @@ -17,6 +17,7 @@ from torch.testing._internal.common_utils import ( run_tests, skipIfTorchDynamo, + skipIfWindows, skipIfXpu, TemporaryFileName, TestCase, @@ -284,6 +285,7 @@ def test_manual_seed(self): self.assertEqual(torch.openreg.initial_seed(), 2024) # type: ignore[misc] # Autograd + @skipIfWindows() def test_autograd_init(self): # Make sure autograd is initialized torch.ones(2, requires_grad=True, device="openreg").sum().backward()