
Commit be4788a

Working conversational everything

1 parent b6b7ec6 · commit be4788a

2 files changed (+44, -49 lines)

pgml-extension/src/api.rs

Lines changed: 6 additions & 6 deletions
@@ -682,12 +682,12 @@ pub fn transform_conversational_string(
 pub fn transform_stream_json(
     task: JsonB,
     args: default!(JsonB, "'{}'"),
-    inputs: default!(Vec<&str>, "ARRAY[]::TEXT[]"),
+    input: default!(&str, "''"),
     cache: default!(bool, false),
 ) -> SetOfIterator<'static, JsonB> {
     // We can unwrap this because if there is an error the current transaction is aborted in the map_err call
     let python_iter =
-        crate::bindings::transformers::transform_stream_iterator(&task.0, &args.0, inputs)
+        crate::bindings::transformers::transform_stream_iterator(&task.0, &args.0, input)
             .map_err(|e| error!("{e}"))
             .unwrap();
     SetOfIterator::new(python_iter)
@@ -699,13 +699,13 @@ pub fn transform_stream_json(
 pub fn transform_stream_string(
     task: String,
     args: default!(JsonB, "'{}'"),
-    inputs: default!(Vec<&str>, "ARRAY[]::TEXT[]"),
+    input: default!(&str, "''"),
     cache: default!(bool, false),
 ) -> SetOfIterator<'static, JsonB> {
     let task_json = json!({ "task": task });
     // We can unwrap this because if there is an error the current transaction is aborted in the map_err call
     let python_iter =
-        crate::bindings::transformers::transform_stream_iterator(&task_json, &args.0, inputs)
+        crate::bindings::transformers::transform_stream_iterator(&task_json, &args.0, input)
             .map_err(|e| error!("{e}"))
             .unwrap();
     SetOfIterator::new(python_iter)
@@ -725,7 +725,7 @@ pub fn transform_stream_conversational_json(
         .is_some_and(|v| v == "conversational")
     {
         error!(
-            "JSONB inputs for transformer_stream should only be used with a conversational task"
+            "ARRAY[]::JSONB inputs for transformer_stream should only be used with a conversational task"
         );
     }
     // We can unwrap this because if there is an error the current transaction is aborted in the map_err call
@@ -747,7 +747,7 @@ pub fn transform_stream_conversational_string(
 ) -> SetOfIterator<'static, JsonB> {
     if task != "conversational" {
         error!(
-            "JSONB inputs for transformer_stream should only be used with a conversational task"
+            "ARRAY::JSONB inputs for transformer_stream should only be used with a conversational task"
         );
     }
     let task_json = json!({ "task": task });
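
For callers, the net effect of the api.rs changes is that each streaming transform call now takes a single TEXT input instead of a TEXT[] batch. Below is a minimal client-side sketch; the SQL-level name pgml.transform_stream and the connection details are assumptions for illustration, not something this diff shows. Arguments follow the Rust signature above: task, args, input.

# Hypothetical usage sketch -- pgml.transform_stream as the exposed SQL name
# and the connection string are assumptions, not part of this commit.
import psycopg2

conn = psycopg2.connect("dbname=pgml_development")
# A named (server-side) cursor lets rows stream back incrementally rather
# than being buffered in one response.
cur = conn.cursor(name="transform_stream_cursor")
cur.execute(
    """
    SELECT * FROM pgml.transform_stream(
        '{"task": "text-generation", "model": "gpt2"}'::jsonb, -- task
        '{"max_new_tokens": 64}'::jsonb,                       -- args
        'PostgresML is'                                        -- input: TEXT, no longer TEXT[]
    )
    """
)
for (chunk,) in cur:  # each row is one JSONB chunk of generated output
    print(chunk, end="", flush=True)
cur.close()
conn.close()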

pgml-extension/src/bindings/transformers/transformers.py

Lines changed: 38 additions & 43 deletions
@@ -226,17 +226,6 @@ def __init__(self, model_name, **kwargs):
         # but that is only possible when the task is passed in, since if you pass the model
         # to the pipeline constructor, the task will no longer be inferred from the default...

-        # We want to create a text-generation pipeline if it is a conversational task
-        self.conversational = False
-        if "task" in kwargs and kwargs["task"] == "conversational":
-            self.conversational = True
-            kwargs["task"] = "text-generation"
-
-        # Tokens can either be left or right padded depending on the architecture
-        padding_side = "right"
-        if "padding_side" in kwargs:
-            padding_side = kwargs.pop("padding_side")
-
         if (
             "task" in kwargs
             and model_name is not None
@@ -246,7 +235,8 @@ def __init__(self, model_name, **kwargs):
                 "question-answering",
                 "summarization",
                 "translation",
-                "text-generation"
+                "text-generation",
+                "conversational",
             ]
         ):
             self.task = kwargs.pop("task")
@@ -261,25 +251,25 @@ def __init__(self, model_name, **kwargs):
                 )
             elif self.task == "summarization" or self.task == "translation":
                 self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name, **kwargs)
-            elif self.task == "text-generation":
+            elif self.task == "text-generation" or self.task == "conversational":
                 self.model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
             else:
                 raise PgMLException(f"Unhandled task: {self.task}")

             if "use_auth_token" in kwargs:
                 self.tokenizer = AutoTokenizer.from_pretrained(
-                    model_name, use_auth_token=kwargs["use_auth_token"], padding_side=padding_side
+                    model_name, use_auth_token=kwargs["use_auth_token"]
                 )
             else:
-                self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side=padding_side)
+                self.tokenizer = AutoTokenizer.from_pretrained(model_name)

             self.pipe = transformers.pipeline(
                 self.task,
                 model=self.model,
                 tokenizer=self.tokenizer,
             )
         else:
-            self.pipe = transformers.pipeline(**kwargs, padding_side=padding_side)
+            self.pipe = transformers.pipeline(**kwargs)
             self.tokenizer = self.pipe.tokenizer
             self.task = self.pipe.task
             self.model = self.pipe.model
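
Context for the stream and __call__ hunk below: with the task now kept as "conversational", the input is expected to be one chat transcript, a list of role/content messages, which tokenizer.apply_chat_template renders into the model's prompt format. A standalone sketch of that input shape (the model name is only illustrative; any model with a chat template works):

# Illustrative only: shows the message format the conversational path
# consumes and what apply_chat_template produces. The model name is an example.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is PostgresML?"},
]
# tokenize=False returns the rendered prompt string; add_generation_prompt=True
# appends the tokens that cue the model to begin the assistant's reply.
prompt = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=False
)
print(prompt)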
@@ -288,48 +278,53 @@ def __init__(self, model_name, **kwargs):
         if self.tokenizer.pad_token is None:
             self.tokenizer.pad_token = self.tokenizer.eos_token

-    def stream(self, inputs, **kwargs):
+    def stream(self, input, **kwargs):
         streamer = None
         generation_kwargs = None
         # Conversational does not work right now with left padded tokenizers. At least for gpt2, apply_chat_template breaks it
-        if self.conversational:
+        if self.task == "conversational":
             streamer = TextIteratorStreamer(
                 self.tokenizer, skip_prompt=True, skip_special_tokens=True
             )
-            templated_inputs = []
-            for input in inputs:
-                templated_inputs.append(
-                    self.tokenizer.apply_chat_template(
-                        input, add_generation_prompt=True, tokenize=False
-                    )
-                )
-            inputs = self.tokenizer(
-                templated_inputs, return_tensors="pt", padding=True
-            ).to(self.model.device)
-            generation_kwargs = dict(inputs, streamer=streamer, **kwargs)
+            input = self.tokenizer.apply_chat_template(
+                input, add_generation_prompt=True, tokenize=False
+            )
+            input = self.tokenizer(input, return_tensors="pt").to(self.model.device)
+            generation_kwargs = dict(input, streamer=streamer, **kwargs)
         else:
             streamer = TextIteratorStreamer(self.tokenizer, skip_special_tokens=True)
-            inputs = self.tokenizer(inputs, return_tensors="pt", padding=True).to(
+            input = self.tokenizer(input, return_tensors="pt", padding=True).to(
                 self.model.device
             )
-            generation_kwargs = dict(inputs, streamer=streamer, **kwargs)
-            print("\n\n", file=sys.stderr)
-            print(inputs, file=sys.stderr)
-            print("\n\n", file=sys.stderr)
+            generation_kwargs = dict(input, streamer=streamer, **kwargs)
         thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
         thread.start()
         return streamer

     def __call__(self, inputs, **kwargs):
-        if self.conversational:
-            templated_inputs = []
-            for input in inputs:
-                templated_inputs.append(
-                    self.tokenizer.apply_chat_template(
-                        input, add_generation_prompt=True, tokenize=False
-                    )
-                )
-            return self.pipe(templated_inputs, return_full_text=False, **kwargs)
+        if self.task == "conversational":
+            inputs = self.tokenizer.apply_chat_template(
+                inputs, add_generation_prompt=True, tokenize=False
+            )
+            inputs = self.tokenizer(inputs, return_tensors="pt").to(self.model.device)
+            args = dict(inputs, **kwargs)
+            outputs = self.model.generate(**args)
+            # We only want the new outputs for conversational pipelines
+            outputs = outputs[:, inputs["input_ids"].shape[1] :]
+            outputs = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
+            return outputs
+
+            # I don't think conversations support num_responses and/or maybe num_beams
+            # Also this is not processed in parallel / truly batched it seems
+            # num_conversations = 1
+            # if "num_return_sequences" in kwargs:
+            #     num_conversations = kwargs.pop("num_return_sequences")
+            # conversations = [Conversation(inputs) for _ in range(0, num_conversations)]
+            # conversations = self.pipe(conversations, **kwargs)
+            # outputs = []
+            # for conversation in conversations:
+            #     outputs.append(conversation.messages[-1]["content"])
+            # return outputs
         else:
             return self.pipe(inputs, **kwargs)
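
Pulled out of the class, the two generation paths this commit settles on look like the sketch below: the non-streaming path (__call__) templates the chat, generates, then slices off the prompt tokens so only the new reply is decoded, while the streaming path (stream) runs generate on a background thread and iterates a TextIteratorStreamer. Model name and generation settings are illustrative.

# Standalone sketch of both paths, under the assumption of a chat-capable
# causal LM. The model name is illustrative.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

messages = [{"role": "user", "content": "Say hello."}]
prompt = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=False
)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# Non-streaming path: generate returns prompt + continuation for causal LMs,
# so slice off the prompt tokens before decoding, as __call__ does above.
outputs = model.generate(**inputs, max_new_tokens=32)
outputs = outputs[:, inputs["input_ids"].shape[1]:]
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))

# Streaming path: run generate on a background thread and iterate the
# streamer as tokens arrive; skip_prompt hides the echoed prompt, as
# stream does above.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
thread = Thread(
    target=model.generate,
    kwargs=dict(inputs, streamer=streamer, max_new_tokens=32),
)
thread.start()
for text in streamer:
    print(text, end="", flush=True)
thread.join()

Because the conversational path now handles a single transcript per call, no padding is needed before generate, which sidesteps the left/right padding issues noted in the removed comments.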
