From ea558e145985ade6a9f337c828831ef1945cd4bd Mon Sep 17 00:00:00 2001
From: Umar Arshad
Date: Fri, 26 Jun 2020 11:38:40 -0400
Subject: [PATCH] Propagate nvrtc errors up the stack in AFError exceptions

* NVRTC errors were only printed in debug builds. The error messages
  were not passed to the exceptions thrown by the lib. This made it
  harder to debug issues
---
Review notes (placed after the "---" separator, so not part of the commit
message):

* CU_LINK_CHECK is no longer conditional on NDEBUG; it always formats the
  CUDA error name, the numeric code and the contents of `linkError` into
  the message handed to AF_ERROR.
* NVRTC_CHECK keeps only the short form (code + nvrtcGetErrorString).
  The new NVRTC_COMPILE_CHECK additionally reads the program log with
  nvrtcGetProgramLogSize/nvrtcGetProgramLog and appends it to the
  AF_ERROR message. It refers to a variable named `prog` that must be in
  scope where the macro is expanded.
* linkLogSize grows from 1024 to 4096.
* AF_TRACE format string corrected relative to the patch as first posted:
  it began with "{{{compile:" and ended with "}}}". After the escaped
  "{{" that leaves "{compile:...}" as a replacement field naming an
  argument `compile` that is never passed, plus one unmatched "}". The
  string now starts with "{{ compile:" and ends with "}}", which lines up
  with the four remaining positional arguments.
* NOTE(review): in this copy of the patch, text in angle brackets
  (template arguments such as the ones expected on `array`, `vector`,
  `unique_ptr`, `shared_ptr`, `duration_cast`, and the author e-mail)
  appears to be missing, and the original whitespace could not be
  recovered; indentation below is reconstructed. Regenerate the patch
  from the commit before applying it.

 src/backend/cuda/compile_module.cpp | 72 ++++++++++++++---------------
 1 file changed, 34 insertions(+), 38 deletions(-)

diff --git a/src/backend/cuda/compile_module.cpp b/src/backend/cuda/compile_module.cpp
index ee4ce27e49..1f54aa8079 100644
--- a/src/backend/cuda/compile_module.cpp
+++ b/src/backend/cuda/compile_module.cpp
@@ -81,47 +81,44 @@ using std::chrono::duration_cast;
 using std::chrono::high_resolution_clock;
 using std::chrono::milliseconds;
 
-#ifdef NDEBUG
-#define CU_LINK_CHECK(fn) \
-    do { \
-        CUresult res = fn; \
-        if (res == CUDA_SUCCESS) break; \
-        char cu_err_msg[2048]; \
-        const char *cu_err_name; \
-        cuGetErrorName(res, &cu_err_name); \
-        snprintf(cu_err_msg, sizeof(cu_err_msg), "CU Error %s(%d): %s\n", \
-                 cu_err_name, (int)(res), linkError); \
-        AF_ERROR(cu_err_msg, AF_ERR_INTERNAL); \
+#define CU_LINK_CHECK(fn) \
+    do { \
+        CUresult res = (fn); \
+        if (res == CUDA_SUCCESS) break; \
+        array cu_err_msg; \
+        const char *cu_err_name; \
+        cuGetErrorName(res, &cu_err_name); \
+        snprintf(cu_err_msg.data(), cu_err_msg.size(), \
+                 "CU Link Error %s(%d): %s\n", cu_err_name, (int)(res), \
+                 linkError); \
+        AF_ERROR(cu_err_msg.data(), AF_ERR_INTERNAL); \
     } while (0)
-#else
-#define CU_LINK_CHECK(fn) CU_CHECK(fn)
-#endif
 
-#ifndef NDEBUG
-#define NVRTC_CHECK(fn) \
-    do { \
-        nvrtcResult res = fn; \
-        if (res == NVRTC_SUCCESS) break; \
-        size_t logSize; \
-        nvrtcGetProgramLogSize(prog, &logSize); \
-        unique_ptr log(new char[logSize + 1]); \
-        char *logptr = log.get(); \
-        nvrtcGetProgramLog(prog, logptr); \
-        logptr[logSize] = '\x0'; \
-        puts(logptr); \
-        AF_ERROR("NVRTC ERROR", AF_ERR_INTERNAL); \
-    } while (0)
-#else
 #define NVRTC_CHECK(fn) \
     do { \
         nvrtcResult res = (fn); \
         if (res == NVRTC_SUCCESS) break; \
-        char nvrtc_err_msg[2048]; \
-        snprintf(nvrtc_err_msg, sizeof(nvrtc_err_msg), \
+        array nvrtc_err_msg; \
+        snprintf(nvrtc_err_msg.data(), nvrtc_err_msg.size(), \
                  "NVRTC Error(%d): %s\n", res, nvrtcGetErrorString(res)); \
-        AF_ERROR(nvrtc_err_msg, AF_ERR_INTERNAL); \
+        AF_ERROR(nvrtc_err_msg.data(), AF_ERR_INTERNAL); \
+    } while (0)
+
+#define NVRTC_COMPILE_CHECK(fn) \
+    do { \
+        nvrtcResult res = (fn); \
+        if (res == NVRTC_SUCCESS) break; \
+        size_t logSize; \
+        nvrtcGetProgramLogSize(prog, &logSize); \
+        vector log(logSize + 1); \
+        nvrtcGetProgramLog(prog, log.data()); \
+        log[logSize] = '\0'; \
+        array nvrtc_err_msg; \
+        snprintf(nvrtc_err_msg.data(), nvrtc_err_msg.size(), \
+                 "NVRTC Error(%d): %s\nLog: \n%s\n", res, \
+                 nvrtcGetErrorString(res), log.data()); \
+        AF_ERROR(nvrtc_err_msg.data(), AF_ERR_INTERNAL); \
     } while (0)
-#endif
 
 spdlog::logger *getLogger() {
     static std::shared_ptr logger(common::loggerFactory("jit"));
@@ -264,8 +261,8 @@ Module compileModule(const string &moduleKey, const vector &sources,
     }
 
     auto compile = high_resolution_clock::now();
-    NVRTC_CHECK(nvrtcCompileProgram(prog, compiler_options.size(),
-                                    compiler_options.data()));
+    NVRTC_COMPILE_CHECK(nvrtcCompileProgram(prog, compiler_options.size(),
+                                            compiler_options.data()));
     auto compile_end = high_resolution_clock::now();
     size_t ptx_size;
     vector ptx;
@@ -273,7 +270,7 @@ Module compileModule(const string &moduleKey, const vector &sources,
     ptx.resize(ptx_size);
     NVRTC_CHECK(nvrtcGetPTX(prog, ptx.data()));
 
-    const size_t linkLogSize = 1024;
+    const size_t linkLogSize = 4096;
     char linkInfo[linkLogSize] = {0};
     char linkError[linkLogSize] = {0};
 
@@ -367,8 +364,7 @@ Module compileModule(const string &moduleKey, const vector &sources,
                 return lhs + ", " + rhs;
             });
     };
-    AF_TRACE("{{{:<30} : {{ compile:{:>5} ms, link:{:>4} ms, {{ {} }}, {} }}}}",
-             sources[0],
+    AF_TRACE("{{ compile:{:>5} ms, link:{:>4} ms, {{ {} }}, {} }}",
              duration_cast(compile_end - compile).count(),
              duration_cast(link_end - link).count(),
              listOpts(compiler_options), getDeviceProp(device).name);