Commit b6b7ec6
Not great, pivoting to better solution after talking with Santi
1 parent 1ca5dc8

File tree

3 files changed: +103, -63 lines

pgml-extension/src/api.rs

Lines changed: 15 additions & 18 deletions

@@ -682,17 +682,14 @@ pub fn transform_conversational_string(
 pub fn transform_stream_json(
     task: JsonB,
     args: default!(JsonB, "'{}'"),
-    input: default!(&str, "''"),
+    inputs: default!(Vec<&str>, "ARRAY[]::TEXT[]"),
     cache: default!(bool, false),
-) -> SetOfIterator<'static, String> {
+) -> SetOfIterator<'static, JsonB> {
     // We can unwrap this becuase if there is an error the current transaction is aborted in the map_err call
-    let python_iter = crate::bindings::transformers::transform_stream_iterator(
-        &task.0,
-        &args.0,
-        input.to_string(),
-    )
-    .map_err(|e| error!("{e}"))
-    .unwrap();
+    let python_iter =
+        crate::bindings::transformers::transform_stream_iterator(&task.0, &args.0, inputs)
+            .map_err(|e| error!("{e}"))
+            .unwrap();
     SetOfIterator::new(python_iter)
 }

@@ -702,13 +699,13 @@ pub fn transform_stream_json(
 pub fn transform_stream_string(
     task: String,
     args: default!(JsonB, "'{}'"),
-    input: default!(&str, "''"),
+    inputs: default!(Vec<&str>, "ARRAY[]::TEXT[]"),
     cache: default!(bool, false),
-) -> SetOfIterator<'static, String> {
+) -> SetOfIterator<'static, JsonB> {
     let task_json = json!({ "task": task });
     // We can unwrap this becuase if there is an error the current transaction is aborted in the map_err call
     let python_iter =
-        crate::bindings::transformers::transform_stream_iterator(&task_json, &args.0, input)
+        crate::bindings::transformers::transform_stream_iterator(&task_json, &args.0, inputs)
             .map_err(|e| error!("{e}"))
             .unwrap();
     SetOfIterator::new(python_iter)

@@ -720,9 +717,9 @@ pub fn transform_stream_string(
 pub fn transform_stream_conversational_json(
     task: JsonB,
     args: default!(JsonB, "'{}'"),
-    input: default!(JsonB, "'[]'::JSONB"),
+    inputs: default!(Vec<JsonB>, "ARRAY[]::JSONB[]"),
     cache: default!(bool, false),
-) -> SetOfIterator<'static, String> {
+) -> SetOfIterator<'static, JsonB> {
     if !task.0["task"]
         .as_str()
         .is_some_and(|v| v == "conversational")

@@ -733,7 +730,7 @@ pub fn transform_stream_conversational_json(
     }
     // We can unwrap this becuase if there is an error the current transaction is aborted in the map_err call
     let python_iter =
-        crate::bindings::transformers::transform_stream_iterator(&task.0, &args.0, input.0)
+        crate::bindings::transformers::transform_stream_iterator(&task.0, &args.0, inputs)
             .map_err(|e| error!("{e}"))
             .unwrap();
     SetOfIterator::new(python_iter)

@@ -745,9 +742,9 @@ pub fn transform_stream_conversational_json(
 pub fn transform_stream_conversational_string(
     task: String,
     args: default!(JsonB, "'{}'"),
-    input: default!(JsonB, "'[]'::JSONB"),
+    inputs: default!(Vec<JsonB>, "ARRAY[]::JSONB[]"),
     cache: default!(bool, false),
-) -> SetOfIterator<'static, String> {
+) -> SetOfIterator<'static, JsonB> {
     if task != "conversational" {
         error!(
             "JSONB inputs for transformer_stream should only be used with a conversational task"

@@ -756,7 +753,7 @@ pub fn transform_stream_conversational_string(
     let task_json = json!({ "task": task });
     // We can unwrap this becuase if there is an error the current transaction is aborted in the map_err call
     let python_iter =
-        crate::bindings::transformers::transform_stream_iterator(&task_json, &args.0, input.0)
+        crate::bindings::transformers::transform_stream_iterator(&task_json, &args.0, inputs)
            .map_err(|e| error!("{e}"))
            .unwrap();
     SetOfIterator::new(python_iter)
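
Net effect of these signature changes: every transform_stream overload now takes an array of inputs instead of a single value, and returns a set of JSONB rows, one JSON array per generation step with one text chunk per batched input. A minimal sketch of how a client might consume that, assuming a running PostgresML instance that exposes these overloads as pgml.transform_stream with these parameter names, and the psycopg driver; the model name, prompts, and connection string are illustrative, not from this commit:

# sketch_stream_client.py -- consume the batched streaming API (assumptions above)
import psycopg

QUERY = """
SELECT * FROM pgml.transform_stream(
    task   => '{"task": "text-generation", "model": "gpt2"}'::JSONB,
    inputs => ARRAY['AI is', 'Postgres is']::TEXT[]
)
"""

with psycopg.connect("postgresql://localhost:5432/pgml") as conn:
    with conn.cursor() as cur:
        cur.execute(QUERY)
        for (step,) in cur:
            # Each row is now JSONB: a JSON array with one chunk per input,
            # e.g. ["going to", " a database"] -- no longer a single string.
            print(step)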

pgml-extension/src/bindings/transformers/transform.rs

Lines changed: 4 additions & 5 deletions

@@ -4,7 +4,6 @@ use anyhow::Result;
 use pgrx::*;
 use pyo3::prelude::*;
 use pyo3::types::{IntoPyDict, PyDict, PyTuple};
-use pyo3::AsPyPointer;

 create_pymodule!("/src/bindings/transformers/transformers.py");

@@ -24,17 +23,17 @@ impl TransformStreamIterator {
 }

 impl Iterator for TransformStreamIterator {
-    type Item = String;
+    type Item = JsonB;
     fn next(&mut self) -> Option<Self::Item> {
         // We can unwrap this becuase if there is an error the current transaction is aborted in the map_err call
-        Python::with_gil(|py| -> Result<Option<String>, PyErr> {
+        Python::with_gil(|py| -> Result<Option<JsonB>, PyErr> {
             let code = "next(python_iter)";
             let res: &PyAny = py.eval(code, Some(self.locals.as_ref(py)), None)?;
             if res.is_none() {
                 Ok(None)
             } else {
-                let res: String = res.extract()?;
-                Ok(Some(res))
+                let res: Vec<String> = res.extract()?;
+                Ok(Some(JsonB(serde_json::to_value(res).unwrap())))
             }
         })
         .map_err(|e| error!("{e}"))

pgml-extension/src/bindings/transformers/transformers.py

Lines changed: 84 additions & 40 deletions

@@ -41,7 +41,7 @@
     TrainingArguments,
     Trainer,
     TextStreamer,
-    Conversation
+    Conversation,
 )
 from threading import Thread
 from typing import Optional

@@ -95,24 +95,34 @@ def ensure_device(kwargs):
     else:
         kwargs["device"] = "cpu"

-# A copy of HuggingFace's with small changes in the __next__ to not raise an exception
-class TextIteratorStreamer(TextStreamer):
-    def __init__(
-        self, tokenizer, skip_prompt = False, timeout = None, **decode_kwargs
-    ):
-        super().__init__(tokenizer, skip_prompt, **decode_kwargs)
-        self.text_queue = queue.Queue()
-        self.stop_signal = None
+
+# Follows BaseStreamer template from transformers library
+class TextIteratorStreamer:
+    def __init__(self, tokenizer, skip_prompt=False, timeout=None, **decode_kwargs):
+        self.tokenizer = tokenizer
+        self.skip_prompt = skip_prompt
         self.timeout = timeout
+        self.decode_kwargs = decode_kwargs
+        self.next_tokens_are_prompt = True
+        self.stop_signal = None
+        self.text_queue = queue.Queue()

-    def on_finalized_text(self, text: str, stream_end: bool = False):
-        """Put the new text in the queue. If the stream is ending, also put a stop signal in the queue."""
-        self.text_queue.put(text, timeout=self.timeout)
-        if stream_end:
-            self.text_queue.put(self.stop_signal, timeout=self.timeout)
+    def put(self, value):
+        if self.skip_prompt and self.next_tokens_are_prompt:
+            self.next_tokens_are_prompt = False
+            return
+        # Can't batch this decode
+        decoded_values = []
+        for v in value:
+            decoded_values.append(self.tokenizer.decode(v, **self.decode_kwargs))
+        self.text_queue.put(decoded_values, self.timeout)
+
+    def end(self):
+        self.next_tokens_are_prompt = True
+        self.text_queue.put(self.stop_signal, self.timeout)

     def __iter__(self):
-        return self
+        self

     def __next__(self):
         value = self.text_queue.get(timeout=self.timeout)

@@ -215,6 +225,18 @@ def __init__(self, model_name, **kwargs):
         # to the model constructor, so we construct the model/tokenizer manually if possible,
         # but that is only possible when the task is passed in, since if you pass the model
         # to the pipeline constructor, the task will no longer be inferred from the default...
+
+        # We want to create a text-generation pipeline if it is a conversational task
+        self.conversational = False
+        if "task" in kwargs and kwargs["task"] == "conversational":
+            self.conversational = True
+            kwargs["task"] = "text-generation"
+
+        # Tokens can either be left or right padded depending on the architecture
+        padding_side = "right"
+        if "padding_side" in kwargs:
+            padding_side = kwargs.pop("padding_side")
+
         if (
             "task" in kwargs
             and model_name is not None

@@ -224,8 +246,7 @@ def __init__(self, model_name, **kwargs):
                 "question-answering",
                 "summarization",
                 "translation",
-                "text-generation",
-                "conversational"
+                "text-generation"
             ]
         ):
             self.task = kwargs.pop("task")

@@ -240,56 +261,75 @@ def __init__(self, model_name, **kwargs):
                 )
             elif self.task == "summarization" or self.task == "translation":
                 self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name, **kwargs)
-            elif self.task == "text-generation" or self.task == "conversational":
+            elif self.task == "text-generation":
                 self.model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
             else:
                 raise PgMLException(f"Unhandled task: {self.task}")

             if "use_auth_token" in kwargs:
                 self.tokenizer = AutoTokenizer.from_pretrained(
-                    model_name, use_auth_token=kwargs["use_auth_token"]
+                    model_name, use_auth_token=kwargs["use_auth_token"], padding_side=padding_side
                 )
             else:
-                self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+                self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side=padding_side)

             self.pipe = transformers.pipeline(
                 self.task,
                 model=self.model,
                 tokenizer=self.tokenizer,
             )
         else:
-            self.pipe = transformers.pipeline(**kwargs)
+            self.pipe = transformers.pipeline(**kwargs, padding_side=padding_side)
+            self.tokenizer = self.pipe.tokenizer
             self.task = self.pipe.task
             self.model = self.pipe.model
-            if self.pipe.tokenizer is None:
-                self.pipe.tokenizer = AutoTokenizer.from_pretrained(
-                    self.model.name_or_path
-                )
-            self.tokenizer = self.pipe.tokenizer
+
+        # Make sure we set the pad token if it does not exist
+        if self.tokenizer.pad_token is None:
+            self.tokenizer.pad_token = self.tokenizer.eos_token

     def stream(self, inputs, **kwargs):
         streamer = None
         generation_kwargs = None
-        if self.task == "conversational":
-            streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True)
-            inputs = tokenized_chat = self.tokenizer.apply_chat_template(inputs, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(self.model.device)
-            generation_kwargs = dict(inputs=inputs, streamer=streamer, **kwargs)
+        # Conversational does not work right now with left padded tokenizers. At least for gpt2, the apply_chat_template breaks it
+        if self.conversational:
+            streamer = TextIteratorStreamer(
+                self.tokenizer, skip_prompt=True, skip_special_tokens=True
+            )
+            templated_inputs = []
+            for input in inputs:
+                templated_inputs.append(
+                    self.tokenizer.apply_chat_template(
+                        input, add_generation_prompt=True, tokenize=False
+                    )
+                )
+            inputs = self.tokenizer(
+                templated_inputs, return_tensors="pt", padding=True
+            ).to(self.model.device)
+            generation_kwargs = dict(inputs, streamer=streamer, **kwargs)
         else:
-            streamer = TextIteratorStreamer(self.tokenizer)
-            inputs = self.tokenizer([inputs], return_tensors="pt").to(self.model.device)
+            streamer = TextIteratorStreamer(self.tokenizer, skip_special_tokens=True)
+            inputs = self.tokenizer(inputs, return_tensors="pt", padding=True).to(
+                self.model.device
+            )
             generation_kwargs = dict(inputs, streamer=streamer, **kwargs)
+        print("\n\n", file=sys.stderr)
+        print(inputs, file=sys.stderr)
+        print("\n\n", file=sys.stderr)
         thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
         thread.start()
         return streamer

     def __call__(self, inputs, **kwargs):
-        if self.task == "conversational":
-            outputs = []
-            for conversation in inputs:
-                conversation = Conversation(conversation)
-                conversation = self.pipe(conversation, **kwargs)
-                outputs.append(conversation.generated_responses[-1])
-            return outputs
+        if self.conversational:
+            templated_inputs = []
+            for input in inputs:
+                templated_inputs.append(
+                    self.tokenizer.apply_chat_template(
+                        input, add_generation_prompt=True, tokenize=False
+                    )
+                )
+            return self.pipe(templated_inputs, return_full_text=False, **kwargs)
         else:
             return self.pipe(inputs, **kwargs)

@@ -320,7 +360,11 @@ def create_pipeline(task):
         lower = None
     if lower and ("-ggml" in lower or "-gguf" in lower):
         pipe = GGMLPipeline(model_name, **task)
-    elif lower and "-gptq" in lower and not (model_type == "mistral" or model_type == "llama"):
+    elif (
+        lower
+        and "-gptq" in lower
+        and not (model_type == "mistral" or model_type == "llama")
+    ):
         pipe = GPTQPipeline(model_name, **task)
     else:
         try:
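
The rewritten TextIteratorStreamer is the core of the batching change: instead of subclassing TextStreamer (which rejects batches), it implements the put/end streamer interface directly and queues one list of decoded strings per generation step, one entry per sequence in the padded batch. A standalone toy version of that mechanism follows, with a stub tokenizer so it runs without transformers; class and variable names are hypothetical, and the sketch keeps the `return self` in `__iter__` that the diff above drops.

import queue

class StubTokenizer:
    """Stand-in for a Hugging Face tokenizer; decode() is all the streamer needs."""
    def decode(self, token_ids, **kwargs):
        return "".join(chr(t) for t in token_ids)  # fake ids -> fake text

class BatchedIteratorStreamer:
    """Mirrors the commit's streamer: model.generate calls put()/end() on the
    generation thread, and the consumer iterates lists of per-sequence chunks."""
    def __init__(self, tokenizer, timeout=None, **decode_kwargs):
        self.tokenizer = tokenizer
        self.timeout = timeout
        self.decode_kwargs = decode_kwargs
        self.stop_signal = None
        self.text_queue = queue.Queue()

    def put(self, value):
        # value holds one row of new token ids per sequence in the batch;
        # each row is decoded separately since the decode can't be batched
        self.text_queue.put(
            [self.tokenizer.decode(v, **self.decode_kwargs) for v in value]
        )

    def end(self):
        self.text_queue.put(self.stop_signal)

    def __iter__(self):
        return self  # kept here, unlike the diff above

    def __next__(self):
        return self.text_queue.get(timeout=self.timeout)

streamer = BatchedIteratorStreamer(StubTokenizer())
streamer.put([[72, 105], [89, 111]])  # one step for a batch of two sequences
streamer.end()
while (chunk := next(streamer)) is not None:
    print(chunk)  # ['Hi', 'Yo']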
