From d298b30fec1c45a77e258ed88ba8891acdbe4fcf Mon Sep 17 00:00:00 2001
From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com>
Date: Thu, 2 Nov 2023 14:41:40 -0700
Subject: [PATCH 1/7] Initial streaming working

---
 pgml-extension/src/api.rs                     | 55 +++++++++++++++++++
 .../src/bindings/transformers/mod.rs          | 37 +++++++++++++
 .../src/bindings/transformers/transformers.py | 35 +++++++++++-
 3 files changed, 126 insertions(+), 1 deletion(-)

diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs
index ad952e485..bd04d0ede 100644
--- a/pgml-extension/src/api.rs
+++ b/pgml-extension/src/api.rs
@@ -4,6 +4,8 @@ use std::str::FromStr;
 use ndarray::Zip;
 use pgrx::iter::{SetOfIterator, TableIterator};
 use pgrx::*;
+use pyo3::prelude::*;
+use pyo3::types::{IntoPyDict, PyDict, PyString};
 
 #[cfg(feature = "python")]
 use serde_json::json;
@@ -632,6 +634,59 @@ pub fn transform_string(
     }
 }
 
+struct TransformStreamIterator {
+    locals: Py<PyDict>,
+}
+
+impl TransformStreamIterator {
+    fn new(python_iter: Py<PyAny>) -> Self {
+        let locals = Python::with_gil(|py| -> Result<Py<PyDict>, PyErr> {
+            Ok([("python_iter", python_iter)].into_py_dict(py).into())
+        })
+        .map_err(|e| error!("{e}"))
+        .unwrap();
+        Self { locals }
+    }
+}
+
+impl Iterator for TransformStreamIterator {
+    type Item = String;
+    fn next(&mut self) -> Option<Self::Item> {
+        // We can unwrap this because if there is an error the current transaction is aborted in the map_err call
+        Python::with_gil(|py| -> Result<Option<String>, PyErr> {
+            println!("Getting next token!");
+            let code = "next(python_iter, 'DEFAULT_DONE_STRING1239847uuuuu')";
+            let res: String = py
+                .eval(code, Some(self.locals.as_ref(py)), None)?
+                .extract()?;
+            println!("WE GOT A VALUE {:?}", res);
+            if res == "DEFAULT_DONE_STRING1239847uuuuu" {
+                Ok(None)
+            } else {
+                Ok(Some(res))
+            }
+        })
+        .map_err(|e| error!("{e}"))
+        .unwrap()
+    }
+}
+
+#[pg_extern(name = "transform_stream")]
+pub fn transform_stream_json(
+    task: JsonB,
+    args: default!(JsonB, "'{}'"),
+    inputs: default!(Vec<&str>, "ARRAY[]::TEXT[]"),
+    _cache: default!(bool, false),
+) -> SetOfIterator<'static, String> {
+    // We can unwrap this because if there is an error the current transaction is aborted in the map_err call
+    let python_iter = crate::bindings::transformers::transform_stream(&task.0, &args.0, inputs)
+        .map_err(|e| error!("{e}"))
+        .unwrap();
+    println!("We got out of the transform call!");
+    let res = TransformStreamIterator::new(python_iter);
+    SetOfIterator::new(res)
+}
+
 #[cfg(feature = "python")]
 #[pg_extern(immutable, parallel_safe, name = "generate")]
 fn generate(project_name: &str, inputs: &str, config: default!(JsonB, "'{}'")) -> String {
diff --git a/pgml-extension/src/bindings/transformers/mod.rs b/pgml-extension/src/bindings/transformers/mod.rs
index c4e262761..b0967c342 100644
--- a/pgml-extension/src/bindings/transformers/mod.rs
+++ b/pgml-extension/src/bindings/transformers/mod.rs
@@ -52,6 +52,43 @@ pub fn transform(
     Ok(serde_json::from_str(&results)?)
 }
 
+pub fn transform_stream(
+    task: &serde_json::Value,
+    args: &serde_json::Value,
+    inputs: Vec<&str>,
+) -> Result<Py<PyAny>> {
+    crate::bindings::python::activate()?;
+
+    whitelist::verify_task(task)?;
+
+    let task = serde_json::to_string(task)?;
+    let args = serde_json::to_string(args)?;
+    let inputs = serde_json::to_string(&inputs)?;
+
+    Python::with_gil(|py| -> Result<Py<PyAny>> {
+        let transform: Py<PyAny> = get_module!(PY_MODULE)
+            .getattr(py, "transform")
+            .format_traceback(py)?;
+
+        let output = transform
+            .call1(
+                py,
+                PyTuple::new(
+                    py,
+                    &[
+                        task.into_py(py),
+                        args.into_py(py),
+                        inputs.into_py(py),
+                        true.into_py(py),
+                    ],
+                ),
+            )
+            .format_traceback(py)?;
+
+        Ok(output)
+    })
+}
+
 pub fn get_model_from(task: &Value) -> Result<String> {
     Python::with_gil(|py| -> Result<String> {
         let get_model_from = get_module!(PY_MODULE)
diff --git a/pgml-extension/src/bindings/transformers/transformers.py b/pgml-extension/src/bindings/transformers/transformers.py
index 8b1d1a43d..c59dcdee8 100644
--- a/pgml-extension/src/bindings/transformers/transformers.py
+++ b/pgml-extension/src/bindings/transformers/transformers.py
@@ -38,7 +38,9 @@
     PegasusTokenizer,
     TrainingArguments,
     Trainer,
+    TextIteratorStreamer,
 )
+from threading import Thread
 
 __cache_transformer_by_model_id = {}
 __cache_sentence_transformer_by_name = {}
@@ -110,6 +112,31 @@ def __call__(self, inputs, **kwargs):
         return outputs
 
 
+class ThreadedGeneratorIterator:
+    def __init__(self, output):
+        self.done_data = []
+        self.output = output
+        self.done = False
+
+        def do_work(g):
+            for x in g.output:
+                g.done_data.append(x)
+            g.done = True
+
+        thread = Thread(target=do_work, args=(self,))
+        thread.start()
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        if len(self.done_data) > 0:
+            return self.done_data.pop(0)
+        elif self.done:
+            raise StopIteration
+        time.sleep(0.1)
+        return self.__next__()
+
+
 class GGMLPipeline(object):
     def __init__(self, model_name, **task):
         import ctransformers
@@ -121,6 +148,10 @@ def __init__(self, model_name, **task):
         self.tokenizer = None
         self.task = "text-generation"
 
+    def stream(self, inputs, **kwargs):
+        output = self.model(inputs[0], stream=True, **kwargs)
+        return ThreadedGeneratorIterator(output)
+
     def __call__(self, inputs, **kwargs):
         outputs = []
         for input in inputs:
@@ -222,7 +253,7 @@ def transform_using(pipeline, args, inputs):
     return orjson.dumps(pipeline(inputs, **args), default=orjson_default).decode()
 
 
-def transform(task, args, inputs):
+def transform(task, args, inputs, stream=False):
     task = orjson.loads(task)
     args = orjson.loads(args)
     inputs = orjson.loads(inputs)
@@ -238,6 +269,8 @@ def transform(task, args, inputs, stream=False):
         inputs = [orjson.loads(input) for input in inputs]
 
     convert_eos_token(pipe.tokenizer, args)
 
+    if stream:
+        return pipe.stream(inputs, **args)
     return orjson.dumps(pipe(inputs, **args), default=orjson_default).decode()
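A note on the bridge this first patch sets up: the Rust iterator keeps the Python generator alive in a locals dict (the `Py<PyDict>` on the struct) and pulls one token per `Iterator::next` call by eval'ing a two-argument `next(...)`, whose sentinel default avoids a `StopIteration` exception crossing the pyo3 boundary. A standalone sketch of that handoff in plain Python (the sentinel string is the one from the patch; the generator is illustrative):

    # What TransformStreamIterator effectively executes on each next() call.
    it = iter(["Hello", ",", " world"])
    locals_dict = {"python_iter": it}  # the Py<PyDict> stored on the Rust struct

    SENTINEL = "DEFAULT_DONE_STRING1239847uuuuu"
    while True:
        # py.eval(code, None, Some(locals)) runs an expression just like this:
        value = eval("next(python_iter, 'DEFAULT_DONE_STRING1239847uuuuu')", None, locals_dict)
        if value == SENTINEL:  # maps to Ok(None): the SETOF result ends
            break
        print(value)           # maps to Ok(Some(res)): one row streamed out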
From 81594408a29b0a7a5b6aa40563eb856969592ec3 Mon Sep 17 00:00:00 2001
From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com>
Date: Thu, 2 Nov 2023 15:17:01 -0700
Subject: [PATCH 2/7] Working streaming for standard pipeline

---
 pgml-extension/src/api.rs                     |  2 +-
 .../src/bindings/transformers/transformers.py | 10 +++++++++-
 2 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs
index bd04d0ede..a125e4f7b 100644
--- a/pgml-extension/src/api.rs
+++ b/pgml-extension/src/api.rs
@@ -5,7 +5,7 @@ use ndarray::Zip;
 use pgrx::iter::{SetOfIterator, TableIterator};
 use pgrx::*;
 use pyo3::prelude::*;
-use pyo3::types::{IntoPyDict, PyDict, PyString};
+use pyo3::types::{IntoPyDict, PyDict};
 
 #[cfg(feature = "python")]
 use serde_json::json;
diff --git a/pgml-extension/src/bindings/transformers/transformers.py b/pgml-extension/src/bindings/transformers/transformers.py
index c59dcdee8..c013ddfd5 100644
--- a/pgml-extension/src/bindings/transformers/transformers.py
+++ b/pgml-extension/src/bindings/transformers/transformers.py
@@ -189,7 +189,7 @@ def __init__(self, model_name, **kwargs):
             self.tokenizer = AutoTokenizer.from_pretrained(model_name,use_auth_token=kwargs["use_auth_token"])
         else:
             self.tokenizer = AutoTokenizer.from_pretrained(model_name)
-        
+
         self.pipe = transformers.pipeline(
             self.task,
             model=self.model,
@@ -203,6 +203,14 @@ def __init__(self, model_name, **kwargs):
             self.pipe.tokenizer = AutoTokenizer.from_pretrained(self.model.name_or_path)
         self.tokenizer = self.pipe.tokenizer
 
+    def stream(self, inputs, **kwargs):
+        streamer = TextIteratorStreamer(self.tokenizer)
+        inputs = self.tokenizer(inputs, return_tensors="pt").to(self.model.device)
+        generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=20)
+        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
+        thread.start()
+        return streamer
+
     def __call__(self, inputs, **kwargs):
         return self.pipe(inputs, **kwargs)

From 57e20ce927a85b554bb97193b033f84f716e5007 Mon Sep 17 00:00:00 2001
From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com>
Date: Fri, 3 Nov 2023 09:25:54 -0700
Subject: [PATCH 3/7] Added streaming for GPTQ models

---
 .../src/bindings/transformers/transformers.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/pgml-extension/src/bindings/transformers/transformers.py b/pgml-extension/src/bindings/transformers/transformers.py
index c013ddfd5..d8878c50d 100644
--- a/pgml-extension/src/bindings/transformers/transformers.py
+++ b/pgml-extension/src/bindings/transformers/transformers.py
@@ -103,6 +103,14 @@ def __init__(self, model_name, **task):
             self.tokenizer = AutoTokenizer.from_pretrained(model_path)
         self.task = "text-generation"
 
+    def stream(self, inputs, **kwargs):
+        streamer = TextIteratorStreamer(self.tokenizer)
+        inputs = self.tokenizer(inputs, return_tensors="pt").to(self.model.device)
+        generation_kwargs = dict(inputs, streamer=streamer, **kwargs)
+        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
+        thread.start()
+        return streamer
+
     def __call__(self, inputs, **kwargs):
         outputs = []
         for input in inputs:
@@ -206,7 +214,7 @@ def stream(self, inputs, **kwargs):
         streamer = TextIteratorStreamer(self.tokenizer)
         inputs = self.tokenizer(inputs, return_tensors="pt").to(self.model.device)
-        generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=20)
+        generation_kwargs = dict(inputs, streamer=streamer, **kwargs)
         thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
         thread.start()
         return streamer
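Patches 2 and 3 both lean on the stock `transformers` streaming recipe: run `model.generate` on a worker thread and hand it a `TextIteratorStreamer`, which yields decoded text on the calling thread as tokens arrive. A minimal self-contained sketch of that recipe (the model name and `max_new_tokens` value are illustrative, not something the extension pins):

    from threading import Thread
    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    streamer = TextIteratorStreamer(tokenizer)
    inputs = tokenizer(["Once upon a time"], return_tensors="pt").to(model.device)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=20)

    # generate() blocks until generation finishes, so it runs on a worker
    # thread while the caller consumes the streamer; this is exactly the
    # shape of the stream() methods added in these two patches.
    Thread(target=model.generate, kwargs=generation_kwargs).start()
    for text in streamer:  # blocks until the next decoded chunk is ready
        print(text, end="", flush=True)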
From cb039572dbec0eefbc7fc0b53daeed3105fe7487 Mon Sep 17 00:00:00 2001
From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com>
Date: Fri, 3 Nov 2023 09:49:16 -0700
Subject: [PATCH 4/7] Cleaned up and split transformers into individual module

---
 pgml-extension/src/api.rs                     |  3 -
 .../src/bindings/transformers/mod.rs          | 74 +-----------------
 .../src/bindings/transformers/transformers.rs | 77 +++++++++++++++++++
 3 files changed, 80 insertions(+), 74 deletions(-)
 create mode 100644 pgml-extension/src/bindings/transformers/transformers.rs

diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs
index a125e4f7b..e882498b9 100644
--- a/pgml-extension/src/api.rs
+++ b/pgml-extension/src/api.rs
@@ -654,12 +654,10 @@ impl Iterator for TransformStreamIterator {
     fn next(&mut self) -> Option<Self::Item> {
         // We can unwrap this because if there is an error the current transaction is aborted in the map_err call
         Python::with_gil(|py| -> Result<Option<String>, PyErr> {
-            println!("Getting next token!");
             let code = "next(python_iter, 'DEFAULT_DONE_STRING1239847uuuuu')";
             let res: String = py
                 .eval(code, Some(self.locals.as_ref(py)), None)?
                 .extract()?;
-            println!("WE GOT A VALUE {:?}", res);
             if res == "DEFAULT_DONE_STRING1239847uuuuu" {
                 Ok(None)
             } else {
@@ -682,7 +680,6 @@ pub fn transform_stream_json(
     let python_iter = crate::bindings::transformers::transform_stream(&task.0, &args.0, inputs)
         .map_err(|e| error!("{e}"))
         .unwrap();
-    println!("We got out of the transform call!");
     let res = TransformStreamIterator::new(python_iter);
     SetOfIterator::new(res)
 }
diff --git a/pgml-extension/src/bindings/transformers/mod.rs b/pgml-extension/src/bindings/transformers/mod.rs
index b0967c342..8871c8458 100644
--- a/pgml-extension/src/bindings/transformers/mod.rs
+++ b/pgml-extension/src/bindings/transformers/mod.rs
@@ -16,78 +16,10 @@ use super::TracebackError;
 
 pub mod whitelist;
 
-create_pymodule!("/src/bindings/transformers/transformers.py");
-
-pub fn transform(
-    task: &serde_json::Value,
-    args: &serde_json::Value,
-    inputs: Vec<&str>,
-) -> Result<serde_json::Value> {
-    crate::bindings::python::activate()?;
-
-    whitelist::verify_task(task)?;
-
-    let task = serde_json::to_string(task)?;
-    let args = serde_json::to_string(args)?;
-    let inputs = serde_json::to_string(&inputs)?;
-
-    let results = Python::with_gil(|py| -> Result<String> {
-        let transform: Py<PyAny> = get_module!(PY_MODULE)
-            .getattr(py, "transform")
-            .format_traceback(py)?;
-
-        let output = transform
-            .call1(
-                py,
-                PyTuple::new(
-                    py,
-                    &[task.into_py(py), args.into_py(py), inputs.into_py(py)],
-                ),
-            )
-            .format_traceback(py)?;
-
-        output.extract(py).format_traceback(py)
-    })?;
-
-    Ok(serde_json::from_str(&results)?)
-}
+mod transformers;
+pub use transformers::*;
 
-pub fn transform_stream(
-    task: &serde_json::Value,
-    args: &serde_json::Value,
-    inputs: Vec<&str>,
-) -> Result<Py<PyAny>> {
-    crate::bindings::python::activate()?;
-
-    whitelist::verify_task(task)?;
-
-    let task = serde_json::to_string(task)?;
-    let args = serde_json::to_string(args)?;
-    let inputs = serde_json::to_string(&inputs)?;
-
-    Python::with_gil(|py| -> Result<Py<PyAny>> {
-        let transform: Py<PyAny> = get_module!(PY_MODULE)
-            .getattr(py, "transform")
-            .format_traceback(py)?;
-
-        let output = transform
-            .call1(
-                py,
-                PyTuple::new(
-                    py,
-                    &[
-                        task.into_py(py),
-                        args.into_py(py),
-                        inputs.into_py(py),
-                        true.into_py(py),
-                    ],
-                ),
-            )
-            .format_traceback(py)?;
-
-        Ok(output)
-    })
-}
+create_pymodule!("/src/bindings/transformers/transformers.py");
 
 pub fn get_model_from(task: &Value) -> Result<String> {
     Python::with_gil(|py| -> Result<String> {
         let get_model_from = get_module!(PY_MODULE)
diff --git a/pgml-extension/src/bindings/transformers/transformers.rs b/pgml-extension/src/bindings/transformers/transformers.rs
new file mode 100644
index 000000000..77d967056
--- /dev/null
+++ b/pgml-extension/src/bindings/transformers/transformers.rs
@@ -0,0 +1,77 @@
+use super::whitelist;
+use super::TracebackError;
+use anyhow::Result;
+use pyo3::prelude::*;
+use pyo3::types::PyTuple;
+
+create_pymodule!("/src/bindings/transformers/transformers.py");
+
+pub fn transform(
+    task: &serde_json::Value,
+    args: &serde_json::Value,
+    inputs: Vec<&str>,
+) -> Result<serde_json::Value> {
+    crate::bindings::python::activate()?;
+
+    whitelist::verify_task(task)?;
+
+    let task = serde_json::to_string(task)?;
+    let args = serde_json::to_string(args)?;
+    let inputs = serde_json::to_string(&inputs)?;
+
+    let results = Python::with_gil(|py| -> Result<String> {
+        let transform: Py<PyAny> = get_module!(PY_MODULE)
+            .getattr(py, "transform")
+            .format_traceback(py)?;
+
+        let output = transform
+            .call1(
+                py,
+                PyTuple::new(
+                    py,
+                    &[task.into_py(py), args.into_py(py), inputs.into_py(py)],
+                ),
+            )
+            .format_traceback(py)?;
+
+        output.extract(py).format_traceback(py)
+    })?;
+
+    Ok(serde_json::from_str(&results)?)
+}
+
+pub fn transform_stream(
+    task: &serde_json::Value,
+    args: &serde_json::Value,
+    inputs: Vec<&str>,
+) -> Result<Py<PyAny>> {
+    crate::bindings::python::activate()?;
+
+    whitelist::verify_task(task)?;
+
+    let task = serde_json::to_string(task)?;
+    let args = serde_json::to_string(args)?;
+    let inputs = serde_json::to_string(&inputs)?;
+
+    Python::with_gil(|py| -> Result<Py<PyAny>> {
+        let transform: Py<PyAny> = get_module!(PY_MODULE)
+            .getattr(py, "transform")
+            .format_traceback(py)?;
+
+        let output = transform
+            .call1(
+                py,
+                PyTuple::new(
+                    py,
+                    &[
+                        task.into_py(py),
+                        args.into_py(py),
+                        inputs.into_py(py),
+                        true.into_py(py),
+                    ],
+                ),
+            )
+            .format_traceback(py)?;
+
+        Ok(output)
+    })
+}

From bd4002b0b90595d8c9935a3b5d7c8bc0e06ae706 Mon Sep 17 00:00:00 2001
From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com>
Date: Mon, 6 Nov 2023 10:00:25 -0800
Subject: [PATCH 5/7] Switched to use Queues

---
 pgml-extension/src/api.rs                     |  9 ++++---
 .../src/bindings/transformers/transformers.py | 24 +++++++++----------
 2 files changed, 16 insertions(+), 17 deletions(-)

diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs
index e882498b9..052f7753f 100644
--- a/pgml-extension/src/api.rs
+++ b/pgml-extension/src/api.rs
@@ -654,13 +654,12 @@ impl Iterator for TransformStreamIterator {
     fn next(&mut self) -> Option<Self::Item> {
         // We can unwrap this because if there is an error the current transaction is aborted in the map_err call
         Python::with_gil(|py| -> Result<Option<String>, PyErr> {
-            let code = "next(python_iter, 'DEFAULT_DONE_STRING1239847uuuuu')";
-            let res: String = py
-                .eval(code, Some(self.locals.as_ref(py)), None)?
-                .extract()?;
-            if res == "DEFAULT_DONE_STRING1239847uuuuu" {
+            let code = "next(python_iter)";
+            let res: &PyAny = py.eval(code, Some(self.locals.as_ref(py)), None)?;
+            if res.is_none() {
                 Ok(None)
             } else {
+                let res: String = res.extract()?;
                 Ok(Some(res))
             }
         })
diff --git a/pgml-extension/src/bindings/transformers/transformers.py b/pgml-extension/src/bindings/transformers/transformers.py
index d8878c50d..3f714771d 100644
--- a/pgml-extension/src/bindings/transformers/transformers.py
+++ b/pgml-extension/src/bindings/transformers/transformers.py
@@ -2,6 +2,8 @@
 import os
 import shutil
 import time
+import queue
+import sys
 
 import datasets
 from InstructorEmbedding import INSTRUCTOR
@@ -125,25 +127,23 @@ def __init__(self, output):
         self.done_data = []
         self.output = output
         self.done = False
+        self.q = queue.Queue()
 
-        def do_work(g):
-            for x in g.output:
-                g.done_data.append(x)
-            g.done = True
-        thread = Thread(target=do_work, args=(self,))
+        def do_work():
+            for x in self.output:
+                self.q.put(x)
+            self.done = True
+        thread = Thread(target=do_work)
         thread.start()
 
     def __iter__(self):
         return self
 
     def __next__(self):
-        if len(self.done_data) > 0:
-            return self.done_data.pop(0)
-        elif self.done:
-            raise StopIteration
-        time.sleep(0.1)
-        return self.__next__()
-
+        while not self.done or not self.q.empty():
+            v = self.q.get()
+            self.q.task_done()
+            return v
 
 class GGMLPipeline(object):
     def __init__(self, model_name, **task):
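The queue switch in patch 5 replaces the sleep-and-recurse polling loop from patch 1 with a blocking producer/consumer handoff, and it also changes the end-of-stream signal: `__next__` now falls through and returns `None` when the generator is exhausted, which is why the Rust side moved from the sentinel string to `res.is_none()`. A condensed sketch of that shape (class and variable names are illustrative):

    import queue
    from threading import Thread

    class QueueBackedIterator:
        def __init__(self, generator):
            self.q = queue.Queue()
            self.done = False

            def do_work():
                for x in generator:
                    self.q.put(x)
                self.done = True

            Thread(target=do_work).start()

        def __iter__(self):
            return self

        def __next__(self):
            if not self.done or not self.q.empty():
                return self.q.get()  # blocks instead of polling with sleep()
            # exhausted: implicitly returns None rather than raising
            # StopIteration, matching the res.is_none() check in api.rs

    it = QueueBackedIterator(iter(["a", "b", "c"]))
    while (v := next(it)) is not None:
        print(v)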
From 82c0c668edbbb263dfedd884057f7ee2ddf7e498 Mon Sep 17 00:00:00 2001
From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com>
Date: Wed, 8 Nov 2023 13:00:24 -0800
Subject: [PATCH 6/7] Updated to support json and string transform_stream and
 updated transformers.py to use correct TextIteratorStreamer

---
 pgml-extension/src/api.rs                     |  24 ++++-
 .../src/bindings/transformers/transformers.py | 102 +++++++++++++-----
 2 files changed, 100 insertions(+), 26 deletions(-)

diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs
index 052f7753f..c4965bb4b 100644
--- a/pgml-extension/src/api.rs
+++ b/pgml-extension/src/api.rs
@@ -668,12 +668,14 @@ impl Iterator for TransformStreamIterator {
     }
 }
 
-#[pg_extern(name = "transform_stream")]
+#[cfg(all(feature = "python", not(feature = "use_as_lib")))]
+#[pg_extern(immutable, parallel_safe, name = "transform_stream")]
+#[allow(unused_variables)] // cache is maintained for api compatibility
 pub fn transform_stream_json(
     task: JsonB,
     args: default!(JsonB, "'{}'"),
     inputs: default!(Vec<&str>, "ARRAY[]::TEXT[]"),
-    _cache: default!(bool, false),
+    cache: default!(bool, false),
 ) -> SetOfIterator<'static, String> {
     // We can unwrap this because if there is an error the current transaction is aborted in the map_err call
     let python_iter = crate::bindings::transformers::transform_stream(&task.0, &args.0, inputs)
@@ -683,6 +685,24 @@ pub fn transform_stream_json(
     SetOfIterator::new(res)
 }
 
+#[cfg(all(feature = "python", not(feature = "use_as_lib")))]
+#[pg_extern(immutable, parallel_safe, name = "transform_stream")]
+#[allow(unused_variables)] // cache is maintained for api compatibility
+pub fn transform_stream_string(
+    task: String,
+    args: default!(JsonB, "'{}'"),
+    inputs: default!(Vec<&str>, "ARRAY[]::TEXT[]"),
+    cache: default!(bool, false),
+) -> SetOfIterator<'static, String> {
+    let task_json = json!({ "task": task });
+    // We can unwrap this because if there is an error the current transaction is aborted in the map_err call
+    let python_iter = crate::bindings::transformers::transform_stream(&task_json, &args.0, inputs)
+        .map_err(|e| error!("{e}"))
+        .unwrap();
+    let res = TransformStreamIterator::new(python_iter);
+    SetOfIterator::new(res)
+}
+
 #[cfg(feature = "python")]
 #[pg_extern(immutable, parallel_safe, name = "generate")]
 fn generate(project_name: &str, inputs: &str, config: default!(JsonB, "'{}'")) -> String {
diff --git a/pgml-extension/src/bindings/transformers/transformers.py b/pgml-extension/src/bindings/transformers/transformers.py
index 3f714771d..640dbfb9c 100644
--- a/pgml-extension/src/bindings/transformers/transformers.py
+++ b/pgml-extension/src/bindings/transformers/transformers.py
@@ -40,9 +40,10 @@
     PegasusTokenizer,
     TrainingArguments,
     Trainer,
-    TextIteratorStreamer,
+    TextStreamer,
 )
 from threading import Thread
+from typing import Optional
 
 __cache_transformer_by_model_id = {}
 __cache_sentence_transformer_by_name = {}
@@ -63,14 +64,17 @@
     "bool": torch.bool,
 }
 
+
 class PgMLException(Exception):
     pass
 
+
 def orjson_default(obj):
     if isinstance(obj, numpy.float32):
         return float(obj)
     raise TypeError
 
+
 def convert_dtype(kwargs):
     if "torch_dtype" in kwargs:
         kwargs["torch_dtype"] = DTYPE_MAP[kwargs["torch_dtype"]]
@@ -90,17 +94,46 @@ def ensure_device(kwargs):
         else:
             kwargs["device"] = "cpu"
 
+
+# A copy of HuggingFace's with small changes in the __next__
+class TextIteratorStreamer(TextStreamer):
+    def __init__(
+        self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, timeout: Optional[float] = None, **decode_kwargs
+    ):
+        super().__init__(tokenizer, skip_prompt, **decode_kwargs)
+        self.text_queue = queue.Queue()
+        self.stop_signal = None
+        self.timeout = timeout
+
+    def on_finalized_text(self, text: str, stream_end: bool = False):
+        """Put the new text in the queue. If the stream is ending, also put a stop signal in the queue."""
+        self.text_queue.put(text, timeout=self.timeout)
+        if stream_end:
+            self.text_queue.put(self.stop_signal, timeout=self.timeout)
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        value = self.text_queue.get(timeout=self.timeout)
+        if value != self.stop_signal:
+            return value
+
+
 class GPTQPipeline(object):
     def __init__(self, model_name, **task):
         from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
         from huggingface_hub import snapshot_download
+
         model_path = snapshot_download(model_name)
 
         quantized_config = BaseQuantizeConfig.from_pretrained(model_path)
-        self.model = AutoGPTQForCausalLM.from_quantized(model_path, quantized_config=quantized_config, **task)
+        self.model = AutoGPTQForCausalLM.from_quantized(
+            model_path, quantized_config=quantized_config, **task
+        )
         if "use_fast_tokenizer" in task:
-            self.tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=task.pop("use_fast_tokenizer"))
+            self.tokenizer = AutoTokenizer.from_pretrained(
+                model_path, use_fast=task.pop("use_fast_tokenizer")
+            )
         else:
             self.tokenizer = AutoTokenizer.from_pretrained(model_path)
         self.task = "text-generation"
@@ -116,7 +149,11 @@ def stream(self, inputs, **kwargs):
     def __call__(self, inputs, **kwargs):
         outputs = []
         for input in inputs:
-            tokens = self.tokenizer(input, return_tensors="pt").to(self.model.device).input_ids
+            tokens = (
+                self.tokenizer(input, return_tensors="pt")
+                .to(self.model.device)
+                .input_ids
+            )
             token_ids = self.model.generate(input_ids=tokens, **kwargs)[0]
             outputs.append(self.tokenizer.decode(token_ids))
         return outputs
@@ -133,6 +170,7 @@ def do_work():
             for x in self.output:
                 self.q.put(x)
             self.done = True
+
         thread = Thread(target=do_work)
         thread.start()
 
@@ -140,11 +178,12 @@ def __iter__(self):
         return self
 
     def __next__(self):
-        while not self.done or not self.q.empty():
+        if not self.done or not self.q.empty():
             v = self.q.get()
             self.q.task_done()
             return v
 
+
 class GGMLPipeline(object):
     def __init__(self, model_name, **task):
         import ctransformers
@@ -152,7 +191,9 @@ def __init__(self, model_name, **task):
         task.pop("model")
         task.pop("task")
         task.pop("device")
-        self.model = ctransformers.AutoModelForCausalLM.from_pretrained(model_name, **task)
+        self.model = ctransformers.AutoModelForCausalLM.from_pretrained(
+            model_name, **task
+        )
         self.tokenizer = None
         self.task = "text-generation"
 
@@ -173,28 +214,39 @@ def __init__(self, model_name, **kwargs):
         # to the model constructor, so we construct the model/tokenizer manually if possible,
         # but that is only possible when the task is passed in, since if you pass the model
        # to the pipeline constructor, the task will no longer be inferred from the default...
-        if "task" in kwargs and model_name is not None and kwargs["task"] in [
-            "text-classification",
-            "question-answering",
-            "summarization",
-            "translation",
-            "text-generation"
-        ]:
+        if (
+            "task" in kwargs
+            and model_name is not None
+            and kwargs["task"]
+            in [
+                "text-classification",
+                "question-answering",
+                "summarization",
+                "translation",
+                "text-generation",
+            ]
+        ):
             self.task = kwargs.pop("task")
             kwargs.pop("model", None)
             if self.task == "text-classification":
-                self.model = AutoModelForSequenceClassification.from_pretrained(model_name, **kwargs)
+                self.model = AutoModelForSequenceClassification.from_pretrained(
+                    model_name, **kwargs
+                )
            elif self.task == "question-answering":
-                self.model = AutoModelForQuestionAnswering.from_pretrained(model_name, **kwargs)
+                self.model = AutoModelForQuestionAnswering.from_pretrained(
+                    model_name, **kwargs
+                )
             elif self.task == "summarization" or self.task == "translation":
                 self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name, **kwargs)
             elif self.task == "text-generation":
                 self.model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
             else:
                 raise PgMLException(f"Unhandled task: {self.task}")
-        
+
         if "use_auth_token" in kwargs:
-            self.tokenizer = AutoTokenizer.from_pretrained(model_name,use_auth_token=kwargs["use_auth_token"])
+            self.tokenizer = AutoTokenizer.from_pretrained(
+                model_name, use_auth_token=kwargs["use_auth_token"]
+            )
         else:
             self.tokenizer = AutoTokenizer.from_pretrained(model_name)
 
@@ -208,7 +260,9 @@ def __init__(self, model_name, **kwargs):
             self.task = self.pipe.task
             self.model = self.pipe.model
             if self.pipe.tokenizer is None:
-                self.pipe.tokenizer = AutoTokenizer.from_pretrained(self.model.name_or_path)
+                self.pipe.tokenizer = AutoTokenizer.from_pretrained(
+                    self.model.name_or_path
+                )
             self.tokenizer = self.pipe.tokenizer
 
     def stream(self, inputs, **kwargs):
@@ -227,7 +281,7 @@ def get_model_from(task):
     task = orjson.loads(task)
     if "model" in task:
         return task["model"]
-    
+
     if "task" in task:
         model = transformers.pipelines.SUPPORTED_TASKS[task["task"]]["default"]["model"]
         ty = "tf" if "tf" in model else "pt"
@@ -292,7 +346,7 @@ def transform(task, args, inputs, stream=False):
 
 def create_embedding(transformer):
     instructor = transformer.startswith("hkunlp/instructor")
-    klass = INSTRUCTOR if instructor else SentenceTransformer 
+    klass = INSTRUCTOR if instructor else SentenceTransformer
     return klass(transformer)
 
 
@@ -306,7 +360,7 @@ def embed_using(model, transformer, inputs, kwargs):
         instruction = kwargs.pop("instruction")
         for text in inputs:
             texts_with_instructions.append([instruction, text])
-        
+
         inputs = texts_with_instructions
 
     return model.encode(inputs, **kwargs)
@@ -318,7 +372,9 @@ def embed(transformer, inputs, kwargs):
     ensure_device(kwargs)
 
     if transformer not in __cache_sentence_transformer_by_name:
-        __cache_sentence_transformer_by_name[transformer] = create_embedding(transformer)
+        __cache_sentence_transformer_by_name[transformer] = create_embedding(
+            transformer
+        )
     model = __cache_sentence_transformer_by_name[transformer]
 
     return embed_using(model, transformer, inputs, kwargs)
@@ -783,5 +839,3 @@ def generate(model_id, data, config):
         )
         all_preds.extend(decoded_preds)
     return all_preds
-
-
From c8e066ba87616714103d959a1731dcb47fb9e3b9 Mon Sep 17 00:00:00 2001
From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com>
Date: Wed, 8 Nov 2023 14:52:18 -0800
Subject: [PATCH 7/7] Updated signature of transform_stream call

---
 pgml-extension/src/api.rs                     |  8 ++++----
 .../src/bindings/transformers/transformers.py | 10 +++++-----
 .../src/bindings/transformers/transformers.rs |  4 ++--
 3 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs
index c4965bb4b..5b8ddc4e7 100644
--- a/pgml-extension/src/api.rs
+++ b/pgml-extension/src/api.rs
@@ -674,11 +674,11 @@ impl Iterator for TransformStreamIterator {
 pub fn transform_stream_json(
     task: JsonB,
     args: default!(JsonB, "'{}'"),
-    inputs: default!(Vec<&str>, "ARRAY[]::TEXT[]"),
+    input: default!(&str, "''"),
     cache: default!(bool, false),
 ) -> SetOfIterator<'static, String> {
     // We can unwrap this because if there is an error the current transaction is aborted in the map_err call
-    let python_iter = crate::bindings::transformers::transform_stream(&task.0, &args.0, inputs)
+    let python_iter = crate::bindings::transformers::transform_stream(&task.0, &args.0, input)
         .map_err(|e| error!("{e}"))
         .unwrap();
     let res = TransformStreamIterator::new(python_iter);
@@ -691,12 +691,12 @@ pub fn transform_stream_json(
 pub fn transform_stream_string(
     task: String,
     args: default!(JsonB, "'{}'"),
-    inputs: default!(Vec<&str>, "ARRAY[]::TEXT[]"),
+    input: default!(&str, "''"),
     cache: default!(bool, false),
 ) -> SetOfIterator<'static, String> {
     let task_json = json!({ "task": task });
     // We can unwrap this because if there is an error the current transaction is aborted in the map_err call
-    let python_iter = crate::bindings::transformers::transform_stream(&task_json, &args.0, inputs)
+    let python_iter = crate::bindings::transformers::transform_stream(&task_json, &args.0, input)
         .map_err(|e| error!("{e}"))
         .unwrap();
     let res = TransformStreamIterator::new(python_iter);
diff --git a/pgml-extension/src/bindings/transformers/transformers.py b/pgml-extension/src/bindings/transformers/transformers.py
index 640dbfb9c..2117cb9f6 100644
--- a/pgml-extension/src/bindings/transformers/transformers.py
+++ b/pgml-extension/src/bindings/transformers/transformers.py
@@ -94,10 +94,10 @@ def ensure_device(kwargs):
         else:
             kwargs["device"] = "cpu"
 
-# A copy of HuggingFace's with small changes in the __next__
+# A copy of HuggingFace's with small changes in the __next__ to not raise an exception
 class TextIteratorStreamer(TextStreamer):
     def __init__(
-        self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, timeout: Optional[float] = None, **decode_kwargs
+        self, tokenizer, skip_prompt = False, timeout = None, **decode_kwargs
     ):
         super().__init__(tokenizer, skip_prompt, **decode_kwargs)
         self.text_queue = queue.Queue()
@@ -160,11 +160,11 @@ def __call__(self, inputs, **kwargs):
 
 class ThreadedGeneratorIterator:
-    def __init__(self, output):
-        self.done_data = []
+    def __init__(self, output, starting_input):
         self.output = output
         self.done = False
         self.q = queue.Queue()
+        self.q.put(starting_input)
 
         def do_work():
             for x in self.output:
@@ -199,7 +199,7 @@ def __init__(self, model_name, **task):
 
     def stream(self, inputs, **kwargs):
         output = self.model(inputs[0], stream=True, **kwargs)
-        return ThreadedGeneratorIterator(output)
+        return ThreadedGeneratorIterator(output, inputs[0])
 
     def __call__(self, inputs, **kwargs):
         outputs = []
diff --git a/pgml-extension/src/bindings/transformers/transformers.rs b/pgml-extension/src/bindings/transformers/transformers.rs
index 77d967056..55d59b070 100644
--- a/pgml-extension/src/bindings/transformers/transformers.rs
+++ b/pgml-extension/src/bindings/transformers/transformers.rs
@@ -42,7 +42,7 @@
 pub fn transform_stream(
     task: &serde_json::Value,
     args: &serde_json::Value,
-    inputs: Vec<&str>,
+    input: &str,
 ) -> Result<Py<PyAny>> {
     crate::bindings::python::activate()?;
 
@@ -49,8 +49,8 @@ pub fn transform_stream(
     whitelist::verify_task(task)?;
 
     let task = serde_json::to_string(task)?;
     let args = serde_json::to_string(args)?;
-    let inputs = serde_json::to_string(&inputs)?;
+    let inputs = serde_json::to_string(&vec![input])?;
 
     Python::with_gil(|py| -> Result<Py<PyAny>> {
         let transform: Py<PyAny> = get_module!(PY_MODULE)
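For reference, a hypothetical client-side check of where the series lands (the connection string, schema name, task arguments, and prompt are placeholders, not part of the patches): after patch 7 the SQL function takes a single `input` string rather than a TEXT[], and each streamed chunk arrives as one row of the SETOF TEXT result, so a named (server-side) cursor can consume it incrementally:

    import psycopg2

    conn = psycopg2.connect("dbname=postgres")
    cur = conn.cursor(name="token_stream")  # named cursor: rows fetched lazily
    cur.execute(
        """SELECT * FROM pgml.transform_stream(
               task  => 'text-generation',
               args  => '{"max_new_tokens": 30}'::jsonb,
               input => 'AI is going to'
           )"""
    )
    for (token,) in cur:
        print(token, end="", flush=True)
    conn.close()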