@@ -104,19 +104,42 @@ def __init__(self, tokenizer, skip_prompt=False, timeout=None, **decode_kwargs):
104
104
self .next_tokens_are_prompt = True
105
105
self .stop_signal = None
106
106
self .text_queue = queue .Queue ()
107
+ self .token_cache = []
108
+ self .text_index_cache = []
107
109
108
def put(self, values):
    """Receive one newly generated token (or token list) per batch sequence
    and stream decoded text to ``self.text_queue``.

    Tokens are accumulated per sequence in ``self.token_cache`` so the
    tokenizer can decode multi-token graphemes correctly; only text up to the
    last completed word (or a full line ending in ``"\n"``) is released, with
    ``self.text_index_cache`` tracking how much of the decoded string has
    already been emitted for each sequence.
    """
    # The first tokens seen after a reset are the prompt; optionally swallow
    # them so only generated text is streamed to the consumer.
    if self.skip_prompt and self.next_tokens_are_prompt:
        self.next_tokens_are_prompt = False
        return
    output = []
    for i, v in enumerate(values):
        # Lazily grow the per-sequence caches up to the batch size.
        if len(self.token_cache) <= i:
            self.token_cache.append([])
            self.text_index_cache.append(0)
        token = v.tolist()  # tensor -> list of ints, or a single int
        if isinstance(token, list):
            self.token_cache[i].extend(token)
        else:
            self.token_cache[i].append(token)
        text = self.tokenizer.decode(self.token_cache[i], **self.decode_kwargs)
        if text.endswith("\n"):
            # Line complete: flush everything not yet emitted and reset this
            # sequence's caches.
            output.append(text[self.text_index_cache[i]:])
            self.token_cache[i] = []
            self.text_index_cache[i] = 0
        else:
            # Emit only up to (and including) the last space so a partially
            # decoded word is never streamed.
            printable_text = text[self.text_index_cache[i]: text.rfind(" ") + 1]
            self.text_index_cache[i] += len(printable_text)
            output.append(printable_text)
    if any(output):
        # BUGFIX: timeout must be passed by keyword — positionally it binds
        # to Queue.put's `block` parameter and the timeout is ignored.
        self.text_queue.put(output, timeout=self.timeout)
117
135
118
136
def end(self):
    """Flush any remaining buffered text for every sequence, then push the
    stop signal so iterating consumers terminate."""
    # The next call to put() will carry a fresh prompt again.
    self.next_tokens_are_prompt = True
    output = []
    for i, tokens in enumerate(self.token_cache):
        # Emit whatever decoded text was held back beyond the last word
        # boundary for this sequence.
        text = self.tokenizer.decode(tokens, **self.decode_kwargs)
        output.append(text[self.text_index_cache[i]:])
    # BUGFIX: reset the internal caches so a reused streamer does not
    # re-emit stale text from the previous generation.
    self.token_cache = []
    self.text_index_cache = []
    # BUGFIX: timeout must be passed by keyword — positionally it binds to
    # Queue.put's `block` parameter and the timeout is ignored.
    self.text_queue.put(output, timeout=self.timeout)
    self.text_queue.put(self.stop_signal, timeout=self.timeout)
121
144
122
145
def __iter__ (self ):
@@ -127,6 +150,7 @@ def __next__(self):
127
150
if value != self .stop_signal :
128
151
return value
129
152
153
+
130
154
class GGMLPipeline (object ):
131
155
def __init__ (self , model_name , ** task ):
132
156
import ctransformers
@@ -245,7 +269,8 @@ def stream(self, input, **kwargs):
245
269
generation_kwargs = None
246
270
if self .task == "conversational" :
247
271
streamer = TextIteratorStreamer (
248
- self .tokenizer , skip_prompt = True , skip_special_tokens = True
272
+ self .tokenizer ,
273
+ skip_prompt = True ,
249
274
)
250
275
if "chat_template" in kwargs :
251
276
input = self .tokenizer .apply_chat_template (
@@ -261,7 +286,7 @@ def stream(self, input, **kwargs):
261
286
input = self .tokenizer (input , return_tensors = "pt" ).to (self .model .device )
262
287
generation_kwargs = dict (input , streamer = streamer , ** kwargs )
263
288
else :
264
- streamer = TextIteratorStreamer (self .tokenizer , skip_special_tokens = True )
289
+ streamer = TextIteratorStreamer (self .tokenizer )
265
290
input = self .tokenizer (input , return_tensors = "pt" , padding = True ).to (
266
291
self .model .device
267
292
)
0 commit comments