Commit 3ce5f36

Refactor Phi3TornadoVMLayerPlanner by reverting fusedGateUpSiLUDown to individual FFN tasks, updating worker grid mappings, and removing the fused kernel from TransformerComputeKernelsLayered for improved maintainability and clarity.
1 parent 9cbc5e5 commit 3ce5f36

2 files changed: +7, -73 lines
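For context, the three restored FFN tasks implement a standard gated feed-forward block: a combined gate/up projection (wGateUp), an element-wise SiLU gating step (gateUpSiLU), and a down projection that accumulates into the residual stream (wDown). Below is a plain-Java host-side sketch of the equivalent computation, using float arrays in place of the FloatArray/HalfFloatArray device buffers; the gate-first/up-second weight layout follows the removed fused kernel, and the intermediate wrapHbG device buffer is omitted here.

    // Host-side reference for the three reverted FFN tasks (a sketch, not the device kernels):
    //   wGateUp    : hb = Wup * xb, where Wup has 2 * hiddenDim rows (gate rows first, up rows second)
    //   gateUpSiLU : hbU[h] = silu(hb[h]) * hb[h + hiddenDim], with silu(x) = x / (1 + e^-x)
    //   wDown      : x[r] += sum_h Wdown[r][h] * hbU[h]   (residual added in place)
    static void ffnReference(float[] xb, float[] x, float[][] wUp, float[][] wDown, int dim, int hiddenDim) {
        float[] hb = new float[2 * hiddenDim];
        for (int r = 0; r < 2 * hiddenDim; r++) {          // combined gate + up projection
            float acc = 0.0f;
            for (int i = 0; i < dim; i++) {
                acc += wUp[r][i] * xb[i];
            }
            hb[r] = acc;
        }
        float[] hbU = new float[hiddenDim];
        for (int h = 0; h < hiddenDim; h++) {              // SiLU-gated element-wise product
            float gate = hb[h];
            float silu = gate / (1.0f + (float) Math.exp(-gate));
            hbU[h] = silu * hb[h + hiddenDim];
        }
        for (int r = 0; r < dim; r++) {                    // down projection + residual add
            float acc = 0.0f;
            for (int h = 0; h < hiddenDim; h++) {
                acc += wDown[r][h] * hbU[h];
            }
            x[r] += acc;
        }
    }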

src/main/java/com/example/tornadovm/Phi3TornadoVMLayerPlanner.java

Lines changed: 7 additions & 19 deletions
@@ -81,22 +81,12 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
                         state.wrapX, config.dim(), config.rmsNormEps(), state.localSize)
                 .task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb,
                         state.wrapX, weights.rms_ffn_weightLayered[layerIndex], state.tempFFN)
-                // Before (3 tasks):
-                // .task("wGateUp", ...)
-                // .task("gateUpSiLU", ...)
-                // .task("wDown", ...)
-
-                // After (1 fused task):
-                .task("fusedFFN", TransformerComputeKernelsLayered::fusedGateUpSiLUDown, context,
-                        state.wrapXb, state.wrapX, weights.wUpLayered[layerIndex],
-                        weights.wDownLayered[layerIndex], config.dim(), config.hiddenDim(),
-                        LOCAL_WORK_GROUP_SIZE_ALLOC)
-                // .task("wGateUp", TransformerComputeKernelsLayered::matrixVectorGeneric, context,
-                //         state.wrapXb, state.wrapHb, weights.wUpLayered[layerIndex], config.dim(), 2 * config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
-                // .task("gateUpSiLU", TransformerComputeKernelsLayered::splitGateUpAndSiLU,
-                //         state.wrapHb, state.wrapHbG, state.wrapHbU, config.hiddenDim())
-                // .task("wDown", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context,
-                //         state.wrapHbU, state.wrapX, weights.wDownLayered[layerIndex], config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
+                .task("wGateUp", TransformerComputeKernelsLayered::matrixVectorGeneric, context,
+                        state.wrapXb, state.wrapHb, weights.wUpLayered[layerIndex], config.dim(), 2 * config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
+                .task("gateUpSiLU", TransformerComputeKernelsLayered::splitGateUpAndSiLU,
+                        state.wrapHb, state.wrapHbG, state.wrapHbU, config.hiddenDim())
+                .task("wDown", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context,
+                        state.wrapHbU, state.wrapX, weights.wDownLayered[layerIndex], config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
                 .persistOnDevice(
                         state.wrapX
                 );
@@ -348,9 +338,7 @@ private GridScheduler setupGridSchedulersLayered() {
             tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".parallel-attention", parallelAttentionWorker);
             tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".copyToCaches", copyToCachesWorker);
             // New FFN tasks
-            // tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".gateUpSiLU", splitGateUpSiLUWorker);
-            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".fusedFFN", configDimRowMajorGlobalWorker);
-
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".gateUpSiLU", splitGateUpSiLUWorker);
         }

         // Vocabulary worker configuration
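The restored scheduler entry maps the gateUpSiLU task to splitGateUpSiLUWorker, whose definition is not part of this diff. A minimal sketch of how such a grid could be declared, assuming a one-dimensional TornadoVM WorkerGrid1D sized to config.hiddenDim(); the actual grid dimensions and local size live elsewhere in the planner and are assumptions here.

    // Hypothetical worker-grid setup for the element-wise gateUpSiLU task:
    // splitGateUpAndSiLU produces hiddenDim outputs, so a 1D global range over hiddenDim fits.
    WorkerGrid splitGateUpSiLUWorker = new WorkerGrid1D(config.hiddenDim());
    splitGateUpSiLUWorker.setLocalWork(128, 1, 1);  // assumed local work-group size
    // Registered per layer exactly as in the hunk above:
    tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".gateUpSiLU", splitGateUpSiLUWorker);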

src/main/java/com/example/tornadovm/TransformerComputeKernelsLayered.java

Lines changed: 0 additions & 54 deletions
@@ -962,59 +962,5 @@ public static void splitGateUpAndSiLU(FloatArray hb, FloatArray hbG, FloatArray
             hbU.set(i, siluGate * upVal);
         }
     }
-    public static void fusedGateUpSiLUDown(KernelContext context,
-            FloatArray input,          // state.wrapXb
-            FloatArray output,         // state.wrapX (with residual)
-            HalfFloatArray wUp,        // weights.wUpLayered[layerIndex]
-            HalfFloatArray wDown,      // weights.wDownLayered[layerIndex]
-            int dim,                   // config.dim()
-            int hiddenDim,             // config.hiddenDim()
-            int localWorkGroupSize) {
-
-        int rowId = context.groupIdx;  // Each workgroup computes one output dimension
-        int localId = context.localIdx;
-
-        if (rowId >= dim) return;
-
-        float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize);
-        float accumulator = 0.0f;
-
-        // Process hidden dimensions in chunks to maintain numerical stability
-        for (int h = localId; h < hiddenDim; h += localWorkGroupSize) {
-            // Step 1: Compute gate value (first half of wUp)
-            float gateValue = 0.0f;
-            for (int i = 0; i < dim; i++) {
-                gateValue += wUp.get(h * dim + i).getFloat32() * input.get(i);
-            }
-
-            // Step 2: Compute up value (second half of wUp)
-            float upValue = 0.0f;
-            for (int i = 0; i < dim; i++) {
-                upValue += wUp.get((h + hiddenDim) * dim + i).getFloat32() * input.get(i);
-            }
-
-            // Step 3: Apply SiLU to gate and multiply with up
-            float siluGate = gateValue / (1.0f + TornadoMath.exp(-gateValue));
-            float activated = siluGate * upValue;
-
-            // Step 4: Apply down projection for this row
-            accumulator += wDown.get(rowId * hiddenDim + h).getFloat32() * activated;
-        }

-        // Reduce across workgroup
-        localSum[localId] = accumulator;
-        context.localBarrier();
-
-        for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) {
-            if (localId < stride) {
-                localSum[localId] += localSum[localId + stride];
-            }
-            context.localBarrier();
-        }
-
-        // Add residual connection
-        if (localId == 0) {
-            output.set(rowId, output.get(rowId) + localSum[0]);
-        }
-    }
 }
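The restored task chain relies on matrixVectorGeneric and matrixVectorGenericWithResidual, neither of which appears in this diff. As a rough host-side reference only, the residual variant is assumed to compute a row-major matrix-vector product accumulated into its output, inferred from its call site (hiddenDim inputs from state.wrapHbU, dim outputs into state.wrapX) and from the row-major indexing visible in the removed fused kernel; the name and signature below are illustrative, not the kernel's actual API.

    // Hypothetical reference for the wDown step: out[r] += sum_c w[r * cols + c] * in[c]
    static void matVecWithResidualReference(float[] in, float[] out, float[] w, int cols, int rows) {
        for (int r = 0; r < rows; r++) {
            float acc = 0.0f;
            for (int c = 0; c < cols; c++) {
                acc += w[r * cols + c] * in[c];
            }
            out[r] += acc;  // accumulate into the residual stream
        }
    }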
