Skip to content

Commit 3ff2f07

Browse files
committed
Added OpenSourceAI and conversational support in the extension
1 parent c20f517 commit 3ff2f07

File tree

10 files changed

+410
-34
lines changed

10 files changed

+410
-34
lines changed

pgml-extension/src/api.rs

Lines changed: 83 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -632,6 +632,37 @@ pub fn transform_string(
632632
}
633633
}
634634

635+
#[cfg(all(feature = "python", not(feature = "use_as_lib")))]
636+
#[pg_extern(immutable, parallel_safe, name = "transform")]
637+
#[allow(unused_variables)] // cache is maintained for api compatibility
638+
pub fn transform_conversational_json(
639+
task: JsonB,
640+
args: default!(JsonB, "'{}'"),
641+
inputs: default!(Vec<JsonB>, "ARRAY[]::JSONB[]"),
642+
cache: default!(bool, false),
643+
) -> JsonB {
644+
match crate::bindings::transformers::transform(&task.0, &args.0, inputs) {
645+
Ok(output) => JsonB(output),
646+
Err(e) => error!("{e}"),
647+
}
648+
}
649+
650+
#[cfg(all(feature = "python", not(feature = "use_as_lib")))]
651+
#[pg_extern(immutable, parallel_safe, name = "transform")]
652+
#[allow(unused_variables)] // cache is maintained for api compatibility
653+
pub fn transform_conversational_string(
654+
task: String,
655+
args: default!(JsonB, "'{}'"),
656+
inputs: default!(Vec<JsonB>, "ARRAY[]::JSONB[]"),
657+
cache: default!(bool, false),
658+
) -> JsonB {
659+
let task_json = json!({ "task": task });
660+
match crate::bindings::transformers::transform(&task_json, &args.0, inputs) {
661+
Ok(output) => JsonB(output),
662+
Err(e) => error!("{e}"),
663+
}
664+
}
665+
635666
#[cfg(all(feature = "python", not(feature = "use_as_lib")))]
636667
#[pg_extern(immutable, parallel_safe, name = "transform_stream")]
637668
#[allow(unused_variables)] // cache is maintained for api compatibility
@@ -642,10 +673,13 @@ pub fn transform_stream_json(
642673
cache: default!(bool, false),
643674
) -> SetOfIterator<'static, String> {
644675
// We can unwrap this because if there is an error the current transaction is aborted in the map_err call
645-
let python_iter =
646-
crate::bindings::transformers::transform_stream_iterator(&task.0, &args.0, input)
647-
.map_err(|e| error!("{e}"))
648-
.unwrap();
676+
let python_iter = crate::bindings::transformers::transform_stream_iterator(
677+
&task.0,
678+
&args.0,
679+
input.to_string(),
680+
)
681+
.map_err(|e| error!("{e}"))
682+
.unwrap();
649683
SetOfIterator::new(python_iter)
650684
}
651685

@@ -667,6 +701,51 @@ pub fn transform_stream_string(
667701
SetOfIterator::new(python_iter)
668702
}
669703

704+
#[cfg(all(feature = "python", not(feature = "use_as_lib")))]
705+
#[pg_extern(immutable, parallel_safe, name = "transform_stream")]
706+
#[allow(unused_variables)] // cache is maintained for api compatibility
707+
pub fn transform_stream_conversational_json(
708+
task: JsonB,
709+
args: default!(JsonB, "'{}'"),
710+
input: default!(JsonB, "'[]'::JSONB"),
711+
cache: default!(bool, false),
712+
) -> SetOfIterator<'static, String> {
713+
// If they have Vec<JsonB> inputs let's make sure they have the right task
714+
if !task.0["task"]
715+
.as_str()
716+
.is_some_and(|v| v == "conversational")
717+
{
718+
error!("ARRAY[]::JSONB inputs for transformer_stream should only be used with a conversational task");
719+
}
720+
// We can unwrap this because if there is an error the current transaction is aborted in the map_err call
721+
let python_iter =
722+
crate::bindings::transformers::transform_stream_iterator(&task.0, &args.0, input.0)
723+
.map_err(|e| error!("{e}"))
724+
.unwrap();
725+
SetOfIterator::new(python_iter)
726+
}
727+
728+
#[cfg(all(feature = "python", not(feature = "use_as_lib")))]
729+
#[pg_extern(immutable, parallel_safe, name = "transform_stream")]
730+
#[allow(unused_variables)] // cache is maintained for api compatibility
731+
pub fn transform_stream_conversational_string(
732+
task: String,
733+
args: default!(JsonB, "'{}'"),
734+
input: default!(JsonB, "'[]'::JSONB"),
735+
cache: default!(bool, false),
736+
) -> SetOfIterator<'static, String> {
737+
if task != "conversational" {
738+
error!("ARRAY[]::JSONB inputs for transformer_stream should only be used with a conversational task");
739+
}
740+
let task_json = json!({ "task": task });
741+
// We can unwrap this because if there is an error the current transaction is aborted in the map_err call
742+
let python_iter =
743+
crate::bindings::transformers::transform_stream_iterator(&task_json, &args.0, input.0)
744+
.map_err(|e| error!("{e}"))
745+
.unwrap();
746+
SetOfIterator::new(python_iter)
747+
}
748+
670749
#[cfg(feature = "python")]
671750
#[pg_extern(immutable, parallel_safe, name = "generate")]
672751
fn generate(project_name: &str, inputs: &str, config: default!(JsonB, "'{}'")) -> String {

pgml-extension/src/bindings/transformers/transform.rs

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use anyhow::Result;
44
use pgrx::*;
55
use pyo3::prelude::*;
66
use pyo3::types::{IntoPyDict, PyDict, PyTuple};
7+
use pyo3::AsPyPointer;
78

89
create_pymodule!("/src/bindings/transformers/transformers.py");
910

@@ -41,10 +42,10 @@ impl Iterator for TransformStreamIterator {
4142
}
4243
}
4344

44-
pub fn transform(
45+
pub fn transform<T: serde::Serialize>(
4546
task: &serde_json::Value,
4647
args: &serde_json::Value,
47-
inputs: Vec<&str>,
48+
inputs: T,
4849
) -> Result<serde_json::Value> {
4950
crate::bindings::python::activate()?;
5051
whitelist::verify_task(task)?;
@@ -74,17 +75,17 @@ pub fn transform(
7475
Ok(serde_json::from_str(&results)?)
7576
}
7677

77-
pub fn transform_stream(
78+
pub fn transform_stream<T: serde::Serialize>(
7879
task: &serde_json::Value,
7980
args: &serde_json::Value,
80-
input: &str,
81+
input: T,
8182
) -> Result<Py<PyAny>> {
8283
crate::bindings::python::activate()?;
8384
whitelist::verify_task(task)?;
8485

8586
let task = serde_json::to_string(task)?;
8687
let args = serde_json::to_string(args)?;
87-
let inputs = serde_json::to_string(&vec![input])?;
88+
let input = serde_json::to_string(&input)?;
8889

8990
Python::with_gil(|py| -> Result<Py<PyAny>> {
9091
let transform: Py<PyAny> = get_module!(PY_MODULE)
@@ -99,7 +100,7 @@ pub fn transform_stream(
99100
&[
100101
task.into_py(py),
101102
args.into_py(py),
102-
inputs.into_py(py),
103+
input.into_py(py),
103104
true.into_py(py),
104105
],
105106
),
@@ -110,10 +111,10 @@ pub fn transform_stream(
110111
})
111112
}
112113

113-
pub fn transform_stream_iterator(
114+
pub fn transform_stream_iterator<T: serde::Serialize>(
114115
task: &serde_json::Value,
115116
args: &serde_json::Value,
116-
input: &str,
117+
input: T,
117118
) -> Result<TransformStreamIterator> {
118119
let python_iter = transform_stream(task, args, input)
119120
.map_err(|e| error!("{e}"))

pgml-extension/src/bindings/transformers/transformers.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
TrainingArguments,
4242
Trainer,
4343
TextStreamer,
44+
Conversation
4445
)
4546
from threading import Thread
4647
from typing import Optional
@@ -198,8 +199,8 @@ def __init__(self, model_name, **task):
198199
self.task = "text-generation"
199200

200201
def stream(self, inputs, **kwargs):
201-
output = self.model(inputs[0], stream=True, **kwargs)
202-
return ThreadedGeneratorIterator(output, inputs[0])
202+
output = self.model(inputs, stream=True, **kwargs)
203+
return ThreadedGeneratorIterator(output, inputs)
203204

204205
def __call__(self, inputs, **kwargs):
205206
outputs = []
@@ -224,6 +225,7 @@ def __init__(self, model_name, **kwargs):
224225
"summarization",
225226
"translation",
226227
"text-generation",
228+
"conversational"
227229
]
228230
):
229231
self.task = kwargs.pop("task")
@@ -238,7 +240,7 @@ def __init__(self, model_name, **kwargs):
238240
)
239241
elif self.task == "summarization" or self.task == "translation":
240242
self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name, **kwargs)
241-
elif self.task == "text-generation":
243+
elif self.task == "text-generation" or self.task == "conversational":
242244
self.model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
243245
else:
244246
raise PgMLException(f"Unhandled task: {self.task}")
@@ -266,15 +268,30 @@ def __init__(self, model_name, **kwargs):
266268
self.tokenizer = self.pipe.tokenizer
267269

268270
def stream(self, inputs, **kwargs):
269-
streamer = TextIteratorStreamer(self.tokenizer)
270-
inputs = self.tokenizer(inputs, return_tensors="pt").to(self.model.device)
271-
generation_kwargs = dict(inputs, streamer=streamer, **kwargs)
271+
streamer = None
272+
generation_kwargs = None
273+
if self.task == "conversational":
274+
streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True)
275+
inputs = tokenized_chat = self.tokenizer.apply_chat_template(inputs, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(self.model.device)
276+
generation_kwargs = dict(inputs=inputs, streamer=streamer, **kwargs)
277+
else:
278+
streamer = TextIteratorStreamer(self.tokenizer)
279+
inputs = self.tokenizer([inputs], return_tensors="pt").to(self.model.device)
280+
generation_kwargs = dict(inputs, streamer=streamer, **kwargs)
272281
thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
273282
thread.start()
274283
return streamer
275284

276285
def __call__(self, inputs, **kwargs):
277-
return self.pipe(inputs, **kwargs)
286+
if self.task == "conversational":
287+
outputs = []
288+
for conversation in inputs:
289+
conversation = Conversation(conversation)
290+
conversation = self.pipe(conversation, **kwargs)
291+
outputs.append(conversation.generated_responses[-1])
292+
return outputs
293+
else:
294+
return self.pipe(inputs, **kwargs)
278295

279296

280297
def get_model_from(task):

pgml-sdks/pgml/build.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ const ADDITIONAL_DEFAULTS_FOR_JAVASCRIPT: &[u8] = br#"
1414
export function init_logger(level?: string, format?: string): void;
1515
export function migrate(): Promise<void>;
1616
17-
export type Json = { [key: string]: any };
17+
export type Json = any;
1818
export type DateTime = Date;
1919
2020
export function newCollection(name: string, database_url?: string): Collection;
@@ -23,6 +23,7 @@ export function newSplitter(name?: string, parameters?: Json): Splitter;
2323
export function newBuiltins(database_url?: string): Builtins;
2424
export function newPipeline(name: string, model?: Model, splitter?: Splitter, parameters?: Json): Pipeline;
2525
export function newTransformerPipeline(task: string, model?: string, args?: Json, database_url?: string): TransformerPipeline;
26+
export function newOpenSourceAI(database_url?: string): OpenSourceAI;
2627
"#;
2728

2829
fn main() {

pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,29 @@ it("can transformer pipeline stream", async () => {
299299
output.push(result.value);
300300
result = await it.next();
301301
}
302-
expect(output.length).toBeGreaterThan(0)
302+
expect(output.length).toBeGreaterThan(0);
303+
});
304+
305+
///////////////////////////////////////////////////
306+
// Test OpenSourceAI //////////////////////////////
307+
///////////////////////////////////////////////////
308+
309+
it("can open source ai create", async () => {
310+
const client = pgml.newOpenSourceAI();
311+
const results = client.chat_completions_create(
312+
"mistralai/Mistral-7B-v0.1",
313+
[
314+
{
315+
role: "system",
316+
content: "You are a friendly chatbot who always responds in the style of a pirate",
317+
},
318+
{
319+
role: "user",
320+
content: "How many helicopters can a human eat in one sitting?",
321+
},
322+
],
323+
);
324+
expect(results.choices.length).toBeGreaterThan(0);
303325
});
304326

305327
///////////////////////////////////////////////////

pgml-sdks/pgml/python/tests/test.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,8 @@ async def test_order_documents():
307307
async def test_transformer_pipeline():
308308
t = pgml.TransformerPipeline("text-generation")
309309
it = await t.transform(["AI is going to"], {"max_new_tokens": 5})
310-
assert (len(it)) > 0
310+
assert len(it) > 0
311+
311312

312313
@pytest.mark.asyncio
313314
async def test_transformer_pipeline_stream():
@@ -316,7 +317,31 @@ async def test_transformer_pipeline_stream():
316317
total = []
317318
async for c in it:
318319
total.append(c)
319-
assert (len(total)) > 0
320+
assert len(total) > 0
321+
322+
323+
###################################################
324+
## Transformer Pipeline Tests #####################
325+
###################################################
326+
327+
328+
def test_open_source_ai_create():
329+
client = pgml.OpenSourceAI()
330+
results = client.chat_completions_create(
331+
"mistralai/Mistral-7B-v0.1",
332+
[
333+
{
334+
"role": "system",
335+
"content": "You are a friendly chatbot who always responds in the style of a pirate",
336+
},
337+
{
338+
"role": "user",
339+
"content": "How many helicopters can a human eat in one sitting?",
340+
},
341+
],
342+
temperature=0.85
343+
)
344+
assert len(results["choices"]) > 0
320345

321346

322347
###################################################

pgml-sdks/pgml/src/lib.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ mod languages;
1919
pub mod migrations;
2020
mod model;
2121
pub mod models;
22+
mod open_source_ai;
2223
mod order_by_builder;
2324
mod pipeline;
2425
mod queries;
@@ -34,6 +35,7 @@ mod utils;
3435
pub use builtins::Builtins;
3536
pub use collection::Collection;
3637
pub use model::Model;
38+
pub use open_source_ai::OpenSourceAI;
3739
pub use pipeline::Pipeline;
3840
pub use splitter::Splitter;
3941
pub use transformer_pipeline::TransformerPipeline;
@@ -152,6 +154,7 @@ fn pgml(_py: pyo3::Python, m: &pyo3::types::PyModule) -> pyo3::PyResult<()> {
152154
m.add_class::<splitter::SplitterPython>()?;
153155
m.add_class::<builtins::BuiltinsPython>()?;
154156
m.add_class::<transformer_pipeline::TransformerPipelinePython>()?;
157+
m.add_class::<open_source_ai::OpenSourceAIPython>()?;
155158
Ok(())
156159
}
157160

@@ -201,6 +204,10 @@ fn main(mut cx: neon::context::ModuleContext) -> neon::result::NeonResult<()> {
201204
transformer_pipeline::TransformerPipelineJavascript::new,
202205
)?;
203206
cx.export_function("newPipeline", pipeline::PipelineJavascript::new)?;
207+
cx.export_function(
208+
"newOpenSourceAI",
209+
open_source_ai::OpenSourceAIJavascript::new,
210+
)?;
204211
Ok(())
205212
}
206213

0 commit comments

Comments
 (0)