Commit 2969c89

committed
Cleaned up and tested well
1 parent f3d8e1f commit 2969c89

File tree

5 files changed: +107 -52 lines

pgml-extension/src/bindings/transformers/transformers.py

Lines changed: 22 additions & 7 deletions
@@ -126,7 +126,6 @@ def __iter__(self):
 
     def __next__(self):
        value = self.text_queue.get(timeout=self.timeout)
-       print("\n\n", value, "\n\n", file=sys.stderr)
        if value != self.stop_signal:
            return value
 
@@ -286,9 +285,17 @@ def stream(self, input, **kwargs):
            streamer = TextIteratorStreamer(
                self.tokenizer, skip_prompt=True, skip_special_tokens=True
            )
-            input = self.tokenizer.apply_chat_template(
-                input, add_generation_prompt=True, tokenize=False
-            )
+            if "chat_template" in kwargs:
+                input = self.tokenizer.apply_chat_template(
+                    input,
+                    add_generation_prompt=True,
+                    tokenize=False,
+                    chat_template=kwargs.pop("chat_template"),
+                )
+            else:
+                input = self.tokenizer.apply_chat_template(
+                    input, add_generation_prompt=True, tokenize=False
+                )
            input = self.tokenizer(input, return_tensors="pt").to(self.model.device)
            generation_kwargs = dict(input, streamer=streamer, **kwargs)
        else:
@@ -303,9 +310,17 @@ def stream(self, input, **kwargs):
 
    def __call__(self, inputs, **kwargs):
        if self.task == "conversational":
-            inputs = self.tokenizer.apply_chat_template(
-                inputs, add_generation_prompt=True, tokenize=False
-            )
+            if "chat_template" in kwargs:
+                inputs = self.tokenizer.apply_chat_template(
+                    inputs,
+                    add_generation_prompt=True,
+                    tokenize=False,
+                    chat_template=kwargs.pop("chat_template"),
+                )
+            else:
+                inputs = self.tokenizer.apply_chat_template(
+                    inputs, add_generation_prompt=True, tokenize=False
+                )
            inputs = self.tokenizer(inputs, return_tensors="pt").to(self.model.device)
            args = dict(inputs, **kwargs)
            outputs = self.model.generate(**args)
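Note: both branches above call Hugging Face's tokenizer.apply_chat_template; the only new behavior is popping an optional chat_template kwarg and forwarding it, so callers can override the Jinja template bundled with the tokenizer. A minimal standalone sketch of that mechanism (the model name and template string here are illustrative, not taken from this commit):

# Sketch of the apply_chat_template override this patch relies on.
# Assumes a transformers version (>= 4.34) whose apply_chat_template
# accepts a chat_template argument; model and template are illustrative.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]

# Without the kwarg, the template bundled with the tokenizer is used.
default_prompt = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=False
)

# With the kwarg, a caller-supplied Jinja template takes precedence,
# which is exactly what the patched stream()/__call__ paths forward.
custom_template = (
    "{% for message in messages %}"
    "<|{{ message['role'] }}|>\n{{ message['content'] }}{{ eos_token }}\n"
    "{% endfor %}"
    "{% if add_generation_prompt %}<|assistant|>\n{% endif %}"
)
custom_prompt = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=False,
    chat_template=custom_template,
)
print(default_prompt != custom_prompt)  # True whenever the templates differ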

pgml-sdks/pgml/src/builtins.rs

Lines changed: 1 addition & 1 deletion
@@ -101,7 +101,7 @@ mod tests {
        let query = "SELECT * from pgml.collections";
        let results = builtins.query(query).fetch_all().await?;
        assert!(results.as_array().is_some());
-        Ok(())
+        Ok(())
    }
 
    #[sqlx::test]

pgml-sdks/pgml/src/languages/javascript.rs

Lines changed: 13 additions & 13 deletions
@@ -74,17 +74,17 @@ impl IntoJsResult for PipelineSyncData {
 }
 
 #[derive(Clone)]
-struct GeneralJsonAsyncIteratorArcMutex(Arc<tokio::sync::Mutex<GeneralJsonAsyncIterator>>);
+struct GeneralJsonAsyncIteratorJavaScript(Arc<tokio::sync::Mutex<GeneralJsonAsyncIterator>>);
 
-impl Finalize for GeneralJsonAsyncIteratorArcMutex {}
+impl Finalize for GeneralJsonAsyncIteratorJavaScript {}
 
 fn transform_stream_iterate_next(mut cx: FunctionContext) -> JsResult<JsPromise> {
    let this = cx.this();
-    let s: Handle<JsBox<GeneralJsonAsyncIteratorArcMutex>> = this
+    let s: Handle<JsBox<GeneralJsonAsyncIteratorJavaScript>> = this
        .get(&mut cx, "s")
        .expect("Error getting self in transformer_stream_iterate_next");
-    let ts: &GeneralJsonAsyncIteratorArcMutex = &s;
-    let ts: GeneralJsonAsyncIteratorArcMutex = ts.clone();
+    let ts: &GeneralJsonAsyncIteratorJavaScript = &s;
+    let ts: GeneralJsonAsyncIteratorJavaScript = ts.clone();
 
    let channel = cx.channel();
    let (deferred, promise) = cx.promise();
@@ -101,13 +101,13 @@ fn transform_stream_iterate_next(mut cx: FunctionContext) -> JsResult<JsPromise>
                .expect("Error converting rust Json to JavaScript Object");
            let d = cx.boolean(false);
            o.set(&mut cx, "value", v)
-                .expect("Error setting object value in transformer_sream_iterate_next");
+                .expect("Error setting object value in transform_sream_iterate_next");
            o.set(&mut cx, "done", d)
-                .expect("Error setting object value in transformer_sream_iterate_next");
+                .expect("Error setting object value in transform_sream_iterate_next");
        } else {
            let d = cx.boolean(true);
            o.set(&mut cx, "done", d)
-                .expect("Error setting object value in transformer_sream_iterate_next");
+                .expect("Error setting object value in transform_sream_iterate_next");
        }
        Ok(o)
    })
@@ -125,7 +125,7 @@ impl IntoJsResult for GeneralJsonAsyncIterator {
        let o = cx.empty_object();
        let f: Handle<JsFunction> = JsFunction::new(cx, transform_stream_iterate_next)?;
        o.set(cx, "next", f)?;
-        let s = cx.boxed(GeneralJsonAsyncIteratorArcMutex(Arc::new(
+        let s = cx.boxed(GeneralJsonAsyncIteratorJavaScript(Arc::new(
            tokio::sync::Mutex::new(self),
        )));
        o.set(cx, "s", s)?;
@@ -141,7 +141,7 @@ fn transform_iterate_next(mut cx: FunctionContext) -> JsResult<JsObject> {
    let this = cx.this();
    let s: Handle<JsBox<GeneralJsonIteratorJavaScript>> = this
        .get(&mut cx, "s")
-        .expect("Error getting self in transformer_stream_iterate_next");
+        .expect("Error getting self in transform_iterate_next");
    let v = s.0.borrow_mut().next();
    let o = cx.empty_object();
    if let Some(v) = v {
@@ -151,13 +151,13 @@ fn transform_iterate_next(mut cx: FunctionContext) -> JsResult<JsObject> {
            .expect("Error converting rust Json to JavaScript Object");
        let d = cx.boolean(false);
        o.set(&mut cx, "value", v)
-            .expect("Error setting object value in transformer_sream_iterate_next");
+            .expect("Error setting object value in transform_iterate_next");
        o.set(&mut cx, "done", d)
-            .expect("Error setting object value in transformer_sream_iterate_next");
+            .expect("Error setting object value in transform_iterate_next");
    } else {
        let d = cx.boolean(true);
        o.set(&mut cx, "done", d)
-            .expect("Error setting object value in transformer_sream_iterate_next");
+            .expect("Error setting object value in transform_iterate_next");
    }
    Ok(o)
 }

pgml-sdks/pgml/src/lib.rs

Lines changed: 0 additions & 1 deletion
@@ -765,7 +765,6 @@ mod tests {
                .filter(filter)
                .fetch_all()
                .await?;
-            println!("{:?}", results);
            assert_eq!(results.len(), expected_result_count);
        }
 

pgml-sdks/pgml/src/open_source_ai.rs

Lines changed: 71 additions & 30 deletions
@@ -1,16 +1,17 @@
 use anyhow::Context;
-use futures::{StreamExt, Stream};
+use futures::{Stream, StreamExt};
 use rust_bridge::{alias, alias_methods};
 use std::time::{SystemTime, UNIX_EPOCH};
 use uuid::Uuid;
 
 use crate::{
-    types::{GeneralJsonAsyncIterator, Json, GeneralJsonIterator},
-    TransformerPipeline, get_or_set_runtime,
+    get_or_set_runtime,
+    types::{GeneralJsonAsyncIterator, GeneralJsonIterator, Json},
+    TransformerPipeline,
 };
 
 #[cfg(feature = "python")]
-use crate::types::{JsonPython, GeneralJsonAsyncIteratorPython, GeneralJsonIteratorPython};
+use crate::types::{GeneralJsonAsyncIteratorPython, GeneralJsonIteratorPython, JsonPython};
 
 #[derive(alias, Debug, Clone)]
 pub struct OpenSourceAI {
@@ -43,10 +44,26 @@ fn try_model_nice_name_to_model_name_and_parameters(
            })
            .into(),
        )),
+        "PygmalionAI/mythalion-13b" => Some((
+            "TheBloke/Mythalion-13B-GPTQ",
+            serde_json::json!({
+                "model": "TheBloke/Mythalion-13B-GPTQ",
+                "device_map": "auto",
+                "revision": "main"
+            })
+            .into(),
+        )),
        _ => None,
    }
 }
 
+fn try_get_model_chat_template(model_name: &str) -> Option<&'static str> {
+    match model_name {
+        "PygmalionAI/mythalion-13b" => Some("{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'model' %}\n{{ '<|model|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|model|>' }}\n{% endif %}\n{% endfor %}"),
+        _ => None
+    }
+}
+
 struct AsyncToSyncJsonIterator(std::pin::Pin<Box<dyn Stream<Item = anyhow::Result<Json>> + Send>>);
 
 impl Iterator for AsyncToSyncJsonIterator {
@@ -58,7 +75,13 @@ impl Iterator for AsyncToSyncJsonIterator {
    }
 }
 
-#[alias_methods(new, chat_completions_create, chat_completions_create_async, chat_completions_create_stream, chat_completions_create_stream_async)]
+#[alias_methods(
+    new,
+    chat_completions_create,
+    chat_completions_create_async,
+    chat_completions_create_stream,
+    chat_completions_create_stream_async
+)]
 impl OpenSourceAI {
    pub fn new(database_url: Option<String>) -> Self {
        Self { database_url }
@@ -114,6 +137,7 @@ mistralai/Mistral-7B-v0.1
        max_tokens: Option<i32>,
        temperature: Option<f64>,
        n: Option<i32>,
+        chat_template: Option<String>,
    ) -> anyhow::Result<GeneralJsonAsyncIterator> {
        let (transformer_pipeline, model_name, model_parameters) =
            self.create_pipeline_model_name_parameters(model)?;
@@ -125,24 +149,26 @@ mistralai/Mistral-7B-v0.1
        let md5_digest = md5::compute(to_hash.as_bytes());
        let fingerprint = uuid::Uuid::from_slice(&md5_digest.0)?;
 
+        let mut args = serde_json::json!({ "max_length": max_tokens, "temperature": temperature, "do_sample": true, "num_return_sequences": n });
+        if let Some(t) = chat_template
+            .or_else(|| try_get_model_chat_template(&model_name).map(|s| s.to_string()))
+        {
+            args.as_object_mut().unwrap().insert(
+                "chat_template".to_string(),
+                serde_json::to_value(t).unwrap(),
+            );
+        }
+
        let messages = serde_json::to_value(messages)?.into();
        let iterator = transformer_pipeline
-            .transform_stream(
-                messages,
-                Some(
-                    serde_json::json!({ "max_length": max_tokens, "temperature": temperature, "do_sample": true, "num_return_sequences": n })
-                        .into(),
-                ),
-                Some(1)
-            )
+            .transform_stream(messages, Some(args.into()), Some(1))
            .await?;
 
        let id = Uuid::new_v4().to_string();
        let iter = iterator.map(move |choices| {
            let since_the_epoch = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .expect("Time went backwards");
-            eprintln!("{:?}", choices);
            Ok(serde_json::json!({
                "id": id.clone(),
                "system_fingerprint": fingerprint.clone(),
@@ -155,9 +181,8 @@ mistralai/Mistral-7B-v0.1
                    "delta": {
                        "role": "assistant",
                        "content": c
-                    }
+                    }
                })
-                // finish_reason goes here
            }).collect::<serde_json::Value>()
        })
        .into())
@@ -172,11 +197,21 @@ mistralai/Mistral-7B-v0.1
        messages: Vec<Json>,
        max_tokens: Option<i32>,
        temperature: Option<f64>,
+        chat_template: Option<String>,
        n: Option<i32>,
    ) -> anyhow::Result<GeneralJsonIterator> {
        let runtime = crate::get_or_set_runtime();
-        let iter = runtime.block_on(self.chat_completions_create_stream_async(model, messages, max_tokens, temperature, n))?;
-        Ok(GeneralJsonIterator(Box::new(AsyncToSyncJsonIterator(Box::pin(iter)))))
+        let iter = runtime.block_on(self.chat_completions_create_stream_async(
+            model,
+            messages,
+            max_tokens,
+            temperature,
+            n,
+            chat_template,
+        ))?;
+        Ok(GeneralJsonIterator(Box::new(AsyncToSyncJsonIterator(
+            Box::pin(iter),
+        ))))
    }
 
    pub async fn chat_completions_create_async(
@@ -186,6 +221,7 @@ mistralai/Mistral-7B-v0.1
        max_tokens: Option<i32>,
        temperature: Option<f64>,
        n: Option<i32>,
+        chat_template: Option<String>,
    ) -> anyhow::Result<Json> {
        let (transformer_pipeline, model_name, model_parameters) =
            self.create_pipeline_model_name_parameters(model)?;
@@ -197,14 +233,18 @@ mistralai/Mistral-7B-v0.1
        let md5_digest = md5::compute(to_hash.as_bytes());
        let fingerprint = uuid::Uuid::from_slice(&md5_digest.0)?;
 
+        let mut args = serde_json::json!({ "max_length": max_tokens, "temperature": temperature, "do_sample": true, "num_return_sequences": n });
+        if let Some(t) = chat_template
+            .or_else(|| try_get_model_chat_template(&model_name).map(|s| s.to_string()))
+        {
+            args.as_object_mut().unwrap().insert(
+                "chat_template".to_string(),
+                serde_json::to_value(t).unwrap(),
+            );
+        }
+
        let choices = transformer_pipeline
-            .transform(
-                messages,
-                Some(
-                    serde_json::json!({ "max_length": max_tokens, "temperature": temperature, "do_sample": true, "num_return_sequences": n })
-                        .into(),
-                ),
-            )
+            .transform(messages, Some(args.into()))
            .await?;
        let choices: Vec<Json> = choices
            .as_array()
@@ -249,6 +289,7 @@ mistralai/Mistral-7B-v0.1
        max_tokens: Option<i32>,
        temperature: Option<f64>,
        n: Option<i32>,
+        chat_template: Option<String>,
    ) -> anyhow::Result<Json> {
        let runtime = crate::get_or_set_runtime();
        runtime.block_on(self.chat_completions_create_async(
@@ -257,6 +298,7 @@ mistralai/Mistral-7B-v0.1
            max_tokens,
            temperature,
            n,
+            chat_template,
        ))
    }
 }
@@ -272,7 +314,7 @@ mod tests {
        let results = client.chat_completions_create(Json::from_serializable("mistralai/Mistral-7B-v0.1"), vec![
            serde_json::json!({"role": "system", "content": "You are a friendly chatbot who always responds in the style of a pirate"}).into(),
            serde_json::json!({"role": "user", "content": "How many helicopters can a human eat in one sitting?"}).into(),
-        ], Some(10), None, Some(3))?;
+        ], Some(10), None, Some(3), None)?;
        assert!(results["choices"].as_array().is_some());
        Ok(())
    }
@@ -283,7 +325,7 @@ mod tests {
        let results = client.chat_completions_create_async(Json::from_serializable("mistralai/Mistral-7B-v0.1"), vec![
            serde_json::json!({"role": "system", "content": "You are a friendly chatbot who always responds in the style of a pirate"}).into(),
            serde_json::json!({"role": "user", "content": "How many helicopters can a human eat in one sitting?"}).into(),
-        ], Some(10), None, Some(3)).await?;
+        ], Some(10), None, Some(3), None).await?;
        assert!(results["choices"].as_array().is_some());
        Ok(())
    }
@@ -294,7 +336,7 @@ mod tests {
        let mut stream = client.chat_completions_create_stream_async(Json::from_serializable("mistralai/Mistral-7B-v0.1"), vec![
            serde_json::json!({"role": "system", "content": "You are a friendly chatbot who always responds in the style of a pirate"}).into(),
            serde_json::json!({"role": "user", "content": "How many helicopters can a human eat in one sitting?"}).into(),
-        ], Some(10), None, Some(3)).await?;
+        ], Some(10), None, Some(3), None).await?;
        while let Some(o) = stream.next().await {
            o?;
        }
@@ -307,11 +349,10 @@ mod tests {
        let iterator = client.chat_completions_create_stream(Json::from_serializable("mistralai/Mistral-7B-v0.1"), vec![
            serde_json::json!({"role": "system", "content": "You are a friendly chatbot who always responds in the style of a pirate"}).into(),
            serde_json::json!({"role": "user", "content": "How many helicopters can a human eat in one sitting?"}).into(),
-        ], Some(10), None, Some(3))?;
+        ], Some(10), None, Some(3), None)?;
        for o in iterator {
            o?;
        }
        Ok(())
    }
-
 }
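Taken together, the SDK change threads an optional chat_template argument through every chat_completions_create* entry point into the pipeline args, falling back to try_get_model_chat_template for models with a known built-in template. A hedged usage sketch through the Python bindings, assuming the alias_methods-generated package exposes the same positional arguments as the Rust impl above (verify the order and package layout against the generated pgml SDK before relying on this):

# Hypothetical usage via the pgml Python SDK; assumes the aliased
# chat_completions_create takes positional arguments matching the Rust
# signature: (model, messages, max_tokens, temperature, n, chat_template).
import pgml

client = pgml.OpenSourceAI()  # no URL given; presumably resolved from the environment

messages = [
    {"role": "system", "content": "You are a friendly chatbot who always responds in the style of a pirate"},
    {"role": "user", "content": "How many helicopters can a human eat in one sitting?"},
]

# With chat_template=None the SDK falls back to try_get_model_chat_template,
# which carries a built-in template for PygmalionAI/mythalion-13b.
results = client.chat_completions_create(
    "PygmalionAI/mythalion-13b", messages, 25, None, 1, None
)
print(results["choices"])

# An explicit Jinja template (illustrative only) overrides the built-in one.
custom_template = (
    "{% for message in messages %}<|{{ message['role'] }}|>\n"
    "{{ message['content'] }}\n{% endfor %}<|model|>"
)
results = client.chat_completions_create(
    "PygmalionAI/mythalion-13b", messages, 25, None, 1, custom_template
)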
