Skip to content

Commit 3e8cc28

Browse files
authored
Add streaming (#1145)
1 parent 37a888f commit 3e8cc28

File tree

4 files changed

+278
-58
lines changed

4 files changed

+278
-58
lines changed

pgml-extension/src/api.rs

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ use std::str::FromStr;
44
use ndarray::Zip;
55
use pgrx::iter::{SetOfIterator, TableIterator};
66
use pgrx::*;
7+
use pyo3::prelude::*;
8+
use pyo3::types::{IntoPyDict, PyDict};
79

810
#[cfg(feature = "python")]
911
use serde_json::json;
@@ -632,6 +634,75 @@ pub fn transform_string(
632634
}
633635
}
634636

637+
struct TransformStreamIterator {
638+
locals: Py<PyDict>,
639+
}
640+
641+
impl TransformStreamIterator {
642+
fn new(python_iter: Py<PyAny>) -> Self {
643+
let locals = Python::with_gil(|py| -> Result<Py<PyDict>, PyErr> {
644+
Ok([("python_iter", python_iter)].into_py_dict(py).into())
645+
})
646+
.map_err(|e| error!("{e}"))
647+
.unwrap();
648+
Self { locals }
649+
}
650+
}
651+
652+
impl Iterator for TransformStreamIterator {
653+
type Item = String;
654+
fn next(&mut self) -> Option<Self::Item> {
655+
// We can unwrap this becuase if there is an error the current transaction is aborted in the map_err call
656+
Python::with_gil(|py| -> Result<Option<String>, PyErr> {
657+
let code = "next(python_iter)";
658+
let res: &PyAny = py.eval(code, Some(self.locals.as_ref(py)), None)?;
659+
if res.is_none() {
660+
Ok(None)
661+
} else {
662+
let res: String = res.extract()?;
663+
Ok(Some(res))
664+
}
665+
})
666+
.map_err(|e| error!("{e}"))
667+
.unwrap()
668+
}
669+
}
670+
671+
#[cfg(all(feature = "python", not(feature = "use_as_lib")))]
672+
#[pg_extern(immutable, parallel_safe, name = "transform_stream")]
673+
#[allow(unused_variables)] // cache is maintained for api compatibility
674+
pub fn transform_stream_json(
675+
task: JsonB,
676+
args: default!(JsonB, "'{}'"),
677+
input: default!(&str, "''"),
678+
cache: default!(bool, false),
679+
) -> SetOfIterator<'static, String> {
680+
// We can unwrap this becuase if there is an error the current transaction is aborted in the map_err call
681+
let python_iter = crate::bindings::transformers::transform_stream(&task.0, &args.0, input)
682+
.map_err(|e| error!("{e}"))
683+
.unwrap();
684+
let res = TransformStreamIterator::new(python_iter);
685+
SetOfIterator::new(res)
686+
}
687+
688+
#[cfg(all(feature = "python", not(feature = "use_as_lib")))]
689+
#[pg_extern(immutable, parallel_safe, name = "transform_stream")]
690+
#[allow(unused_variables)] // cache is maintained for api compatibility
691+
pub fn transform_stream_string(
692+
task: String,
693+
args: default!(JsonB, "'{}'"),
694+
input: default!(&str, "''"),
695+
cache: default!(bool, false),
696+
) -> SetOfIterator<'static, String> {
697+
let task_json = json!({ "task": task });
698+
// We can unwrap this becuase if there is an error the current transaction is aborted in the map_err call
699+
let python_iter = crate::bindings::transformers::transform_stream(&task_json, &args.0, input)
700+
.map_err(|e| error!("{e}"))
701+
.unwrap();
702+
let res = TransformStreamIterator::new(python_iter);
703+
SetOfIterator::new(res)
704+
}
705+
635706
#[cfg(feature = "python")]
636707
#[pg_extern(immutable, parallel_safe, name = "generate")]
637708
fn generate(project_name: &str, inputs: &str, config: default!(JsonB, "'{}'")) -> String {

pgml-extension/src/bindings/transformers/mod.rs

Lines changed: 3 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -16,41 +16,10 @@ use super::TracebackError;
1616

1717
pub mod whitelist;
1818

19-
create_pymodule!("/src/bindings/transformers/transformers.py");
20-
21-
pub fn transform(
22-
task: &serde_json::Value,
23-
args: &serde_json::Value,
24-
inputs: Vec<&str>,
25-
) -> Result<serde_json::Value> {
26-
crate::bindings::python::activate()?;
27-
28-
whitelist::verify_task(task)?;
29-
30-
let task = serde_json::to_string(task)?;
31-
let args = serde_json::to_string(args)?;
32-
let inputs = serde_json::to_string(&inputs)?;
19+
mod transformers;
20+
pub use transformers::*;
3321

34-
let results = Python::with_gil(|py| -> Result<String> {
35-
let transform: Py<PyAny> = get_module!(PY_MODULE)
36-
.getattr(py, "transform")
37-
.format_traceback(py)?;
38-
39-
let output = transform
40-
.call1(
41-
py,
42-
PyTuple::new(
43-
py,
44-
&[task.into_py(py), args.into_py(py), inputs.into_py(py)],
45-
),
46-
)
47-
.format_traceback(py)?;
48-
49-
output.extract(py).format_traceback(py)
50-
})?;
51-
52-
Ok(serde_json::from_str(&results)?)
53-
}
22+
create_pymodule!("/src/bindings/transformers/transformers.py");
5423

5524
pub fn get_model_from(task: &Value) -> Result<String> {
5625
Python::with_gil(|py| -> Result<String> {

0 commit comments

Comments
 (0)