@@ -21,8 +21,7 @@
  * Low-level operations for model inference.
  *
  * <p>
- * This class provides core computational operations such as RMS normalization and
- * forward passes through model layers. It supports both CPU and GPU implementations.
+ * This class provides core computational operations such as RMS normalization and forward passes through model layers. It supports both CPU and GPU implementations.
  * </p>
  *
  * <p>
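
As a reference for the RMS normalization the class comment mentions, here is a minimal sketch in plain Java. It works on float[] arrays for clarity; the actual rmsnorm used below operates on FloatTensor with an explicit offset and size, so the signature here is illustrative only.

    // RMSNorm sketch: out[i] = weight[i] * x[i] / sqrt(mean(x^2) + eps)
    // Plain arrays for illustration; not the FloatTensor-based API of this class.
    static void rmsnormSketch(float[] out, float[] x, float[] weight, int size, float eps) {
        float ss = 0f;
        for (int i = 0; i < size; i++) {
            ss += x[i] * x[i];
        }
        float scale = (float) (1.0 / Math.sqrt(ss / size + eps));
        for (int i = 0; i < size; i++) {
            out[i] = weight[i] * x[i] * scale;
        }
    }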
@@ -312,16 +311,13 @@ public static FloatTensor forwardJavaQwen3(Model model, State state, int token,
 }
 
 public static FloatTensor forwardJavaPhi3(Model model, Phi3State state, int token, int position) {
-    // a few convenience variables
     Phi3Configuration config = (Phi3Configuration) model.configuration();
     Phi3StandardWeights weights = (Phi3StandardWeights) model.weights();
     int dim = config.dim();
     int headSize = config.headSize();
     int kvDim = (config.dim() * config.numberOfKeyValueHeads()) / config.numberOfHeads();
     int kvMul = config.numberOfHeads() / config.numberOfKeyValueHeads(); // integer multiplier of the kv sharing in multiquery
     float sqrtHeadSize = (float) Math.sqrt(headSize);
-    // dim=3072, headSize=96, kvDim=3072, kvMul=1
-    // System.out.println(String.format("dim=%d, headSize=%d, kvDim=%d, kvMul=%d", dim, headSize, kvDim, kvMul));
 
     // copy the token embedding into x
     weights.token_embedding_table.copyTo(token * dim, state.x, 0, dim);
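
To make the derived sizes concrete, the debug comment removed above recorded dim=3072, headSize=96, kvDim=3072, kvMul=1 for this model. A worked example follows; the head counts are inferred from those values, so treat them as assumptions:

    int dim = 3072;
    int headSize = 96;
    int numberOfHeads = 32;          // inferred: dim / headSize = 3072 / 96
    int numberOfKeyValueHeads = 32;  // inferred from kvMul == 1 below
    int kvDim = (dim * numberOfKeyValueHeads) / numberOfHeads;  // 3072: full-width KV entry per position
    int kvMul = numberOfHeads / numberOfKeyValueHeads;          // 1: each query head has its own KV head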
@@ -331,22 +327,15 @@ public static FloatTensor forwardJavaPhi3(Model model, Phi3State state, int toke
 
     // forward all the layers
     for (int l = 0; l < config.numberOfLayers(); l++) {
-        // attention rmsnorm
         rmsnorm(state.xb, state.x, weights.rms_att_weight[l], 0, dim, config.rmsNormEps());
 
-        // qkv matmuls for this position
-        // wqkv: (hidden_size, op_size)
         weights.wqkv[l].matmul(state.xb, state.qkv, opSize, dim);
-        // query_pos = self.num_heads * self.head_dim
-        // query_states = qkv[..., :query_pos]
         state.qkv.copyTo(0, state.q, 0, dim);
         // key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
         state.qkv.copyTo(dim, state.k, 0, config.numberOfKeyValueHeads() * headSize);
         // value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
         state.qkv.copyTo(dim + config.numberOfKeyValueHeads() * headSize, state.v, 0, config.numberOfKeyValueHeads() * headSize);
 
-        // RoPE relative positional encoding: complex-valued rotate q and k in each head
-        // phi-3 uses RoPE-type neox, i.e. offset dim/2 instead of 1.
         int dimHalf = headSize / 2;
         for (int i = 0; i < dim; i += 2) {
             int head_dim = i % headSize;
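
The removed comment notes that Phi-3 uses neox-style RoPE, where element i of a head is rotated together with its partner at offset headSize/2 rather than with the adjacent element. A standalone sketch of that rotation follows; the names, layout, and base frequency theta are assumptions, since the actual rotation body is elided from this diff:

    // Neox-style RoPE sketch: rotate (v[i], v[i + headSize/2]) pairs by a
    // position-dependent angle. Applied to q (and, up to kvDim, to k).
    static void ropeNeoxSketch(float[] v, int headOffset, int headSize, int position, double theta) {
        int half = headSize / 2;
        for (int i = 0; i < half; i++) {
            double freq = 1.0 / Math.pow(theta, (2.0 * i) / headSize);
            double angle = position * freq;
            float cos = (float) Math.cos(angle);
            float sin = (float) Math.sin(angle);
            float x0 = v[headOffset + i];
            float x1 = v[headOffset + i + half]; // partner at offset headSize/2, not i + 1
            v[headOffset + i] = x0 * cos - x1 * sin;
            v[headOffset + i + half] = x0 * sin + x1 * cos;
        }
    }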
@@ -365,50 +354,31 @@ public static FloatTensor forwardJavaPhi3(Model model, Phi3State state, int toke
         }
 
         // save key,value at this time step (position) to our kv cache
-        //int loff = l * config.seq_len * kvDim; // kv cache layer offset for convenience
         state.k.copyTo(0, state.keyCache[l], position * kvDim, kvDim);
         state.v.copyTo(0, state.valueCache[l], position * kvDim, kvDim);
 
         int curLayer = l;
 
-        // multihead attention. iterate over all heads
         Parallel.parallelFor(0, config.numberOfHeads(), h -> {
-            // get the query vector for this head
-            // float* q = s.q + h * headSize;
             int qOffset = h * headSize;
 
-            // attention scores for this head
-            // float* att = s.att + h * config.seq_len;
             int attOffset = h * config.contextLength();
 
-            // iterate over all timesteps, including the current one
             for (int t = 0; t <= position; t++) {
-                // get the key vector for this head and at this timestep
-                // float* k = s.key_cache + loff + t * dim + h * headSize;
                 int keyCacheOffset = /* loff + */ t * kvDim + (h / kvMul) * headSize;
-                // calculate the attention score as the dot product of q and k
                 float score = state.q.dot(qOffset, state.keyCache[curLayer], keyCacheOffset, headSize);
                 score /= sqrtHeadSize;
-                // save the score to the attention buffer
                 state.att.setFloat(attOffset + t, score);
             }
 
-            // softmax the scores to get attention weights, from 0..position inclusively
             state.att.softmaxInPlace(attOffset, position + 1);
 
-            // weighted sum of the values, store back into xb
-            // float* xb = s.xb + h * headSize;
             int xbOffset = h * headSize;
-            // memset(xb, 0, headSize * sizeof(float));
             state.xb.fillInPlace(xbOffset, headSize, 0f);
 
             for (int t = 0; t <= position; t++) {
-                // get the value vector for this head and at this timestep
-                // float* v = s.value_cache + loff + t * dim + h * headSize;
                 int vOffset = /* loff + */ t * kvDim + (h / kvMul) * headSize;
-                // get the attention weight for this timestep
                 float a = state.att.getFloat(attOffset + t);
-                // accumulate the weighted value into xb
                 state.xb.saxpyInPlace(xbOffset, state.valueCache[curLayer], vOffset, headSize, a);
             }
         });
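
The (h / kvMul) * headSize offsets above are what implement grouped-query attention: each group of kvMul consecutive query heads reads the same cached key/value head. A self-contained sketch of one head's pass over the cache, on plain arrays with illustrative names:

    // One attention head over the KV cache: scores, softmax, weighted value sum.
    // keyCache/valueCache are laid out as [t * kvDim + kvHead * headSize].
    static void attendOneHead(float[] q, int qOffset, float[] keyCache, float[] valueCache,
                              float[] att, int attOffset, float[] xb, int xbOffset,
                              int h, int kvMul, int kvDim, int headSize, int position) {
        float sqrtHeadSize = (float) Math.sqrt(headSize);
        for (int t = 0; t <= position; t++) {
            int kOff = t * kvDim + (h / kvMul) * headSize; // query head h reads KV head h / kvMul
            float score = 0f;
            for (int i = 0; i < headSize; i++) {
                score += q[qOffset + i] * keyCache[kOff + i];
            }
            att[attOffset + t] = score / sqrtHeadSize;
        }
        // numerically stable softmax over positions 0..position
        float max = Float.NEGATIVE_INFINITY;
        for (int t = 0; t <= position; t++) {
            max = Math.max(max, att[attOffset + t]);
        }
        float sum = 0f;
        for (int t = 0; t <= position; t++) {
            att[attOffset + t] = (float) Math.exp(att[attOffset + t] - max);
            sum += att[attOffset + t];
        }
        for (int t = 0; t <= position; t++) {
            att[attOffset + t] /= sum;
        }
        // weighted sum of cached values into this head's output slice
        java.util.Arrays.fill(xb, xbOffset, xbOffset + headSize, 0f);
        for (int t = 0; t <= position; t++) {
            int vOff = t * kvDim + (h / kvMul) * headSize;
            float a = att[attOffset + t];
            for (int i = 0; i < headSize; i++) {
                xb[xbOffset + i] += a * valueCache[vOff + i];
            }
        }
    }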
@@ -419,31 +389,19 @@ public static FloatTensor forwardJavaPhi3(Model model, Phi3State state, int toke
         // residual connection back into x
         state.x.addInPlace(state.xb2);
 
-        // ffn rmsnorm
         rmsnorm(state.xb, state.x, weights.rms_ffn_weight[l], 0, dim, config.rmsNormEps());
 
-        // MLP in phi3:
-        // up_states = self.gate_up_proj(hidden_states)
         weights.wGateUp[l].matmul(state.xb, state.hb, 2 * config.hiddenDim(), dim);
-        // gate, up_states = up_states.chunk(2, dim=-1)
         copyChunk(state.hb, state.hbG, 2 * config.hiddenDim(), config.hiddenDim(), 2, 0);
         copyChunk(state.hb, state.hbU, 2 * config.hiddenDim(), config.hiddenDim(), 2, 1);
 
-        // self.activation_fn(gate)
-        // SwiGLU non-linearity
-        // silu(x)=x*σ(x), where σ(x) is the logistic sigmoid
         state.hbG.mapInPlace(value -> value / (float) (1.0 + Math.exp(-value)));
 
-        // up_states = up_states * self.activation_fn(gate)
-        // elementwise multiply with w3(x)
         state.hbU.multiplyInPlace(state.hbG);
 
-        // self.down_proj(up_states)
         weights.wDown[l].matmul(state.hbU, state.xb, dim, config.hiddenDim());
 
-        // residual connection
         state.x.addInPlace(state.xb);
-
     }
 
     // final rmsnorm
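
The FFN block above is Phi-3's fused gate/up projection followed by SiLU gating and a down projection. A plain-array sketch of the same computation; the row-major weight layout and the dot helper are assumptions, not this repo's FloatTensor API:

    // Phi-3 gated MLP sketch: wGateUp stacks the gate rows (0..hiddenDim-1)
    // on top of the up rows (hiddenDim..2*hiddenDim-1), matching the chunk
    // order above (chunk 0 = gate, chunk 1 = up).
    static void phi3MlpSketch(float[] x, float[] out, float[][] wGateUp, float[][] wDown,
                              int dim, int hiddenDim) {
        float[] gate = new float[hiddenDim];
        float[] up = new float[hiddenDim];
        for (int r = 0; r < hiddenDim; r++) {
            gate[r] = dot(wGateUp[r], x, dim);
            up[r] = dot(wGateUp[hiddenDim + r], x, dim);
        }
        for (int r = 0; r < hiddenDim; r++) {
            float silu = gate[r] / (1f + (float) Math.exp(-gate[r])); // silu(g) = g * sigmoid(g)
            up[r] *= silu;
        }
        for (int r = 0; r < dim; r++) {
            out[r] = dot(wDown[r], up, hiddenDim); // caller adds this back into x (residual)
        }
    }

    static float dot(float[] w, float[] v, int n) {
        float s = 0f;
        for (int i = 0; i < n; i++) {
            s += w[i] * v[i];
        }
        return s;
    }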