Commit 873f19f

ggml : add quick GELU (leejet#254)
* Implement Quick GELU
* Revert "Implement Quick GELU"
  This reverts commit ff220cc1f91a184f195d19b17ed4c352cc72a6f0.
* Tidy up ggml.h
* Respect to the style of ggml
* Fix: Fix minor typo
* Rename `quick_gelu` -> `gelu_quick`
1 parent f52d2a0 commit 873f19f
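The change adds a new unary op, GGML_OP_GELU_QUICK, which computes the sigmoid ("quick") approximation gelu_quick(x) = x * sigmoid(1.702 * x), alongside the existing tanh-based GELU approximation. Below is a minimal standalone sketch (not part of the commit) comparing the two; the constants match the diff, and the tanh form follows ggml's existing ggml_gelu_f32.

// standalone comparison of the two GELU approximations (illustrative only)
#include <math.h>
#include <stdio.h>

static const float GELU_COEF_A     = 0.044715f;
static const float GELU_QUICK_COEF = -1.702f;
static const float SQRT_2_OVER_PI  = 0.79788456080286535587989211986876f;

// tanh-based approximation already used by ggml_gelu_f32
static float gelu(float x) {
    return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
}

// sigmoid ("quick") approximation added by this commit: x * sigmoid(1.702*x)
static float gelu_quick(float x) {
    return x*(1.0f/(1.0f + expf(GELU_QUICK_COEF*x)));
}

int main(void) {
    const float xs[] = { -2.0f, -0.5f, 0.0f, 0.5f, 2.0f };
    for (int i = 0; i < 5; ++i) {
        printf("x = % .2f   gelu = % .6f   gelu_quick = % .6f\n", xs[i], gelu(xs[i]), gelu_quick(xs[i]));
    }
    return 0;
}

(Compile with the math library, e.g. cc demo.c -lm; the file name is arbitrary.)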

File tree

2 files changed: +153 -4 lines changed

include/ggml/ggml.h

Lines changed: 9 additions & 0 deletions
@@ -290,6 +290,7 @@ extern "C" {
         GGML_OP_STEP,
         GGML_OP_RELU,
         GGML_OP_GELU,
+        GGML_OP_GELU_QUICK,
         GGML_OP_SILU,
         GGML_OP_SILU_BACK,
         GGML_OP_NORM, // normalize
@@ -687,6 +688,14 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor * a);
 
+    GGML_API struct ggml_tensor * ggml_gelu_quick(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a);
+
+    GGML_API struct ggml_tensor * ggml_gelu_quick_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a);
+
     GGML_API struct ggml_tensor * ggml_silu(
             struct ggml_context * ctx,
             struct ggml_tensor * a);
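For context, a hedged usage sketch of the new public API declared above. The surrounding setup (the ggml_init_params fields, ggml_build_forward, ggml_graph_compute, the buffer size) assumes the ggml API as of this commit and is illustrative, not part of the change.

// illustrative only: apply the new op to a small tensor and run the graph
#include <stdio.h>
#include "ggml/ggml.h"

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,   // arbitrary scratch size for this sketch
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    ggml_set_f32(x, 0.5f);                               // fill input with a test value

    struct ggml_tensor * y = ggml_gelu_quick(ctx, x);    // the op added by this commit

    struct ggml_cgraph gf = ggml_build_forward(y);
    ggml_graph_compute(ctx, &gf);

    for (int i = 0; i < 4; ++i) {
        printf("gelu_quick(0.5) = %f\n", ggml_get_f32_1d(y, i));
    }

    ggml_free(ctx);
    return 0;
}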

src/ggml.c

Lines changed: 144 additions & 4 deletions
@@ -98,6 +98,7 @@ typedef void* thread_ret_t;
 /*#define GGML_PERF*/
 #define GGML_DEBUG 0
 #define GGML_GELU_FP16
+#define GGML_GELU_QUICK_FP16
 #define GGML_SILU_FP16
 
 #define GGML_SOFT_MAX_UNROLL 4
@@ -322,6 +323,9 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
 // precomputed gelu table for f16 (128 KB)
 static ggml_fp16_t table_gelu_f16[1 << 16];
 
+// precomputed quick gelu table for f16 (128 KB)
+static ggml_fp16_t table_gelu_quick_f16[1 << 16];
+
 // precomputed silu table for f16 (128 KB)
 static ggml_fp16_t table_silu_f16[1 << 16];
 
@@ -3288,6 +3292,7 @@ inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) {
 inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
 
 static const float GELU_COEF_A = 0.044715f;
+static const float GELU_QUICK_COEF = -1.702f;
 static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
 
 inline static float ggml_gelu_f32(float x) {
@@ -3318,6 +3323,34 @@ inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
 }
 #endif
 
+inline static float ggml_gelu_quick_f32(float x) {
+    return x*(1.0f/(1.0f+expf(GELU_QUICK_COEF*x)));
+}
+
+inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+    const uint16_t * i16 = (const uint16_t *) x;
+    for (int i = 0; i < n; ++i) {
+        y[i] = table_gelu_quick_f16[i16[i]];
+    }
+}
+
+#ifdef GGML_GELU_QUICK_FP16
+inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) {
+    uint16_t t;
+    for (int i = 0; i < n; ++i) {
+        ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
+        memcpy(&t, &fp16, sizeof(uint16_t));
+        y[i] = GGML_FP16_TO_FP32(table_gelu_quick_f16[t]);
+    }
+}
+#else
+inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) {
+    for (int i = 0; i < n; ++i) {
+        y[i] = ggml_gelu_quick_f32(x[i]);
+    }
+}
+#endif
+
 // Sigmoid Linear Unit (SiLU) function
 inline static float ggml_silu_f32(float x) {
     return x/(1.0f + expf(-x));
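As with the existing GELU and SiLU paths, GGML_GELU_QUICK_FP16 trades a little accuracy for speed: ggml_init precomputes gelu_quick for every possible fp16 bit pattern (a 1 << 16 entry table), so each evaluation becomes one fp32-to-fp16 conversion plus a table load. Below is a standalone sketch of the same trick, assuming a compiler with the _Float16 extension; the names are illustrative and the table holds float values here, whereas ggml stores fp16 entries (128 KB).

// standalone illustration of the fp16 lookup-table technique (not ggml code)
#include <math.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>

static float table[1 << 16];                 // one entry per fp16 bit pattern

static float gelu_quick_ref(float x) {
    return x*(1.0f/(1.0f + expf(-1.702f*x)));
}

int main(void) {
    // build the table once, interpreting each index as the bits of a half-precision value
    for (uint32_t i = 0; i < (1u << 16); ++i) {
        uint16_t bits = (uint16_t) i;
        _Float16 h;
        memcpy(&h, &bits, sizeof(h));
        table[i] = gelu_quick_ref((float) h);
    }

    // evaluation: one fp32 -> fp16 conversion, then reuse the bit pattern as the index
    const float x = 0.75f;
    _Float16 hx = (_Float16) x;
    uint16_t idx;
    memcpy(&idx, &hx, sizeof(idx));
    printf("table lookup: %f   direct: %f\n", table[idx], gelu_quick_ref(x));
    return 0;
}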
@@ -3519,6 +3552,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "STEP",
     "RELU",
     "GELU",
+    "GELU_QUICK",
     "SILU",
     "SILU_BACK",
     "NORM",
@@ -3558,7 +3592,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "MAP_BINARY",
 };
 
-static_assert(GGML_OP_COUNT == 54, "GGML_OP_COUNT != 54");
+static_assert(GGML_OP_COUNT == 55, "GGML_OP_COUNT != 55");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -3583,6 +3617,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "step(x)",
     "relu(x)",
     "gelu(x)",
+    "gelu_quick(x)",
     "silu(x)",
     "silu_back(x)",
     "norm(x)",
@@ -3622,7 +3657,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "f(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 54, "GGML_OP_COUNT != 54");
+static_assert(GGML_OP_COUNT == 55, "GGML_OP_COUNT != 55");
 
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
 static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@@ -3899,7 +3934,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
     // initialize time system (required on Windows)
     ggml_time_init();
 
-    // initialize GELU, SILU and EXP F32 tables
+    // initialize GELU, Quick GELU, SILU and EXP F32 tables
     {
         const uint64_t t_start = ggml_time_us(); UNUSED(t_start);
 
@@ -3909,13 +3944,14 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
             memcpy(&ii, &ui, sizeof(ii));
             const float f = table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(ii);
             table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f));
+            table_gelu_quick_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_quick_f32(f));
             table_silu_f16[i] = GGML_FP32_TO_FP16(ggml_silu_f32(f));
             table_exp_f16[i] = GGML_FP32_TO_FP16(expf(f));
         }
 
         const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
 
-        GGML_PRINT_DEBUG("%s: GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
+        GGML_PRINT_DEBUG("%s: GELU, Quick GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
     }
 
     // initialize g_state
@@ -5271,6 +5307,40 @@ struct ggml_tensor * ggml_gelu_inplace(
     return ggml_gelu_impl(ctx, a, true);
 }
 
+// ggml_gelu_quick
+
+struct ggml_tensor * ggml_gelu_quick_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        bool inplace) {
+    bool is_node = false;
+
+    if (!inplace && (a->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op = GGML_OP_GELU_QUICK;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_gelu_quick(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    return ggml_gelu_quick_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_gelu_quick_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    return ggml_gelu_quick_impl(ctx, a, true);
+}
+
 // ggml_silu
 
 struct ggml_tensor * ggml_silu_impl(
@@ -9118,6 +9188,67 @@ static void ggml_compute_forward_gelu(
     //printf("XXXXXXXX gelu\n");
 }
 
+// ggml_compute_forward_gelu_quick
+
+static void ggml_compute_forward_gelu_quick_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_vec_gelu_quick_f32(nc,
+                (float *) ((char *) dst->data + i1*( dst->nb[1])),
+                (float *) ((char *) src0->data + i1*(src0->nb[1])));
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_gelu_quick(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_gelu_quick_f32(params, src0, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+
+    //printf("XXXXXXXX quick gelu\n");
+}
+
 // ggml_compute_forward_silu
 
 static void ggml_compute_forward_silu_f32(
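The forward kernel above parallelizes across rows: each of the nth threads takes a contiguous block of dr = ceil(nr / nth) rows, with the final block clamped to nr. A small standalone sketch of that partitioning (row and thread counts are hypothetical):

// standalone illustration of the per-thread row partitioning used above
#include <stdio.h>

#define MIN(a, b) ((a) < (b) ? (a) : (b))

int main(void) {
    const int nr  = 10;                      // total rows (hypothetical)
    const int nth = 4;                       // number of threads (hypothetical)

    const int dr = (nr + nth - 1)/nth;       // rows per thread, rounded up
    for (int ith = 0; ith < nth; ++ith) {
        const int ir0 = dr*ith;              // first row for thread ith
        const int ir1 = MIN(ir0 + dr, nr);   // one past the last row
        printf("thread %d handles rows [%d, %d)\n", ith, ir0, ir1);
    }
    return 0;
}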
@@ -13360,6 +13491,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_gelu(params, tensor->src0, tensor);
             } break;
+        case GGML_OP_GELU_QUICK:
+            {
+                ggml_compute_forward_gelu_quick(params, tensor->src0, tensor);
+            } break;
         case GGML_OP_SILU:
             {
                 ggml_compute_forward_silu(params, tensor->src0, tensor);
@@ -13771,6 +13906,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
+        case GGML_OP_GELU_QUICK:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_ALIBI:
             {
                 GGML_ASSERT(false); // TODO: not implemented
@@ -14578,6 +14717,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                     } break;
                 case GGML_OP_MUL:
                 case GGML_OP_GELU:
+                case GGML_OP_GELU_QUICK:
                 case GGML_OP_SILU:
                 case GGML_OP_SILU_BACK:
                 case GGML_OP_NORM:
