
Commit b8a9886

Working streaming tokenizer (#1210)
1 parent ffe4bfe commit b8a9886

File tree

1 file changed: +33 −8 lines changed

pgml-extension/src/bindings/transformers/transformers.py

Lines changed: 33 additions & 8 deletions
@@ -104,19 +104,42 @@ def __init__(self, tokenizer, skip_prompt=False, timeout=None, **decode_kwargs):
         self.next_tokens_are_prompt = True
         self.stop_signal = None
         self.text_queue = queue.Queue()
+        self.token_cache = []
+        self.text_index_cache = []
 
-    def put(self, value):
+    def put(self, values):
         if self.skip_prompt and self.next_tokens_are_prompt:
             self.next_tokens_are_prompt = False
             return
-        # Can't batch this decode
-        decoded_values = []
-        for v in value:
-            decoded_values.append(self.tokenizer.decode(v, **self.decode_kwargs))
-        self.text_queue.put(decoded_values, self.timeout)
+        output = []
+        for i, v in enumerate(values):
+            if len(self.token_cache) <= i:
+                self.token_cache.append([])
+                self.text_index_cache.append(0)
+            token = v.tolist()  # Returns a list or number
+            if type(token) == list:
+                self.token_cache[i].extend(token)
+            else:
+                self.token_cache[i].append(token)
+            text = self.tokenizer.decode(self.token_cache[i], **self.decode_kwargs)
+            if text.endswith("\n"):
+                output.append(text[self.text_index_cache[i] :])
+                self.token_cache[i] = []
+                self.text_index_cache[i] = 0
+            else:
+                printable_text = text[self.text_index_cache[i] : text.rfind(" ") + 1]
+                self.text_index_cache[i] += len(printable_text)
+                output.append(printable_text)
+        if any(output):
+            self.text_queue.put(output, self.timeout)
 
     def end(self):
         self.next_tokens_are_prompt = True
+        output = []
+        for i, tokens in enumerate(self.token_cache):
+            text = self.tokenizer.decode(tokens, **self.decode_kwargs)
+            output.append(text[self.text_index_cache[i] :])
+        self.text_queue.put(output, self.timeout)
         self.text_queue.put(self.stop_signal, self.timeout)
 
     def __iter__(self):
@@ -127,6 +150,7 @@ def __next__(self):
         if value != self.stop_signal:
             return value
 
+
 class GGMLPipeline(object):
     def __init__(self, model_name, **task):
         import ctransformers
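
Note on the hunks above: put() now buffers token ids per sequence and only emits text up to the last space, flushing the whole buffer when the decoded text ends in a newline, so partially decoded subword tokens are never sent to the queue. A minimal standalone sketch of that flush rule follows; StubTokenizer and the byte chunks are illustrative assumptions, not part of the pgml codebase.

# Sketch of the cache-and-flush rule from put(), runnable standalone.
class StubTokenizer:
    def decode(self, ids, **kwargs):
        # Pretend each "token id" is just a character code.
        return "".join(chr(i) for i in ids)


def flush_printable(cache, index, tokenizer, new_ids):
    """Buffer ids; emit only complete words, or everything on a newline."""
    cache.extend(new_ids)
    text = tokenizer.decode(cache)
    if text.endswith("\n"):
        out = text[index[0]:]   # flush the rest of the buffer
        cache.clear()
        index[0] = 0
    else:
        out = text[index[0] : text.rfind(" ") + 1]  # up to last space only
        index[0] += len(out)
    return out


tokenizer = StubTokenizer()
cache, index = [], [0]
for chunk in (b"hel", b"lo wor", b"ld\n"):
    print(repr(flush_printable(cache, index, tokenizer, list(chunk))))
# prints '' then 'hello ' then 'world\n' -- no partial words are emitted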
@@ -245,7 +269,8 @@ def stream(self, input, **kwargs):
         generation_kwargs = None
         if self.task == "conversational":
             streamer = TextIteratorStreamer(
-                self.tokenizer, skip_prompt=True, skip_special_tokens=True
+                self.tokenizer,
+                skip_prompt=True,
             )
             if "chat_template" in kwargs:
                 input = self.tokenizer.apply_chat_template(
@@ -261,7 +286,7 @@ def stream(self, input, **kwargs):
             input = self.tokenizer(input, return_tensors="pt").to(self.model.device)
             generation_kwargs = dict(input, streamer=streamer, **kwargs)
         else:
-            streamer = TextIteratorStreamer(self.tokenizer, skip_special_tokens=True)
+            streamer = TextIteratorStreamer(self.tokenizer)
             input = self.tokenizer(input, return_tensors="pt", padding=True).to(
                 self.model.device
             )
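
For context, a streamer built this way is consumed by iterating it while generate() runs on a background thread. A hedged sketch of that pattern using the stock Hugging Face TextIteratorStreamer, which follows the same put()/end() protocol as the patched class in this commit; the "gpt2" checkpoint and the prompt are assumptions for illustration only.

# Drive generation on a thread and drain the streamer's queue here.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
inputs = tokenizer("Streaming tokenizers let you", return_tensors="pt")
generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=20)

# generate() blocks until done, so it runs on a worker thread
Thread(target=model.generate, kwargs=generation_kwargs).start()
for text in streamer:
    print(text, end="", flush=True)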
