Skip to content

Commit 23ef2a3

Browse files
committed
Working non-streaming open-source AI replacement
1 parent 719fdc5 commit 23ef2a3

File tree

6 files changed

+54
-37
lines changed

6 files changed

+54
-37
lines changed

pgml-extension/src/bindings/transformers/transform.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@ impl Iterator for TransformStreamIterator {
3232
if res.is_none() {
3333
Ok(None)
3434
} else {
35+
eprintln!("\nHERE WE ARE!\n");
3536
let res: Vec<String> = res.extract()?;
37+
eprintln!("\nYUP WE DIDNT GET HERE\n");
3638
Ok(Some(JsonB(serde_json::to_value(res).unwrap())))
3739
}
3840
})

pgml-extension/src/bindings/transformers/transformers.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ def __iter__(self):
126126

127127
def __next__(self):
128128
value = self.text_queue.get(timeout=self.timeout)
129+
print("\n\n", value, "\n\n", file=sys.stderr)
129130
if value != self.stop_signal:
130131
return value
131132

pgml-sdks/pgml/src/languages/javascript.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,10 @@ fn transform_stream_iterate_next(mut cx: FunctionContext) -> JsResult<JsPromise>
9595
.try_settle_with(&channel, move |mut cx| {
9696
let o = cx.empty_object();
9797
if let Some(v) = v {
98-
let v: String = v.expect("Error calling next on TransformerStream");
99-
let v = cx.string(v);
98+
let v: Json = v.expect("Error calling next on TransformerStream");
99+
let v = v
100+
.into_js_result(&mut cx)
101+
.expect("Error converting rust Json to JavaScript Object");
100102
let d = cx.boolean(false);
101103
o.set(&mut cx, "value", v)
102104
.expect("Error setting object value in transformer_sream_iterate_next");

pgml-sdks/pgml/src/languages/python.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ impl TransformerStreamPython {
7272
if let Some(o) = ts.next().await {
7373
Ok(Some(Python::with_gil(|py| {
7474
o.expect("Error calling next on TransformerStream")
75-
.to_object(py)
75+
.into_py(py)
7676
})))
7777
} else {
7878
Err(pyo3::exceptions::PyStopAsyncIteration::new_err(

pgml-sdks/pgml/src/open_source_ai.rs

Lines changed: 42 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -49,21 +49,17 @@ impl OpenSourceAI {
4949
Self { database_url }
5050
}
5151

52-
pub async fn chat_completions_create_async(
52+
fn create_pipeline_model_name_parameters(
5353
&self,
5454
mut model: Json,
55-
messages: Json,
56-
max_tokens: Option<i32>,
57-
temperature: Option<f64>,
58-
n: Option<i32>,
59-
) -> anyhow::Result<Json> {
60-
let (transformer_pipeline, model_name, model_parameters) = if model.is_object() {
55+
) -> anyhow::Result<(TransformerPipeline, String, Json)> {
56+
if model.is_object() {
6157
let args = model.as_object_mut().unwrap();
6258
let model_name = args
6359
.remove("model")
6460
.context("`model` is a required key in the model object")?;
6561
let model_name = model_name.as_str().context("`model` must be a string")?;
66-
(
62+
Ok((
6763
TransformerPipeline::new(
6864
"conversational",
6965
Some(model_name.to_string()),
@@ -72,7 +68,7 @@ impl OpenSourceAI {
7268
),
7369
model_name.to_string(),
7470
model,
75-
)
71+
))
7672
} else {
7773
let model_name = model
7874
.as_str()
@@ -83,7 +79,7 @@ impl OpenSourceAI {
8379
mistralai/Mistral-7B-v0.1
8480
"#,
8581
)?;
86-
(
82+
Ok((
8783
TransformerPipeline::new(
8884
"conversational",
8985
Some(real_model_name.to_string()),
@@ -92,36 +88,52 @@ mistralai/Mistral-7B-v0.1
9288
),
9389
model_name.to_string(),
9490
parameters,
95-
)
96-
};
91+
))
92+
}
93+
}
94+
95+
pub async fn chat_completions_create_stream_async(
96+
&self,
97+
model: Json,
98+
messages: Vec<Json>,
99+
max_tokens: Option<i32>,
100+
temperature: Option<f64>,
101+
n: Option<i32>,
102+
) -> anyhow::Result<()> {
103+
Ok(())
104+
}
105+
106+
pub async fn chat_completions_create_async(
107+
&self,
108+
model: Json,
109+
messages: Vec<Json>,
110+
max_tokens: Option<i32>,
111+
temperature: Option<f64>,
112+
n: Option<i32>,
113+
) -> anyhow::Result<Json> {
114+
let (transformer_pipeline, model_name, model_parameters) =
115+
self.create_pipeline_model_name_parameters(model)?;
97116

98117
let max_tokens = max_tokens.unwrap_or(1000);
99118
let temperature = temperature.unwrap_or(0.8);
100119
let n = n.unwrap_or(1) as usize;
101-
let to_hash = format!(
102-
"{}{}{}{}",
103-
model_parameters.to_string(),
104-
max_tokens,
105-
temperature,
106-
n
107-
);
120+
let to_hash = format!("{}{}{}{}", *model_parameters, max_tokens, temperature, n);
108121
let md5_digest = md5::compute(to_hash.as_bytes());
109122
let fingerprint = uuid::Uuid::from_slice(&md5_digest.0)?;
110123

111-
let messages: Vec<Json> = std::iter::repeat(messages).take(n).collect();
112124
let choices = transformer_pipeline
113125
.transform(
114126
messages,
115127
Some(
116-
serde_json::json!({ "max_length": max_tokens, "temperature": temperature })
128+
serde_json::json!({ "max_length": max_tokens, "temperature": temperature, "do_sample": true, "num_return_sequences": n })
117129
.into(),
118130
),
119131
)
120132
.await?;
121133
let choices: Vec<Json> = choices
122134
.as_array()
123135
.context("Error parsing return from TransformerPipeline")?
124-
.into_iter()
136+
.iter()
125137
.enumerate()
126138
.map(|(i, c)| {
127139
serde_json::json!({
@@ -157,7 +169,7 @@ mistralai/Mistral-7B-v0.1
157169
pub fn chat_completions_create(
158170
&self,
159171
model: Json,
160-
messages: Json,
172+
messages: Vec<Json>,
161173
max_tokens: Option<i32>,
162174
temperature: Option<f64>,
163175
n: Option<i32>,
@@ -177,14 +189,14 @@ mistralai/Mistral-7B-v0.1
177189
mod tests {
178190
use super::*;
179191

180-
#[sqlx::test]
181-
async fn can_open_source_ai_create() -> anyhow::Result<()> {
192+
#[test]
193+
fn can_open_source_ai_create() -> anyhow::Result<()> {
182194
let client = OpenSourceAI::new(None);
183-
let results = client.chat_completions_create(Json::from_serializable("mistralai/Mistral-7B-v0.1"), serde_json::json!([
184-
{"role": "system", "content": "You are a friendly chatbot who always responds in the style of a pirate"},
185-
{"role": "user", "content": "How many helicopters can a human eat in one sitting?"}
186-
]).into(), Some(1000), None, None)?;
187-
assert!(results.as_array().is_some());
195+
let results = client.chat_completions_create(Json::from_serializable("mistralai/Mistral-7B-v0.1"), vec![
196+
serde_json::json!({"role": "system", "content": "You are a friendly chatbot who always responds in the style of a pirate"}).into(),
197+
serde_json::json!({"role": "user", "content": "How many helicopters can a human eat in one sitting?"}).into(),
198+
], Some(1000), None, Some(3))?;
199+
assert!(results["choices"].as_array().is_some());
188200
Ok(())
189201
}
190202
}

pgml-sdks/pgml/src/transformer_pipeline.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ impl TransformerStream {
5555
}
5656

5757
impl Stream for TransformerStream {
58-
type Item = Result<String, sqlx::Error>;
58+
type Item = Result<Json, sqlx::Error>;
5959

6060
fn poll_next(
6161
mut self: Pin<&mut Self>,
@@ -106,7 +106,7 @@ impl Stream for TransformerStream {
106106

107107
if !self.results.is_empty() {
108108
let r = self.results.pop_front().unwrap();
109-
Poll::Ready(Some(Ok(r.get::<String, _>(0))))
109+
Poll::Ready(Some(Ok(r.get::<Json, _>(0))))
110110
} else if self.done {
111111
Poll::Ready(None)
112112
} else {
@@ -251,10 +251,10 @@ mod tests {
251251
internal_init_logger(None, None).ok();
252252
let t = TransformerPipeline::new(
253253
"text-generation",
254-
Some("TheBloke/zephyr-7B-beta-GGUF".to_string()),
254+
Some("TheBloke/zephyr-7B-beta-GPTQ".to_string()),
255255
Some(
256256
serde_json::json!({
257-
"model_file": "zephyr-7b-beta.Q5_K_M.gguf", "model_type": "mistral"
257+
"model_type": "mistral", "revision": "main", "device_map": "auto"
258258
})
259259
.into(),
260260
),

0 commit comments

Comments
 (0)