@@ -81,22 +81,12 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
                 state.wrapX, config.dim(), config.rmsNormEps(), state.localSize)
                 .task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb,
                         state.wrapX, weights.rms_ffn_weightLayered[layerIndex], state.tempFFN)
-                // Before (3 tasks):
-                // .task("wGateUp", ...)
-                // .task("gateUpSiLU", ...)
-                // .task("wDown", ...)
-
-                // After (1 fused task):
-                .task("fusedFFN", TransformerComputeKernelsLayered::fusedGateUpSiLUDown, context,
-                        state.wrapXb, state.wrapX, weights.wUpLayered[layerIndex],
-                        weights.wDownLayered[layerIndex], config.dim(), config.hiddenDim(),
-                        LOCAL_WORK_GROUP_SIZE_ALLOC)
-                // .task("wGateUp", TransformerComputeKernelsLayered::matrixVectorGeneric, context,
-                //         state.wrapXb, state.wrapHb, weights.wUpLayered[layerIndex], config.dim(), 2 * config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
-                // .task("gateUpSiLU", TransformerComputeKernelsLayered::splitGateUpAndSiLU,
-                //         state.wrapHb, state.wrapHbG, state.wrapHbU, config.hiddenDim())
-                // .task("wDown", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context,
-                //         state.wrapHbU, state.wrapX, weights.wDownLayered[layerIndex], config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
+                .task("wGateUp", TransformerComputeKernelsLayered::matrixVectorGeneric, context,
+                        state.wrapXb, state.wrapHb, weights.wUpLayered[layerIndex], config.dim(), 2 * config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
+                .task("gateUpSiLU", TransformerComputeKernelsLayered::splitGateUpAndSiLU,
+                        state.wrapHb, state.wrapHbG, state.wrapHbU, config.hiddenDim())
+                .task("wDown", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context,
+                        state.wrapHbU, state.wrapX, weights.wDownLayered[layerIndex], config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
                 .persistOnDevice(
                         state.wrapX
                 );
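This hunk drops the experimental `fusedFFN` task and restores the three-task FFN pipeline: `wGateUp` projects `wrapXb` into a `2 * hiddenDim` buffer holding the gate and up projections, `gateUpSiLU` splits that buffer and applies SiLU gating, and `wDown` projects back to `dim` with a residual add into `wrapX`. Below is a minimal scalar sketch of what `splitGateUpAndSiLU` plausibly computes; the buffer layout (gate in the first half, up in the second) and the role of `wrapHbG` are assumptions, not taken from this patch, and the real task runs as a TornadoVM-parallel kernel rather than a plain loop.

```java
import java.util.Arrays;

// Hypothetical scalar reference for splitGateUpAndSiLU: split the concatenated
// gate/up projection in hb, apply SiLU to the gate half, and store the gated
// product in hbU for the wDown task. hbG's contents are an assumption.
public final class SplitGateUpSiLUSketch {
    static void splitGateUpAndSiLU(float[] hb, float[] hbG, float[] hbU, int hiddenDim) {
        for (int i = 0; i < hiddenDim; i++) {
            float gate = hb[i];             // first half of hb: gate projection
            float up = hb[hiddenDim + i];   // second half of hb: up projection
            float silu = gate / (1.0f + (float) Math.exp(-gate)); // SiLU(x) = x * sigmoid(x)
            hbG[i] = silu;                  // assumed: activated gate kept separately
            hbU[i] = silu * up;             // gated value consumed by wDown
        }
    }

    public static void main(String[] args) {
        float[] hb = {1f, -2f, 0.5f, 3f};   // hiddenDim = 2, so hb holds 2 * hiddenDim values
        float[] hbG = new float[2], hbU = new float[2];
        splitGateUpAndSiLU(hb, hbG, hbU, 2);
        System.out.println(Arrays.toString(hbU));
    }
}
```

Keeping the stage as its own task also lets the scheduler assign it a dedicated worker grid, which is what the second hunk below restores.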
@@ -348,9 +338,7 @@ private GridScheduler setupGridSchedulersLayered() {
             tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".parallel-attention", parallelAttentionWorker);
             tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".copyToCaches", copyToCachesWorker);
             // New FFN tasks
-            // tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".gateUpSiLU", splitGateUpSiLUWorker);
-            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".fusedFFN", configDimRowMajorGlobalWorker);
-
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".gateUpSiLU", splitGateUpSiLUWorker);
         }

         // Vocabulary worker configuration
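For reference, here is a hedged sketch of how the restored `splitGateUpSiLUWorker` grid might be built and registered per layer. `WorkerGrid1D`, `setLocalWork`, and `addWorkerGrid` are standard TornadoVM APIs, but the sizes shown (one thread per hidden unit, local size 128) and the `numLayers`/`hiddenDim` parameters are placeholder assumptions, not values from this patch.

```java
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.WorkerGrid;
import uk.ac.manchester.tornado.api.WorkerGrid1D;

// Hypothetical construction of the per-layer "gateUpSiLU" worker grid,
// using the same task-name pattern the patch registers under.
public final class GateUpSiLUGridSketch {
    public static GridScheduler build(int numLayers, int hiddenDim) {
        WorkerGrid splitGateUpSiLUWorker = new WorkerGrid1D(hiddenDim); // global size = hiddenDim (assumed)
        splitGateUpSiLUWorker.setLocalWork(128, 1, 1);                  // assumed local work-group size
        GridScheduler scheduler = new GridScheduler();
        for (int i = 0; i < numLayers; i++) {
            scheduler.addWorkerGrid("layer_" + i + ".gateUpSiLU", splitGateUpSiLUWorker);
        }
        return scheduler;
    }
}
```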