2
2
3
3
import com .example .auxiliary .LastRunMetrics ;
4
4
import com .example .inference .sampler .Sampler ;
5
- import com .example .inference .state .Phi3State ;
6
5
import com .example .inference .state .State ;
7
6
import com .example .model .Configuration ;
8
7
import com .example .model .Model ;
9
8
import com .example .tokenizer .impl .Tokenizer ;
10
9
import com .example .tornadovm .TornadoVMMasterPlan ;
11
10
import uk .ac .manchester .tornado .api .types .arrays .FloatArray ;
12
11
12
+ import java .io .ByteArrayOutputStream ;
13
13
import java .util .ArrayList ;
14
14
import java .util .List ;
15
15
import java .util .Set ;
@@ -217,7 +217,57 @@ public static List<Integer> generateTokensQwen3(Model model, State state, int st
217
217
218
218
public static List <Integer > generateTokensPhi3 (Model model , State state , int startPosition , List <Integer > promptTokens , Set <Integer > stopTokens , int maxTokens , Sampler sampler , boolean echo ,
219
219
IntConsumer onTokenGenerated ) {
220
- return null ;
220
+
221
+ long startNanos = System .nanoTime ();
222
+ if (maxTokens < 0 || model .configuration ().contextLength () < maxTokens ) {
223
+ maxTokens = model .configuration ().contextLength ();
224
+ }
225
+ List <Integer > generatedTokens = new ArrayList <>(maxTokens );
226
+ int token = state .latestToken ; // BOS?
227
+ int nextToken ;
228
+ int promptIndex = 0 ;
229
+ ByteArrayOutputStream baos = new ByteArrayOutputStream (5 );
230
+ for (int position = startPosition ; position < maxTokens ; ++position ) {
231
+
232
+ model .forward (state , token , position );
233
+ if (promptIndex < promptTokens .size ()) {
234
+ // Force-pick token from prompt.
235
+ nextToken = promptTokens .get (promptIndex ++);
236
+ if (echo ) {
237
+ // log prompt token (different color?)
238
+ System .out .println ("NextToken: " + nextToken );
239
+ //System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(nextToken))));
240
+ String decoded = model .tokenizer ().decode (List .of (nextToken ));
241
+ System .err .print (Tokenizer .replaceControlCharacters (model .tokenizer ().decode (List .of (nextToken ))));
242
+
243
+ // System.err.print(de(decoded, baos));
244
+ }
245
+ } else {
246
+ nextToken = sampler .sampleToken (state .logits );
247
+ if (echo ) {
248
+ // log inferred token
249
+ System .err .print (Tokenizer .replaceControlCharacters (model .tokenizer ().decode (List .of (nextToken ))));
250
+ }
251
+ generatedTokens .add (nextToken );
252
+ if (onTokenGenerated != null ) {
253
+ onTokenGenerated .accept (nextToken );
254
+ }
255
+ if (stopTokens .contains (nextToken )) {
256
+ break ;
257
+ }
258
+ }
259
+ state .latestToken = token = nextToken ;
260
+ if (position == 2000 ) {
261
+ break ;
262
+ }
263
+ }
264
+
265
+ long elapsedNanos = System .nanoTime () - startNanos ;
266
+ int totalTokens = promptIndex + generatedTokens .size ();
267
+ System .err .printf ("%n%.2f tokens/s (%d)%n" , totalTokens / (elapsedNanos / 1_000_000_000.0 ), totalTokens );
268
+
269
+ return generatedTokens ;
270
+
221
271
}
222
272
223
273
public static List <Integer > generateTokensGPULlama (Model model , State state , int startPosition , List <Integer > promptTokens , Set <Integer > stopTokens , int maxTokens , Sampler sampler , boolean echo ,
@@ -406,4 +456,5 @@ public static List<Integer> generateTokensGPUPhi3(Model model, State state, int
406
456
IntConsumer onTokenGenerated , TornadoVMMasterPlan tornadoVMPlan ) {
407
457
return null ;
408
458
}
409
- }
459
+
460
+ }
0 commit comments