Reformat TornadoVMMasterPlan by re-organizing line breaks, updating parameter annotations, and consolidating documentation for improved readability and consistency.
src/main/java/com/example/tornadovm/TornadoVMMasterPlan.java
29 additions, 34 deletions
@@ -39,14 +39,13 @@ public TornadoVMMasterPlan(State state, Model model) {
     }

     /**
-     * Initializes the TornadoVM plan for GPU acceleration with optional timing.
-     * This method handles:
-     * 1. Creation of the TornadoVM master plan
-     * 2. Warming up the JIT compiler for better performance
-     * 3. Copying read-only model weights to the GPU
+     * Initializes the TornadoVM plan for GPU acceleration with optional timing. This method handles: 1. Creation of the TornadoVM master plan 2. Warming up the JIT compiler for better performance 3.
+     * Copying read-only model weights to the GPU
      *
-     * @param state The model state containing KV cache
-     * @param model The Llama model instance
+     * @param state
+     *            The model state containing KV cache
+     * @param model
+     *            The Llama model instance
      * @return The initialized TornadoVMMasterPlan ready for inference
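For orientation, here is a minimal usage sketch of the initialization step documented above. Only the (state, model) parameters and the TornadoVMMasterPlan return type are fixed by the Javadoc; the factory method name is an assumption made for illustration.

    // Hypothetical call site: build the plan once before decoding starts.
    // initializeTornadoVMPlan is an assumed name; the Javadoc only fixes the
    // parameters and the return type.
    TornadoVMMasterPlan plan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, model);
    // After this call the JIT-compiled kernels are warmed up and the read-only
    // model weights already reside on the GPU.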
@@ … @@ case UNKNOWN -> throw new UnsupportedOperationException("Unknown model type");
-        };
-    }
-
-    /**
-     * Determines whether the NVIDIA-specific scheduler should be used based on the current
-     * hardware backend and the model type.
+     * Determines whether the NVIDIA-specific scheduler should be used based on the current hardware backend and the model type.
      * <p>
-     * The scheduler is used only if the runtime is targeting an NVIDIA backend and the model
-     * is not of type {@code MISTRAL}. If either the hardware is not NVIDIA or the model is
-     * {@code MISTRAL}, the NVIDIA-specific scheduler should not be used.
+     * The scheduler is used only if the runtime is targeting an NVIDIA backend and the model is not of type {@code MISTRAL}. If either the hardware is not NVIDIA or the model is {@code MISTRAL}, the
+     * NVIDIA-specific scheduler should not be used.
      *
-     * @param model the model whose type may affect the scheduler decision
+     * @param model
+     *            the model whose type may affect the scheduler decision
      * @return {@code true} if the NVIDIA-specific scheduler should be used; {@code false} otherwise
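The decision described in this Javadoc reduces to a single conjunction. A minimal sketch, assuming a hypothetical isNvidiaBackend() helper and a getModelType() accessor; the diff confirms only the method signature and the rule itself:

    public static boolean shouldUseNvidiaScheduler(Model model) {
        // How the active backend is detected is an assumption; the diff does not
        // show the TornadoVM runtime call used for this check.
        boolean nvidiaBackend = isNvidiaBackend();
        // Even on NVIDIA hardware, the NVIDIA-specific scheduler is skipped for MISTRAL.
        return nvidiaBackend && model.getModelType() != ModelType.MISTRAL;
    }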
@@ … @@ case UNKNOWN -> throw new UnsupportedOperationException("Unknown model type");
+        };
+    }
+
+    /**
+     * Executes the forward pass of a LLaMA transformer model using TornadoVM acceleration. This method processes the transformer layers in sequence for a particular token position in the context
      * window.
      *
      * <p>The execution happens in three phases:
@@ -140,7 +137,6 @@ public static boolean shouldUseNvidiaScheduler(Model model) {
      * <li>Final projection to logits using TornadoVM</li>
      * </ol>
      *
-     *
      * @param position
      *            The current position in the sequence being processed
      * @return FloatTensor containing the output logits for token prediction
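A hedged sketch of how this forward pass is typically driven from a decode loop. The int position parameter and the FloatTensor return value come from the Javadoc; the method name tornadoForward, the sampler, and the stop-token check are assumptions for illustration:

    // Hypothetical decode loop: one accelerated forward pass per context position.
    for (int position = startPosition; position < maxTokens; position++) {
        FloatTensor logits = plan.tornadoForward(position); // method name assumed
        int nextToken = sampler.sampleToken(logits);         // sampler is hypothetical
        if (nextToken == stopTokenId) {                      // stop id is hypothetical
            break;
        }
        // the sampled token becomes the model input for the next position
    }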
@@ -183,7 +179,9 @@ private int getPreprocessingGraphIndex() {

     /**
      * Returns the graph index for the given transformer layer.
-     * @param layerIndex Index of the transformer layer (0-based)
+     *
+     * @param layerIndex
+     *            Index of the transformer layer (0-based)
      */
     private int getLayerGraphIndex(int layerIndex) {
         return 1 + layerIndex;
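The offset of one encodes the overall task-graph layout: graph 0 holds preprocessing, graphs 1..N hold the N transformer layers, and the last entry produces the final logits (taskGraphs.size() - 1, as the next hunk shows). A quick illustration for a hypothetical 32-layer model; the value returned by getPreprocessingGraphIndex() is inferred from its name and not shown in the diff:

    // Hypothetical 32-layer model, so taskGraphs.size() == 34
    // getPreprocessingGraphIndex() -> 0   (assumed from the method name)
    // getLayerGraphIndex(0)        -> 1
    // getLayerGraphIndex(31)       -> 32
    // getFinalLogitsGraphIndex()   -> 33  (taskGraphs.size() - 1)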
@@ -196,8 +194,7 @@ private int getFinalLogitsGraphIndex() {
         return taskGraphs.size() - 1;
     }

-    /// Execute the forward pass of the LLaMA transformer model using TornadoVM acceleration
-    /// just once to copy the data into the read-only data layer.
+    /// Execute the forward pass of the LLaMA transformer model using TornadoVM acceleration just once to copy the data into the read-only data layer.
     public void forceCopyInReadOnlyDataLayered() {
         // Execute all TornadoVM graphs
         state.wrapX.init(0.0f);
@@ -216,9 +213,7 @@ public void forceCopyInReadOnlyDataLayered() {
     }

     /**
-     * Frees the device memory allocated for the TornadoVM execution plan.
-     * This method should be called when the execution plan is no longer needed
-     * to release resources and avoid memory leaks.
+     * Frees the device memory allocated for the TornadoVM execution plan. This method should be called when the execution plan is no longer needed to release resources and avoid memory leaks.
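Taken together, the documented methods imply the following plan lifecycle. In this sketch only TornadoVMMasterPlan and forceCopyInReadOnlyDataLayered are names confirmed by the diff; initializeTornadoVMPlan and freeTornadoExecutionPlan are assumptions used for illustration:

    // 1. Build the plan: warms up the JIT compiler and copies read-only weights to the GPU.
    TornadoVMMasterPlan plan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, model); // name assumed
    // 2. Optionally run every graph once so the read-only data layer is populated up front.
    plan.forceCopyInReadOnlyDataLayered();
    // 3. ... per-token forward passes during inference ...
    // 4. Release device memory once the plan is no longer needed.
    plan.freeTornadoExecutionPlan(); // name assumed; the Javadoc only requires that device memory be freed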