@@ -98,6 +98,7 @@ typedef void* thread_ret_t;
 /*#define GGML_PERF*/
 #define GGML_DEBUG 0
 #define GGML_GELU_FP16
+#define GGML_GELU_QUICK_FP16
 #define GGML_SILU_FP16
 
 #define GGML_SOFT_MAX_UNROLL 4
@@ -322,6 +323,9 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
 // precomputed gelu table for f16 (128 KB)
 static ggml_fp16_t table_gelu_f16[1 << 16];
 
+// precomputed quick gelu table for f16 (128 KB)
+static ggml_fp16_t table_gelu_quick_f16[1 << 16];
+
 // precomputed silu table for f16 (128 KB)
 static ggml_fp16_t table_silu_f16[1 << 16];
 
@@ -3288,6 +3292,7 @@ inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) {
 inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
 
 static const float GELU_COEF_A = 0.044715f;
+static const float GELU_QUICK_COEF = -1.702f;
 static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
 
 inline static float ggml_gelu_f32(float x) {
@@ -3318,6 +3323,34 @@ inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
 }
 #endif
 
+inline static float ggml_gelu_quick_f32(float x) {
+    return x*(1.0f/(1.0f+expf(GELU_QUICK_COEF*x)));
+}
+
+inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
+    const uint16_t * i16 = (const uint16_t *) x;
+    for (int i = 0; i < n; ++i) {
+        y[i] = table_gelu_quick_f16[i16[i]];
+    }
+}
+
+#ifdef GGML_GELU_QUICK_FP16
+inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) {
+    uint16_t t;
+    for (int i = 0; i < n; ++i) {
+        ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
+        memcpy(&t, &fp16, sizeof(uint16_t));
+        y[i] = GGML_FP16_TO_FP32(table_gelu_quick_f16[t]);
+    }
+}
+#else
+inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) {
+    for (int i = 0; i < n; ++i) {
+        y[i] = ggml_gelu_quick_f32(x[i]);
+    }
+}
+#endif
+
 // Sigmoid Linear Unit (SiLU) function
 inline static float ggml_silu_f32(float x) {
     return x/(1.0f + expf(-x));
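For reference, a standalone sketch (plain C, not ggml code) of the approximation added above: with `GELU_QUICK_COEF = -1.702f`, the expression `x*(1.0f/(1.0f+expf(GELU_QUICK_COEF*x)))` is just x·sigmoid(1.702·x), the usual "quick GELU" form.

```c
#include <math.h>
#include <stdio.h>

// Standalone illustration of the quick GELU formula: x * sigmoid(1.702*x),
// algebraically identical to ggml_gelu_quick_f32 above.
static float gelu_quick_ref(float x) {
    return x / (1.0f + expf(-1.702f * x));
}

int main(void) {
    for (float x = -4.0f; x <= 4.0f; x += 1.0f) {
        printf("x = %5.1f  gelu_quick(x) = %8.5f\n", x, gelu_quick_ref(x));
    }
    return 0;
}
```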
@@ -3519,6 +3552,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "STEP",
     "RELU",
     "GELU",
+    "GELU_QUICK",
     "SILU",
     "SILU_BACK",
     "NORM",
@@ -3558,7 +3592,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "MAP_BINARY",
 };
 
-static_assert(GGML_OP_COUNT == 54, "GGML_OP_COUNT != 54");
+static_assert(GGML_OP_COUNT == 55, "GGML_OP_COUNT != 55");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -3583,6 +3617,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "step(x)",
     "relu(x)",
     "gelu(x)",
+    "gelu_quick(x)",
     "silu(x)",
     "silu_back(x)",
     "norm(x)",
@@ -3622,7 +3657,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "f(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 54, "GGML_OP_COUNT != 54");
+static_assert(GGML_OP_COUNT == 55, "GGML_OP_COUNT != 55");
 
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
 static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@@ -3899,7 +3934,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
     // initialize time system (required on Windows)
     ggml_time_init();
 
-    // initialize GELU, SILU and EXP F32 tables
+    // initialize GELU, Quick GELU, SILU and EXP F32 tables
     {
         const uint64_t t_start = ggml_time_us(); UNUSED(t_start);
 
@@ -3909,13 +3944,14 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
             memcpy(&ii, &ui, sizeof(ii));
             const float f = table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(ii);
             table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f));
+            table_gelu_quick_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_quick_f32(f));
             table_silu_f16[i] = GGML_FP32_TO_FP16(ggml_silu_f32(f));
             table_exp_f16[i] = GGML_FP32_TO_FP16(expf(f));
         }
 
         const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
 
-        GGML_PRINT_DEBUG("%s: GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
+        GGML_PRINT_DEBUG("%s: GELU, Quick GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
     }
 
     // initialize g_state
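The loop above walks every 16-bit pattern, reinterprets it as an fp16 value, and stores the precomputed quick GELU of it, so the `GGML_GELU_QUICK_FP16` path earlier reduces to one table lookup per element at the cost of 128 KB and fp16 input precision. A simplified standalone analogue of that precompute-once / lookup-many pattern (using an 8-bit quantization over [-8, 8] instead of fp16 bit patterns, purely for illustration):

```c
#include <math.h>
#include <stdio.h>

// Toy lookup table: one precomputed quick GELU value per 8-bit input code.
static float table_gelu_quick[1 << 8];

static float gelu_quick_ref(float x) {
    return x / (1.0f + expf(-1.702f * x));
}

int main(void) {
    // precompute: enumerate every 8-bit code once
    for (int i = 0; i < (1 << 8); ++i) {
        const float x = -8.0f + 16.0f*i/255.0f;
        table_gelu_quick[i] = gelu_quick_ref(x);
    }

    // lookup: quantize the input, then index the table
    const float x = 1.3f;
    const int   q = (int) roundf((x + 8.0f)*255.0f/16.0f);
    printf("exact = %f\n", gelu_quick_ref(x));
    printf("table = %f\n", table_gelu_quick[q]);
    return 0;
}
```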
@@ -5271,6 +5307,40 @@ struct ggml_tensor * ggml_gelu_inplace(
     return ggml_gelu_impl(ctx, a, true);
 }
 
+// ggml_gelu_quick
+
+struct ggml_tensor * ggml_gelu_quick_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        bool inplace) {
+    bool is_node = false;
+
+    if (!inplace && (a->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_GELU_QUICK;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = NULL;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_gelu_quick(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    return ggml_gelu_quick_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_gelu_quick_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a) {
+    return ggml_gelu_quick_impl(ctx, a, true);
+}
+
 // ggml_silu
 
 struct ggml_tensor * ggml_silu_impl(
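A hedged usage sketch of the new public entry point, assuming the graph API of this ggml version (`ggml_init`, `ggml_new_tensor_1d`, `ggml_build_forward`, `ggml_graph_compute`); the memory size and input values are arbitrary illustration choices, not part of this change:

```c
#include "ggml.h"
#include <stdio.h>

int main(void) {
    struct ggml_init_params params = {
        .mem_size   = 16*1024*1024, // arbitrary scratch size for this sketch
        .mem_buffer = NULL,
        .no_alloc   = false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // 4-element f32 input tensor
    struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    float * xd = (float *) x->data;
    xd[0] = -2.0f; xd[1] = -0.5f; xd[2] = 0.5f; xd[3] = 2.0f;

    // build and run a graph containing the new op
    struct ggml_tensor * y = ggml_gelu_quick(ctx, x);
    struct ggml_cgraph gf = ggml_build_forward(y);
    ggml_graph_compute(ctx, &gf);

    float * yd = (float *) y->data;
    for (int i = 0; i < 4; ++i) {
        printf("gelu_quick(%f) = %f\n", xd[i], yd[i]);
    }

    ggml_free(ctx);
    return 0;
}
```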
@@ -9118,6 +9188,67 @@ static void ggml_compute_forward_gelu(
     //printf("XXXXXXXX gelu\n");
 }
 
+// ggml_compute_forward_gelu_quick
+
+static void ggml_compute_forward_gelu_quick_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        ggml_vec_gelu_quick_f32(nc,
+                (float *) ((char *) dst->data  + i1*( dst->nb[1])),
+                (float *) ((char *) src0->data + i1*(src0->nb[1])));
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_gelu_quick(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_gelu_quick_f32(params, src0, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+
+    //printf("XXXXXXXX quick gelu\n");
+}
+
 // ggml_compute_forward_silu
 
 static void ggml_compute_forward_silu_f32(
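The forward kernel splits work by rows: each of `nth` threads takes a contiguous block of `dr = ceil(nr/nth)` rows, with the last block clamped to `nr`. A standalone sketch of just that partitioning arithmetic (plain C, illustrative values):

```c
#include <stdio.h>

#define MIN(a, b) ((a) < (b) ? (a) : (b))

int main(void) {
    const int nr  = 10; // total rows
    const int nth = 4;  // threads

    // rows per thread, rounded up (same formula as the kernel above)
    const int dr = (nr + nth - 1)/nth;

    for (int ith = 0; ith < nth; ++ith) {
        const int ir0 = dr*ith;
        const int ir1 = MIN(ir0 + dr, nr);
        printf("thread %d: rows [%d, %d)\n", ith, ir0, ir1);
    }
    return 0;
}
```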
@@ -13360,6 +13491,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_gelu(params, tensor->src0, tensor);
             } break;
+        case GGML_OP_GELU_QUICK:
+            {
+                ggml_compute_forward_gelu_quick(params, tensor->src0, tensor);
+            } break;
         case GGML_OP_SILU:
             {
                 ggml_compute_forward_silu(params, tensor->src0, tensor);
@@ -13771,6 +13906,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
+        case GGML_OP_GELU_QUICK:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_ALIBI:
             {
                 GGML_ASSERT(false); // TODO: not implemented
@@ -14578,6 +14717,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                     } break;
                 case GGML_OP_MUL:
                 case GGML_OP_GELU:
+                case GGML_OP_GELU_QUICK:
                 case GGML_OP_SILU:
                 case GGML_OP_SILU_BACK:
                 case GGML_OP_NORM: