Skip to content

Commit 8adb2bb

Browse files
committed
Complete generateTokensPhi3 implementation and refine Phi3 tokenizer logic
1 parent 104a5bb commit 8adb2bb

File tree

3 files changed

+57
-6
lines changed

3 files changed

+57
-6
lines changed

src/main/java/com/example/inference/InferenceEngine.java

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@
22

33
import com.example.auxiliary.LastRunMetrics;
44
import com.example.inference.sampler.Sampler;
5-
import com.example.inference.state.Phi3State;
65
import com.example.inference.state.State;
76
import com.example.model.Configuration;
87
import com.example.model.Model;
98
import com.example.tokenizer.impl.Tokenizer;
109
import com.example.tornadovm.TornadoVMMasterPlan;
1110
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
1211

12+
import java.io.ByteArrayOutputStream;
1313
import java.util.ArrayList;
1414
import java.util.List;
1515
import java.util.Set;
@@ -217,7 +217,57 @@ public static List<Integer> generateTokensQwen3(Model model, State state, int st
217217

218218
public static List<Integer> generateTokensPhi3(Model model, State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo,
219219
IntConsumer onTokenGenerated) {
220-
return null;
220+
221+
long startNanos = System.nanoTime();
222+
if (maxTokens < 0 || model.configuration().contextLength() < maxTokens) {
223+
maxTokens = model.configuration().contextLength();
224+
}
225+
List<Integer> generatedTokens = new ArrayList<>(maxTokens);
226+
int token = state.latestToken; // BOS?
227+
int nextToken;
228+
int promptIndex = 0;
229+
ByteArrayOutputStream baos = new ByteArrayOutputStream(5);
230+
for (int position = startPosition; position < maxTokens; ++position) {
231+
232+
model.forward(state, token, position);
233+
if (promptIndex < promptTokens.size()) {
234+
// Force-pick token from prompt.
235+
nextToken = promptTokens.get(promptIndex++);
236+
if (echo) {
237+
// log prompt token (different color?)
238+
System.out.println("NextToken: " + nextToken);
239+
//System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(nextToken))));
240+
String decoded = model.tokenizer().decode(List.of(nextToken));
241+
System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(nextToken))));
242+
243+
// System.err.print(de(decoded, baos));
244+
}
245+
} else {
246+
nextToken = sampler.sampleToken(state.logits);
247+
if (echo) {
248+
// log inferred token
249+
System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(nextToken))));
250+
}
251+
generatedTokens.add(nextToken);
252+
if (onTokenGenerated != null) {
253+
onTokenGenerated.accept(nextToken);
254+
}
255+
if (stopTokens.contains(nextToken)) {
256+
break;
257+
}
258+
}
259+
state.latestToken = token = nextToken;
260+
if (position == 2000) {
261+
break;
262+
}
263+
}
264+
265+
long elapsedNanos = System.nanoTime() - startNanos;
266+
int totalTokens = promptIndex + generatedTokens.size();
267+
System.err.printf("%n%.2f tokens/s (%d)%n", totalTokens / (elapsedNanos / 1_000_000_000.0), totalTokens);
268+
269+
return generatedTokens;
270+
221271
}
222272

223273
public static List<Integer> generateTokensGPULlama(Model model, State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo,
@@ -406,4 +456,5 @@ public static List<Integer> generateTokensGPUPhi3(Model model, State state, int
406456
IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan) {
407457
return null;
408458
}
409-
}
459+
460+
}

src/main/java/com/example/model/phi3/Phi3.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ public void forward(State state, int token, int position) {
6666
@Override
6767
public List<Integer> generateTokens(State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo,
6868
IntConsumer onTokenGenerated) {
69-
return InferenceEngine.generateTokensQwen3(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated);
69+
return InferenceEngine.generateTokensPhi3(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated);
7070
}
7171

7272
@Override

src/main/java/com/example/tokenizer/impl/Phi3Tokenizer.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,12 +73,12 @@ public Map<String, Integer> getSpecialTokens() {
7373

7474
@Override
7575
public boolean isSpecialToken(int tokenIndex) {
76-
return false;
76+
return specialTokens.containsValue(tokenIndex);
7777
}
7878

7979
@Override
8080
public boolean shouldDisplayToken(int token) {
81-
return false;
81+
return !isSpecialToken(token);
8282
}
8383

8484
@Override

0 commit comments

Comments
 (0)