Skip to content

Commit 8ffc282

Browse files
committed
Reformat TornadoVMMasterPlan by re-organizing line breaks, updating parameter annotations, and consolidating documentation for improved readability and consistency.
1 parent 0e1c68c commit 8ffc282

File tree

1 file changed

+29
-34
lines changed

1 file changed

+29
-34
lines changed

src/main/java/com/example/tornadovm/TornadoVMMasterPlan.java

Lines changed: 29 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -39,14 +39,13 @@ public TornadoVMMasterPlan(State state, Model model) {
3939
}
4040

4141
/**
42-
* Initializes the TornadoVM plan for GPU acceleration with optional timing.
43-
* This method handles:
44-
* 1. Creation of the TornadoVM master plan
45-
* 2. Warming up the JIT compiler for better performance
46-
* 3. Copying read-only model weights to the GPU
42+
* Initializes the TornadoVM plan for GPU acceleration with optional timing. This method handles: (1) creation of the TornadoVM master plan, (2) warming up the JIT compiler for better performance, and
43+
* (3) copying read-only model weights to the GPU.
4744
*
48-
* @param state The model state containing KV cache
49-
* @param model The Llama model instance
45+
* @param state
46+
* The model state containing KV cache
47+
* @param model
48+
* The Llama model instance
5049
* @return The initialized TornadoVMMasterPlan ready for inference
5150
*/
5251
public static TornadoVMMasterPlan initializeTornadoVMPlan(State state, Model model) {
@@ -94,26 +93,13 @@ public static TornadoVMMasterPlan initializeTornadoVMPlan(State state, Model mod
9493
}
9594

9695
/**
97-
* Dispatcher method to select the TornadoVMLayerPlanner for the model.
98-
*/
99-
TornadoVMLayerPlanner createPlanner(State state, Model model) {
100-
return switch (model.getModelType()) {
101-
case LLAMA_3, MISTRAL -> new TornadoVMLayerPlanner(state, model);
102-
case QWEN_3 -> new Qwen3TornadoVMLayerPlanner((Qwen3State) state, model);
103-
case PHI_3 -> new Phi3TornadoVMLayerPlanner((Phi3State) state, model);
104-
case UNKNOWN -> throw new UnsupportedOperationException("Unknown model type");
105-
};
106-
}
107-
108-
/**
109-
* Determines whether the NVIDIA-specific scheduler should be used based on the current
110-
* hardware backend and the model type.
96+
* Determines whether the NVIDIA-specific scheduler should be used based on the current hardware backend and the model type.
11197
* <p>
112-
* The scheduler is used only if the runtime is targeting an NVIDIA backend and the model
113-
* is not of type {@code MISTRAL}. If either the hardware is not NVIDIA or the model is
114-
* {@code MISTRAL}, the NVIDIA-specific scheduler should not be used.
98+
* The scheduler is used only if the runtime is targeting an NVIDIA backend and the model is not of type {@code MISTRAL}. If either the hardware is not NVIDIA or the model is {@code MISTRAL}, the
99+
* NVIDIA-specific scheduler should not be used.
115100
*
116-
* @param model the model whose type may affect the scheduler decision
101+
* @param model
102+
* the model whose type may affect the scheduler decision
117103
* @return {@code true} if the NVIDIA-specific scheduler should be used; {@code false} otherwise
118104
*/
119105
public static boolean shouldUseNvidiaScheduler(Model model) {
@@ -129,8 +115,19 @@ public static boolean shouldUseNvidiaScheduler(Model model) {
129115
}
130116

131117
/**
132-
* Executes the forward pass of a LLaMA transformer model using TornadoVM acceleration.
133-
*This method processes the transformer layers in sequence for a particular token position in the context
118+
* Dispatcher method to select the TornadoVMLayerPlanner for the model.
119+
*/
120+
TornadoVMLayerPlanner createPlanner(State state, Model model) {
121+
return switch (model.getModelType()) {
122+
case LLAMA_3, MISTRAL -> new TornadoVMLayerPlanner(state, model);
123+
case QWEN_3 -> new Qwen3TornadoVMLayerPlanner((Qwen3State) state, model);
124+
case PHI_3 -> new Phi3TornadoVMLayerPlanner((Phi3State) state, model);
125+
case UNKNOWN -> throw new UnsupportedOperationException("Unknown model type");
126+
};
127+
}
128+
129+
/**
130+
* Executes the forward pass of a LLaMA transformer model using TornadoVM acceleration. This method processes the transformer layers in sequence for a particular token position in the context
134131
* window.
135132
*
136133
* <p>The execution happens in three phases:
@@ -140,7 +137,6 @@ public static boolean shouldUseNvidiaScheduler(Model model) {
140137
* <li>Final projection to logits using TornadoVM</li>
141138
* </ol>
142139
*
143-
*
144140
* @param position
145141
* The current position in the sequence being processed
146142
* @return FloatTensor containing the output logits for token prediction
@@ -183,7 +179,9 @@ private int getPreprocessingGraphIndex() {
183179

184180
/**
185181
* Returns the graph index for the given transformer layer.
186-
* @param layerIndex Index of the transformer layer (0-based)
182+
*
183+
* @param layerIndex
184+
* Index of the transformer layer (0-based)
187185
*/
188186
private int getLayerGraphIndex(int layerIndex) {
189187
return 1 + layerIndex;
@@ -196,8 +194,7 @@ private int getFinalLogitsGraphIndex() {
196194
return taskGraphs.size() - 1;
197195
}
198196

199-
/// Execute the forward pass of the LLaMA transformer model using TornadoVM acceleration
200-
/// just once to copy the data into the read-only data layer.
197+
/// Execute the forward pass of the LLaMA transformer model using TornadoVM acceleration just once to copy the data into the read-only data layer.
201198
public void forceCopyInReadOnlyDataLayered() {
202199
// Execute all TornadoVM graphs
203200
state.wrapX.init(0.0f);
@@ -216,9 +213,7 @@ public void forceCopyInReadOnlyDataLayered() {
216213
}
217214

218215
/**
219-
* Frees the device memory allocated for the TornadoVM execution plan.
220-
* This method should be called when the execution plan is no longer needed
221-
* to release resources and avoid memory leaks.
216+
* Frees the device memory allocated for the TornadoVM execution plan. This method should be called when the execution plan is no longer needed to release resources and avoid memory leaks.
222217
*/
223218
public void freeTornadoExecutionPlan() {
224219
executionPlan.freeDeviceMemory();

0 commit comments

Comments
 (0)