@@ -226,17 +226,6 @@ def __init__(self, model_name, **kwargs):
         # but that is only possible when the task is passed in, since if you pass the model
         # to the pipeline constructor, the task will no longer be inferred from the default...

-        # We want to create a text-generation pipeline if it is a conversational task
-        self.conversational = False
-        if "task" in kwargs and kwargs["task"] == "conversational":
-            self.conversational = True
-            kwargs["task"] = "text-generation"
-
-        # Tokens can either be left or right padded depending on the architecture
-        padding_side = "right"
-        if "padding_side" in kwargs:
-            padding_side = kwargs.pop("padding_side")
-
         if (
             "task" in kwargs
             and model_name is not None
@@ -246,7 +235,8 @@ def __init__(self, model_name, **kwargs):
                 "question-answering",
                 "summarization",
                 "translation",
-                "text-generation"
+                "text-generation",
+                "conversational",
             ]
         ):
             self.task = kwargs.pop("task")
@@ -261,25 +251,25 @@ def __init__(self, model_name, **kwargs):
                 )
             elif self.task == "summarization" or self.task == "translation":
                 self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name, **kwargs)
-            elif self.task == "text-generation":
+            elif self.task == "text-generation" or self.task == "conversational":
                 self.model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
             else:
                 raise PgMLException(f"Unhandled task: {self.task}")

             if "use_auth_token" in kwargs:
                 self.tokenizer = AutoTokenizer.from_pretrained(
-                    model_name, use_auth_token=kwargs["use_auth_token"], padding_side=padding_side
+                    model_name, use_auth_token=kwargs["use_auth_token"]
                 )
             else:
-                self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side=padding_side)
+                self.tokenizer = AutoTokenizer.from_pretrained(model_name)

             self.pipe = transformers.pipeline(
                 self.task,
                 model=self.model,
                 tokenizer=self.tokenizer,
             )
         else:
-            self.pipe = transformers.pipeline(**kwargs, padding_side=padding_side)
+            self.pipe = transformers.pipeline(**kwargs)
             self.tokenizer = self.pipe.tokenizer
             self.task = self.pipe.task
             self.model = self.pipe.model
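
Since "conversational" now shares the causal-LM path, the constructor's routing boils down to roughly the following sketch; gpt2 is only an assumed example checkpoint:

import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # assumed example checkpoint
# "text-generation" and "conversational" both load a causal language model now
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # same fallback applied later in __init__
pipe = transformers.pipeline("text-generation", model=model, tokenizer=tokenizer)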
@@ -288,48 +278,53 @@ def __init__(self, model_name, **kwargs):
         if self.tokenizer.pad_token is None:
             self.tokenizer.pad_token = self.tokenizer.eos_token

-    def stream(self, inputs, **kwargs):
+    def stream(self, input, **kwargs):
         streamer = None
         generation_kwargs = None
         # Conversational does not work right now with left-padded tokenizers. At least for gpt2, apply_chat_template breaks it
-        if self.conversational:
+        if self.task == "conversational":
             streamer = TextIteratorStreamer(
                 self.tokenizer, skip_prompt=True, skip_special_tokens=True
             )
-            templated_inputs = []
-            for input in inputs:
-                templated_inputs.append(
-                    self.tokenizer.apply_chat_template(
-                        input, add_generation_prompt=True, tokenize=False
-                    )
-                )
-            inputs = self.tokenizer(
-                templated_inputs, return_tensors="pt", padding=True
-            ).to(self.model.device)
-            generation_kwargs = dict(inputs, streamer=streamer, **kwargs)
+            input = self.tokenizer.apply_chat_template(
+                input, add_generation_prompt=True, tokenize=False
+            )
+            input = self.tokenizer(input, return_tensors="pt").to(self.model.device)
+            generation_kwargs = dict(input, streamer=streamer, **kwargs)
         else:
             streamer = TextIteratorStreamer(self.tokenizer, skip_special_tokens=True)
-            inputs = self.tokenizer(inputs, return_tensors="pt", padding=True).to(
+            input = self.tokenizer(input, return_tensors="pt", padding=True).to(
                 self.model.device
             )
-            generation_kwargs = dict(inputs, streamer=streamer, **kwargs)
-        print("\n\n", file=sys.stderr)
-        print(inputs, file=sys.stderr)
-        print("\n\n", file=sys.stderr)
+            generation_kwargs = dict(input, streamer=streamer, **kwargs)
         thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
         thread.start()
         return streamer

     def __call__(self, inputs, **kwargs):
-        if self.conversational:
-            templated_inputs = []
-            for input in inputs:
-                templated_inputs.append(
-                    self.tokenizer.apply_chat_template(
-                        input, add_generation_prompt=True, tokenize=False
-                    )
-                )
-            return self.pipe(templated_inputs, return_full_text=False, **kwargs)
+        if self.task == "conversational":
+            inputs = self.tokenizer.apply_chat_template(
+                inputs, add_generation_prompt=True, tokenize=False
+            )
+            inputs = self.tokenizer(inputs, return_tensors="pt").to(self.model.device)
+            args = dict(inputs, **kwargs)
+            outputs = self.model.generate(**args)
+            # We only want the new outputs for conversational pipelines
+            outputs = outputs[:, inputs["input_ids"].shape[1] :]
+            outputs = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
+            return outputs
+
+            # I don't think conversations support num_responses and/or maybe num_beams
+            # Also this is not processed in parallel / truly batched it seems
+            # num_conversations = 1
+            # if "num_return_sequences" in kwargs:
+            #     num_conversations = kwargs.pop("num_return_sequences")
+            # conversations = [Conversation(inputs) for _ in range(0, num_conversations)]
+            # conversations = self.pipe(conversations, **kwargs)
+            # outputs = []
+            # for conversation in conversations:
+            #     outputs.append(conversation.messages[-1]["content"])
+            # return outputs
         else:
             return self.pipe(inputs, **kwargs)
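
Taken out of the class, the new conversational path in __call__ and stream reduces to the following sketch. The checkpoint name is an assumption; any model whose tokenizer ships a chat template should behave the same way:

from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # assumed chat-capable checkpoint
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name)

messages = [{"role": "user", "content": "What is PostgresML?"}]
# Render the chat template to a prompt string, then tokenize it
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# __call__ path: generate, then slice off the prompt tokens so only the new reply is decoded
outputs = model.generate(**inputs, max_new_tokens=64)
new_tokens = outputs[:, inputs["input_ids"].shape[1]:]
print(tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0])

# stream() path: run generate on a background thread and iterate tokens as they arrive
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
Thread(target=model.generate, kwargs=dict(inputs, streamer=streamer, max_new_tokens=64)).start()
for chunk in streamer:
    print(chunk, end="", flush=True)

Because the reply is recovered by slicing off the prompt's token count rather than via the pipeline's return_full_text option, the same logic serves both the blocking and the streaming path.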