Commit 3e1887d

Refactor Phi3TornadoVMLayerPlanner by consolidating FFN tasks into fusedGateUpSiLUDownOptimized, removing redundant operations, and updating worker grid mappings for improved efficiency and maintainability.
1 parent c5b53b2

File tree: 2 files changed, +103 -29 lines


src/main/java/com/example/tornadovm/Phi3TornadoVMLayerPlanner.java

Lines changed: 18 additions & 26 deletions
@@ -66,13 +66,6 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
                 .task("splitQKV", TransformerComputeKernelsLayered::splitQKV,
                         state.wrapQkv, state.wrapQ, state.wrapK, state.wrapV,
                         config.dim(), config.headSize() * config.numberOfKeyValueHeads())
-                // .task("copyQ", TransformerComputeKernelsLayered::copyTo,
-                //         state.wrapQkv, 0, state.wrapQ, 0, config.dim())
-                // .task("copyK", TransformerComputeKernelsLayered::copyTo,
-                //         state.wrapQkv, config.dim(), state.wrapK, 0, config.headSize() * config.numberOfKeyValueHeads())
-                // .task("copyV", TransformerComputeKernelsLayered::copyTo,
-                //         state.wrapQkv, config.dim() + config.headSize() * config.numberOfKeyValueHeads(),
-                //         state.wrapV, 0, config.headSize() * config.numberOfKeyValueHeads())
                 .task("rope", TransformerComputeKernelsLayered::ropeRotationPhi3, context,
                         state.positionHolder, state.wrapQ, state.wrapK, config.kvDim(),
                         config.headSize())
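
For readers outside the codebase: each chained .task(name, method, args...) call above registers a kernel in a TornadoVM TaskGraph, and the task names are what the GridScheduler binds worker grids to later in this diff. A minimal sketch of the pattern, assuming the TornadoVM 1.x API (the class, kernel, and buffer names here are invented for illustration):

    import uk.ac.manchester.tornado.api.GridScheduler;
    import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
    import uk.ac.manchester.tornado.api.TaskGraph;
    import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
    import uk.ac.manchester.tornado.api.WorkerGrid1D;
    import uk.ac.manchester.tornado.api.annotations.Parallel;
    import uk.ac.manchester.tornado.api.enums.DataTransferMode;
    import uk.ac.manchester.tornado.api.types.arrays.FloatArray;

    public class TaskGraphSketch {
        // Stand-in kernel; the real planner wires methods from TransformerComputeKernelsLayered.
        public static void scale(FloatArray in, FloatArray out, int n) {
            for (@Parallel int i = 0; i < n; i++) {
                out.set(i, 2.0f * in.get(i));
            }
        }

        public static void main(String[] args) {
            int n = 1024;
            FloatArray in = new FloatArray(n);
            FloatArray out = new FloatArray(n);
            in.init(1.0f);

            TaskGraph graph = new TaskGraph("layer_0")
                    .transferToDevice(DataTransferMode.FIRST_EXECUTION, in)
                    .task("scale", TaskGraphSketch::scale, in, out, n)
                    .transferToHost(DataTransferMode.EVERY_EXECUTION, out);

            ImmutableTaskGraph itg = graph.snapshot();

            // Bind a worker grid to "<graph>.<task>", as setupGridSchedulersLayered() does.
            GridScheduler scheduler = new GridScheduler();
            scheduler.addWorkerGrid("layer_0.scale", new WorkerGrid1D(n));

            new TornadoExecutionPlan(itg).withGridScheduler(scheduler).execute();
        }
    }
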
@@ -88,13 +81,22 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
                         state.wrapX, config.dim(), config.rmsNormEps(), state.localSize)
                 .task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb,
                         state.wrapX, weights.rms_ffn_weightLayered[layerIndex], state.tempFFN)
-                .task("wGateUp", TransformerComputeKernelsLayered::matrixVectorGeneric, context,
-                        state.wrapXb, state.wrapHb, weights.wUpLayered[layerIndex], config.dim(), 2 * config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
-                // Copy gate chunk: hb[0:hiddenDim] -> hbG[0:hiddenDim]
-                .task("gateUpSiLU", TransformerComputeKernelsLayered::splitGateUpAndSiLU,
-                        state.wrapHb, state.wrapHbG, state.wrapHbU, config.hiddenDim())
-                .task("wDown", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context,
-                        state.wrapHbU, state.wrapX, weights.wDownLayered[layerIndex], config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
+                // Before (3 tasks):
+                // .task("wGateUp", ...)
+                // .task("gateUpSiLU", ...)
+                // .task("wDown", ...)
+
+                // After (1 fused task):
+                .task("fusedFFN", TransformerComputeKernelsLayered::fusedGateUpSiLUDownOptimized, context,
+                        state.wrapXb, state.wrapX, weights.wUpLayered[layerIndex],
+                        weights.wDownLayered[layerIndex], config.dim(), config.hiddenDim(),
+                        LOCAL_WORK_GROUP_SIZE_ALLOC)
+                // .task("wGateUp", TransformerComputeKernelsLayered::matrixVectorGeneric, context,
+                //         state.wrapXb, state.wrapHb, weights.wUpLayered[layerIndex], config.dim(), 2 * config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
+                // .task("gateUpSiLU", TransformerComputeKernelsLayered::splitGateUpAndSiLU,
+                //         state.wrapHb, state.wrapHbG, state.wrapHbU, config.hiddenDim())
+                // .task("wDown", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context,
+                //         state.wrapHbU, state.wrapX, weights.wDownLayered[layerIndex], config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
                 .persistOnDevice(
                         state.wrapX
                 );
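
What the fusion computes: wUp packs the gate rows in [0, hiddenDim) and the up rows in [hiddenDim, 2*hiddenDim), each of length dim. Writing xb for state.wrapXb and x for state.wrapX, the FFN per output row r is

    gate_h  = sum_{i < dim} wUp[h*dim + i] * xb[i]
    up_h    = sum_{i < dim} wUp[(h + hiddenDim)*dim + i] * xb[i]
    silu(g) = g / (1 + exp(-g))
    x[r]   += sum_{h < hiddenDim} wDown[r*hiddenDim + h] * silu(gate_h) * up_h

In the old pipeline, wGateUp produced the 2*hiddenDim-wide wrapHb, gateUpSiLU applied SiLU and the elementwise product into wrapHbU, and wDown did the down projection plus residual add. The fused kernel performs all three steps per output row; these offsets and the residual add match gateRowOffset, upRowOffset, and the final output.set(...) in the kernel body further down.
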
@@ -334,13 +336,7 @@ private GridScheduler setupGridSchedulersLayered() {
         tornadoForwardScheduler.addWorkerGrid("activationUpdate.updateX", singleWorker);
         for (int i = 0; i < config.numberOfLayers(); i++) {
             tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qkvmatmul", qkvDimRowMajorGlobalWorker);
-            // tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".copyQ", copyQWorker);
-            // tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".copyK", copyKWorker);
-            // tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".copyV", copyVWorker);
-
             tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".splitQKV", splitQKVWorker);
-
-
             tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope", ropeWorker);
             tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".matmul1", configDimRowMajorGlobalWorker);
             tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".wDown", configDimRowMajorGlobalWorker);
@@ -352,12 +348,8 @@ private GridScheduler setupGridSchedulersLayered() {
             tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".parallel-attention", parallelAttentionWorker);
             tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".copyToCaches", copyToCachesWorker);
             // New FFN tasks
-            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".gateUpSiLU", splitGateUpSiLUWorker);
-
-            // tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".copyGate", hiddenDimWorker);
-            // tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".copyUp", hiddenDimWorker);
-            // tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".siluActivation", hiddenDimWorker);
-            // tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".gatedMultiply", hiddenDimWorker);
+            // tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".gateUpSiLU", splitGateUpSiLUWorker);
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".fusedFFN", configDimRowMajorGlobalWorker);

         }
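
Because fusedGateUpSiLUDownOptimized reads context.groupIdx as its output-row id, the grid supplied by configDimRowMajorGlobalWorker must launch one workgroup per row. A sketch of how such a worker can be built with TornadoVM's WorkerGrid1D/GridScheduler API (the concrete sizes are illustrative; the planner derives them from config.dim() and LOCAL_WORK_GROUP_SIZE_ALLOC):

    import uk.ac.manchester.tornado.api.GridScheduler;
    import uk.ac.manchester.tornado.api.WorkerGrid;
    import uk.ac.manchester.tornado.api.WorkerGrid1D;

    public class FusedFFNGridSketch {
        public static void main(String[] args) {
            int dim = 3072;       // e.g. config.dim() for a Phi-3-mini-sized model
            int localSize = 128;  // e.g. LOCAL_WORK_GROUP_SIZE_ALLOC

            // One workgroup per output row: global = dim * localSize threads,
            // local = localSize, so context.groupIdx ranges over [0, dim).
            WorkerGrid fusedFFNWorker = new WorkerGrid1D(dim * localSize);
            fusedFFNWorker.setLocalWork(localSize, 1, 1);

            GridScheduler scheduler = new GridScheduler();
            scheduler.addWorkerGrid("layer_0.fusedFFN", fusedFFNWorker);
        }
    }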

src/main/java/com/example/tornadovm/TransformerComputeKernelsLayered.java

Lines changed: 85 additions & 3 deletions
@@ -144,8 +144,7 @@ public static void copyTo(FloatArray src, int srcOffset, FloatArray dest, int de
         }
     }

-    public static void splitQKV(FloatArray qkv, FloatArray q, FloatArray k, FloatArray v,
-            int dimQ, int dimKV) {
+    public static void splitQKV(FloatArray qkv, FloatArray q, FloatArray k, FloatArray v, int dimQ, int dimKV) {
         int totalSize = dimQ + 2 * dimKV;

         for (@Parallel int i = 0; i < totalSize; i++) {
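
Only splitQKV's signature reflow and loop header are visible in this hunk. For orientation, here is a hypothetical body consistent with the bounds shown: one flat parallel loop over totalSize = dimQ + 2*dimKV that routes each index into q, k, or v (the branch layout is my assumption, not the commit's code; FloatArray and @Parallel come from the uk.ac.manchester.tornado.api packages):

    public static void splitQKV(FloatArray qkv, FloatArray q, FloatArray k, FloatArray v, int dimQ, int dimKV) {
        int totalSize = dimQ + 2 * dimKV;

        for (@Parallel int i = 0; i < totalSize; i++) {
            if (i < dimQ) {
                q.set(i, qkv.get(i));                // [0, dimQ) -> query
            } else if (i < dimQ + dimKV) {
                k.set(i - dimQ, qkv.get(i));         // [dimQ, dimQ + dimKV) -> key
            } else {
                v.set(i - dimQ - dimKV, qkv.get(i)); // remainder -> value
            }
        }
    }
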
@@ -949,7 +948,6 @@ public static void reductionFinalNormalization(KernelContext context, FloatArray
         }
     }

-
     public static void splitGateUpAndSiLU(FloatArray hb, FloatArray hbG, FloatArray hbU, int hiddenDim) {
         // Copy and apply SiLU to gate in one pass
         for (@Parallel int i = 0; i < hiddenDim; i++) {
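
splitGateUpAndSiLU likewise appears only in part here (its tail, hbU.set(i, siluGate * upVal), shows up as context in the next hunk). A plausible reconstruction of the full body, applying silu(g) = g / (1 + exp(-g)) to the gate half of hb and gating the up half — my reading, not the commit's verbatim code:

    public static void splitGateUpAndSiLU(FloatArray hb, FloatArray hbG, FloatArray hbU, int hiddenDim) {
        // Copy and apply SiLU to gate in one pass
        for (@Parallel int i = 0; i < hiddenDim; i++) {
            float gateVal = hb.get(i);               // first half of hb = gate
            float upVal = hb.get(i + hiddenDim);     // second half of hb = up
            float siluGate = gateVal / (1.0f + TornadoMath.exp(-gateVal));
            hbG.set(i, siluGate);                    // activated gate
            hbU.set(i, siluGate * upVal);            // gated up value, fed to wDown
        }
    }
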
@@ -964,4 +962,88 @@ public static void splitGateUpAndSiLU(FloatArray hb, FloatArray hbG, FloatArray
             hbU.set(i, siluGate * upVal);
         }
     }
+
+    public static void fusedGateUpSiLUDownOptimized(KernelContext context,
+            FloatArray input,
+            FloatArray output,
+            HalfFloatArray wUp,
+            HalfFloatArray wDown,
+            int dim,
+            int hiddenDim,
+            int localWorkGroupSize) {
+
+        int rowId = context.groupIdx;
+        int localId = context.localIdx;
+
+        if (rowId >= dim) return;
+
+        // Shared memory for input vector (reused across all hidden computations)
+        float[] sharedInput = context.allocateFloatLocalArray(dim);
+        float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize);
+
+        // Cooperatively load input into shared memory
+        for (int i = localId; i < dim; i += localWorkGroupSize) {
+            sharedInput[i] = input.get(i);
+        }
+        context.localBarrier();
+
+        float accumulator = 0.0f;
+
+        // Each thread processes multiple hidden dimensions
+        for (int h = localId; h < hiddenDim; h += localWorkGroupSize) {
+            // Compute gate and up values using shared input
+            float gateValue = 0.0f;
+            float upValue = 0.0f;
+
+            int gateRowOffset = h * dim;
+            int upRowOffset = (h + hiddenDim) * dim;
+
+            // Unrolled loop for better performance
+            int i = 0;
+            for (; i < dim - 3; i += 4) {
+                float in0 = sharedInput[i];
+                float in1 = sharedInput[i + 1];
+                float in2 = sharedInput[i + 2];
+                float in3 = sharedInput[i + 3];
+
+                gateValue += wUp.get(gateRowOffset + i).getFloat32() * in0;
+                gateValue += wUp.get(gateRowOffset + i + 1).getFloat32() * in1;
+                gateValue += wUp.get(gateRowOffset + i + 2).getFloat32() * in2;
+                gateValue += wUp.get(gateRowOffset + i + 3).getFloat32() * in3;
+
+                upValue += wUp.get(upRowOffset + i).getFloat32() * in0;
+                upValue += wUp.get(upRowOffset + i + 1).getFloat32() * in1;
+                upValue += wUp.get(upRowOffset + i + 2).getFloat32() * in2;
+                upValue += wUp.get(upRowOffset + i + 3).getFloat32() * in3;
+            }
+
+            // Handle remainder
+            for (; i < dim; i++) {
+                float inVal = sharedInput[i];
+                gateValue += wUp.get(gateRowOffset + i).getFloat32() * inVal;
+                upValue += wUp.get(upRowOffset + i).getFloat32() * inVal;
+            }
+
+            // Apply SiLU and multiply
+            float activated = (gateValue / (1.0f + TornadoMath.exp(-gateValue))) * upValue;
+
+            // Apply down projection
+            accumulator += wDown.get(rowId * hiddenDim + h).getFloat32() * activated;
+        }
+
+        // Final reduction and residual add
+        localSum[localId] = accumulator;
+        context.localBarrier();
+
+        for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) {
+            if (localId < stride) {
+                localSum[localId] += localSum[localId + stride];
+            }
+            context.localBarrier();
+        }
+
+        if (localId == 0) {
+            output.set(rowId, output.get(rowId) + localSum[0]);
+        }
+    }
 }
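
To sanity-check the fused kernel's math, here is a plain-Java reference of the same computation (a hypothetical helper, not part of the commit; float weights stand in for HalfFloatArray and Math.exp for TornadoMath.exp):

    // CPU reference for fusedGateUpSiLUDownOptimized.
    // wUp packs gate rows [0, hiddenDim) and up rows [hiddenDim, 2*hiddenDim),
    // each of length dim; wDown is dim x hiddenDim, row-major.
    static void fusedFFNReference(float[] input, float[] output,
            float[] wUp, float[] wDown, int dim, int hiddenDim) {
        for (int row = 0; row < dim; row++) {          // one GPU workgroup per row
            float acc = 0.0f;
            for (int h = 0; h < hiddenDim; h++) {
                float gate = 0.0f;
                float up = 0.0f;
                for (int i = 0; i < dim; i++) {
                    gate += wUp[h * dim + i] * input[i];
                    up += wUp[(h + hiddenDim) * dim + i] * input[i];
                }
                float silu = gate / (1.0f + (float) Math.exp(-gate));
                acc += wDown[row * hiddenDim + h] * (silu * up);
            }
            output[row] += acc;                        // residual accumulate
        }
    }

The reference makes the trade-off visible: the gate/up activations do not depend on the output row, yet the fused kernel keeps them in registers, so every one of the dim workgroups recomputes all hiddenDim of them. That redundancy buys the removal of the wrapHb / wrapHbG / wrapHbU intermediate buffers and two kernel launches per layer. Note also that the in-kernel tree reduction assumes localWorkGroupSize is a power of two.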
