Skip to content

Commit 50a0909

Browse files
committed
Fix worker grid configuration in Phi3TornadoVMLayerPlanner for correct local work settings
1 parent 80632f0 commit 50a0909

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/main/java/com/example/tornadovm/Phi3TornadoVMLayerPlanner.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ private GridScheduler setupGridSchedulersLayered() {
266266

267267
int qkvmatmulDimRowMajorGlobal = opSize * LOCAL_WORK_GROUP_SIZE_ALLOC;
268268
WorkerGrid qkvDimRowMajorGlobalWorker = new WorkerGrid1D(qkvmatmulDimRowMajorGlobal);
269-
configDimRowMajorGlobalWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1);
269+
qkvDimRowMajorGlobalWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC, 1, 1);
270270

271271

272272
// config.kvDim Worker for Row major access
@@ -346,7 +346,7 @@ private GridScheduler setupGridSchedulersLayered() {
346346
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope", ropeWorker);
347347
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".matmul1", configDimRowMajorGlobalWorker);
348348
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".wDown", configDimRowMajorGlobalWorker);
349-
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".wGateup", wgetHiddenDimRowMajorWorker);
349+
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".wGateUp", wgetHiddenDimRowMajorWorker);
350350
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlock", rmsNormWorker);
351351
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContext", rmsNormWorker);
352352
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlockFFN", rmsNormWorker);

0 commit comments

Comments
 (0)