musa: extract ggml_cuda_mul_mat_batched_cublas_gemm_batched_ex #13887
Conversation
Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
CUBLAS_CHECK(
    cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
            ne01, ne11, ne10,
            alpha, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F,   nb01/nb00,
                   (const void **) (ptrs_src + 1*ne23), CUDA_R_16F,   s11,
            beta,  (      void **) (ptrs_dst + 0*ne23), cu_data_type, ne0,
            ne23,
            cu_compute_type,
            CUBLAS_GEMM_DEFAULT_TENSOR_OP));

CUDA_CHECK(cudaFree(ptrs_src));
CUDA_CHECK(cudaFree(ptrs_dst));
This wouldn't be OK in CUDA, since cublasGemmBatchedEx is normally asynchronous and freeing the memory immediately would likely lead to a use-after-free.
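To illustrate the point being made here, a sketch of two stream-safe ways to release the pointer arrays after an asynchronous cublasGemmBatchedEx call (this is not the PR's code; `stream` is assumed to be the stream the cuBLAS handle is bound to):

```cpp
// Option 1: block the host until all work queued on the stream
// (including the batched GEMM) has finished, then free.
CUDA_CHECK(cudaStreamSynchronize(stream));
CUDA_CHECK(cudaFree(ptrs_src));
CUDA_CHECK(cudaFree(ptrs_dst));

// Option 2 (CUDA 11.2+): stream-ordered free, no host-side blocking.
// The memory is only reclaimed once prior work on `stream` completes.
CUDA_CHECK(cudaFreeAsync(ptrs_src, stream));
CUDA_CHECK(cudaFreeAsync(ptrs_dst, stream));
```

Either pattern avoids the hazard: the allocation cannot be reclaimed while the GEMM may still be reading the pointer arrays.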
Thanks for taking a look at this!
Yes, that's due to the current incompatibility between cublasGemmBatchedEx and mublasGemmBatchedEx. Also, mublas.cu is only compiled when using the MUSA backend.
This is not an acceptable solution to me. For this level of changes I would only be willing to accept vendor-specific code if
How about we just wait for the next release then?
Thanks for the review! Here are a few considerations and experiments I explored prior to the last PR:
That's certainly an option, but based on previous experience, MUSA SDK releases often take several months, so it may delay this improvement for quite a while.
Just wanted to share some good news: in our internal build,
Make sure to read the contributing guidelines before submitting a PR
This PR extracts ggml_cuda_mul_mat_batched_cublas_gemm_batched_ex and implements a MUSA-only version that allocates memory for the pointer arrays using cudaMalloc, in order to avoid segmentation faults in muBLAS.
Since #13842 is still open, I will rebase this PR once it is merged into master.
Hopefully, we can revert this change in the next MUSA SDK release.
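A rough sketch of the vendor split the description implies (names and structure are illustrative, not the PR's exact code): on the regular CUDA path the pointer arrays come from llama.cpp's stream-ordered pool allocator, while the MUSA-only path falls back to a plain cudaMalloc/cudaFree pair to sidestep the muBLAS segfault.

```cpp
#ifdef GGML_USE_MUSA
    // MUSA workaround: plain device allocations for the pointer arrays.
    const void ** ptrs_src = nullptr;
          void ** ptrs_dst = nullptr;
    CUDA_CHECK(cudaMalloc((void **) &ptrs_src, 2*ne23*sizeof(void *)));
    CUDA_CHECK(cudaMalloc((void **) &ptrs_dst, 1*ne23*sizeof(void *)));
    // ... fill the arrays, call cublasGemmBatchedEx, synchronize ...
    CUDA_CHECK(cudaFree(ptrs_src));
    CUDA_CHECK(cudaFree(ptrs_dst));
#else
    // Regular CUDA path: pool-backed, stream-ordered allocations that are
    // released automatically when the guards go out of scope.
    ggml_cuda_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
    ggml_cuda_pool_alloc<      void *> ptrs_dst(ctx.pool(), 1*ne23);
    // ... fill the arrays and call cublasGemmBatchedEx ...
#endif
```

The pool path is preferable where it works, since it avoids a synchronous cudaMalloc/cudaFree in the hot path; the #ifdef confines the slower fallback to the MUSA backend.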
Testing Done
All tests below were performed on the MTT S4000.
- ./build/bin/test-backend-ops passed
- DeepSeek-R1-Distill-Qwen-7B-Q4_K_M.gguf, qwen3_8b_q4_k_m.gguf, nvidia-llama-3_1-nemotron-nano-8b-v1-q4_k_m.gguf, with or without the -fa flag