Skip to content

Commit

Permalink
Add some minimal optimizations for CDNA (#10498)
Browse files Browse the repository at this point in the history
* Add some minimal optimizations for CDNA

* ggml_cuda: set launch bounds also for GCN as it helps there too
  • Loading branch information
IMbackK authored Nov 27, 2024
1 parent 46c69e0 commit 3ad5451
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 8 deletions.
17 changes: 14 additions & 3 deletions ggml/src/ggml-cuda/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,20 @@
#define CC_TURING 750
#define CC_AMPERE 800
#define CC_OFFSET_AMD 1000000
#define CC_RDNA1 (CC_OFFSET_AMD + 1010)
#define CC_RDNA2 (CC_OFFSET_AMD + 1030)
#define CC_RDNA3 (CC_OFFSET_AMD + 1100)

// GCN/CNDA, wave size is 64
#define CC_GCN4 (CC_OFFSET_AMD + 803) // Tonga, Fiji, Polaris, minimum for fast fp16
#define CC_VEGA (CC_OFFSET_AMD + 900) // Vega56/64, minimum for fp16 dual issue
#define CC_VEGA20 (CC_OFFSET_AMD + 906) // MI50/Radeon VII, minimum for dp4a
#define CC_CDNA (CC_OFFSET_AMD + 908) // MI100, minimum for MFMA, acc registers
#define CC_CDNA2 (CC_OFFSET_AMD + 910) // MI210, minimum acc register renameing
#define CC_CDNA3 (CC_OFFSET_AMD + 942) // MI300

// RNDA removes MFMA, dp4a, xnack, acc registers, wave size is 32
#define CC_RDNA1 (CC_OFFSET_AMD + 1010) // RX 5000
#define CC_RDNA2 (CC_OFFSET_AMD + 1030) // RX 6000, minimum for dp4a
#define CC_RDNA3 (CC_OFFSET_AMD + 1100) // RX 7000, minimum for WMMA

#define CC_QY1 210
#define CC_QY2 220

Expand Down
11 changes: 10 additions & 1 deletion ggml/src/ggml-cuda/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1107,14 +1107,19 @@ static void ggml_cuda_op_mul_mat_cublas(
const half alpha_f16 = 1.0f;
const half beta_f16 = 0.0f;

cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
if (ggml_cuda_info().devices[ctx.device].cc == CC_CDNA) {
cu_compute_type = CUBLAS_COMPUTE_32F;
}

CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
CUBLAS_CHECK(
cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
row_diff, src1_ncols, ne10,
&alpha_f16, src0_ptr, CUDA_R_16F, ne00,
src1_ptr, CUDA_R_16F, ne10,
&beta_f16, dst_f16.get(), CUDA_R_16F, ldc,
CUBLAS_COMPUTE_16F,
cu_compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));

const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
Expand Down Expand Up @@ -1607,6 +1612,10 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
cudaDataType_t cu_data_type = CUDA_R_16F;

if (ggml_cuda_info().devices[ctx.device].cc == CC_CDNA) {
cu_compute_type = CUBLAS_COMPUTE_32F;
}

// dst strides
size_t nbd2 = dst->nb[2];
size_t nbd3 = dst->nb[3];
Expand Down
2 changes: 1 addition & 1 deletion ggml/src/ggml-cuda/mmq.cu
Original file line number Diff line number Diff line change
Expand Up @@ -148,5 +148,5 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
return cc < CC_VOLTA || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
}

return cc < CC_RDNA3 || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
return (cc < CC_RDNA3 && cc != CC_CDNA && cc != CC_VEGA20) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
}
4 changes: 2 additions & 2 deletions ggml/src/ggml-cuda/mmq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -2570,9 +2570,9 @@ static __device__ void mul_mat_q_process_tile(

template <ggml_type type, int mmq_x, int nwarps, bool need_check>
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
#if defined(RDNA3) || defined(RDNA2)
#if defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
__launch_bounds__(WARP_SIZE*nwarps, 2)
#endif // defined(RDNA3) || defined(RDNA2)
#endif // defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
#else
#if __CUDA_ARCH__ >= CC_VOLTA
__launch_bounds__(WARP_SIZE*nwarps, 1)
Expand Down
2 changes: 1 addition & 1 deletion ggml/src/ggml-cuda/mmvq.cu
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ static void mul_mat_vec_q_cuda(
int64_t nwarps = 1;
int64_t rows_per_cuda_block = 1;

if (ggml_cuda_info().devices[id].cc < CC_RDNA2) { // NVIDIA and AMD older than RDNA2
if (ggml_cuda_info().devices[id].cc < CC_CDNA || ggml_cuda_info().devices[id].cc == CC_RDNA1) { // NVIDIA and AMD older than RDNA2 but not CDNA
switch(ncols_y) {
case 1:
nwarps = 4;
Expand Down
8 changes: 8 additions & 0 deletions ggml/src/ggml-cuda/vendors/hip.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,14 @@

#define __CUDA_ARCH__ 1300

#if defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__)
#define GCN
#endif

#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__)
#define CDNA
#endif

#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
defined(__gfx1150__) || defined(__gfx1151__)
#define RDNA3
Expand Down

0 comments on commit 3ad5451

Please sign in to comment.