From 3ff2f07e7a4eff3511ac06f2ad8a1959860dfb30 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Wed, 22 Nov 2023 10:06:32 -0800 Subject: [PATCH 01/22] Added OpenSourceAI and conversational support in the extension --- pgml-extension/src/api.rs | 87 +++++++- .../src/bindings/transformers/transform.rs | 17 +- .../src/bindings/transformers/transformers.py | 31 ++- pgml-sdks/pgml/build.rs | 3 +- .../javascript/tests/typescript-tests/test.ts | 24 ++- pgml-sdks/pgml/python/tests/test.py | 29 ++- pgml-sdks/pgml/src/lib.rs | 7 + pgml-sdks/pgml/src/open_source_ai.rs | 190 ++++++++++++++++++ pgml-sdks/pgml/src/transformer_pipeline.rs | 43 +++- pgml-sdks/pgml/src/types.rs | 13 ++ 10 files changed, 410 insertions(+), 34 deletions(-) create mode 100644 pgml-sdks/pgml/src/open_source_ai.rs diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs index ab132bc4c..de9bf51eb 100644 --- a/pgml-extension/src/api.rs +++ b/pgml-extension/src/api.rs @@ -632,6 +632,37 @@ pub fn transform_string( } } +#[cfg(all(feature = "python", not(feature = "use_as_lib")))] +#[pg_extern(immutable, parallel_safe, name = "transform")] +#[allow(unused_variables)] // cache is maintained for api compatibility +pub fn transform_conversational_json( + task: JsonB, + args: default!(JsonB, "'{}'"), + inputs: default!(Vec, "ARRAY[]::JSONB[]"), + cache: default!(bool, false), +) -> JsonB { + match crate::bindings::transformers::transform(&task.0, &args.0, inputs) { + Ok(output) => JsonB(output), + Err(e) => error!("{e}"), + } +} + +#[cfg(all(feature = "python", not(feature = "use_as_lib")))] +#[pg_extern(immutable, parallel_safe, name = "transform")] +#[allow(unused_variables)] // cache is maintained for api compatibility +pub fn transform_conversational_string( + task: String, + args: default!(JsonB, "'{}'"), + inputs: default!(Vec, "ARRAY[]::JSONB[]"), + cache: default!(bool, false), +) -> JsonB { + let task_json = json!({ "task": task }); + match crate::bindings::transformers::transform(&task_json, &args.0, inputs) { + Ok(output) => JsonB(output), + Err(e) => error!("{e}"), + } +} + #[cfg(all(feature = "python", not(feature = "use_as_lib")))] #[pg_extern(immutable, parallel_safe, name = "transform_stream")] #[allow(unused_variables)] // cache is maintained for api compatibility @@ -642,10 +673,13 @@ pub fn transform_stream_json( cache: default!(bool, false), ) -> SetOfIterator<'static, String> { // We can unwrap this becuase if there is an error the current transaction is aborted in the map_err call - let python_iter = - crate::bindings::transformers::transform_stream_iterator(&task.0, &args.0, input) - .map_err(|e| error!("{e}")) - .unwrap(); + let python_iter = crate::bindings::transformers::transform_stream_iterator( + &task.0, + &args.0, + input.to_string(), + ) + .map_err(|e| error!("{e}")) + .unwrap(); SetOfIterator::new(python_iter) } @@ -667,6 +701,51 @@ pub fn transform_stream_string( SetOfIterator::new(python_iter) } +#[cfg(all(feature = "python", not(feature = "use_as_lib")))] +#[pg_extern(immutable, parallel_safe, name = "transform_stream")] +#[allow(unused_variables)] // cache is maintained for api compatibility +pub fn transform_stream_conversational_json( + task: JsonB, + args: default!(JsonB, "'{}'"), + input: default!(JsonB, "'[]'::JSONB"), + cache: default!(bool, false), +) -> SetOfIterator<'static, String> { + // If they have Vec inputs lets make sure they have the write task + if !task.0["task"] + .as_str() + .is_some_and(|v| v == "conversational") + { + 
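        // error! aborts the current transaction, so non-conversational tasks are rejected before streaming starts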
error!("ARRAY[]::JSONB inputs for transformer_stream should only be used with a conversational task"); + } + // We can unwrap this becuase if there is an error the current transaction is aborted in the map_err call + let python_iter = + crate::bindings::transformers::transform_stream_iterator(&task.0, &args.0, input.0) + .map_err(|e| error!("{e}")) + .unwrap(); + SetOfIterator::new(python_iter) +} + +#[cfg(all(feature = "python", not(feature = "use_as_lib")))] +#[pg_extern(immutable, parallel_safe, name = "transform_stream")] +#[allow(unused_variables)] // cache is maintained for api compatibility +pub fn transform_stream_conversational_string( + task: String, + args: default!(JsonB, "'{}'"), + input: default!(JsonB, "'[]'::JSONB"), + cache: default!(bool, false), +) -> SetOfIterator<'static, String> { + if task != "conversational" { + error!("ARRAY[]::JSONB inputs for transformer_stream should only be used with a conversational task"); + } + let task_json = json!({ "task": task }); + // We can unwrap this becuase if there is an error the current transaction is aborted in the map_err call + let python_iter = + crate::bindings::transformers::transform_stream_iterator(&task_json, &args.0, input.0) + .map_err(|e| error!("{e}")) + .unwrap(); + SetOfIterator::new(python_iter) +} + #[cfg(feature = "python")] #[pg_extern(immutable, parallel_safe, name = "generate")] fn generate(project_name: &str, inputs: &str, config: default!(JsonB, "'{}'")) -> String { diff --git a/pgml-extension/src/bindings/transformers/transform.rs b/pgml-extension/src/bindings/transformers/transform.rs index a03c0d751..81aa1d77d 100644 --- a/pgml-extension/src/bindings/transformers/transform.rs +++ b/pgml-extension/src/bindings/transformers/transform.rs @@ -4,6 +4,7 @@ use anyhow::Result; use pgrx::*; use pyo3::prelude::*; use pyo3::types::{IntoPyDict, PyDict, PyTuple}; +use pyo3::AsPyPointer; create_pymodule!("/src/bindings/transformers/transformers.py"); @@ -41,10 +42,10 @@ impl Iterator for TransformStreamIterator { } } -pub fn transform( +pub fn transform( task: &serde_json::Value, args: &serde_json::Value, - inputs: Vec<&str>, + inputs: T, ) -> Result { crate::bindings::python::activate()?; whitelist::verify_task(task)?; @@ -74,17 +75,17 @@ pub fn transform( Ok(serde_json::from_str(&results)?) 
} -pub fn transform_stream( +pub fn transform_stream( task: &serde_json::Value, args: &serde_json::Value, - input: &str, + input: T, ) -> Result> { crate::bindings::python::activate()?; whitelist::verify_task(task)?; let task = serde_json::to_string(task)?; let args = serde_json::to_string(args)?; - let inputs = serde_json::to_string(&vec![input])?; + let input = serde_json::to_string(&input)?; Python::with_gil(|py| -> Result> { let transform: Py = get_module!(PY_MODULE) @@ -99,7 +100,7 @@ pub fn transform_stream( &[ task.into_py(py), args.into_py(py), - inputs.into_py(py), + input.into_py(py), true.into_py(py), ], ), @@ -110,10 +111,10 @@ pub fn transform_stream( }) } -pub fn transform_stream_iterator( +pub fn transform_stream_iterator( task: &serde_json::Value, args: &serde_json::Value, - input: &str, + input: T, ) -> Result { let python_iter = transform_stream(task, args, input) .map_err(|e| error!("{e}")) diff --git a/pgml-extension/src/bindings/transformers/transformers.py b/pgml-extension/src/bindings/transformers/transformers.py index 143f6d393..7f66d7975 100644 --- a/pgml-extension/src/bindings/transformers/transformers.py +++ b/pgml-extension/src/bindings/transformers/transformers.py @@ -41,6 +41,7 @@ TrainingArguments, Trainer, TextStreamer, + Conversation ) from threading import Thread from typing import Optional @@ -198,8 +199,8 @@ def __init__(self, model_name, **task): self.task = "text-generation" def stream(self, inputs, **kwargs): - output = self.model(inputs[0], stream=True, **kwargs) - return ThreadedGeneratorIterator(output, inputs[0]) + output = self.model(inputs, stream=True, **kwargs) + return ThreadedGeneratorIterator(output, inputs) def __call__(self, inputs, **kwargs): outputs = [] @@ -224,6 +225,7 @@ def __init__(self, model_name, **kwargs): "summarization", "translation", "text-generation", + "conversational" ] ): self.task = kwargs.pop("task") @@ -238,7 +240,7 @@ def __init__(self, model_name, **kwargs): ) elif self.task == "summarization" or self.task == "translation": self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name, **kwargs) - elif self.task == "text-generation": + elif self.task == "text-generation" or self.task == "conversational": self.model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs) else: raise PgMLException(f"Unhandled task: {self.task}") @@ -266,15 +268,30 @@ def __init__(self, model_name, **kwargs): self.tokenizer = self.pipe.tokenizer def stream(self, inputs, **kwargs): - streamer = TextIteratorStreamer(self.tokenizer) - inputs = self.tokenizer(inputs, return_tensors="pt").to(self.model.device) - generation_kwargs = dict(inputs, streamer=streamer, **kwargs) + streamer = None + generation_kwargs = None + if self.task == "conversational": + streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True) + inputs = tokenized_chat = self.tokenizer.apply_chat_template(inputs, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(self.model.device) + generation_kwargs = dict(inputs=inputs, streamer=streamer, **kwargs) + else: + streamer = TextIteratorStreamer(self.tokenizer) + inputs = self.tokenizer([inputs], return_tensors="pt").to(self.model.device) + generation_kwargs = dict(inputs, streamer=streamer, **kwargs) thread = Thread(target=self.model.generate, kwargs=generation_kwargs) thread.start() return streamer def __call__(self, inputs, **kwargs): - return self.pipe(inputs, **kwargs) + if self.task == "conversational": + outputs = [] + for conversation in inputs: + conversation = Conversation(conversation) 
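+                # running the pipe appends the model's reply to the Conversation; keep only the newest response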
+ conversation = self.pipe(conversation, **kwargs) + outputs.append(conversation.generated_responses[-1]) + return outputs + else: + return self.pipe(inputs, **kwargs) def get_model_from(task): diff --git a/pgml-sdks/pgml/build.rs b/pgml-sdks/pgml/build.rs index 4f476884f..82b51670c 100644 --- a/pgml-sdks/pgml/build.rs +++ b/pgml-sdks/pgml/build.rs @@ -14,7 +14,7 @@ const ADDITIONAL_DEFAULTS_FOR_JAVASCRIPT: &[u8] = br#" export function init_logger(level?: string, format?: string): void; export function migrate(): Promise; -export type Json = { [key: string]: any }; +export type Json = any; export type DateTime = Date; export function newCollection(name: string, database_url?: string): Collection; @@ -23,6 +23,7 @@ export function newSplitter(name?: string, parameters?: Json): Splitter; export function newBuiltins(database_url?: string): Builtins; export function newPipeline(name: string, model?: Model, splitter?: Splitter, parameters?: Json): Pipeline; export function newTransformerPipeline(task: string, model?: string, args?: Json, database_url?: string): TransformerPipeline; +export function newOpenSourceAI(database_url?: string): OpenSourceAI; "#; fn main() { diff --git a/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts b/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts index affb314fa..a802ef400 100644 --- a/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts +++ b/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts @@ -299,7 +299,29 @@ it("can transformer pipeline stream", async () => { output.push(result.value); result = await it.next(); } - expect(output.length).toBeGreaterThan(0) + expect(output.length).toBeGreaterThan(0); +}); + +/////////////////////////////////////////////////// +// Test OpenSourceAI ////////////////////////////// +/////////////////////////////////////////////////// + +it("can open source ai create", async () => { + const client = pgml.newOpenSourceAI(); + const results = client.chat_completions_create( + "mistralai/Mistral-7B-v0.1", + [ + { + role: "system", + content: "You are a friendly chatbot who always responds in the style of a pirate", + }, + { + role: "user", + content: "How many helicopters can a human eat in one sitting?", + }, + ], + ); + expect(results.choices.length).toBeGreaterThan(0); }); /////////////////////////////////////////////////// diff --git a/pgml-sdks/pgml/python/tests/test.py b/pgml-sdks/pgml/python/tests/test.py index 97ca155f5..fdf3725b9 100644 --- a/pgml-sdks/pgml/python/tests/test.py +++ b/pgml-sdks/pgml/python/tests/test.py @@ -307,7 +307,8 @@ async def test_order_documents(): async def test_transformer_pipeline(): t = pgml.TransformerPipeline("text-generation") it = await t.transform(["AI is going to"], {"max_new_tokens": 5}) - assert (len(it)) > 0 + assert len(it) > 0 + @pytest.mark.asyncio async def test_transformer_pipeline_stream(): @@ -316,7 +317,31 @@ async def test_transformer_pipeline_stream(): total = [] async for c in it: total.append(c) - assert (len(total)) > 0 + assert len(total) > 0 + + +################################################### +## Transformer Pipeline Tests ##################### +################################################### + + +def test_open_source_ai_create(): + client = pgml.OpenSourceAI() + results = client.chat_completions_create( + "mistralai/Mistral-7B-v0.1", + [ + { + "role": "system", + "content": "You are a friendly chatbot who always responds in the style of a pirate", + }, + { + "role": "user", + "content": "How many helicopters can a human eat in one 
sitting?", + }, + ], + temperature=0.85 + ) + assert len(results["choices"]) > 0 ################################################### diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index cd0eaaeef..b115da69c 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -19,6 +19,7 @@ mod languages; pub mod migrations; mod model; pub mod models; +mod open_source_ai; mod order_by_builder; mod pipeline; mod queries; @@ -34,6 +35,7 @@ mod utils; pub use builtins::Builtins; pub use collection::Collection; pub use model::Model; +pub use open_source_ai::OpenSourceAI; pub use pipeline::Pipeline; pub use splitter::Splitter; pub use transformer_pipeline::TransformerPipeline; @@ -152,6 +154,7 @@ fn pgml(_py: pyo3::Python, m: &pyo3::types::PyModule) -> pyo3::PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; Ok(()) } @@ -201,6 +204,10 @@ fn main(mut cx: neon::context::ModuleContext) -> neon::result::NeonResult<()> { transformer_pipeline::TransformerPipelineJavascript::new, )?; cx.export_function("newPipeline", pipeline::PipelineJavascript::new)?; + cx.export_function( + "newOpenSourceAI", + open_source_ai::OpenSourceAIJavascript::new, + )?; Ok(()) } diff --git a/pgml-sdks/pgml/src/open_source_ai.rs b/pgml-sdks/pgml/src/open_source_ai.rs new file mode 100644 index 000000000..408797ef2 --- /dev/null +++ b/pgml-sdks/pgml/src/open_source_ai.rs @@ -0,0 +1,190 @@ +use anyhow::Context; +use rust_bridge::{alias, alias_methods}; +use std::time::{SystemTime, UNIX_EPOCH}; +use uuid::Uuid; + +use crate::{types::Json, TransformerPipeline}; + +#[cfg(feature = "python")] +use crate::types::JsonPython; + +#[derive(alias, Debug, Clone)] +pub struct OpenSourceAI { + database_url: Option, +} + +fn try_model_nice_name_to_model_name_and_parameters( + model_name: &str, +) -> Option<(&'static str, Json)> { + match model_name { + "mistralai/Mistral-7B-v0.1" => Some(( + "TheBloke/zephyr-7B-beta-GPTQ", + serde_json::json!({ + "task": "conversational", + "model": "TheBloke/zephyr-7B-beta-GPTQ", + "device_map": "auto", + "revision": "main", + "model_type": "mistral" + }) + .into(), + )), + "Llama-2-7b-chat-hf" => Some(( + "TheBloke/Llama-2-7B-Chat-GPTQ", + serde_json::json!({ + "task": "conversational", + "model": "TheBloke/zephyr-7B-beta-GPTQ", + "device_map": "auto", + "revision": "main", + "model_type": "llama" + }) + .into(), + )), + _ => None, + } +} + +#[alias_methods(new, chat_completions_create, chat_completions_create_async)] +impl OpenSourceAI { + pub fn new(database_url: Option) -> Self { + Self { database_url } + } + + pub async fn chat_completions_create_async( + &self, + mut model: Json, + messages: Json, + max_tokens: Option, + temperature: Option, + n: Option, + ) -> anyhow::Result { + let (transformer_pipeline, model_name, model_parameters) = if model.is_object() { + let args = model.as_object_mut().unwrap(); + let model_name = args + .remove("model") + .context("`model` is a required key in the model object")?; + let model_name = model_name.as_str().context("`model` must be a string")?; + ( + TransformerPipeline::new( + "conversational", + Some(model_name.to_string()), + Some(model.clone()), + self.database_url.clone(), + ), + model_name.to_string(), + model, + ) + } else { + let model_name = model + .as_str() + .context("`model` must either be a string or an object")?; + let (real_model_name, parameters) = + try_model_nice_name_to_model_name_and_parameters(model_name).context( + r#"Please select one of the provided models: 
+mistralai/Mistral-7B-v0.1 +"#, + )?; + ( + TransformerPipeline::new( + "conversational", + Some(real_model_name.to_string()), + Some(parameters.clone()), + self.database_url.clone(), + ), + model_name.to_string(), + parameters, + ) + }; + + let max_tokens = max_tokens.unwrap_or(1000); + let temperature = temperature.unwrap_or(0.8); + let n = n.unwrap_or(1) as usize; + let to_hash = format!( + "{}{}{}{}", + model_parameters.to_string(), + max_tokens, + temperature, + n + ); + let md5_digest = md5::compute(to_hash.as_bytes()); + let fingerprint = uuid::Uuid::from_slice(&md5_digest.0)?; + + let messages: Vec = std::iter::repeat(messages).take(n).collect(); + let choices = transformer_pipeline + .transform( + messages, + Some( + serde_json::json!({ "max_length": max_tokens, "temperature": temperature }) + .into(), + ), + ) + .await?; + let choices: Vec = choices + .as_array() + .context("Error parsing return from TransformerPipeline")? + .into_iter() + .enumerate() + .map(|(i, c)| { + serde_json::json!({ + "index": i, + "message": { + "role": "assistant", + "content": c + } + // Finish reason should be here + }) + .into() + }) + .collect(); + let since_the_epoch = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("Time went backwards"); + Ok(serde_json::json!({ + "id": Uuid::new_v4().to_string(), + "object": "chat.completion", + "created": since_the_epoch.as_secs(), + "model": model_name, + "system_fingerprint": fingerprint, + "choices": choices, + "usage": { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0 + } + }) + .into()) + } + + pub fn chat_completions_create( + &self, + model: Json, + messages: Json, + max_tokens: Option, + temperature: Option, + n: Option, + ) -> anyhow::Result { + let runtime = crate::get_or_set_runtime(); + runtime.block_on(self.chat_completions_create_async( + model, + messages, + max_tokens, + temperature, + n, + )) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[sqlx::test] + async fn can_open_source_ai_create() -> anyhow::Result<()> { + let client = OpenSourceAI::new(None); + let results = client.chat_completions_create(Json::from_serializable("mistralai/Mistral-7B-v0.1"), serde_json::json!([ + {"role": "system", "content": "You are a friendly chatbot who always responds in the style of a pirate"}, + {"role": "user", "content": "How many helicopters can a human eat in one sitting?"} + ]).into(), Some(1000), None, None)?; + assert!(results.as_array().is_some()); + Ok(()) + } +} diff --git a/pgml-sdks/pgml/src/transformer_pipeline.rs b/pgml-sdks/pgml/src/transformer_pipeline.rs index 70fd3f925..1ec11808c 100644 --- a/pgml-sdks/pgml/src/transformer_pipeline.rs +++ b/pgml-sdks/pgml/src/transformer_pipeline.rs @@ -1,3 +1,4 @@ +use anyhow::Context; use futures::Stream; use rust_bridge::{alias, alias_manual, alias_methods}; use sqlx::{postgres::PgRow, Row}; @@ -141,16 +142,36 @@ impl TransformerPipeline { } #[instrument(skip(self))] - pub async fn transform(&self, inputs: Vec, args: Option) -> anyhow::Result { + pub async fn transform(&self, inputs: Vec, args: Option) -> anyhow::Result { let pool = get_or_initialize_pool(&self.database_url).await?; let args = args.unwrap_or_default(); - let results = sqlx::query("SELECT pgml.transform(task => $1, inputs => $2, args => $3)") - .bind(&self.task) - .bind(inputs) - .bind(&args) - .fetch_all(&pool) - .await?; + // We set the task in the new constructor so we can unwrap here + let results = if self.task["task"].as_str().unwrap() == "conversational" { + let inputs: Vec = 
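                // unwrap the Json newtype so the chat-message arrays bind as raw JSONB values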
inputs.into_iter().map(|j| j.0).collect(); + sqlx::query("SELECT pgml.transform(task => $1, inputs => $2, args => $3)") + .bind(&self.task) + .bind(inputs) + .bind(&args) + .fetch_all(&pool) + .await? + } else { + let inputs: anyhow::Result> = + inputs + .into_iter() + .map(|input| { + input.as_str().context( + "the inputs arg must be strings when not using the conversational task", + ).map(|s| s.to_string()) + }) + .collect(); + sqlx::query("SELECT pgml.transform(task => $1, inputs => $2, args => $3)") + .bind(&self.task) + .bind(inputs?) + .bind(&args) + .fetch_all(&pool) + .await? + }; let results = results.get(0).unwrap().get::(0); Ok(Json(results)) } @@ -198,8 +219,8 @@ mod tests { let results = t .transform( vec![ - "How are you doing today?".to_string(), - "What is a good song?".to_string(), + serde_json::Value::String("How are you doing today?".to_string()).into(), + serde_json::Value::String("How are you doing today?".to_string()).into(), ], None, ) @@ -215,8 +236,8 @@ mod tests { let results = t .transform( vec![ - "How are you doing today?".to_string(), - "What is a good song?".to_string(), + serde_json::Value::String("How are you doing today?".to_string()).into(), + serde_json::Value::String("How are you doing today?".to_string()).into(), ], None, ) diff --git a/pgml-sdks/pgml/src/types.rs b/pgml-sdks/pgml/src/types.rs index ba80583e8..e69e2f42d 100644 --- a/pgml-sdks/pgml/src/types.rs +++ b/pgml-sdks/pgml/src/types.rs @@ -42,6 +42,19 @@ impl Serialize for Json { } } +// This will cause some conflicting trait issue +// impl From for Json { +// fn from(v: T) -> Self { +// Self(serde_json::to_value(v).unwrap()) +// } +// } + +impl Json { + pub fn from_serializable(v: T) -> Self { + Self(serde_json::to_value(v).unwrap()) + } +} + pub(crate) trait TryToNumeric { fn try_to_u64(&self) -> anyhow::Result; fn try_to_i64(&self) -> anyhow::Result { From 1ca5dc8ea134676cd61dac48382db35337feea0e Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Wed, 22 Nov 2023 10:21:59 -0800 Subject: [PATCH 02/22] Clean up errors and guard rails around conversational api --- pgml-extension/src/api.rs | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs index de9bf51eb..439182f2d 100644 --- a/pgml-extension/src/api.rs +++ b/pgml-extension/src/api.rs @@ -641,6 +641,14 @@ pub fn transform_conversational_json( inputs: default!(Vec, "ARRAY[]::JSONB[]"), cache: default!(bool, false), ) -> JsonB { + if !task.0["task"] + .as_str() + .is_some_and(|v| v == "conversational") + { + error!( + "ARRAY[]::JSONB inputs for transformer should only be used with a conversational task" + ); + } match crate::bindings::transformers::transform(&task.0, &args.0, inputs) { Ok(output) => JsonB(output), Err(e) => error!("{e}"), @@ -656,6 +664,11 @@ pub fn transform_conversational_string( inputs: default!(Vec, "ARRAY[]::JSONB[]"), cache: default!(bool, false), ) -> JsonB { + if task != "conversational" { + error!( + "ARRAY[]::JSONB inputs for transformer should only be used with a conversational task" + ); + } let task_json = json!({ "task": task }); match crate::bindings::transformers::transform(&task_json, &args.0, inputs) { Ok(output) => JsonB(output), @@ -710,12 +723,13 @@ pub fn transform_stream_conversational_json( input: default!(JsonB, "'[]'::JSONB"), cache: default!(bool, false), ) -> SetOfIterator<'static, String> { - // If they have Vec inputs lets make sure they have the write task if 
!task.0["task"] .as_str() .is_some_and(|v| v == "conversational") { - error!("ARRAY[]::JSONB inputs for transformer_stream should only be used with a conversational task"); + error!( + "JSONB inputs for transformer_stream should only be used with a conversational task" + ); } // We can unwrap this becuase if there is an error the current transaction is aborted in the map_err call let python_iter = @@ -735,7 +749,9 @@ pub fn transform_stream_conversational_string( cache: default!(bool, false), ) -> SetOfIterator<'static, String> { if task != "conversational" { - error!("ARRAY[]::JSONB inputs for transformer_stream should only be used with a conversational task"); + error!( + "JSONB inputs for transformer_stream should only be used with a conversational task" + ); } let task_json = json!({ "task": task }); // We can unwrap this becuase if there is an error the current transaction is aborted in the map_err call From b6b7ec600bc86f7e7a38c4fadf35423f06263b7e Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Tue, 28 Nov 2023 12:54:01 -0800 Subject: [PATCH 03/22] Not great, pivoting to better solution after talking with Santi --- pgml-extension/src/api.rs | 33 +++-- .../src/bindings/transformers/transform.rs | 9 +- .../src/bindings/transformers/transformers.py | 124 ++++++++++++------ 3 files changed, 103 insertions(+), 63 deletions(-) diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs index 439182f2d..16910b091 100644 --- a/pgml-extension/src/api.rs +++ b/pgml-extension/src/api.rs @@ -682,17 +682,14 @@ pub fn transform_conversational_string( pub fn transform_stream_json( task: JsonB, args: default!(JsonB, "'{}'"), - input: default!(&str, "''"), + inputs: default!(Vec<&str>, "ARRAY[]::TEXT[]"), cache: default!(bool, false), -) -> SetOfIterator<'static, String> { +) -> SetOfIterator<'static, JsonB> { // We can unwrap this becuase if there is an error the current transaction is aborted in the map_err call - let python_iter = crate::bindings::transformers::transform_stream_iterator( - &task.0, - &args.0, - input.to_string(), - ) - .map_err(|e| error!("{e}")) - .unwrap(); + let python_iter = + crate::bindings::transformers::transform_stream_iterator(&task.0, &args.0, inputs) + .map_err(|e| error!("{e}")) + .unwrap(); SetOfIterator::new(python_iter) } @@ -702,13 +699,13 @@ pub fn transform_stream_json( pub fn transform_stream_string( task: String, args: default!(JsonB, "'{}'"), - input: default!(&str, "''"), + inputs: default!(Vec<&str>, "ARRAY[]::TEXT[]"), cache: default!(bool, false), -) -> SetOfIterator<'static, String> { +) -> SetOfIterator<'static, JsonB> { let task_json = json!({ "task": task }); // We can unwrap this becuase if there is an error the current transaction is aborted in the map_err call let python_iter = - crate::bindings::transformers::transform_stream_iterator(&task_json, &args.0, input) + crate::bindings::transformers::transform_stream_iterator(&task_json, &args.0, inputs) .map_err(|e| error!("{e}")) .unwrap(); SetOfIterator::new(python_iter) @@ -720,9 +717,9 @@ pub fn transform_stream_string( pub fn transform_stream_conversational_json( task: JsonB, args: default!(JsonB, "'{}'"), - input: default!(JsonB, "'[]'::JSONB"), + inputs: default!(Vec, "ARRAY[]::JSONB[]"), cache: default!(bool, false), -) -> SetOfIterator<'static, String> { +) -> SetOfIterator<'static, JsonB> { if !task.0["task"] .as_str() .is_some_and(|v| v == "conversational") @@ -733,7 +730,7 @@ pub fn transform_stream_conversational_json( } // We can 
unwrap this becuase if there is an error the current transaction is aborted in the map_err call let python_iter = - crate::bindings::transformers::transform_stream_iterator(&task.0, &args.0, input.0) + crate::bindings::transformers::transform_stream_iterator(&task.0, &args.0, inputs) .map_err(|e| error!("{e}")) .unwrap(); SetOfIterator::new(python_iter) @@ -745,9 +742,9 @@ pub fn transform_stream_conversational_json( pub fn transform_stream_conversational_string( task: String, args: default!(JsonB, "'{}'"), - input: default!(JsonB, "'[]'::JSONB"), + inputs: default!(Vec, "ARRAY[]::JSONB[]"), cache: default!(bool, false), -) -> SetOfIterator<'static, String> { +) -> SetOfIterator<'static, JsonB> { if task != "conversational" { error!( "JSONB inputs for transformer_stream should only be used with a conversational task" @@ -756,7 +753,7 @@ pub fn transform_stream_conversational_string( let task_json = json!({ "task": task }); // We can unwrap this becuase if there is an error the current transaction is aborted in the map_err call let python_iter = - crate::bindings::transformers::transform_stream_iterator(&task_json, &args.0, input.0) + crate::bindings::transformers::transform_stream_iterator(&task_json, &args.0, inputs) .map_err(|e| error!("{e}")) .unwrap(); SetOfIterator::new(python_iter) diff --git a/pgml-extension/src/bindings/transformers/transform.rs b/pgml-extension/src/bindings/transformers/transform.rs index 81aa1d77d..fa03984d9 100644 --- a/pgml-extension/src/bindings/transformers/transform.rs +++ b/pgml-extension/src/bindings/transformers/transform.rs @@ -4,7 +4,6 @@ use anyhow::Result; use pgrx::*; use pyo3::prelude::*; use pyo3::types::{IntoPyDict, PyDict, PyTuple}; -use pyo3::AsPyPointer; create_pymodule!("/src/bindings/transformers/transformers.py"); @@ -24,17 +23,17 @@ impl TransformStreamIterator { } impl Iterator for TransformStreamIterator { - type Item = String; + type Item = JsonB; fn next(&mut self) -> Option { // We can unwrap this becuase if there is an error the current transaction is aborted in the map_err call - Python::with_gil(|py| -> Result, PyErr> { + Python::with_gil(|py| -> Result, PyErr> { let code = "next(python_iter)"; let res: &PyAny = py.eval(code, Some(self.locals.as_ref(py)), None)?; if res.is_none() { Ok(None) } else { - let res: String = res.extract()?; - Ok(Some(res)) + let res: Vec = res.extract()?; + Ok(Some(JsonB(serde_json::to_value(res).unwrap()))) } }) .map_err(|e| error!("{e}")) diff --git a/pgml-extension/src/bindings/transformers/transformers.py b/pgml-extension/src/bindings/transformers/transformers.py index 7f66d7975..d6ee1ae2f 100644 --- a/pgml-extension/src/bindings/transformers/transformers.py +++ b/pgml-extension/src/bindings/transformers/transformers.py @@ -41,7 +41,7 @@ TrainingArguments, Trainer, TextStreamer, - Conversation + Conversation, ) from threading import Thread from typing import Optional @@ -95,24 +95,34 @@ def ensure_device(kwargs): else: kwargs["device"] = "cpu" -# A copy of HuggingFace's with small changes in the __next__ to not raise an exception -class TextIteratorStreamer(TextStreamer): - def __init__( - self, tokenizer, skip_prompt = False, timeout = None, **decode_kwargs - ): - super().__init__(tokenizer, skip_prompt, **decode_kwargs) - self.text_queue = queue.Queue() - self.stop_signal = None + +# Follows BaseStreamer template from transformers library +class TextIteratorStreamer: + def __init__(self, tokenizer, skip_prompt=False, timeout=None, **decode_kwargs): + self.tokenizer = tokenizer + self.skip_prompt = 
skip_prompt self.timeout = timeout + self.decode_kwargs = decode_kwargs + self.next_tokens_are_prompt = True + self.stop_signal = None + self.text_queue = queue.Queue() - def on_finalized_text(self, text: str, stream_end: bool = False): - """Put the new text in the queue. If the stream is ending, also put a stop signal in the queue.""" - self.text_queue.put(text, timeout=self.timeout) - if stream_end: - self.text_queue.put(self.stop_signal, timeout=self.timeout) + def put(self, value): + if self.skip_prompt and self.next_tokens_are_prompt: + self.next_tokens_are_prompt = False + return + # Can't batch this decode + decoded_values = [] + for v in value: + decoded_values.append(self.tokenizer.decode(v, **self.decode_kwargs)) + self.text_queue.put(decoded_values, self.timeout) + + def end(self): + self.next_tokens_are_prompt = True + self.text_queue.put(self.stop_signal, self.timeout) def __iter__(self): - return self + self def __next__(self): value = self.text_queue.get(timeout=self.timeout) @@ -215,6 +225,18 @@ def __init__(self, model_name, **kwargs): # to the model constructor, so we construct the model/tokenizer manually if possible, # but that is only possible when the task is passed in, since if you pass the model # to the pipeline constructor, the task will no longer be inferred from the default... + + # We want to create a text-generation pipeline if it is a conversational task + self.conversational = False + if "task" in kwargs and kwargs["task"] == "conversational": + self.conversational = True + kwargs["task"] = "text-generation" + + # Tokens can either be left or right padded depending on the architecture + padding_side = "right" + if "padding_side" in kwargs: + padding_side = kwargs.pop("padding_side") + if ( "task" in kwargs and model_name is not None @@ -224,8 +246,7 @@ def __init__(self, model_name, **kwargs): "question-answering", "summarization", "translation", - "text-generation", - "conversational" + "text-generation" ] ): self.task = kwargs.pop("task") @@ -240,17 +261,17 @@ def __init__(self, model_name, **kwargs): ) elif self.task == "summarization" or self.task == "translation": self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name, **kwargs) - elif self.task == "text-generation" or self.task == "conversational": + elif self.task == "text-generation": self.model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs) else: raise PgMLException(f"Unhandled task: {self.task}") if "use_auth_token" in kwargs: self.tokenizer = AutoTokenizer.from_pretrained( - model_name, use_auth_token=kwargs["use_auth_token"] + model_name, use_auth_token=kwargs["use_auth_token"], padding_side=padding_side ) else: - self.tokenizer = AutoTokenizer.from_pretrained(model_name) + self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side=padding_side) self.pipe = transformers.pipeline( self.task, @@ -258,38 +279,57 @@ def __init__(self, model_name, **kwargs): tokenizer=self.tokenizer, ) else: - self.pipe = transformers.pipeline(**kwargs) + self.pipe = transformers.pipeline(**kwargs, padding_side=padding_side) + self.tokenizer = self.pipe.tokenizer self.task = self.pipe.task self.model = self.pipe.model - if self.pipe.tokenizer is None: - self.pipe.tokenizer = AutoTokenizer.from_pretrained( - self.model.name_or_path - ) - self.tokenizer = self.pipe.tokenizer + + # Make sure we set the pad token if it does not exist + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token def stream(self, inputs, **kwargs): streamer = None 
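        # the branches below pick the tokenization and streamer settings appropriate to the task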
generation_kwargs = None - if self.task == "conversational": - streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True) - inputs = tokenized_chat = self.tokenizer.apply_chat_template(inputs, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(self.model.device) - generation_kwargs = dict(inputs=inputs, streamer=streamer, **kwargs) + # Conversational does not work right now with left padded tokenizers. At least for gpt2, the apply_chat_template breaks it + if self.conversational: + streamer = TextIteratorStreamer( + self.tokenizer, skip_prompt=True, skip_special_tokens=True + ) + templated_inputs = [] + for input in inputs: + templated_inputs.append( + self.tokenizer.apply_chat_template( + input, add_generation_prompt=True, tokenize=False + ) + ) + inputs = self.tokenizer( + templated_inputs, return_tensors="pt", padding=True + ).to(self.model.device) + generation_kwargs = dict(inputs, streamer=streamer, **kwargs) else: - streamer = TextIteratorStreamer(self.tokenizer) - inputs = self.tokenizer([inputs], return_tensors="pt").to(self.model.device) + streamer = TextIteratorStreamer(self.tokenizer, skip_special_tokens=True) + inputs = self.tokenizer(inputs, return_tensors="pt", padding=True).to( + self.model.device + ) generation_kwargs = dict(inputs, streamer=streamer, **kwargs) + print("\n\n", file=sys.stderr) + print(inputs, file=sys.stderr) + print("\n\n", file=sys.stderr) thread = Thread(target=self.model.generate, kwargs=generation_kwargs) thread.start() return streamer def __call__(self, inputs, **kwargs): - if self.task == "conversational": - outputs = [] - for conversation in inputs: - conversation = Conversation(conversation) - conversation = self.pipe(conversation, **kwargs) - outputs.append(conversation.generated_responses[-1]) - return outputs + if self.conversational: + templated_inputs = [] + for input in inputs: + templated_inputs.append( + self.tokenizer.apply_chat_template( + input, add_generation_prompt=True, tokenize=False + ) + ) + return self.pipe(templated_inputs, return_full_text=False, **kwargs) else: return self.pipe(inputs, **kwargs) @@ -320,7 +360,11 @@ def create_pipeline(task): lower = None if lower and ("-ggml" in lower or "-gguf" in lower): pipe = GGMLPipeline(model_name, **task) - elif lower and "-gptq" in lower and not (model_type == "mistral" or model_type == "llama"): + elif ( + lower + and "-gptq" in lower + and not (model_type == "mistral" or model_type == "llama") + ): pipe = GPTQPipeline(model_name, **task) else: try: From be4788a473a9f31a1d015aa963fdf9879bfbe031 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Tue, 28 Nov 2023 14:31:04 -0800 Subject: [PATCH 04/22] Working conversational everything --- pgml-extension/src/api.rs | 12 +-- .../src/bindings/transformers/transformers.py | 81 +++++++++---------- 2 files changed, 44 insertions(+), 49 deletions(-) diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs index 16910b091..b6ee865bf 100644 --- a/pgml-extension/src/api.rs +++ b/pgml-extension/src/api.rs @@ -682,12 +682,12 @@ pub fn transform_conversational_string( pub fn transform_stream_json( task: JsonB, args: default!(JsonB, "'{}'"), - inputs: default!(Vec<&str>, "ARRAY[]::TEXT[]"), + input: default!(&str, "''"), cache: default!(bool, false), ) -> SetOfIterator<'static, JsonB> { // We can unwrap this becuase if there is an error the current transaction is aborted in the map_err call let python_iter = - 
crate::bindings::transformers::transform_stream_iterator(&task.0, &args.0, inputs) + crate::bindings::transformers::transform_stream_iterator(&task.0, &args.0, input) .map_err(|e| error!("{e}")) .unwrap(); SetOfIterator::new(python_iter) @@ -699,13 +699,13 @@ pub fn transform_stream_json( pub fn transform_stream_string( task: String, args: default!(JsonB, "'{}'"), - inputs: default!(Vec<&str>, "ARRAY[]::TEXT[]"), + input: default!(&str, "''"), cache: default!(bool, false), ) -> SetOfIterator<'static, JsonB> { let task_json = json!({ "task": task }); // We can unwrap this becuase if there is an error the current transaction is aborted in the map_err call let python_iter = - crate::bindings::transformers::transform_stream_iterator(&task_json, &args.0, inputs) + crate::bindings::transformers::transform_stream_iterator(&task_json, &args.0, input) .map_err(|e| error!("{e}")) .unwrap(); SetOfIterator::new(python_iter) @@ -725,7 +725,7 @@ pub fn transform_stream_conversational_json( .is_some_and(|v| v == "conversational") { error!( - "JSONB inputs for transformer_stream should only be used with a conversational task" + "ARRAY[]::JSONB inputs for transformer_stream should only be used with a conversational task" ); } // We can unwrap this becuase if there is an error the current transaction is aborted in the map_err call @@ -747,7 +747,7 @@ pub fn transform_stream_conversational_string( ) -> SetOfIterator<'static, JsonB> { if task != "conversational" { error!( - "JSONB inputs for transformer_stream should only be used with a conversational task" + "ARRAY::JSONB inputs for transformer_stream should only be used with a conversational task" ); } let task_json = json!({ "task": task }); diff --git a/pgml-extension/src/bindings/transformers/transformers.py b/pgml-extension/src/bindings/transformers/transformers.py index d6ee1ae2f..c81bc48cd 100644 --- a/pgml-extension/src/bindings/transformers/transformers.py +++ b/pgml-extension/src/bindings/transformers/transformers.py @@ -226,17 +226,6 @@ def __init__(self, model_name, **kwargs): # but that is only possible when the task is passed in, since if you pass the model # to the pipeline constructor, the task will no longer be inferred from the default... 
- # We want to create a text-generation pipeline if it is a conversational task - self.conversational = False - if "task" in kwargs and kwargs["task"] == "conversational": - self.conversational = True - kwargs["task"] = "text-generation" - - # Tokens can either be left or right padded depending on the architecture - padding_side = "right" - if "padding_side" in kwargs: - padding_side = kwargs.pop("padding_side") - if ( "task" in kwargs and model_name is not None @@ -246,7 +235,8 @@ def __init__(self, model_name, **kwargs): "question-answering", "summarization", "translation", - "text-generation" + "text-generation", + "conversational", ] ): self.task = kwargs.pop("task") @@ -261,17 +251,17 @@ def __init__(self, model_name, **kwargs): ) elif self.task == "summarization" or self.task == "translation": self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name, **kwargs) - elif self.task == "text-generation": + elif self.task == "text-generation" or self.task == "conversational": self.model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs) else: raise PgMLException(f"Unhandled task: {self.task}") if "use_auth_token" in kwargs: self.tokenizer = AutoTokenizer.from_pretrained( - model_name, use_auth_token=kwargs["use_auth_token"], padding_side=padding_side + model_name, use_auth_token=kwargs["use_auth_token"] ) else: - self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side=padding_side) + self.tokenizer = AutoTokenizer.from_pretrained(model_name) self.pipe = transformers.pipeline( self.task, @@ -279,7 +269,7 @@ def __init__(self, model_name, **kwargs): tokenizer=self.tokenizer, ) else: - self.pipe = transformers.pipeline(**kwargs, padding_side=padding_side) + self.pipe = transformers.pipeline(**kwargs) self.tokenizer = self.pipe.tokenizer self.task = self.pipe.task self.model = self.pipe.model @@ -288,48 +278,53 @@ def __init__(self, model_name, **kwargs): if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token - def stream(self, inputs, **kwargs): + def stream(self, input, **kwargs): streamer = None generation_kwargs = None # Conversational does not work right now with left padded tokenizers. 
At least for gpt2, the apply_chat_template breaks it - if self.conversational: + if self.task == "conversational": streamer = TextIteratorStreamer( self.tokenizer, skip_prompt=True, skip_special_tokens=True ) - templated_inputs = [] - for input in inputs: - templated_inputs.append( - self.tokenizer.apply_chat_template( - input, add_generation_prompt=True, tokenize=False - ) - ) - inputs = self.tokenizer( - templated_inputs, return_tensors="pt", padding=True - ).to(self.model.device) - generation_kwargs = dict(inputs, streamer=streamer, **kwargs) + input = self.tokenizer.apply_chat_template( + input, add_generation_prompt=True, tokenize=False + ) + input = self.tokenizer(input, return_tensors="pt").to(self.model.device) + generation_kwargs = dict(input, streamer=streamer, **kwargs) else: streamer = TextIteratorStreamer(self.tokenizer, skip_special_tokens=True) - inputs = self.tokenizer(inputs, return_tensors="pt", padding=True).to( + input = self.tokenizer(input, return_tensors="pt", padding=True).to( self.model.device ) - generation_kwargs = dict(inputs, streamer=streamer, **kwargs) - print("\n\n", file=sys.stderr) - print(inputs, file=sys.stderr) - print("\n\n", file=sys.stderr) + generation_kwargs = dict(input, streamer=streamer, **kwargs) thread = Thread(target=self.model.generate, kwargs=generation_kwargs) thread.start() return streamer def __call__(self, inputs, **kwargs): - if self.conversational: - templated_inputs = [] - for input in inputs: - templated_inputs.append( - self.tokenizer.apply_chat_template( - input, add_generation_prompt=True, tokenize=False - ) - ) - return self.pipe(templated_inputs, return_full_text=False, **kwargs) + if self.task == "conversational": + inputs = self.tokenizer.apply_chat_template( + inputs, add_generation_prompt=True, tokenize=False + ) + inputs = self.tokenizer(inputs, return_tensors="pt").to(self.model.device) + args = dict(inputs, **kwargs) + outputs = self.model.generate(**args) + # We only want the new ouputs for conversational pipelines + outputs = outputs[:, inputs["input_ids"].shape[1] :] + outputs = self.tokenizer.batch_decode(outputs, skip_special_tokens=True) + return outputs + + # I don't think conversations support num_responses and/or maybe num_beams + # Also this is not processed in parallel / truly batched it seems + # num_conversations = 1 + # if "num_return_sequences" in kwargs: + # num_conversations = kwargs.pop("num_return_sequences") + # conversations = [Conversation(inputs) for _ in range(0, num_conversations)] + # conversations = self.pipe(conversations, **kwargs) + # outputs = [] + # for conversation in conversations: + # outputs.append(conversation.messages[-1]["content"]) + # return outputs else: return self.pipe(inputs, **kwargs) From 719fdc5a679068eceed35077d3ddd2f5c97a3ab0 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Tue, 28 Nov 2023 14:32:39 -0800 Subject: [PATCH 05/22] Fixed typo --- pgml-extension/src/bindings/transformers/transformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgml-extension/src/bindings/transformers/transformers.py b/pgml-extension/src/bindings/transformers/transformers.py index c81bc48cd..fbffc3c7e 100644 --- a/pgml-extension/src/bindings/transformers/transformers.py +++ b/pgml-extension/src/bindings/transformers/transformers.py @@ -122,7 +122,7 @@ def end(self): self.text_queue.put(self.stop_signal, self.timeout) def __iter__(self): - self + return self def __next__(self): value = 
self.text_queue.get(timeout=self.timeout) From 23ef2a336aeb2a8c5a436238efda7b801ae59804 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Tue, 28 Nov 2023 15:50:55 -0800 Subject: [PATCH 06/22] Working non streaming open source ai replacement --- .../src/bindings/transformers/transform.rs | 2 + .../src/bindings/transformers/transformers.py | 1 + pgml-sdks/pgml/src/languages/javascript.rs | 6 +- pgml-sdks/pgml/src/languages/python.rs | 2 +- pgml-sdks/pgml/src/open_source_ai.rs | 72 +++++++++++-------- pgml-sdks/pgml/src/transformer_pipeline.rs | 8 +-- 6 files changed, 54 insertions(+), 37 deletions(-) diff --git a/pgml-extension/src/bindings/transformers/transform.rs b/pgml-extension/src/bindings/transformers/transform.rs index fa03984d9..9bd81b77a 100644 --- a/pgml-extension/src/bindings/transformers/transform.rs +++ b/pgml-extension/src/bindings/transformers/transform.rs @@ -32,7 +32,9 @@ impl Iterator for TransformStreamIterator { if res.is_none() { Ok(None) } else { + eprintln!("\nHERE WE ARE!\n"); let res: Vec = res.extract()?; + eprintln!("\nYUP WE DIDNT GET HERE\n"); Ok(Some(JsonB(serde_json::to_value(res).unwrap()))) } }) diff --git a/pgml-extension/src/bindings/transformers/transformers.py b/pgml-extension/src/bindings/transformers/transformers.py index fbffc3c7e..170dc8824 100644 --- a/pgml-extension/src/bindings/transformers/transformers.py +++ b/pgml-extension/src/bindings/transformers/transformers.py @@ -126,6 +126,7 @@ def __iter__(self): def __next__(self): value = self.text_queue.get(timeout=self.timeout) + print("\n\n", value, "\n\n", file=sys.stderr) if value != self.stop_signal: return value diff --git a/pgml-sdks/pgml/src/languages/javascript.rs b/pgml-sdks/pgml/src/languages/javascript.rs index 1aafd654b..9301233ad 100644 --- a/pgml-sdks/pgml/src/languages/javascript.rs +++ b/pgml-sdks/pgml/src/languages/javascript.rs @@ -95,8 +95,10 @@ fn transform_stream_iterate_next(mut cx: FunctionContext) -> JsResult .try_settle_with(&channel, move |mut cx| { let o = cx.empty_object(); if let Some(v) = v { - let v: String = v.expect("Error calling next on TransformerStream"); - let v = cx.string(v); + let v: Json = v.expect("Error calling next on TransformerStream"); + let v = v + .into_js_result(&mut cx) + .expect("Error converting rust Json to JavaScript Object"); let d = cx.boolean(false); o.set(&mut cx, "value", v) .expect("Error setting object value in transformer_sream_iterate_next"); diff --git a/pgml-sdks/pgml/src/languages/python.rs b/pgml-sdks/pgml/src/languages/python.rs index 2cf1bcf9c..d9988373b 100644 --- a/pgml-sdks/pgml/src/languages/python.rs +++ b/pgml-sdks/pgml/src/languages/python.rs @@ -72,7 +72,7 @@ impl TransformerStreamPython { if let Some(o) = ts.next().await { Ok(Some(Python::with_gil(|py| { o.expect("Error calling next on TransformerStream") - .to_object(py) + .into_py(py) }))) } else { Err(pyo3::exceptions::PyStopAsyncIteration::new_err( diff --git a/pgml-sdks/pgml/src/open_source_ai.rs b/pgml-sdks/pgml/src/open_source_ai.rs index 408797ef2..b86c12926 100644 --- a/pgml-sdks/pgml/src/open_source_ai.rs +++ b/pgml-sdks/pgml/src/open_source_ai.rs @@ -49,21 +49,17 @@ impl OpenSourceAI { Self { database_url } } - pub async fn chat_completions_create_async( + fn create_pipeline_model_name_parameters( &self, mut model: Json, - messages: Json, - max_tokens: Option, - temperature: Option, - n: Option, - ) -> anyhow::Result { - let (transformer_pipeline, model_name, model_parameters) = if model.is_object() { + ) -> 
anyhow::Result<(TransformerPipeline, String, Json)> { + if model.is_object() { let args = model.as_object_mut().unwrap(); let model_name = args .remove("model") .context("`model` is a required key in the model object")?; let model_name = model_name.as_str().context("`model` must be a string")?; - ( + Ok(( TransformerPipeline::new( "conversational", Some(model_name.to_string()), @@ -72,7 +68,7 @@ impl OpenSourceAI { ), model_name.to_string(), model, - ) + )) } else { let model_name = model .as_str() @@ -83,7 +79,7 @@ impl OpenSourceAI { mistralai/Mistral-7B-v0.1 "#, )?; - ( + Ok(( TransformerPipeline::new( "conversational", Some(real_model_name.to_string()), @@ -92,28 +88,44 @@ mistralai/Mistral-7B-v0.1 ), model_name.to_string(), parameters, - ) - }; + )) + } + } + + pub async fn chat_completions_create_stream_async( + &self, + model: Json, + messages: Vec, + max_tokens: Option, + temperature: Option, + n: Option, + ) -> anyhow::Result<()> { + Ok(()) + } + + pub async fn chat_completions_create_async( + &self, + model: Json, + messages: Vec, + max_tokens: Option, + temperature: Option, + n: Option, + ) -> anyhow::Result { + let (transformer_pipeline, model_name, model_parameters) = + self.create_pipeline_model_name_parameters(model)?; let max_tokens = max_tokens.unwrap_or(1000); let temperature = temperature.unwrap_or(0.8); let n = n.unwrap_or(1) as usize; - let to_hash = format!( - "{}{}{}{}", - model_parameters.to_string(), - max_tokens, - temperature, - n - ); + let to_hash = format!("{}{}{}{}", *model_parameters, max_tokens, temperature, n); let md5_digest = md5::compute(to_hash.as_bytes()); let fingerprint = uuid::Uuid::from_slice(&md5_digest.0)?; - let messages: Vec = std::iter::repeat(messages).take(n).collect(); let choices = transformer_pipeline .transform( messages, Some( - serde_json::json!({ "max_length": max_tokens, "temperature": temperature }) + serde_json::json!({ "max_length": max_tokens, "temperature": temperature, "do_sample": true, "num_return_sequences": n }) .into(), ), ) @@ -121,7 +133,7 @@ mistralai/Mistral-7B-v0.1 let choices: Vec = choices .as_array() .context("Error parsing return from TransformerPipeline")? 
- .into_iter() + .iter() .enumerate() .map(|(i, c)| { serde_json::json!({ @@ -157,7 +169,7 @@ mistralai/Mistral-7B-v0.1 pub fn chat_completions_create( &self, model: Json, - messages: Json, + messages: Vec, max_tokens: Option, temperature: Option, n: Option, @@ -177,14 +189,14 @@ mistralai/Mistral-7B-v0.1 mod tests { use super::*; - #[sqlx::test] - async fn can_open_source_ai_create() -> anyhow::Result<()> { + #[test] + fn can_open_source_ai_create() -> anyhow::Result<()> { let client = OpenSourceAI::new(None); - let results = client.chat_completions_create(Json::from_serializable("mistralai/Mistral-7B-v0.1"), serde_json::json!([ - {"role": "system", "content": "You are a friendly chatbot who always responds in the style of a pirate"}, - {"role": "user", "content": "How many helicopters can a human eat in one sitting?"} - ]).into(), Some(1000), None, None)?; - assert!(results.as_array().is_some()); + let results = client.chat_completions_create(Json::from_serializable("mistralai/Mistral-7B-v0.1"), vec![ + serde_json::json!({"role": "system", "content": "You are a friendly chatbot who always responds in the style of a pirate"}).into(), + serde_json::json!({"role": "user", "content": "How many helicopters can a human eat in one sitting?"}).into(), + ], Some(1000), None, Some(3))?; + assert!(results["choices"].as_array().is_some()); Ok(()) } } diff --git a/pgml-sdks/pgml/src/transformer_pipeline.rs b/pgml-sdks/pgml/src/transformer_pipeline.rs index 1ec11808c..a5f6451b0 100644 --- a/pgml-sdks/pgml/src/transformer_pipeline.rs +++ b/pgml-sdks/pgml/src/transformer_pipeline.rs @@ -55,7 +55,7 @@ impl TransformerStream { } impl Stream for TransformerStream { - type Item = Result; + type Item = Result; fn poll_next( mut self: Pin<&mut Self>, @@ -106,7 +106,7 @@ impl Stream for TransformerStream { if !self.results.is_empty() { let r = self.results.pop_front().unwrap(); - Poll::Ready(Some(Ok(r.get::(0)))) + Poll::Ready(Some(Ok(r.get::(0)))) } else if self.done { Poll::Ready(None) } else { @@ -251,10 +251,10 @@ mod tests { internal_init_logger(None, None).ok(); let t = TransformerPipeline::new( "text-generation", - Some("TheBloke/zephyr-7B-beta-GGUF".to_string()), + Some("TheBloke/zephyr-7B-beta-GPTQ".to_string()), Some( serde_json::json!({ - "model_file": "zephyr-7b-beta.Q5_K_M.gguf", "model_type": "mistral" + "model_type": "mistral", "revision": "main", "device_map": "auto" }) .into(), ), From dedf4341884eea16d6ea6e64aa5a6e9677a9eae7 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Tue, 28 Nov 2023 15:53:04 -0800 Subject: [PATCH 07/22] Remove outdated comment --- pgml-extension/src/bindings/transformers/transformers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pgml-extension/src/bindings/transformers/transformers.py b/pgml-extension/src/bindings/transformers/transformers.py index 170dc8824..eaaa190a8 100644 --- a/pgml-extension/src/bindings/transformers/transformers.py +++ b/pgml-extension/src/bindings/transformers/transformers.py @@ -282,7 +282,6 @@ def __init__(self, model_name, **kwargs): def stream(self, input, **kwargs): streamer = None generation_kwargs = None - # Conversational does not work right now with left padded tokenizers. 
At least for gpt2, the apply_chat_template breaks it if self.task == "conversational": streamer = TextIteratorStreamer( self.tokenizer, skip_prompt=True, skip_special_tokens=True From f3d8e1f7db5708c21331f48f05e2191322d31582 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Wed, 29 Nov 2023 12:08:50 -0800 Subject: [PATCH 08/22] Working OpenSourceAI with both sync and async options --- .../pgml/javascript/tests/jest.config.js | 3 +- .../javascript/tests/typescript-tests/test.ts | 66 ++++++++- pgml-sdks/pgml/python/tests/test.py | 69 +++++++++- pgml-sdks/pgml/src/languages/javascript.rs | 68 +++++++-- pgml-sdks/pgml/src/languages/python.rs | 61 +++++++-- pgml-sdks/pgml/src/open_source_ai.rs | 129 +++++++++++++++++- pgml-sdks/pgml/src/transformer_pipeline.rs | 59 +++++--- pgml-sdks/pgml/src/types.rs | 28 ++++ 8 files changed, 433 insertions(+), 50 deletions(-) diff --git a/pgml-sdks/pgml/javascript/tests/jest.config.js b/pgml-sdks/pgml/javascript/tests/jest.config.js index 7e67de525..7cf8a2c1e 100644 --- a/pgml-sdks/pgml/javascript/tests/jest.config.js +++ b/pgml-sdks/pgml/javascript/tests/jest.config.js @@ -4,5 +4,6 @@ export default { roots: [''], transform: { '^.+\\.tsx?$': 'ts-jest' - } + }, + testTimeout: 30000, } diff --git a/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts b/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts index a802ef400..acc766bd8 100644 --- a/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts +++ b/pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts @@ -306,7 +306,7 @@ it("can transformer pipeline stream", async () => { // Test OpenSourceAI ////////////////////////////// /////////////////////////////////////////////////// -it("can open source ai create", async () => { +it("can open source ai create", () => { const client = pgml.newOpenSourceAI(); const results = client.chat_completions_create( "mistralai/Mistral-7B-v0.1", @@ -324,6 +324,70 @@ it("can open source ai create", async () => { expect(results.choices.length).toBeGreaterThan(0); }); + +it("can open source ai create async", async () => { + const client = pgml.newOpenSourceAI(); + const results = await client.chat_completions_create_async( + "mistralai/Mistral-7B-v0.1", + [ + { + role: "system", + content: "You are a friendly chatbot who always responds in the style of a pirate", + }, + { + role: "user", + content: "How many helicopters can a human eat in one sitting?", + }, + ], + ); + expect(results.choices.length).toBeGreaterThan(0); +}); + + +it("can open source ai create stream", () => { + const client = pgml.newOpenSourceAI(); + const it = client.chat_completions_create_stream( + "mistralai/Mistral-7B-v0.1", + [ + { + role: "system", + content: "You are a friendly chatbot who always responds in the style of a pirate", + }, + { + role: "user", + content: "How many helicopters can a human eat in one sitting?", + }, + ], + ); + let result = it.next(); + while (!result.done) { + expect(result.value.choices.length).toBeGreaterThan(0); + result = it.next(); + } +}); + +it("can open source ai create stream async", async () => { + const client = pgml.newOpenSourceAI(); + const it = await client.chat_completions_create_stream_async( + "mistralai/Mistral-7B-v0.1", + [ + { + role: "system", + content: "You are a friendly chatbot who always responds in the style of a pirate", + }, + { + role: "user", + content: "How many helicopters can a human eat in one sitting?", + }, + ], + ); + let result = await it.next(); + while (!result.done) { + 
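    // every streamed chunk should include at least one choice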
expect(result.value.choices.length).toBeGreaterThan(0); + result = await it.next(); + } +}); + /////////////////////////////////////////////////// // Test migrations //////////////////////////////// /////////////////////////////////////////////////// diff --git a/pgml-sdks/pgml/python/tests/test.py b/pgml-sdks/pgml/python/tests/test.py index fdf3725b9..b91da1578 100644 --- a/pgml-sdks/pgml/python/tests/test.py +++ b/pgml-sdks/pgml/python/tests/test.py @@ -321,7 +321,7 @@ async def test_transformer_pipeline_stream(): ################################################### -## Transformer Pipeline Tests ##################### +## OpenSourceAI tests ########################### ################################################### @@ -339,11 +339,76 @@ def test_open_source_ai_create(): "content": "How many helicopters can a human eat in one sitting?", }, ], - temperature=0.85 + temperature=0.85, ) assert len(results["choices"]) > 0 +@pytest.mark.asyncio +async def test_open_source_ai_create_async(): + client = pgml.OpenSourceAI() + results = await client.chat_completions_create_async( + "mistralai/Mistral-7B-v0.1", + [ + { + "role": "system", + "content": "You are a friendly chatbot who always responds in the style of a pirate", + }, + { + "role": "user", + "content": "How many helicopters can a human eat in one sitting?", + }, + ], + temperature=0.85, + ) + import json + assert len(results["choices"]) > 0 + + +def test_open_source_ai_create_stream(): + client = pgml.OpenSourceAI() + results = client.chat_completions_create_stream( + "mistralai/Mistral-7B-v0.1", + [ + { + "role": "system", + "content": "You are a friendly chatbot who always responds in the style of a pirate", + }, + { + "role": "user", + "content": "How many helicopters can a human eat in one sitting?", + }, + ], + temperature=0.85, + n=3, + ) + for c in results: + assert len(c["choices"]) > 0 + + +@pytest.mark.asyncio +async def test_open_source_ai_create_stream_async(): + client = pgml.OpenSourceAI() + results = await client.chat_completions_create_stream_async( + "mistralai/Mistral-7B-v0.1", + [ + { + "role": "system", + "content": "You are a friendly chatbot who always responds in the style of a pirate", + }, + { + "role": "user", + "content": "How many helicopters can a human eat in one sitting?", + }, + ], + temperature=0.85, + n=3, + ) + import json + async for c in results: + assert len(c["choices"]) > 0 + + ################################################### ## Migration tests ################################ ################################################### diff --git a/pgml-sdks/pgml/src/languages/javascript.rs b/pgml-sdks/pgml/src/languages/javascript.rs index 9301233ad..d4999c1e4 100644 --- a/pgml-sdks/pgml/src/languages/javascript.rs +++ b/pgml-sdks/pgml/src/languages/javascript.rs @@ -1,12 +1,12 @@ use futures::StreamExt; use neon::prelude::*; use rust_bridge::javascript::{FromJsType, IntoJsResult}; +use std::cell::RefCell; use std::sync::Arc; use crate::{ pipeline::PipelineSyncData, - transformer_pipeline::TransformerStream, - types::{DateTime, Json}, + types::{DateTime, GeneralJsonAsyncIterator, GeneralJsonIterator, Json}, }; //////////////////////////////////////////////////////////////////////////////// @@ -74,17 +74,17 @@ impl IntoJsResult for PipelineSyncData { } #[derive(Clone)] -struct TransformerStreamArcMutex(Arc>); +struct GeneralJsonAsyncIteratorArcMutex(Arc>); -impl Finalize for TransformerStreamArcMutex {} +impl Finalize for GeneralJsonAsyncIteratorArcMutex {} fn transform_stream_iterate_next(mut cx: 
FunctionContext) -> JsResult { let this = cx.this(); - let s: Handle> = this + let s: Handle> = this .get(&mut cx, "s") .expect("Error getting self in transformer_stream_iterate_next"); - let ts: &TransformerStreamArcMutex = &s; - let ts: TransformerStreamArcMutex = ts.clone(); + let ts: &GeneralJsonAsyncIteratorArcMutex = &s; + let ts: GeneralJsonAsyncIteratorArcMutex = ts.clone(); let channel = cx.channel(); let (deferred, promise) = cx.promise(); @@ -95,7 +95,7 @@ fn transform_stream_iterate_next(mut cx: FunctionContext) -> JsResult .try_settle_with(&channel, move |mut cx| { let o = cx.empty_object(); if let Some(v) = v { - let v: Json = v.expect("Error calling next on TransformerStream"); + let v: Json = v.expect("Error calling next on GeneralJsonAsyncIterator"); let v = v .into_js_result(&mut cx) .expect("Error converting rust Json to JavaScript Object"); @@ -116,8 +116,8 @@ fn transform_stream_iterate_next(mut cx: FunctionContext) -> JsResult Ok(promise) } -impl IntoJsResult for TransformerStream { - type Output = JsObject; +impl IntoJsResult for GeneralJsonAsyncIterator { + type Output = JsValue; fn into_js_result<'a, 'b, 'c: 'b, C: Context<'c>>( self, cx: &mut C, @@ -125,11 +125,55 @@ impl IntoJsResult for TransformerStream { let o = cx.empty_object(); let f: Handle = JsFunction::new(cx, transform_stream_iterate_next)?; o.set(cx, "next", f)?; - let s = cx.boxed(TransformerStreamArcMutex(Arc::new( + let s = cx.boxed(GeneralJsonAsyncIteratorArcMutex(Arc::new( tokio::sync::Mutex::new(self), ))); o.set(cx, "s", s)?; - Ok(o) + Ok(o.as_value(cx)) + } +} + +struct GeneralJsonIteratorJavaScript(RefCell); + +impl Finalize for GeneralJsonIteratorJavaScript {} + +fn transform_iterate_next(mut cx: FunctionContext) -> JsResult { + let this = cx.this(); + let s: Handle> = this + .get(&mut cx, "s") + .expect("Error getting self in transformer_stream_iterate_next"); + let v = s.0.borrow_mut().next(); + let o = cx.empty_object(); + if let Some(v) = v { + let v: Json = v.expect("Error calling next on GeneralJsonAsyncIterator"); + let v = v + .into_js_result(&mut cx) + .expect("Error converting rust Json to JavaScript Object"); + let d = cx.boolean(false); + o.set(&mut cx, "value", v) + .expect("Error setting object value in transformer_sream_iterate_next"); + o.set(&mut cx, "done", d) + .expect("Error setting object value in transformer_sream_iterate_next"); + } else { + let d = cx.boolean(true); + o.set(&mut cx, "done", d) + .expect("Error setting object value in transformer_sream_iterate_next"); + } + Ok(o) +} + +impl IntoJsResult for GeneralJsonIterator { + type Output = JsValue; + fn into_js_result<'a, 'b, 'c: 'b, C: Context<'c>>( + self, + cx: &mut C, + ) -> JsResult<'b, Self::Output> { + let o = cx.empty_object(); + let f: Handle = JsFunction::new(cx, transform_iterate_next)?; + o.set(cx, "next", f)?; + let s = cx.boxed(GeneralJsonIteratorJavaScript(RefCell::new(self))); + o.set(cx, "s", s)?; + Ok(o.as_value(cx)) } } diff --git a/pgml-sdks/pgml/src/languages/python.rs b/pgml-sdks/pgml/src/languages/python.rs index d9988373b..9d19b16bd 100644 --- a/pgml-sdks/pgml/src/languages/python.rs +++ b/pgml-sdks/pgml/src/languages/python.rs @@ -6,7 +6,10 @@ use std::sync::Arc; use rust_bridge::python::CustomInto; -use crate::{pipeline::PipelineSyncData, transformer_pipeline::TransformerStream, types::Json}; +use crate::{ + pipeline::PipelineSyncData, + types::{GeneralJsonAsyncIterator, GeneralJsonIterator, Json}, +}; //////////////////////////////////////////////////////////////////////////////// // Rust 
to PY ////////////////////////////////////////////////////////////////// @@ -55,12 +58,12 @@ impl IntoPy for PipelineSyncData { #[pyclass] #[derive(Clone)] -struct TransformerStreamPython { - wrapped: Arc>, +struct GeneralJsonAsyncIteratorPython { + wrapped: Arc>, } #[pymethods] -impl TransformerStreamPython { +impl GeneralJsonAsyncIteratorPython { fn __aiter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { slf } @@ -71,7 +74,7 @@ impl TransformerStreamPython { let mut ts = ts.lock().await; if let Some(o) = ts.next().await { Ok(Some(Python::with_gil(|py| { - o.expect("Error calling next on TransformerStream") + o.expect("Error calling next on GeneralJsonAsyncIterator") .into_py(py) }))) } else { @@ -84,15 +87,47 @@ impl TransformerStreamPython { } } -impl IntoPy for TransformerStream { +impl IntoPy for GeneralJsonAsyncIterator { fn into_py(self, py: Python) -> PyObject { - let f: Py = Py::new( + let f: Py = Py::new( py, - TransformerStreamPython { + GeneralJsonAsyncIteratorPython { wrapped: Arc::new(tokio::sync::Mutex::new(self)), }, ) - .expect("Error converting TransformerStream to TransformerStreamPython"); + .expect("Error converting GeneralJsonAsyncIterator to GeneralJsonAsyncIteratorPython"); + f.to_object(py) + } +} + +#[pyclass] +struct GeneralJsonIteratorPython { + wrapped: GeneralJsonIterator, +} + +#[pymethods] +impl GeneralJsonIteratorPython { + fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { + slf + } + + fn __next__(mut slf: PyRefMut<'_, Self>, py: Python) -> PyResult> { + if let Some(o) = slf.wrapped.next() { + let o = o.expect("Error calling next on GeneralJsonIterator"); + Ok(Some(o.into_py(py))) + } else { + Err(pyo3::exceptions::PyStopIteration::new_err( + "stream exhausted", + )) + } + } +} + +impl IntoPy for GeneralJsonIterator { + fn into_py(self, py: Python) -> PyObject { + let f: Py = + Py::new(py, GeneralJsonIteratorPython { wrapped: self }) + .expect("Error converting GeneralJsonIterator to GeneralJsonIteratorPython"); f.to_object(py) } } @@ -149,7 +184,13 @@ impl FromPyObject<'_> for PipelineSyncData { } } -impl FromPyObject<'_> for TransformerStream { +impl FromPyObject<'_> for GeneralJsonAsyncIterator { + fn extract(_ob: &PyAny) -> PyResult { + panic!("We must implement this, but this is impossible to be reached") + } +} + +impl FromPyObject<'_> for GeneralJsonIterator { fn extract(_ob: &PyAny) -> PyResult { panic!("We must implement this, but this is impossible to be reached") } diff --git a/pgml-sdks/pgml/src/open_source_ai.rs b/pgml-sdks/pgml/src/open_source_ai.rs index b86c12926..68246a04a 100644 --- a/pgml-sdks/pgml/src/open_source_ai.rs +++ b/pgml-sdks/pgml/src/open_source_ai.rs @@ -1,12 +1,16 @@ use anyhow::Context; +use futures::{StreamExt, Stream}; use rust_bridge::{alias, alias_methods}; use std::time::{SystemTime, UNIX_EPOCH}; use uuid::Uuid; -use crate::{types::Json, TransformerPipeline}; +use crate::{ + types::{GeneralJsonAsyncIterator, Json, GeneralJsonIterator}, + TransformerPipeline, get_or_set_runtime, +}; #[cfg(feature = "python")] -use crate::types::JsonPython; +use crate::types::{JsonPython, GeneralJsonAsyncIteratorPython, GeneralJsonIteratorPython}; #[derive(alias, Debug, Clone)] pub struct OpenSourceAI { @@ -43,7 +47,18 @@ fn try_model_nice_name_to_model_name_and_parameters( } } -#[alias_methods(new, chat_completions_create, chat_completions_create_async)] +struct AsyncToSyncJsonIterator(std::pin::Pin> + Send>>); + +impl Iterator for AsyncToSyncJsonIterator { + type Item = anyhow::Result; + + fn next(&mut self) -> Option { + let runtime 
= get_or_set_runtime(); + runtime.block_on(self.0.next()) + } +} + +#[alias_methods(new, chat_completions_create, chat_completions_create_async, chat_completions_create_stream, chat_completions_create_stream_async)] impl OpenSourceAI { pub fn new(database_url: Option) -> Self { Self { database_url } @@ -99,8 +114,69 @@ mistralai/Mistral-7B-v0.1 max_tokens: Option, temperature: Option, n: Option, - ) -> anyhow::Result<()> { - Ok(()) + ) -> anyhow::Result { + let (transformer_pipeline, model_name, model_parameters) = + self.create_pipeline_model_name_parameters(model)?; + + let max_tokens = max_tokens.unwrap_or(1000); + let temperature = temperature.unwrap_or(0.8); + let n = n.unwrap_or(1) as usize; + let to_hash = format!("{}{}{}{}", *model_parameters, max_tokens, temperature, n); + let md5_digest = md5::compute(to_hash.as_bytes()); + let fingerprint = uuid::Uuid::from_slice(&md5_digest.0)?; + + let messages = serde_json::to_value(messages)?.into(); + let iterator = transformer_pipeline + .transform_stream( + messages, + Some( + serde_json::json!({ "max_length": max_tokens, "temperature": temperature, "do_sample": true, "num_return_sequences": n }) + .into(), + ), + Some(1) + ) + .await?; + + let id = Uuid::new_v4().to_string(); + let iter = iterator.map(move |choices| { + let since_the_epoch = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("Time went backwards"); + eprintln!("{:?}", choices); + Ok(serde_json::json!({ + "id": id.clone(), + "system_fingerprint": fingerprint.clone(), + "object": "chat.completion.chunk", + "created": since_the_epoch.as_secs(), + "model": model_name.clone(), + "choices": choices?.as_array().context("Error parsing choices from GeneralJsonAsyncIterator")?.iter().enumerate().map(|(i, c)| { + serde_json::json!({ + "index": i, + "delta": { + "role": "assistant", + "content": c + } + }) + // finish_reason goes here + }).collect::() + }) + .into()) + }); + + Ok(GeneralJsonAsyncIterator(Box::pin(iter))) + } + + pub fn chat_completions_create_stream( + &self, + model: Json, + messages: Vec, + max_tokens: Option, + temperature: Option, + n: Option, + ) -> anyhow::Result { + let runtime = crate::get_or_set_runtime(); + let iter = runtime.block_on(self.chat_completions_create_stream_async(model, messages, max_tokens, temperature, n))?; + Ok(GeneralJsonIterator(Box::new(AsyncToSyncJsonIterator(Box::pin(iter))))) } pub async fn chat_completions_create_async( @@ -157,7 +233,7 @@ mistralai/Mistral-7B-v0.1 "model": model_name, "system_fingerprint": fingerprint, "choices": choices, - "usage": { + "usage": { "prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0 @@ -188,6 +264,7 @@ mistralai/Mistral-7B-v0.1 #[cfg(test)] mod tests { use super::*; + use futures::StreamExt; #[test] fn can_open_source_ai_create() -> anyhow::Result<()> { @@ -195,8 +272,46 @@ mod tests { let results = client.chat_completions_create(Json::from_serializable("mistralai/Mistral-7B-v0.1"), vec![ serde_json::json!({"role": "system", "content": "You are a friendly chatbot who always responds in the style of a pirate"}).into(), serde_json::json!({"role": "user", "content": "How many helicopters can a human eat in one sitting?"}).into(), - ], Some(1000), None, Some(3))?; + ], Some(10), None, Some(3))?; assert!(results["choices"].as_array().is_some()); Ok(()) } + + #[sqlx::test] + fn can_open_source_ai_create_async() -> anyhow::Result<()> { + let client = OpenSourceAI::new(None); + let results = client.chat_completions_create_async(Json::from_serializable("mistralai/Mistral-7B-v0.1"), vec![ + 
serde_json::json!({"role": "system", "content": "You are a friendly chatbot who always responds in the style of a pirate"}).into(), + serde_json::json!({"role": "user", "content": "How many helicopters can a human eat in one sitting?"}).into(), + ], Some(10), None, Some(3)).await?; + assert!(results["choices"].as_array().is_some()); + Ok(()) + } + + #[sqlx::test] + fn can_open_source_ai_create_stream_async() -> anyhow::Result<()> { + let client = OpenSourceAI::new(None); + let mut stream = client.chat_completions_create_stream_async(Json::from_serializable("mistralai/Mistral-7B-v0.1"), vec![ + serde_json::json!({"role": "system", "content": "You are a friendly chatbot who always responds in the style of a pirate"}).into(), + serde_json::json!({"role": "user", "content": "How many helicopters can a human eat in one sitting?"}).into(), + ], Some(10), None, Some(3)).await?; + while let Some(o) = stream.next().await { + o?; + } + Ok(()) + } + + #[test] + fn can_open_source_ai_create_stream() -> anyhow::Result<()> { + let client = OpenSourceAI::new(None); + let iterator = client.chat_completions_create_stream(Json::from_serializable("mistralai/Mistral-7B-v0.1"), vec![ + serde_json::json!({"role": "system", "content": "You are a friendly chatbot who always responds in the style of a pirate"}).into(), + serde_json::json!({"role": "user", "content": "How many helicopters can a human eat in one sitting?"}).into(), + ], Some(10), None, Some(3))?; + for o in iterator { + o?; + } + Ok(()) + } + } diff --git a/pgml-sdks/pgml/src/transformer_pipeline.rs b/pgml-sdks/pgml/src/transformer_pipeline.rs index a5f6451b0..00dd556f7 100644 --- a/pgml-sdks/pgml/src/transformer_pipeline.rs +++ b/pgml-sdks/pgml/src/transformer_pipeline.rs @@ -1,6 +1,6 @@ use anyhow::Context; use futures::Stream; -use rust_bridge::{alias, alias_manual, alias_methods}; +use rust_bridge::{alias, alias_methods}; use sqlx::{postgres::PgRow, Row}; use sqlx::{Postgres, Transaction}; use std::collections::VecDeque; @@ -16,14 +16,14 @@ pub struct TransformerPipeline { database_url: Option, } +use crate::types::GeneralJsonAsyncIterator; use crate::{get_or_initialize_pool, types::Json}; #[cfg(feature = "python")] -use crate::types::JsonPython; +use crate::types::{GeneralJsonAsyncIteratorPython, JsonPython}; #[allow(clippy::type_complexity)] -#[derive(alias_manual)] -pub struct TransformerStream { +struct TransformerStream { transaction: Option>, future: Option, sqlx::Error>> + Send + 'static>>>, commit: Option> + Send + 'static>>>, @@ -55,7 +55,7 @@ impl TransformerStream { } impl Stream for TransformerStream { - type Item = Result; + type Item = anyhow::Result; fn poll_next( mut self: Pin<&mut Self>, @@ -179,25 +179,50 @@ impl TransformerPipeline { #[instrument(skip(self))] pub async fn transform_stream( &self, - input: &str, + input: Json, args: Option, batch_size: Option, - ) -> anyhow::Result { + ) -> anyhow::Result { let pool = get_or_initialize_pool(&self.database_url).await?; let args = args.unwrap_or_default(); let batch_size = batch_size.unwrap_or(10); let mut transaction = pool.begin().await?; - sqlx::query( - "DECLARE c CURSOR FOR SELECT pgml.transform_stream(task => $1, input => $2, args => $3)", - ) - .bind(&self.task) - .bind(input) - .bind(&args) - .execute(&mut *transaction) - .await?; + // We set the task in the new constructor so we can unwrap here + if self.task["task"].as_str().unwrap() == "conversational" { + let inputs = input + .as_array() + .context("`input` to transformer_stream must be an array of objects")? 
+ .to_vec(); + sqlx::query( + "DECLARE c CURSOR FOR SELECT pgml.transform_stream(task => $1, inputs => $2, args => $3)", + ) + .bind(&self.task) + .bind(inputs) + .bind(&args) + .execute(&mut *transaction) + .await?; + } else { + let input = input + .as_str() + .context( + "`input` to transformer_stream must be a string if task is not conversational", + )? + .to_string(); + sqlx::query( + "DECLARE c CURSOR FOR SELECT pgml.transform_stream(task => $1, input => $2, args => $3)", + ) + .bind(&self.task) + .bind(input) + .bind(&args) + .execute(&mut *transaction) + .await?; + } - Ok(TransformerStream::new(transaction, batch_size)) + Ok(GeneralJsonAsyncIterator(Box::pin(TransformerStream::new( + transaction, + batch_size, + )))) } } @@ -262,7 +287,7 @@ mod tests { ); let mut stream = t .transform_stream( - "AI is going to", + serde_json::json!("AI is going to").into(), Some( serde_json::json!({ "max_new_tokens": 10 diff --git a/pgml-sdks/pgml/src/types.rs b/pgml-sdks/pgml/src/types.rs index e69e2f42d..bdf7308a3 100644 --- a/pgml-sdks/pgml/src/types.rs +++ b/pgml-sdks/pgml/src/types.rs @@ -1,4 +1,5 @@ use anyhow::Context; +use futures::{Stream, StreamExt}; use itertools::Itertools; use rust_bridge::alias_manual; use sea_query::Iden; @@ -122,3 +123,30 @@ impl IntoTableNameAndSchema for String { .expect("Malformed table name in IntoTableNameAndSchema") } } + +#[derive(alias_manual)] +pub struct GeneralJsonAsyncIterator( + pub std::pin::Pin> + Send>>, +); + +impl Stream for GeneralJsonAsyncIterator { + type Item = anyhow::Result; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.0.poll_next_unpin(cx) + } +} + +#[derive(alias_manual)] +pub struct GeneralJsonIterator(pub Box> + Send>); + +impl Iterator for GeneralJsonIterator { + type Item = anyhow::Result; + + fn next(&mut self) -> Option { + self.0.next() + } +} From 2969c89c8d5a64c2eedcb36feccde9fb3c30358d Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Wed, 29 Nov 2023 14:56:37 -0800 Subject: [PATCH 09/22] Cleaned up and tested well --- .../src/bindings/transformers/transformers.py | 29 +++-- pgml-sdks/pgml/src/builtins.rs | 2 +- pgml-sdks/pgml/src/languages/javascript.rs | 26 ++--- pgml-sdks/pgml/src/lib.rs | 1 - pgml-sdks/pgml/src/open_source_ai.rs | 101 ++++++++++++------ 5 files changed, 107 insertions(+), 52 deletions(-) diff --git a/pgml-extension/src/bindings/transformers/transformers.py b/pgml-extension/src/bindings/transformers/transformers.py index eaaa190a8..124819518 100644 --- a/pgml-extension/src/bindings/transformers/transformers.py +++ b/pgml-extension/src/bindings/transformers/transformers.py @@ -126,7 +126,6 @@ def __iter__(self): def __next__(self): value = self.text_queue.get(timeout=self.timeout) - print("\n\n", value, "\n\n", file=sys.stderr) if value != self.stop_signal: return value @@ -286,9 +285,17 @@ def stream(self, input, **kwargs): streamer = TextIteratorStreamer( self.tokenizer, skip_prompt=True, skip_special_tokens=True ) - input = self.tokenizer.apply_chat_template( - input, add_generation_prompt=True, tokenize=False - ) + if "chat_template" in kwargs: + input = self.tokenizer.apply_chat_template( + input, + add_generation_prompt=True, + tokenize=False, + chat_template=kwargs.pop("chat_template"), + ) + else: + input = self.tokenizer.apply_chat_template( + input, add_generation_prompt=True, tokenize=False + ) input = self.tokenizer(input, return_tensors="pt").to(self.model.device) 
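            # A note on the streaming pattern here: the tokenized prompt and the
            # streamer are packed into generation_kwargs, model.generate is run on a
            # background thread, and the caller drains decoded tokens from the
            # streamer as they are produced.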
generation_kwargs = dict(input, streamer=streamer, **kwargs)
         else:
@@ -303,9 +310,17 @@ def stream(self, input, **kwargs):

     def __call__(self, inputs, **kwargs):
         if self.task == "conversational":
-            inputs = self.tokenizer.apply_chat_template(
-                inputs, add_generation_prompt=True, tokenize=False
-            )
+            if "chat_template" in kwargs:
+                inputs = self.tokenizer.apply_chat_template(
+                    inputs,
+                    add_generation_prompt=True,
+                    tokenize=False,
+                    chat_template=kwargs.pop("chat_template"),
+                )
+            else:
+                inputs = self.tokenizer.apply_chat_template(
+                    inputs, add_generation_prompt=True, tokenize=False
+                )
             inputs = self.tokenizer(inputs, return_tensors="pt").to(self.model.device)
             args = dict(inputs, **kwargs)
             outputs = self.model.generate(**args)
diff --git a/pgml-sdks/pgml/src/builtins.rs b/pgml-sdks/pgml/src/builtins.rs
index 188948c72..db023b951 100644
--- a/pgml-sdks/pgml/src/builtins.rs
+++ b/pgml-sdks/pgml/src/builtins.rs
@@ -101,7 +101,7 @@ mod tests {
         let query = "SELECT * from pgml.collections";
         let results = builtins.query(query).fetch_all().await?;
         assert!(results.as_array().is_some());
-        Ok(())
+        Ok(())
     }

     #[sqlx::test]
diff --git a/pgml-sdks/pgml/src/languages/javascript.rs b/pgml-sdks/pgml/src/languages/javascript.rs
index d4999c1e4..c9a09326d 100644
--- a/pgml-sdks/pgml/src/languages/javascript.rs
+++ b/pgml-sdks/pgml/src/languages/javascript.rs
@@ -74,17 +74,17 @@ impl IntoJsResult for PipelineSyncData {
 }

 #[derive(Clone)]
-struct GeneralJsonAsyncIteratorArcMutex(Arc>);
+struct GeneralJsonAsyncIteratorJavaScript(Arc>);

-impl Finalize for GeneralJsonAsyncIteratorArcMutex {}
+impl Finalize for GeneralJsonAsyncIteratorJavaScript {}

 fn transform_stream_iterate_next(mut cx: FunctionContext) -> JsResult {
     let this = cx.this();
-    let s: Handle> = this
+    let s: Handle> = this
         .get(&mut cx, "s")
         .expect("Error getting self in transformer_stream_iterate_next");
-    let ts: &GeneralJsonAsyncIteratorArcMutex = &s;
-    let ts: GeneralJsonAsyncIteratorArcMutex = ts.clone();
+    let ts: &GeneralJsonAsyncIteratorJavaScript = &s;
+    let ts: GeneralJsonAsyncIteratorJavaScript = ts.clone();
     let channel = cx.channel();
     let (deferred, promise) = cx.promise();
@@ -101,13 +101,13 @@ fn transform_stream_iterate_next(mut cx: FunctionContext) -> JsResult
                     .expect("Error converting rust Json to JavaScript Object");
                 let d = cx.boolean(false);
                 o.set(&mut cx, "value", v)
-                    .expect("Error setting object value in transformer_sream_iterate_next");
+                    .expect("Error setting object value in transform_stream_iterate_next");
                 o.set(&mut cx, "done", d)
-                    .expect("Error setting object value in transformer_sream_iterate_next");
+                    .expect("Error setting object value in transform_stream_iterate_next");
             } else {
                 let d = cx.boolean(true);
                 o.set(&mut cx, "done", d)
-                    .expect("Error setting object value in transformer_sream_iterate_next");
+                    .expect("Error setting object value in transform_stream_iterate_next");
             }
             Ok(o)
         })
@@ -125,7 +125,7 @@ impl IntoJsResult for GeneralJsonAsyncIterator {
         let o = cx.empty_object();
         let f: Handle = JsFunction::new(cx, transform_stream_iterate_next)?;
         o.set(cx, "next", f)?;
-        let s = cx.boxed(GeneralJsonAsyncIteratorArcMutex(Arc::new(
+        let s = cx.boxed(GeneralJsonAsyncIteratorJavaScript(Arc::new(
             tokio::sync::Mutex::new(self),
         )));
         o.set(cx, "s", s)?;
@@ -141,7 +141,7 @@ fn transform_iterate_next(mut cx: FunctionContext) -> JsResult {
     let this = cx.this();
     let s: Handle> = this
         .get(&mut cx, "s")
-        .expect("Error getting self in
transform_iterate_next"); let v = s.0.borrow_mut().next(); let o = cx.empty_object(); if let Some(v) = v { @@ -151,13 +151,13 @@ fn transform_iterate_next(mut cx: FunctionContext) -> JsResult { .expect("Error converting rust Json to JavaScript Object"); let d = cx.boolean(false); o.set(&mut cx, "value", v) - .expect("Error setting object value in transformer_sream_iterate_next"); + .expect("Error setting object value in transform_iterate_next"); o.set(&mut cx, "done", d) - .expect("Error setting object value in transformer_sream_iterate_next"); + .expect("Error setting object value in transform_iterate_next"); } else { let d = cx.boolean(true); o.set(&mut cx, "done", d) - .expect("Error setting object value in transformer_sream_iterate_next"); + .expect("Error setting object value in transform_iterate_next"); } Ok(o) } diff --git a/pgml-sdks/pgml/src/lib.rs b/pgml-sdks/pgml/src/lib.rs index b115da69c..f7ef4ceec 100644 --- a/pgml-sdks/pgml/src/lib.rs +++ b/pgml-sdks/pgml/src/lib.rs @@ -765,7 +765,6 @@ mod tests { .filter(filter) .fetch_all() .await?; - println!("{:?}", results); assert_eq!(results.len(), expected_result_count); } diff --git a/pgml-sdks/pgml/src/open_source_ai.rs b/pgml-sdks/pgml/src/open_source_ai.rs index 68246a04a..c4e740859 100644 --- a/pgml-sdks/pgml/src/open_source_ai.rs +++ b/pgml-sdks/pgml/src/open_source_ai.rs @@ -1,16 +1,17 @@ use anyhow::Context; -use futures::{StreamExt, Stream}; +use futures::{Stream, StreamExt}; use rust_bridge::{alias, alias_methods}; use std::time::{SystemTime, UNIX_EPOCH}; use uuid::Uuid; use crate::{ - types::{GeneralJsonAsyncIterator, Json, GeneralJsonIterator}, - TransformerPipeline, get_or_set_runtime, + get_or_set_runtime, + types::{GeneralJsonAsyncIterator, GeneralJsonIterator, Json}, + TransformerPipeline, }; #[cfg(feature = "python")] -use crate::types::{JsonPython, GeneralJsonAsyncIteratorPython, GeneralJsonIteratorPython}; +use crate::types::{GeneralJsonAsyncIteratorPython, GeneralJsonIteratorPython, JsonPython}; #[derive(alias, Debug, Clone)] pub struct OpenSourceAI { @@ -43,10 +44,26 @@ fn try_model_nice_name_to_model_name_and_parameters( }) .into(), )), + "PygmalionAI/mythalion-13b" => Some(( + "TheBloke/Mythalion-13B-GPTQ", + serde_json::json!({ + "model": "TheBloke/Mythalion-13B-GPTQ", + "device_map": "auto", + "revision": "main" + }) + .into(), + )), _ => None, } } +fn try_get_model_chat_template(model_name: &str) -> Option<&'static str> { + match model_name { + "PygmalionAI/mythalion-13b" => Some("{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'model' %}\n{{ '<|model|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|model|>' }}\n{% endif %}\n{% endfor %}"), + _ => None + } +} + struct AsyncToSyncJsonIterator(std::pin::Pin> + Send>>); impl Iterator for AsyncToSyncJsonIterator { @@ -58,7 +75,13 @@ impl Iterator for AsyncToSyncJsonIterator { } } -#[alias_methods(new, chat_completions_create, chat_completions_create_async, chat_completions_create_stream, chat_completions_create_stream_async)] +#[alias_methods( + new, + chat_completions_create, + chat_completions_create_async, + chat_completions_create_stream, + chat_completions_create_stream_async +)] impl OpenSourceAI { pub fn new(database_url: Option) -> Self { Self { database_url } @@ -114,6 +137,7 @@ mistralai/Mistral-7B-v0.1 
max_tokens: Option, temperature: Option, n: Option, + chat_template: Option, ) -> anyhow::Result { let (transformer_pipeline, model_name, model_parameters) = self.create_pipeline_model_name_parameters(model)?; @@ -125,16 +149,19 @@ mistralai/Mistral-7B-v0.1 let md5_digest = md5::compute(to_hash.as_bytes()); let fingerprint = uuid::Uuid::from_slice(&md5_digest.0)?; + let mut args = serde_json::json!({ "max_length": max_tokens, "temperature": temperature, "do_sample": true, "num_return_sequences": n }); + if let Some(t) = chat_template + .or_else(|| try_get_model_chat_template(&model_name).map(|s| s.to_string())) + { + args.as_object_mut().unwrap().insert( + "chat_template".to_string(), + serde_json::to_value(t).unwrap(), + ); + } + let messages = serde_json::to_value(messages)?.into(); let iterator = transformer_pipeline - .transform_stream( - messages, - Some( - serde_json::json!({ "max_length": max_tokens, "temperature": temperature, "do_sample": true, "num_return_sequences": n }) - .into(), - ), - Some(1) - ) + .transform_stream(messages, Some(args.into()), Some(1)) .await?; let id = Uuid::new_v4().to_string(); @@ -142,7 +169,6 @@ mistralai/Mistral-7B-v0.1 let since_the_epoch = SystemTime::now() .duration_since(UNIX_EPOCH) .expect("Time went backwards"); - eprintln!("{:?}", choices); Ok(serde_json::json!({ "id": id.clone(), "system_fingerprint": fingerprint.clone(), @@ -155,9 +181,8 @@ mistralai/Mistral-7B-v0.1 "delta": { "role": "assistant", "content": c - } + } }) - // finish_reason goes here }).collect::() }) .into()) @@ -172,11 +197,21 @@ mistralai/Mistral-7B-v0.1 messages: Vec, max_tokens: Option, temperature: Option, + chat_template: Option, n: Option, ) -> anyhow::Result { let runtime = crate::get_or_set_runtime(); - let iter = runtime.block_on(self.chat_completions_create_stream_async(model, messages, max_tokens, temperature, n))?; - Ok(GeneralJsonIterator(Box::new(AsyncToSyncJsonIterator(Box::pin(iter))))) + let iter = runtime.block_on(self.chat_completions_create_stream_async( + model, + messages, + max_tokens, + temperature, + n, + chat_template, + ))?; + Ok(GeneralJsonIterator(Box::new(AsyncToSyncJsonIterator( + Box::pin(iter), + )))) } pub async fn chat_completions_create_async( @@ -186,6 +221,7 @@ mistralai/Mistral-7B-v0.1 max_tokens: Option, temperature: Option, n: Option, + chat_template: Option, ) -> anyhow::Result { let (transformer_pipeline, model_name, model_parameters) = self.create_pipeline_model_name_parameters(model)?; @@ -197,14 +233,18 @@ mistralai/Mistral-7B-v0.1 let md5_digest = md5::compute(to_hash.as_bytes()); let fingerprint = uuid::Uuid::from_slice(&md5_digest.0)?; + let mut args = serde_json::json!({ "max_length": max_tokens, "temperature": temperature, "do_sample": true, "num_return_sequences": n }); + if let Some(t) = chat_template + .or_else(|| try_get_model_chat_template(&model_name).map(|s| s.to_string())) + { + args.as_object_mut().unwrap().insert( + "chat_template".to_string(), + serde_json::to_value(t).unwrap(), + ); + } + let choices = transformer_pipeline - .transform( - messages, - Some( - serde_json::json!({ "max_length": max_tokens, "temperature": temperature, "do_sample": true, "num_return_sequences": n }) - .into(), - ), - ) + .transform(messages, Some(args.into())) .await?; let choices: Vec = choices .as_array() @@ -249,6 +289,7 @@ mistralai/Mistral-7B-v0.1 max_tokens: Option, temperature: Option, n: Option, + chat_template: Option, ) -> anyhow::Result { let runtime = crate::get_or_set_runtime(); 
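        // Bridging note: each sync variant reuses its async counterpart by blocking
        // on the shared tokio runtime from get_or_set_runtime, so these are meant
        // for plain synchronous callers, never for code already on that runtime.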
runtime.block_on(self.chat_completions_create_async( @@ -257,6 +298,7 @@ mistralai/Mistral-7B-v0.1 max_tokens, temperature, n, + chat_template, )) } } @@ -272,7 +314,7 @@ mod tests { let results = client.chat_completions_create(Json::from_serializable("mistralai/Mistral-7B-v0.1"), vec![ serde_json::json!({"role": "system", "content": "You are a friendly chatbot who always responds in the style of a pirate"}).into(), serde_json::json!({"role": "user", "content": "How many helicopters can a human eat in one sitting?"}).into(), - ], Some(10), None, Some(3))?; + ], Some(10), None, Some(3), None)?; assert!(results["choices"].as_array().is_some()); Ok(()) } @@ -283,7 +325,7 @@ mod tests { let results = client.chat_completions_create_async(Json::from_serializable("mistralai/Mistral-7B-v0.1"), vec![ serde_json::json!({"role": "system", "content": "You are a friendly chatbot who always responds in the style of a pirate"}).into(), serde_json::json!({"role": "user", "content": "How many helicopters can a human eat in one sitting?"}).into(), - ], Some(10), None, Some(3)).await?; + ], Some(10), None, Some(3), None).await?; assert!(results["choices"].as_array().is_some()); Ok(()) } @@ -294,7 +336,7 @@ mod tests { let mut stream = client.chat_completions_create_stream_async(Json::from_serializable("mistralai/Mistral-7B-v0.1"), vec![ serde_json::json!({"role": "system", "content": "You are a friendly chatbot who always responds in the style of a pirate"}).into(), serde_json::json!({"role": "user", "content": "How many helicopters can a human eat in one sitting?"}).into(), - ], Some(10), None, Some(3)).await?; + ], Some(10), None, Some(3), None).await?; while let Some(o) = stream.next().await { o?; } @@ -307,11 +349,10 @@ mod tests { let iterator = client.chat_completions_create_stream(Json::from_serializable("mistralai/Mistral-7B-v0.1"), vec![ serde_json::json!({"role": "system", "content": "You are a friendly chatbot who always responds in the style of a pirate"}).into(), serde_json::json!({"role": "user", "content": "How many helicopters can a human eat in one sitting?"}).into(), - ], Some(10), None, Some(3))?; + ], Some(10), None, Some(3), None)?; for o in iterator { o?; } Ok(()) } - } From 3b2674362032fc81cd536afe2d0eaa3e267a732d Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Wed, 29 Nov 2023 15:15:46 -0800 Subject: [PATCH 10/22] Completely removed the GPTQ pipeline as it is no longer necessary --- .../src/bindings/transformers/transformers.py | 30 ------------------- pgml-sdks/pgml/src/open_source_ai.rs | 2 +- 2 files changed, 1 insertion(+), 31 deletions(-) diff --git a/pgml-extension/src/bindings/transformers/transformers.py b/pgml-extension/src/bindings/transformers/transformers.py index 124819518..fed76ace5 100644 --- a/pgml-extension/src/bindings/transformers/transformers.py +++ b/pgml-extension/src/bindings/transformers/transformers.py @@ -195,30 +195,6 @@ def __next__(self): return v -class GGMLPipeline(object): - def __init__(self, model_name, **task): - import ctransformers - - task.pop("model") - task.pop("task") - task.pop("device") - self.model = ctransformers.AutoModelForCausalLM.from_pretrained( - model_name, **task - ) - self.tokenizer = None - self.task = "text-generation" - - def stream(self, inputs, **kwargs): - output = self.model(inputs, stream=True, **kwargs) - return ThreadedGeneratorIterator(output, inputs) - - def __call__(self, inputs, **kwargs): - outputs = [] - for input in inputs: - outputs.append(self.model(input, 
**kwargs)) - return outputs - - class StandardPipeline(object): def __init__(self, model_name, **kwargs): # the default pipeline constructor doesn't pass all the kwargs (particularly load_in_4bit) @@ -370,12 +346,6 @@ def create_pipeline(task): lower = None if lower and ("-ggml" in lower or "-gguf" in lower): pipe = GGMLPipeline(model_name, **task) - elif ( - lower - and "-gptq" in lower - and not (model_type == "mistral" or model_type == "llama") - ): - pipe = GPTQPipeline(model_name, **task) else: try: pipe = StandardPipeline(model_name, **task) diff --git a/pgml-sdks/pgml/src/open_source_ai.rs b/pgml-sdks/pgml/src/open_source_ai.rs index c4e740859..5a8501e1a 100644 --- a/pgml-sdks/pgml/src/open_source_ai.rs +++ b/pgml-sdks/pgml/src/open_source_ai.rs @@ -33,7 +33,7 @@ fn try_model_nice_name_to_model_name_and_parameters( }) .into(), )), - "Llama-2-7b-chat-hf" => Some(( + "meta-llama/Llama-2-7b-chat-hf" => Some(( "TheBloke/Llama-2-7B-Chat-GPTQ", serde_json::json!({ "task": "conversational", From cf1afc6b29e18e13a157e3bfc4dfe7577adab301 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Wed, 29 Nov 2023 15:17:43 -0800 Subject: [PATCH 11/22] Removed unnecessary python imports --- pgml-extension/src/bindings/transformers/transformers.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pgml-extension/src/bindings/transformers/transformers.py b/pgml-extension/src/bindings/transformers/transformers.py index fed76ace5..41dad64a6 100644 --- a/pgml-extension/src/bindings/transformers/transformers.py +++ b/pgml-extension/src/bindings/transformers/transformers.py @@ -40,8 +40,6 @@ PegasusTokenizer, TrainingArguments, Trainer, - TextStreamer, - Conversation, ) from threading import Thread from typing import Optional From accd1591c3f0cbc2f94eebd4c9ed3d18128183b3 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Thu, 30 Nov 2023 09:35:42 -0800 Subject: [PATCH 12/22] Removed universal debugger output --- pgml-extension/src/bindings/transformers/transform.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/pgml-extension/src/bindings/transformers/transform.rs b/pgml-extension/src/bindings/transformers/transform.rs index 9bd81b77a..fa03984d9 100644 --- a/pgml-extension/src/bindings/transformers/transform.rs +++ b/pgml-extension/src/bindings/transformers/transform.rs @@ -32,9 +32,7 @@ impl Iterator for TransformStreamIterator { if res.is_none() { Ok(None) } else { - eprintln!("\nHERE WE ARE!\n"); let res: Vec = res.extract()?; - eprintln!("\nYUP WE DIDNT GET HERE\n"); Ok(Some(JsonB(serde_json::to_value(res).unwrap()))) } }) From e5eccecf3ef1cdf178961d7dc47b0b3f9296a8db Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Fri, 1 Dec 2023 11:13:25 -0800 Subject: [PATCH 13/22] Finalized models in SDK for open source ai --- pgml-sdks/pgml/python/tests/test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pgml-sdks/pgml/python/tests/test.py b/pgml-sdks/pgml/python/tests/test.py index b91da1578..f3b1fbec9 100644 --- a/pgml-sdks/pgml/python/tests/test.py +++ b/pgml-sdks/pgml/python/tests/test.py @@ -404,7 +404,6 @@ async def test_open_source_ai_create_stream_async(): temperature=0.85, n=3, ) - import json async for c in results: assert len(c["choices"]) > 0 From 95e1e9abf318bd6b061fb414550ed2834d5ef824 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Fri, 1 Dec 2023 11:14:00 -0800 Subject: [PATCH 14/22] Updated to work with 
Hugging Face tokens

---
 .../src/bindings/transformers/transformers.py | 21 +++++++------------
 1 file changed, 7 insertions(+), 14 deletions(-)

diff --git a/pgml-extension/src/bindings/transformers/transformers.py b/pgml-extension/src/bindings/transformers/transformers.py
index 41dad64a6..46fc9946f 100644
--- a/pgml-extension/src/bindings/transformers/transformers.py
+++ b/pgml-extension/src/bindings/transformers/transformers.py
@@ -200,6 +200,11 @@ def __init__(self, model_name, **kwargs):
         # the default pipeline constructor doesn't pass all the kwargs (particularly load_in_4bit)
         # but that is only possible when the task is passed in, since if you pass the model
         # to the pipeline constructor, the task will no longer be inferred from the default...
+        # See: https://huggingface.co/docs/hub/security-tokens
+        # This renaming is for backwards compatibility
+        if "use_auth_token" in kwargs:
+            kwargs["token"] = kwargs.pop("use_auth_token")
+
         if (
             "task" in kwargs
             and model_name is not None
@@ -230,9 +235,9 @@
         else:
             raise PgMLException(f"Unhandled task: {self.task}")

-        if "use_auth_token" in kwargs:
+        if "token" in kwargs:
             self.tokenizer = AutoTokenizer.from_pretrained(
-                model_name, use_auth_token=kwargs["use_auth_token"]
+                model_name, use_auth_token=kwargs["token"]
             )
         else:
             self.tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -302,18 +307,6 @@ def __call__(self, inputs, **kwargs):
             outputs = outputs[:, inputs["input_ids"].shape[1] :]
             outputs = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
             return outputs
-
-            # I don't think conversations support num_responses and/or maybe num_beams
-            # Also this is not processed in parallel / truly batched it seems
-            # num_conversations = 1
-            # if "num_return_sequences" in kwargs:
-            #     num_conversations = kwargs.pop("num_return_sequences")
-            # conversations = [Conversation(inputs) for _ in range(0, num_conversations)]
-            # conversations = self.pipe(conversations, **kwargs)
-            # outputs = []
-            # for conversation in conversations:
-            #     outputs.append(conversation.messages[-1]["content"])
-            # return outputs
         else:
             return self.pipe(inputs, **kwargs)

From c80817b7089ce852ea865cade34af103ace9808f Mon Sep 17 00:00:00 2001
From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com>
Date: Fri, 1 Dec 2023 11:14:27 -0800
Subject: [PATCH 15/22] Finalized models in SDK for open source ai

---
 pgml-sdks/pgml/src/open_source_ai.rs | 117 ++++++++++++++++++++++---
 1 file changed, 104 insertions(+), 13 deletions(-)

diff --git a/pgml-sdks/pgml/src/open_source_ai.rs b/pgml-sdks/pgml/src/open_source_ai.rs
index 5a8501e1a..c73218894 100644
--- a/pgml-sdks/pgml/src/open_source_ai.rs
+++ b/pgml-sdks/pgml/src/open_source_ai.rs
@@ -22,43 +22,130 @@ fn try_model_nice_name_to_model_name_and_parameters(
     model_name: &str,
 ) -> Option<(&'static str, Json)> {
     match model_name {
-        "mistralai/Mistral-7B-v0.1" => Some((
-            "TheBloke/zephyr-7B-beta-GPTQ",
+        // Not all models will necessarily have the same parameters / naming relation but they happen to now
+        "mistralai/Mistral-7B-Instruct-v0.1" => Some((
+            "mistralai/Mistral-7B-Instruct-v0.1",
             serde_json::json!({
                 "task": "conversational",
-                "model": "TheBloke/zephyr-7B-beta-GPTQ",
+                "model": "mistralai/Mistral-7B-Instruct-v0.1",
                 "device_map": "auto",
-                "revision": "main",
-                "model_type": "mistral"
+                "torch_dtype": "bfloat16"
             })
             .into(),
         )),
+
+        "HuggingFaceH4/zephyr-7b-beta" => Some((
+            "HuggingFaceH4/zephyr-7b-beta",
+            serde_json::json!({
+                "task": "conversational",
+                "model": "HuggingFaceH4/zephyr-7b-beta",
+                "device_map":
"auto", + "torch_dtype": "bfloat16" + }) + .into(), + )), + + "TheBloke/Llama-2-7B-Chat-GPTQ" => Some(( "TheBloke/Llama-2-7B-Chat-GPTQ", serde_json::json!({ "task": "conversational", - "model": "TheBloke/zephyr-7B-beta-GPTQ", + "model": "TheBloke/Llama-2-7B-Chat-GPTQ", + "device_map": "auto", + "revision": "main" + }) + .into(), + )), + + "teknium/OpenHermes-2.5-Mistral-7B" => Some(( + "teknium/OpenHermes-2.5-Mistral-7B", + serde_json::json!({ + "task": "conversational", + "model": "teknium/OpenHermes-2.5-Mistral-7B", + "device_map": "auto", + "torch_dtype": "bfloat16" + }) + .into(), + )), + + "Open-Orca/Mistral-7B-OpenOrca" => Some(( + "Open-Orca/Mistral-7B-OpenOrca", + serde_json::json!({ + "task": "conversational", + "model": "Open-Orca/Mistral-7B-OpenOrca", "device_map": "auto", - "revision": "main", - "model_type": "llama" + "torch_dtype": "bfloat16" + }) + .into(), + )), + + "Undi95/Toppy-M-7B" => Some(( + "Undi95/Toppy-M-7B", + serde_json::json!({ + "model": "Undi95/Toppy-M-7B", + "device_map": "auto", + "torch_dtype": "bfloat16" }) .into(), )), + + "Undi95/ReMM-SLERP-L2-13B" => Some(( + "Undi95/ReMM-SLERP-L2-13B", + serde_json::json!({ + "model": "Undi95/ReMM-SLERP-L2-13B", + "device_map": "auto", + "torch_dtype": "bfloat16" + }) + .into(), + )), + + "Gryphe/MythoMax-L2-13b" => Some(( + "Gryphe/MythoMax-L2-13b", + serde_json::json!({ + "model": "Gryphe/MythoMax-L2-13b", + "device_map": "auto", + "torch_dtype": "bfloat16" + }) + .into(), + )), + "PygmalionAI/mythalion-13b" => Some(( - "TheBloke/Mythalion-13B-GPTQ", + "PygmalionAI/mythalion-13b", + serde_json::json!({ + "model": "PygmalionAI/mythalion-13b", + "device_map": "auto", + "torch_dtype": "bfloat16" + }) + .into(), + )), + + "deepseek-ai/deepseek-llm-7b-chat" => Some(( + "deepseek-ai/deepseek-llm-7b-chat", + serde_json::json!({ + "model": "deepseek-ai/deepseek-llm-7b-chat", + "device_map": "auto", + "torch_dtype": "bfloat16" + }) + .into(), + )), + + "Phind/Phind-CodeLlama-34B-v2" => Some(( + "Phind/Phind-CodeLlama-34B-v2", serde_json::json!({ - "model": "TheBloke/Mythalion-13B-GPTQ", + "model": "Phind/Phind-CodeLlama-34B-v2", "device_map": "auto", - "revision": "main" + "torch_dtype": "bfloat16" }) .into(), )), + _ => None, } } fn try_get_model_chat_template(model_name: &str) -> Option<&'static str> { match model_name { + // Any Alpaca instruct tuned model + "Undi95/Toppy-M-7B" | "Undi95/ReMM-SLERP-L2-13B" | "Gryphe/MythoMax-L2-13b" | "Phind/Phind-CodeLlama-34B-v2" => Some("{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '### Instruction:\n' + message['content'] + '\n'}}\n{% elif message['role'] == 'system' %}\n{{ message['content'] + '\n'}}\n{% elif message['role'] == 'model' %}\n{{ '### Response:>\n' + message['content'] + eos_token + '\n'}}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '### Response:' }}\n{% endif %}\n{% endfor %}"), "PygmalionAI/mythalion-13b" => Some("{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'model' %}\n{{ '<|model|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|model|>' }}\n{% endif %}\n{% endfor %}"), _ => None } @@ -130,6 +217,7 @@ mistralai/Mistral-7B-v0.1 } } + #[allow(clippy::too_many_arguments)] pub async fn chat_completions_create_stream_async( &self, model: Json, @@ -191,14 +279,15 @@ 
mistralai/Mistral-7B-v0.1 Ok(GeneralJsonAsyncIterator(Box::pin(iter))) } + #[allow(clippy::too_many_arguments)] pub fn chat_completions_create_stream( &self, model: Json, messages: Vec, max_tokens: Option, temperature: Option, - chat_template: Option, n: Option, + chat_template: Option, ) -> anyhow::Result { let runtime = crate::get_or_set_runtime(); let iter = runtime.block_on(self.chat_completions_create_stream_async( @@ -214,6 +303,7 @@ mistralai/Mistral-7B-v0.1 )))) } + #[allow(clippy::too_many_arguments)] pub async fn chat_completions_create_async( &self, model: Json, @@ -282,6 +372,7 @@ mistralai/Mistral-7B-v0.1 .into()) } + #[allow(clippy::too_many_arguments)] pub fn chat_completions_create( &self, model: Json, From 9a3ca9133b66e4cda6aba3219b81ca00fdf01ac6 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Fri, 1 Dec 2023 11:16:14 -0800 Subject: [PATCH 16/22] Removed unnecessary comment --- pgml-sdks/pgml/src/open_source_ai.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/pgml-sdks/pgml/src/open_source_ai.rs b/pgml-sdks/pgml/src/open_source_ai.rs index c73218894..18adde288 100644 --- a/pgml-sdks/pgml/src/open_source_ai.rs +++ b/pgml-sdks/pgml/src/open_source_ai.rs @@ -22,7 +22,6 @@ fn try_model_nice_name_to_model_name_and_parameters( model_name: &str, ) -> Option<(&'static str, Json)> { match model_name { - // Not all models will necessarily have the same parameters / naming relation but they happen to now "mistralai/Mistral-7B-Instruct-v0.1" => Some(( "mistralai/Mistral-7B-Instruct-v0.1", serde_json::json!({ From 4eb88f8cb40262ddd18b8e6b76e9651aafa2f5b4 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Fri, 1 Dec 2023 11:20:46 -0800 Subject: [PATCH 17/22] Put back the GGML pipeline and removed the GPTQ pipeline earlier commit had it backwards --- .../src/bindings/transformers/transformers.py | 39 ++++++------------- 1 file changed, 11 insertions(+), 28 deletions(-) diff --git a/pgml-extension/src/bindings/transformers/transformers.py b/pgml-extension/src/bindings/transformers/transformers.py index 46fc9946f..5c6078785 100644 --- a/pgml-extension/src/bindings/transformers/transformers.py +++ b/pgml-extension/src/bindings/transformers/transformers.py @@ -127,44 +127,27 @@ def __next__(self): if value != self.stop_signal: return value - -class GPTQPipeline(object): +class GGMLPipeline(object): def __init__(self, model_name, **task): - from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig - from huggingface_hub import snapshot_download - - model_path = snapshot_download(model_name) + import ctransformers - quantized_config = BaseQuantizeConfig.from_pretrained(model_path) - self.model = AutoGPTQForCausalLM.from_quantized( - model_path, quantized_config=quantized_config, **task + task.pop("model") + task.pop("task") + task.pop("device") + self.model = ctransformers.AutoModelForCausalLM.from_pretrained( + model_name, **task ) - if "use_fast_tokenizer" in task: - self.tokenizer = AutoTokenizer.from_pretrained( - model_path, use_fast=task.pop("use_fast_tokenizer") - ) - else: - self.tokenizer = AutoTokenizer.from_pretrained(model_path) + self.tokenizer = None self.task = "text-generation" def stream(self, inputs, **kwargs): - streamer = TextIteratorStreamer(self.tokenizer) - inputs = self.tokenizer(inputs, return_tensors="pt").to(self.model.device) - generation_kwargs = dict(inputs, streamer=streamer, **kwargs) - thread = Thread(target=self.model.generate, kwargs=generation_kwargs) - 
thread.start() - return streamer + output = self.model(inputs[0], stream=True, **kwargs) + return ThreadedGeneratorIterator(output, inputs[0]) def __call__(self, inputs, **kwargs): outputs = [] for input in inputs: - tokens = ( - self.tokenizer(input, return_tensors="pt") - .to(self.model.device) - .input_ids - ) - token_ids = self.model.generate(input_ids=tokens, **kwargs)[0] - outputs.append(self.tokenizer.decode(token_ids)) + outputs.append(self.model(input, **kwargs)) return outputs From 93e7ffbafc0a093626b586d46352224dd63f47fa Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Fri, 1 Dec 2023 12:14:51 -0800 Subject: [PATCH 18/22] Changed some error messages --- pgml-extension/src/api.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs index b6ee865bf..3bf663026 100644 --- a/pgml-extension/src/api.rs +++ b/pgml-extension/src/api.rs @@ -646,7 +646,7 @@ pub fn transform_conversational_json( .is_some_and(|v| v == "conversational") { error!( - "ARRAY[]::JSONB inputs for transformer should only be used with a conversational task" + "ARRAY[]::JSONB inputs for transform should only be used with a conversational task" ); } match crate::bindings::transformers::transform(&task.0, &args.0, inputs) { @@ -666,7 +666,7 @@ pub fn transform_conversational_string( ) -> JsonB { if task != "conversational" { error!( - "ARRAY[]::JSONB inputs for transformer should only be used with a conversational task" + "ARRAY[]::JSONB inputs for transform should only be used with a conversational task" ); } let task_json = json!({ "task": task }); @@ -725,7 +725,7 @@ pub fn transform_stream_conversational_json( .is_some_and(|v| v == "conversational") { error!( - "ARRAY[]::JSONB inputs for transformer_stream should only be used with a conversational task" + "ARRAY[]::JSONB inputs for transform_stream should only be used with a conversational task" ); } // We can unwrap this becuase if there is an error the current transaction is aborted in the map_err call @@ -747,7 +747,7 @@ pub fn transform_stream_conversational_string( ) -> SetOfIterator<'static, JsonB> { if task != "conversational" { error!( - "ARRAY::JSONB inputs for transformer_stream should only be used with a conversational task" + "ARRAY::JSONB inputs for transform_stream should only be used with a conversational task" ); } let task_json = json!({ "task": task }); From 73ee33ab2770f39e624eef80f527f9e8ddc4484c Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Fri, 1 Dec 2023 12:32:00 -0800 Subject: [PATCH 19/22] Added migration for 2.8.1 --- pgml-extension/sql/pgml--2.8.0--2.8.1.sql | 65 +++++++++++++++++++++++ 1 file changed, 65 insertions(+) create mode 100644 pgml-extension/sql/pgml--2.8.0--2.8.1.sql diff --git a/pgml-extension/sql/pgml--2.8.0--2.8.1.sql b/pgml-extension/sql/pgml--2.8.0--2.8.1.sql new file mode 100644 index 000000000..75ed891a4 --- /dev/null +++ b/pgml-extension/sql/pgml--2.8.0--2.8.1.sql @@ -0,0 +1,65 @@ +-- pgml::api::transform_conversational_json +CREATE OR REPLACE FUNCTION pgml."transform"( + "task" jsonb, /* pgrx::datum::json::JsonB */ + "args" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */ + "inputs" jsonb[] DEFAULT 'ARRAY[]::JSONB[]', /* Vec */ + "cache" bool DEFAULT false /* bool */ +) RETURNS jsonb /* alloc::string::String */ +IMMUTABLE STRICT PARALLEL SAFE +LANGUAGE c /* Rust */ +AS 'MODULE_PATHNAME', 'transform_conversational_json_wrapper'; + +-- 
pgml::api::transform_conversational_string +CREATE OR REPLACE FUNCTION pgml."transform"( + "task" TEXT, /* alloc::string::String */ + "args" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */ + "inputs" jsonb[] DEFAULT 'ARRAY[]::JSONB[]', /* Vec */ + "cache" bool DEFAULT false /* bool */ +) RETURNS jsonb /* alloc::string::String */ +IMMUTABLE STRICT PARALLEL SAFE +LANGUAGE c /* Rust */ +AS 'MODULE_PATHNAME', 'transform_conversational_string_wrapper'; + +-- pgml::api::transform_stream +CREATE OR REPLACE FUNCTION pgml."transform_stream"( + "task" TEXT, /* alloc::string::String */ + "args" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */ + "input" TEXT DEFAULT '', /* &str */ + "cache" bool DEFAULT false /* bool */ +) RETURNS SETOF jsonb /* pgrx::datum::json::JsonB */ +IMMUTABLE STRICT PARALLEL SAFE +LANGUAGE c /* Rust */ +AS 'MODULE_PATHNAME', 'transform_stream_string_wrapper_wrapper'; + +-- pgml::api::transform_stream +CREATE OR REPLACE FUNCTION pgml."transform_stream"( + "task" jsonb, /* pgrx::datum::json::JsonB */ + "args" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */ + "input" TEXT DEFAULT '', /* &str */ + "cache" bool DEFAULT false /* bool */ +) RETURNS SETOF jsonb /* pgrx::datum::json::JsonB */ +IMMUTABLE STRICT PARALLEL SAFE +LANGUAGE c /* Rust */ +AS 'MODULE_PATHNAME', 'transform_stream_json_wrapper_wrapper'; + +-- pgml::api::transform_stream_conversational_json +CREATE OR REPLACE FUNCTION pgml."transform_stream"( + "task" TEXT, /* alloc::string::String */ + "args" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */ + "inputs" jsonb[] DEFAULT 'ARRAY[]::JSONB[]', /* Vec */ + "cache" bool DEFAULT false /* bool */ +) RETURNS SETOF jsonb /* pgrx::datum::json::JsonB */ +IMMUTABLE STRICT PARALLEL SAFE +LANGUAGE c /* Rust */ +AS 'MODULE_PATHNAME', 'transform_stream_conversational_string_wrapper'; + +-- pgml::api::transform_stream_conversational_string +CREATE OR REPLACE FUNCTION pgml."transform_stream"( + "task" jsonb, /* pgrx::datum::json::JsonB */ + "args" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */ + "inputs" jsonb[] DEFAULT 'ARRAY[]::JSONB[]', /* Vec */ + "cache" bool DEFAULT false /* bool */ +) RETURNS SETOF jsonb /* pgrx::datum::json::JsonB */ +IMMUTABLE STRICT PARALLEL SAFE +LANGUAGE c /* Rust */ +AS 'MODULE_PATHNAME', 'transform_stream_coversational_json_wrapper'; From 0201880e28a987093dadfe7bbe1d37f08b7a60f9 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Fri, 1 Dec 2023 12:44:08 -0800 Subject: [PATCH 20/22] Working migration file --- pgml-extension/sql/pgml--2.8.0--2.8.1.sql | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/pgml-extension/sql/pgml--2.8.0--2.8.1.sql b/pgml-extension/sql/pgml--2.8.0--2.8.1.sql index 75ed891a4..ba46bdd75 100644 --- a/pgml-extension/sql/pgml--2.8.0--2.8.1.sql +++ b/pgml-extension/sql/pgml--2.8.0--2.8.1.sql @@ -2,7 +2,7 @@ CREATE OR REPLACE FUNCTION pgml."transform"( "task" jsonb, /* pgrx::datum::json::JsonB */ "args" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */ - "inputs" jsonb[] DEFAULT 'ARRAY[]::JSONB[]', /* Vec */ + "inputs" jsonb[] DEFAULT ARRAY[]::JSONB[], /* Vec */ "cache" bool DEFAULT false /* bool */ ) RETURNS jsonb /* alloc::string::String */ IMMUTABLE STRICT PARALLEL SAFE @@ -13,14 +13,15 @@ AS 'MODULE_PATHNAME', 'transform_conversational_json_wrapper'; CREATE OR REPLACE FUNCTION pgml."transform"( "task" TEXT, /* alloc::string::String */ "args" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */ - "inputs" jsonb[] DEFAULT 
'ARRAY[]::JSONB[]', /* Vec */ + "inputs" jsonb[] DEFAULT ARRAY[]::JSONB[], /* Vec */ "cache" bool DEFAULT false /* bool */ ) RETURNS jsonb /* alloc::string::String */ IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c /* Rust */ AS 'MODULE_PATHNAME', 'transform_conversational_string_wrapper'; --- pgml::api::transform_stream +-- pgml::api::transform_stream_string +DROP FUNCTION IF EXISTS pgml."transform_stream"(text,jsonb,text,boolean); CREATE OR REPLACE FUNCTION pgml."transform_stream"( "task" TEXT, /* alloc::string::String */ "args" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */ @@ -29,9 +30,10 @@ CREATE OR REPLACE FUNCTION pgml."transform_stream"( ) RETURNS SETOF jsonb /* pgrx::datum::json::JsonB */ IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c /* Rust */ -AS 'MODULE_PATHNAME', 'transform_stream_string_wrapper_wrapper'; +AS 'MODULE_PATHNAME', 'transform_stream_string_wrapper'; --- pgml::api::transform_stream +-- pgml::api::transform_stream_json +DROP FUNCTION IF EXISTS pgml."transform_stream"(jsonb,jsonb,text,boolean); CREATE OR REPLACE FUNCTION pgml."transform_stream"( "task" jsonb, /* pgrx::datum::json::JsonB */ "args" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */ @@ -40,13 +42,13 @@ CREATE OR REPLACE FUNCTION pgml."transform_stream"( ) RETURNS SETOF jsonb /* pgrx::datum::json::JsonB */ IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c /* Rust */ -AS 'MODULE_PATHNAME', 'transform_stream_json_wrapper_wrapper'; +AS 'MODULE_PATHNAME', 'transform_stream_json_wrapper'; -- pgml::api::transform_stream_conversational_json CREATE OR REPLACE FUNCTION pgml."transform_stream"( "task" TEXT, /* alloc::string::String */ "args" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */ - "inputs" jsonb[] DEFAULT 'ARRAY[]::JSONB[]', /* Vec */ + "inputs" jsonb[] DEFAULT ARRAY[]::JSONB[], /* Vec */ "cache" bool DEFAULT false /* bool */ ) RETURNS SETOF jsonb /* pgrx::datum::json::JsonB */ IMMUTABLE STRICT PARALLEL SAFE @@ -57,9 +59,9 @@ AS 'MODULE_PATHNAME', 'transform_stream_conversational_string_wrapper'; CREATE OR REPLACE FUNCTION pgml."transform_stream"( "task" jsonb, /* pgrx::datum::json::JsonB */ "args" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */ - "inputs" jsonb[] DEFAULT 'ARRAY[]::JSONB[]', /* Vec */ + "inputs" jsonb[] DEFAULT ARRAY[]::JSONB[], /* Vec */ "cache" bool DEFAULT false /* bool */ ) RETURNS SETOF jsonb /* pgrx::datum::json::JsonB */ IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c /* Rust */ -AS 'MODULE_PATHNAME', 'transform_stream_coversational_json_wrapper'; +AS 'MODULE_PATHNAME', 'transform_stream_conversational_json_wrapper'; From 47c18d6e094ca1765979e5815a9baf7b1985f0e3 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Fri, 1 Dec 2023 12:51:46 -0800 Subject: [PATCH 21/22] Really working migration file --- pgml-extension/sql/pgml--2.8.0--2.8.1.sql | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pgml-extension/sql/pgml--2.8.0--2.8.1.sql b/pgml-extension/sql/pgml--2.8.0--2.8.1.sql index ba46bdd75..f5d364156 100644 --- a/pgml-extension/sql/pgml--2.8.0--2.8.1.sql +++ b/pgml-extension/sql/pgml--2.8.0--2.8.1.sql @@ -1,5 +1,5 @@ -- pgml::api::transform_conversational_json -CREATE OR REPLACE FUNCTION pgml."transform"( +CREATE FUNCTION pgml."transform"( "task" jsonb, /* pgrx::datum::json::JsonB */ "args" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */ "inputs" jsonb[] DEFAULT ARRAY[]::JSONB[], /* Vec */ @@ -10,7 +10,7 @@ LANGUAGE c /* Rust */ AS 'MODULE_PATHNAME', 'transform_conversational_json_wrapper'; -- 
pgml::api::transform_conversational_string -CREATE OR REPLACE FUNCTION pgml."transform"( +CREATE FUNCTION pgml."transform"( "task" TEXT, /* alloc::string::String */ "args" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */ "inputs" jsonb[] DEFAULT ARRAY[]::JSONB[], /* Vec */ @@ -22,7 +22,7 @@ AS 'MODULE_PATHNAME', 'transform_conversational_string_wrapper'; -- pgml::api::transform_stream_string DROP FUNCTION IF EXISTS pgml."transform_stream"(text,jsonb,text,boolean); -CREATE OR REPLACE FUNCTION pgml."transform_stream"( +CREATE FUNCTION pgml."transform_stream"( "task" TEXT, /* alloc::string::String */ "args" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */ "input" TEXT DEFAULT '', /* &str */ @@ -34,7 +34,7 @@ AS 'MODULE_PATHNAME', 'transform_stream_string_wrapper'; -- pgml::api::transform_stream_json DROP FUNCTION IF EXISTS pgml."transform_stream"(jsonb,jsonb,text,boolean); -CREATE OR REPLACE FUNCTION pgml."transform_stream"( +CREATE FUNCTION pgml."transform_stream"( "task" jsonb, /* pgrx::datum::json::JsonB */ "args" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */ "input" TEXT DEFAULT '', /* &str */ @@ -45,7 +45,7 @@ LANGUAGE c /* Rust */ AS 'MODULE_PATHNAME', 'transform_stream_json_wrapper'; -- pgml::api::transform_stream_conversational_json -CREATE OR REPLACE FUNCTION pgml."transform_stream"( +CREATE FUNCTION pgml."transform_stream"( "task" TEXT, /* alloc::string::String */ "args" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */ "inputs" jsonb[] DEFAULT ARRAY[]::JSONB[], /* Vec */ @@ -56,7 +56,7 @@ LANGUAGE c /* Rust */ AS 'MODULE_PATHNAME', 'transform_stream_conversational_string_wrapper'; -- pgml::api::transform_stream_conversational_string -CREATE OR REPLACE FUNCTION pgml."transform_stream"( +CREATE FUNCTION pgml."transform_stream"( "task" jsonb, /* pgrx::datum::json::JsonB */ "args" jsonb DEFAULT '{}', /* pgrx::datum::json::JsonB */ "inputs" jsonb[] DEFAULT ARRAY[]::JSONB[], /* Vec */ From fb3f7f70b6a6981658880b4269f13d0d3981dc12 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Fri, 1 Dec 2023 12:59:05 -0800 Subject: [PATCH 22/22] Bumped version --- pgml-extension/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgml-extension/Cargo.toml b/pgml-extension/Cargo.toml index ab3411447..a4da7bcbe 100644 --- a/pgml-extension/Cargo.toml +++ b/pgml-extension/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pgml" -version = "2.8.0" +version = "2.8.1" edition = "2021" [lib]
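With the version bump the series is complete. As a usage reference, here is a minimal sketch of the chat surface the SDK ends up with, assuming the pgml Rust crate built from this branch and a reachable PostgresML database; the literal argument values, and the fallback-to-DATABASE_URL behavior noted in the comment, are illustrative assumptions rather than part of the patch:

use pgml::{types::Json, OpenSourceAI};

fn main() -> anyhow::Result<()> {
    // Passing None is assumed to fall back to the DATABASE_URL environment variable.
    let client = OpenSourceAI::new(None);
    let iterator = client.chat_completions_create_stream(
        Json::from_serializable("HuggingFaceH4/zephyr-7b-beta"),
        vec![
            serde_json::json!({"role": "system", "content": "You are a friendly chatbot"}).into(),
            serde_json::json!({"role": "user", "content": "Say hello in one sentence"}).into(),
        ],
        Some(32), // max_tokens
        None,     // temperature
        Some(1),  // n
        None,     // chat_template
    )?;
    // The sync iterator yields anyhow::Result<Json>, one chat.completion.chunk per step.
    for chunk in iterator {
        println!("{}", chunk?["choices"]);
    }
    Ok(())
}

Under the hood this exercises the same path as the Python and JavaScript tests above: the async stream produced over pgml.transform_stream is wrapped by AsyncToSyncJsonIterator and drained with block_on on the shared runtime.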