     TrainingArguments,
     Trainer,
     TextStreamer,
-    Conversation
+    Conversation,
 )
 from threading import Thread
 from typing import Optional
@@ -95,24 +95,34 @@ def ensure_device(kwargs):
     else:
         kwargs["device"] = "cpu"
 
-# A copy of HuggingFace's TextIteratorStreamer, with small changes to __next__ so it does not raise an exception
-class TextIteratorStreamer(TextStreamer):
-    def __init__(
-        self, tokenizer, skip_prompt=False, timeout=None, **decode_kwargs
-    ):
-        super().__init__(tokenizer, skip_prompt, **decode_kwargs)
-        self.text_queue = queue.Queue()
-        self.stop_signal = None
+
+# Follows the BaseStreamer template from the transformers library
+class TextIteratorStreamer:
+    def __init__(self, tokenizer, skip_prompt=False, timeout=None, **decode_kwargs):
+        self.tokenizer = tokenizer
+        self.skip_prompt = skip_prompt
         self.timeout = timeout
+        self.decode_kwargs = decode_kwargs
+        self.next_tokens_are_prompt = True
+        self.stop_signal = None
+        self.text_queue = queue.Queue()
 
-    def on_finalized_text(self, text: str, stream_end: bool = False):
-        """Put the new text in the queue. If the stream is ending, also put a stop signal in the queue."""
-        self.text_queue.put(text, timeout=self.timeout)
-        if stream_end:
-            self.text_queue.put(self.stop_signal, timeout=self.timeout)
+    def put(self, value):
+        if self.skip_prompt and self.next_tokens_are_prompt:
+            self.next_tokens_are_prompt = False
+            return
+        # The decode cannot be batched: each sequence in the batch is decoded on its own
+        decoded_values = []
+        for v in value:
+            decoded_values.append(self.tokenizer.decode(v, **self.decode_kwargs))
+        self.text_queue.put(decoded_values, timeout=self.timeout)
+
+    def end(self):
+        self.next_tokens_are_prompt = True
+        self.text_queue.put(self.stop_signal, timeout=self.timeout)
 
     def __iter__(self):
         return self
 
     def __next__(self):
         value = self.text_queue.get(timeout=self.timeout)
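
For context, a hedged usage sketch of this streamer; the names model, tokenizer, and inputs are assumed stand-ins, and it assumes the stop signal (None) is surfaced by __next__ rather than raising StopIteration:

# generate() runs on a worker thread and feeds put()/end();
# the caller iterates the streamer on its own thread.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=64)
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
for chunk in streamer:
    # each chunk is a list of decoded strings, one per sequence in the batch
    if chunk is None:  # stop_signal: generation finished
        break
    print(chunk)
thread.join()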
@@ -215,6 +225,18 @@ def __init__(self, model_name, **kwargs):
         # to the model constructor, so we construct the model/tokenizer manually if possible,
         # but that is only possible when the task is passed in, since if you pass the model
         # to the pipeline constructor, the task will no longer be inferred from the default...
+
+        # We create a text-generation pipeline under the hood when the task is conversational
+        self.conversational = False
+        if "task" in kwargs and kwargs["task"] == "conversational":
+            self.conversational = True
+            kwargs["task"] = "text-generation"
+
+        # Tokens can be either left- or right-padded, depending on the architecture
+        padding_side = "right"
+        if "padding_side" in kwargs:
+            padding_side = kwargs.pop("padding_side")
+
         if (
             "task" in kwargs
             and model_name is not None
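
As a side note on padding_side, a minimal sketch of why decoder-only batching usually wants left padding; "gpt2" is an assumed example model, not named in this diff:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
tok.pad_token = tok.eos_token  # gpt2 ships without a pad token
batch = tok(["hi", "a much longer prompt"], return_tensors="pt", padding=True)
# With padding_side="left", pad tokens sit before the short prompt, so every
# sequence ends at its real last token, which is what generate() continues from.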
@@ -224,8 +246,7 @@ def __init__(self, model_name, **kwargs):
                 "question-answering",
                 "summarization",
                 "translation",
-                "text-generation",
-                "conversational"
+                "text-generation"
             ]
         ):
             self.task = kwargs.pop("task")
@@ -240,56 +261,72 @@ def __init__(self, model_name, **kwargs):
                 )
             elif self.task == "summarization" or self.task == "translation":
                 self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name, **kwargs)
-            elif self.task == "text-generation" or self.task == "conversational":
+            elif self.task == "text-generation":
                 self.model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
             else:
                 raise PgMLException(f"Unhandled task: {self.task}")
 
             if "use_auth_token" in kwargs:
                 self.tokenizer = AutoTokenizer.from_pretrained(
-                    model_name, use_auth_token=kwargs["use_auth_token"]
+                    model_name, use_auth_token=kwargs["use_auth_token"], padding_side=padding_side
                 )
             else:
-                self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+                self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side=padding_side)
 
             self.pipe = transformers.pipeline(
                 self.task,
                 model=self.model,
                 tokenizer=self.tokenizer,
             )
         else:
-            self.pipe = transformers.pipeline(**kwargs)
+            self.pipe = transformers.pipeline(**kwargs, padding_side=padding_side)
+            self.tokenizer = self.pipe.tokenizer
             self.task = self.pipe.task
             self.model = self.pipe.model
-            if self.pipe.tokenizer is None:
-                self.pipe.tokenizer = AutoTokenizer.from_pretrained(
-                    self.model.name_or_path
-                )
-            self.tokenizer = self.pipe.tokenizer
+
+        # Make sure we set the pad token if it does not exist
+        if self.tokenizer.pad_token is None:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
 
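
A compact illustration of why the pad-token fallback above is needed (again assuming gpt2): batch tokenization with padding=True raises if the tokenizer has no pad token, so eos is reused:

tok = AutoTokenizer.from_pretrained("gpt2")
assert tok.pad_token is None
# tok(["a", "bb"], padding=True)  # raises: no padding token is set
tok.pad_token = tok.eos_token
ids = tok(["a", "bb"], return_tensors="pt", padding=True)  # now works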
     def stream(self, inputs, **kwargs):
         streamer = None
         generation_kwargs = None
-        if self.task == "conversational":
-            streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True)
-            inputs = tokenized_chat = self.tokenizer.apply_chat_template(inputs, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(self.model.device)
-            generation_kwargs = dict(inputs=inputs, streamer=streamer, **kwargs)
+        # Conversational streaming does not currently work with left-padded tokenizers;
+        # for gpt2 at least, apply_chat_template breaks under left padding
+        if self.conversational:
+            streamer = TextIteratorStreamer(
+                self.tokenizer, skip_prompt=True, skip_special_tokens=True
+            )
+            templated_inputs = []
+            for input in inputs:
+                templated_inputs.append(
+                    self.tokenizer.apply_chat_template(
+                        input, add_generation_prompt=True, tokenize=False
+                    )
+                )
+            inputs = self.tokenizer(
+                templated_inputs, return_tensors="pt", padding=True
+            ).to(self.model.device)
+            generation_kwargs = dict(inputs, streamer=streamer, **kwargs)
         else:
-            streamer = TextIteratorStreamer(self.tokenizer)
-            inputs = self.tokenizer([inputs], return_tensors="pt").to(self.model.device)
+            streamer = TextIteratorStreamer(self.tokenizer, skip_special_tokens=True)
+            inputs = self.tokenizer(inputs, return_tensors="pt", padding=True).to(
+                self.model.device
+            )
             generation_kwargs = dict(inputs, streamer=streamer, **kwargs)
         thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
         thread.start()
         return streamer
 
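
For reference, a sketch of what the tokenize=False templating above produces; the exact markers come from each model's chat template, so the output shown in the comment is indicative only:

messages = [
    {"role": "user", "content": "What is PostgresML?"},
]
prompt = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=False
)
# prompt is one formatted string (e.g. "<|user|>What is PostgresML?<|assistant|>"
# under some templates), ready to be batch-tokenized with padding=True as above.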
     def __call__(self, inputs, **kwargs):
-        if self.task == "conversational":
-            outputs = []
-            for conversation in inputs:
-                conversation = Conversation(conversation)
-                conversation = self.pipe(conversation, **kwargs)
-                outputs.append(conversation.generated_responses[-1])
-            return outputs
+        if self.conversational:
+            templated_inputs = []
+            for input in inputs:
+                templated_inputs.append(
+                    self.tokenizer.apply_chat_template(
+                        input, add_generation_prompt=True, tokenize=False
+                    )
+                )
+            return self.pipe(templated_inputs, return_full_text=False, **kwargs)
         else:
             return self.pipe(inputs, **kwargs)
 
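
A hedged usage sketch of the conversational __call__ path; the variable name pipeline is assumed. inputs is a batch of conversations, each a list of role/content messages, and return_full_text=False keeps only the generated completion:

outputs = pipeline(
    [
        [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello!"},
        ]
    ],
    max_new_tokens=32,
)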
@@ -320,7 +360,11 @@ def create_pipeline(task):
         lower = None
     if lower and ("-ggml" in lower or "-gguf" in lower):
         pipe = GGMLPipeline(model_name, **task)
-    elif lower and "-gptq" in lower and not (model_type == "mistral" or model_type == "llama"):
+    elif (
+        lower
+        and "-gptq" in lower
+        and not (model_type == "mistral" or model_type == "llama")
+    ):
         pipe = GPTQPipeline(model_name, **task)
     else:
         try:
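
To make the dispatch above concrete, an illustration with made-up model names; the suffix check works on lower, the lowercased model name:

# "example/Llama-2-7B-GGUF"      -> GGMLPipeline   ("-gguf" matched)
# "example/WizardCoder-15B-GPTQ" -> GPTQPipeline   ("-gptq" matched, model_type is neither mistral nor llama)
# "example/Mistral-7B-GPTQ"      -> falls through to the else branch (mistral is excluded)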