Skip to content

Commit 8101786

Browse files
authored
CUDA: fix bug in rms_norm fusion (#15660)
* CUDA: fix bug in rms_norm fusion
* Fix bug for OP_REPEAT
* Fix index for add
1 parent 60e5eee commit 8101786

File tree

3 files changed

+51
-23
lines changed

3 files changed

+51
-23
lines changed

ggml/src/ggml-cuda/binbcast.cu

Lines changed: 47 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,11 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst
5757
const int i10 = i0 % ne10;
5858

5959
float result = src0_row ? (float) src0_row[i0] : 0.0f;
60-
result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10])));
60+
if constexpr (sizeof...(src1_ptrs) > 0) {
61+
result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10])));
62+
} else {
63+
result = bin_op(result, (float)src1[i_src1 + i10]);
64+
}
6165

6266
dst_row[i0] = (dst_t) result;
6367
}
@@ -96,7 +100,11 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t *
96100
const int i10 = i0 % ne10;
97101

98102
float result = src0_row ? (float) src0_row[i0] : 0.0f;
99-
result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10])));
103+
if constexpr (sizeof...(src1_ptrs) > 0) {
104+
result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10])));
105+
} else {
106+
result = bin_op(result, (float)src1[i_src1 + i10]);
107+
}
100108

101109
dst_row[i0] = (dst_t) result;
102110
}
@@ -231,23 +239,43 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
231239

232240
if (block_nums.z > 65535) {
233241
int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;
234-
k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t>
235-
<<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd,
236-
ne0, ne1, ne2, ne3,
237-
ne10, ne11, ne12, ne13,
238-
/* s0, */ s1, s2, s3,
239-
/* s00,*/ s01, s02, s03,
240-
/* s10,*/ s11, s12,s13,
241-
(const src1_t *) dst->src[I + 1]->data...);
242+
if constexpr (sizeof...(I) > 0) {
243+
k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t>
244+
<<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd,
245+
ne0, ne1, ne2, ne3,
246+
ne10, ne11, ne12, ne13,
247+
/* s0, */ s1, s2, s3,
248+
/* s00,*/ s01, s02, s03,
249+
/* s10,*/ s11, s12,s13,
250+
(const src1_t *) dst->src[I + 1]->data...);
251+
} else {
252+
k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t>
253+
<<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd,
254+
ne0, ne1, ne2, ne3,
255+
ne10, ne11, ne12, ne13,
256+
/* s0, */ s1, s2, s3,
257+
/* s00,*/ s01, s02, s03,
258+
/* s10,*/ s11, s12,s13);
259+
}
242260
} else {
243-
k_bin_bcast<bin_op, src0_t, src1_t, dst_t>
244-
<<<block_nums, block_dims, 0, stream>>>(src0_dd, src1_dd, dst_dd,
245-
ne0, ne1, ne2, ne3,
246-
ne10, ne11, ne12, ne13,
247-
/* s0, */ s1, s2, s3,
248-
/* s00,*/ s01, s02, s03,
249-
/* s10,*/ s11, s12,s13,
250-
(const src1_t *) dst->src[I + 1]->data...);
261+
if constexpr (sizeof...(I) > 0) {
262+
k_bin_bcast<bin_op, src0_t, src1_t, dst_t>
263+
<<<block_nums, block_dims, 0, stream>>>(src0_dd, src1_dd, dst_dd,
264+
ne0, ne1, ne2, ne3,
265+
ne10, ne11, ne12, ne13,
266+
/* s0, */ s1, s2, s3,
267+
/* s00,*/ s01, s02, s03,
268+
/* s10,*/ s11, s12,s13,
269+
(const src1_t *) dst->src[I + 1]->data...);
270+
} else {
271+
k_bin_bcast<bin_op, src0_t, src1_t, dst_t>
272+
<<<block_nums, block_dims, 0, stream>>>(src0_dd, src1_dd, dst_dd,
273+
ne0, ne1, ne2, ne3,
274+
ne10, ne11, ne12, ne13,
275+
/* s0, */ s1, s2, s3,
276+
/* s00,*/ s01, s02, s03,
277+
/* s10,*/ s11, s12,s13);
278+
}
251279
}
252280
}
253281
}
@@ -327,7 +355,7 @@ static void ggml_cuda_op_bin_bcast(
327355
}
328356

329357
void ggml_cuda_op_repeat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
330-
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_repeat>>(dst, dst->src[0], dst, nullptr, dst->src[0]->data, dst->data, ctx.stream());
358+
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_repeat, 0>>(dst, dst->src[0], dst, nullptr, dst->src[0]->data, dst->data, ctx.stream());
331359
}
332360

333361
void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2827,7 +2827,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
28272827
const ggml_tensor *add = nullptr;
28282828

28292829
if (ops.size() == 3 && ops.begin()[2] == GGML_OP_ADD) {
2830-
add = cgraph->nodes[node_idx+1];
2830+
add = cgraph->nodes[node_idx+2];
28312831
}
28322832

28332833
GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);

ggml/src/ggml-cuda/norm.cu

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ static __global__ void rms_norm_f32(const float * x, float * dst,
127127
const int add_nrows = 0,
128128
const int add_nchannels = 0,
129129
const int add_nsamples = 0) {
130+
130131
const int nrows = gridDim.x;
131132
const int nchannels = gridDim.y;
132133

@@ -135,6 +136,8 @@ static __global__ void rms_norm_f32(const float * x, float * dst,
135136
const int sample = blockIdx.z;
136137
const int tid = threadIdx.x;
137138

139+
static_assert(!do_add || do_multiply, "fusing add is not supported without multiplying");
140+
138141
x += sample*stride_sample + channel*stride_channel + row*stride_row;
139142
dst += ((sample*nchannels + channel)*nrows + row)*ncols;
140143

@@ -185,9 +188,6 @@ static __global__ void rms_norm_f32(const float * x, float * dst,
185188
} else if constexpr (do_multiply) {
186189
const int mul_col = col % mul_ncols;
187190
dst[col] = scale * x[col] * mul[mul_col];
188-
} else if constexpr (do_add) {
189-
const int add_col = col % add_ncols;
190-
dst[col] += add[add_col];
191191
} else {
192192
dst[col] = scale * x[col];
193193
}

0 commit comments

Comments (0)