@@ -66,13 +66,6 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
                .task("splitQKV", TransformerComputeKernelsLayered::splitQKV,
                        state.wrapQkv, state.wrapQ, state.wrapK, state.wrapV,
                        config.dim(), config.headSize() * config.numberOfKeyValueHeads())
-               // .task("copyQ", TransformerComputeKernelsLayered::copyTo,
-               //         state.wrapQkv, 0, state.wrapQ, 0, config.dim())
-               // .task("copyK", TransformerComputeKernelsLayered::copyTo,
-               //         state.wrapQkv, config.dim(), state.wrapK, 0, config.headSize() * config.numberOfKeyValueHeads())
-               // .task("copyV", TransformerComputeKernelsLayered::copyTo,
-               //         state.wrapQkv, config.dim() + config.headSize() * config.numberOfKeyValueHeads(),
-               //         state.wrapV, 0, config.headSize() * config.numberOfKeyValueHeads())
                .task("rope", TransformerComputeKernelsLayered::ropeRotationPhi3, context,
                        state.positionHolder, state.wrapQ, state.wrapK, config.kvDim(),
                        config.headSize())
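
The splitQKV task above stands in for the three commented-out copyTo tasks: instead of three strided copies, one kernel carves the fused QKV projection buffer into its Q, K, and V slices. A minimal sketch of such a kernel, with the signature inferred from the task call and TornadoVM's FloatArray/@Parallel API assumed (the actual TransformerComputeKernelsLayered.splitQKV may be written differently):

    // Assumes: import uk.ac.manchester.tornado.api.annotations.Parallel;
    //          import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
    // Sketch only: Q occupies qkv[0, dim), K occupies qkv[dim, dim + kvSize),
    // and V occupies qkv[dim + kvSize, dim + 2 * kvSize).
    public static void splitQKV(FloatArray qkv, FloatArray q, FloatArray k, FloatArray v,
                                int dim, int kvSize) {
        for (@Parallel int i = 0; i < dim; i++) {
            q.set(i, qkv.get(i));
        }
        for (@Parallel int i = 0; i < kvSize; i++) {
            k.set(i, qkv.get(dim + i));
            v.set(i, qkv.get(dim + kvSize + i));
        }
    }
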
@@ -88,13 +81,22 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
                        state.wrapX, config.dim(), config.rmsNormEps(), state.localSize)
                .task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb,
                        state.wrapX, weights.rms_ffn_weightLayered[layerIndex], state.tempFFN)
-               .task("wGateUp", TransformerComputeKernelsLayered::matrixVectorGeneric, context,
-                       state.wrapXb, state.wrapHb, weights.wUpLayered[layerIndex], config.dim(), 2 * config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
-               // Copy gate chunk: hb[0:hiddenDim] -> hbG[0:hiddenDim]
-               .task("gateUpSiLU", TransformerComputeKernelsLayered::splitGateUpAndSiLU,
-                       state.wrapHb, state.wrapHbG, state.wrapHbU, config.hiddenDim())
-               .task("wDown", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context,
-                       state.wrapHbU, state.wrapX, weights.wDownLayered[layerIndex], config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
+               // Before (3 tasks):
+               // .task("wGateUp", ...)
+               // .task("gateUpSiLU", ...)
+               // .task("wDown", ...)
+
+               // After (1 fused task):
+               .task("fusedFFN", TransformerComputeKernelsLayered::fusedGateUpSiLUDownOptimized, context,
+                       state.wrapXb, state.wrapX, weights.wUpLayered[layerIndex],
+                       weights.wDownLayered[layerIndex], config.dim(), config.hiddenDim(),
+                       LOCAL_WORK_GROUP_SIZE_ALLOC)
+               // .task("wGateUp", TransformerComputeKernelsLayered::matrixVectorGeneric, context,
+               //         state.wrapXb, state.wrapHb, weights.wUpLayered[layerIndex], config.dim(), 2 * config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
+               // .task("gateUpSiLU", TransformerComputeKernelsLayered::splitGateUpAndSiLU,
+               //         state.wrapHb, state.wrapHbG, state.wrapHbU, config.hiddenDim())
+               // .task("wDown", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context,
+               //         state.wrapHbU, state.wrapX, weights.wDownLayered[layerIndex], config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
                .persistOnDevice(
                        state.wrapX
                );
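
This hunk collapses the FFN step: wGateUp produced the concatenated gate/up projection (2 * hiddenDim values in hb), gateUpSiLU split it and applied SiLU gating into hbU, and wDown projected back to dim with a residual add into x. The fused task reads wUp and wDown directly, so hb/hbG/hbU no longer round-trip through global memory between tasks. A sequential reference of the arithmetic the fused kernel has to reproduce, for clarity only; the weight layout (row-major, gate rows before up rows) is an assumption based on the previous 2 * hiddenDim wGateUp output:

    // Reference semantics only; the real fusedGateUpSiLUDownOptimized kernel runs this
    // with a KernelContext and LOCAL_WORK_GROUP_SIZE_ALLOC-sized work groups.
    static void fusedFFNReference(float[] xb, float[] x, float[] wUp, float[] wDown,
                                  int dim, int hiddenDim) {
        float[] gated = new float[hiddenDim];
        for (int h = 0; h < hiddenDim; h++) {
            float gate = 0f, up = 0f;
            for (int j = 0; j < dim; j++) {
                gate += wUp[h * dim + j] * xb[j];                // gate row h (was wGateUp, first hiddenDim rows)
                up   += wUp[(hiddenDim + h) * dim + j] * xb[j];  // up row h (was wGateUp, second hiddenDim rows)
            }
            float silu = gate / (1f + (float) Math.exp(-gate));  // SiLU(gate), was gateUpSiLU
            gated[h] = silu * up;                                 // gated hidden activation
        }
        for (int d = 0; d < dim; d++) {
            float acc = 0f;
            for (int h = 0; h < hiddenDim; h++) {
                acc += wDown[d * hiddenDim + h] * gated[h];       // down projection row d, was wDown
            }
            x[d] += acc;                                          // residual add into x
        }
    }

Note that the down projection needs the full gated vector, so a true single-kernel fusion has to either stage the gated values on the device or recompute them per output row; the sketch fixes only the arithmetic, not the scheduling.
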
@@ -334,13 +336,7 @@ private GridScheduler setupGridSchedulersLayered() {
        tornadoForwardScheduler.addWorkerGrid("activationUpdate.updateX", singleWorker);
        for (int i = 0; i < config.numberOfLayers(); i++) {
            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qkvmatmul", qkvDimRowMajorGlobalWorker);
-           // tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".copyQ", copyQWorker);
-           // tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".copyK", copyKWorker);
-           // tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".copyV", copyVWorker);
-
            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".splitQKV", splitQKVWorker);
-
-
            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope", ropeWorker);
            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".matmul1", configDimRowMajorGlobalWorker);
            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".wDown", configDimRowMajorGlobalWorker);
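
On the scheduler side, the commented-out copyQ/copyK/copyV grids are dropped and only splitQKV keeps a dedicated worker grid. A minimal sketch of how such a grid could be built and registered, assuming the name and sizes (splitQKVWorker, the 128-thread local size) since they are defined outside this hunk:

    // Assumes: import uk.ac.manchester.tornado.api.WorkerGrid;
    //          import uk.ac.manchester.tornado.api.WorkerGrid1D;
    // Assumed sizing: one global thread per element of the largest slice being split out.
    WorkerGrid splitQKVWorker = new WorkerGrid1D(config.dim());
    splitQKVWorker.setLocalWork(128, 1, 1);   // assumed work-group size
    // Registration keys are "<graphName>.<taskName>", matching the task graph above:
    tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".splitQKV", splitQKVWorker);
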
@@ -352,12 +348,8 @@ private GridScheduler setupGridSchedulersLayered() {
            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".parallel-attention", parallelAttentionWorker);
            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".copyToCaches", copyToCachesWorker);
            // New FFN tasks
-           tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".gateUpSiLU", splitGateUpSiLUWorker);
-
-           // tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".copyGate", hiddenDimWorker);
-           // tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".copyUp", hiddenDimWorker);
-           // tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".siluActivation", hiddenDimWorker);
-           // tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".gatedMultiply", hiddenDimWorker);
+           // tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".gateUpSiLU", splitGateUpSiLUWorker);
+           tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".fusedFFN", configDimRowMajorGlobalWorker);

        }
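
The fused task reuses configDimRowMajorGlobalWorker, the grid already driving matmul1 and wDown, which fits because fusedFFN's final write is the dim-sized residual update of x and the row-major matvec deployment launches one work-group per output row. A sketch of what that grid plausibly looks like (the sizes are assumptions; the real definition lives earlier in setupGridSchedulersLayered()):

    // Assumed shape: dim work-groups of LOCAL_WORK_GROUP_SIZE_ALLOC threads each,
    // i.e. one work-group per output row of the dim-sized result.
    WorkerGrid configDimRowMajorGlobalWorker = new WorkerGrid1D(config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC);
    configDimRowMajorGlobalWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1);
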