Skip to content

Commit 34b3c79

Browse files
committed
Clean up InferenceCore by removing commented-out code, redundant lines, and unnecessary variables to improve readability and maintainability.
1 parent 9f6b99e commit 34b3c79

File tree

1 file changed

+1
-43
lines changed

1 file changed

+1
-43
lines changed

src/main/java/com/example/inference/InferenceCore.java

Lines changed: 1 addition & 43 deletions
Original file line number | Diff line number | Diff line change
@@ -21,8 +21,7 @@
2121
* Low-level operations for model inference.
2222
*
2323
* <p>
24-
* This class provides core computational operations such as RMS normalization and
25-
* forward passes through model layers. It supports both CPU and GPU implementations.
24+
* This class provides core computational operations such as RMS normalization and forward passes through model layers. It supports both CPU and GPU implementations.
2625
* </p>
2726
*
2827
* <p>
@@ -312,16 +311,13 @@ public static FloatTensor forwardJavaQwen3(Model model, State state, int token,
312311
}
313312

314313
public static FloatTensor forwardJavaPhi3(Model model, Phi3State state, int token, int position) {
315-
// a few convenience variables
316314
Phi3Configuration config = (Phi3Configuration) model.configuration();
317315
Phi3StandardWeights weights = (Phi3StandardWeights) model.weights();
318316
int dim = config.dim();
319317
int headSize = config.headSize();
320318
int kvDim = (config.dim() * config.numberOfKeyValueHeads()) / config.numberOfHeads();
321319
int kvMul = config.numberOfHeads() / config.numberOfKeyValueHeads(); // integer multiplier of the kv sharing in multiquery
322320
float sqrtHeadSize = (float) Math.sqrt(headSize);
323-
// dim=3072, headSize=96, kvDim=3072, kvMul=1
324-
// System.out.println(String.format("dim=%d, headSize=%d, kvDim=%d, kvMul=%d", dim, headSize, kvDim, kvMul));
325321

326322
// copy the token embedding into x
327323
weights.token_embedding_table.copyTo(token * dim, state.x, 0, dim);
@@ -331,22 +327,15 @@ public static FloatTensor forwardJavaPhi3(Model model, Phi3State state, int toke
331327

332328
// forward all the layers
333329
for (int l = 0; l < config.numberOfLayers(); l++) {
334-
// attention rmsnorm
335330
rmsnorm(state.xb, state.x, weights.rms_att_weight[l], 0, dim, config.rmsNormEps());
336331

337-
// qkv matmuls for this position
338-
// wqkv: (hidden_size, op_size)
339332
weights.wqkv[l].matmul(state.xb, state.qkv, opSize, dim);
340-
// query_pos = self.num_heads * self.head_dim
341-
// query_states = qkv[..., :query_pos]
342333
state.qkv.copyTo(0, state.q, 0, dim);
343334
// key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
344335
state.qkv.copyTo(dim, state.k, 0, config.numberOfKeyValueHeads() * headSize);
345336
// value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
346337
state.qkv.copyTo(dim + config.numberOfKeyValueHeads() * headSize, state.v, 0, config.numberOfKeyValueHeads() * headSize);
347338

348-
// RoPE relative positional encoding: complex-valued rotate q and k in each head
349-
// phi-3 uses RoPE-type neox, i.e. offset dim/2 instead of 1.
350339
int dimHalf = headSize / 2;
351340
for (int i = 0; i < dim; i += 2) {
352341
int head_dim = i % headSize;
@@ -365,50 +354,31 @@ public static FloatTensor forwardJavaPhi3(Model model, Phi3State state, int toke
365354
}
366355

367356
// save key,value at this time step (position) to our kv cache
368-
//int loff = l * config.seq_len * kvDim; // kv cache layer offset for convenience
369357
state.k.copyTo(0, state.keyCache[l], position * kvDim, kvDim);
370358
state.v.copyTo(0, state.valueCache[l], position * kvDim, kvDim);
371359

372360
int curLayer = l;
373361

374-
// multihead attention. iterate over all heads
375362
Parallel.parallelFor(0, config.numberOfHeads(), h -> {
376-
// get the query vector for this head
377-
// float* q = s.q + h * headSize;
378363
int qOffset = h * headSize;
379364

380-
// attention scores for this head
381-
// float* att = s.att + h * config.seq_len;
382365
int attOffset = h * config.contextLength();
383366

384-
// iterate over all timesteps, including the current one
385367
for (int t = 0; t <= position; t++) {
386-
// get the key vector for this head and at this timestep
387-
// float* k = s.key_cache + loff + t * dim + h * headSize;
388368
int keyCacheOffset = /* loff + */ t * kvDim + (h / kvMul) * headSize;
389-
// calculate the attention score as the dot product of q and k
390369
float score = state.q.dot(qOffset, state.keyCache[curLayer], keyCacheOffset, headSize);
391370
score /= sqrtHeadSize;
392-
// save the score to the attention buffer
393371
state.att.setFloat(attOffset + t, score);
394372
}
395373

396-
// softmax the scores to get attention weights, from 0..position inclusively
397374
state.att.softmaxInPlace(attOffset, position + 1);
398375

399-
// weighted sum of the values, store back into xb
400-
// float* xb = s.xb + h * headSize;
401376
int xbOffset = h * headSize;
402-
// memset(xb, 0, headSize * sizeof(float));
403377
state.xb.fillInPlace(xbOffset, headSize, 0f);
404378

405379
for (int t = 0; t <= position; t++) {
406-
// get the value vector for this head and at this timestep
407-
// float* v = s.value_cache + loff + t * dim + h * headSize;
408380
int vOffset = /* loff + */ t * kvDim + (h / kvMul) * headSize;
409-
// get the attention weight for this timestep
410381
float a = state.att.getFloat(attOffset + t);
411-
// accumulate the weighted value into xb
412382
state.xb.saxpyInPlace(xbOffset, state.valueCache[curLayer], vOffset, headSize, a);
413383
}
414384
});
@@ -419,31 +389,19 @@ public static FloatTensor forwardJavaPhi3(Model model, Phi3State state, int toke
419389
// residual connection back into x
420390
state.x.addInPlace(state.xb2);
421391

422-
// ffn rmsnorm
423392
rmsnorm(state.xb, state.x, weights.rms_ffn_weight[l], 0, dim, config.rmsNormEps());
424393

425-
// MLP in phi3:
426-
// up_states = self.gate_up_proj(hidden_states)
427394
weights.wGateUp[l].matmul(state.xb, state.hb, 2 * config.hiddenDim(), dim);
428-
// gate, up_states = up_states.chunk(2, dim=-1)
429395
copyChunk(state.hb, state.hbG, 2 * config.hiddenDim(), config.hiddenDim(), 2, 0);
430396
copyChunk(state.hb, state.hbU, 2 * config.hiddenDim(), config.hiddenDim(), 2, 1);
431397

432-
// self.activation_fn(gate)
433-
// SwiGLU non-linearity
434-
// silu(x)=x*σ(x), where σ(x) is the logistic sigmoid
435398
state.hbG.mapInPlace(value -> value / (float) (1.0 + Math.exp(-value)));
436399

437-
// up_states = up_states * self.activation_fn(gate)
438-
// elementwise multiply with w3(x)
439400
state.hbU.multiplyInPlace(state.hbG);
440401

441-
// self.down_proj(up_states)
442402
weights.wDown[l].matmul(state.hbU, state.xb, dim, config.hiddenDim());
443403

444-
// residual connection
445404
state.x.addInPlace(state.xb);
446-
447405
}
448406

449407
// final rmsnorm

0 commit comments

Comments (0)