Skip to content

Commit f3d8e1f

Browse files
committed
Working OpenSourceAI with both sync and async options
1 parent dedf434 commit f3d8e1f

File tree

8 files changed

+433
-50
lines changed

8 files changed

+433
-50
lines changed

pgml-sdks/pgml/javascript/tests/jest.config.js

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,6 @@ export default {
44
roots: ['<rootDir>'],
55
transform: {
66
'^.+\\.tsx?$': 'ts-jest'
7-
}
7+
},
8+
testTimeout: 30000,
89
}

pgml-sdks/pgml/javascript/tests/typescript-tests/test.ts

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ it("can transformer pipeline stream", async () => {
306306
// Test OpenSourceAI //////////////////////////////
307307
///////////////////////////////////////////////////
308308

309-
it("can open source ai create", async () => {
309+
it("can open source ai create", () => {
310310
const client = pgml.newOpenSourceAI();
311311
const results = client.chat_completions_create(
312312
"mistralai/Mistral-7B-v0.1",
@@ -324,6 +324,70 @@ it("can open source ai create", async () => {
324324
expect(results.choices.length).toBeGreaterThan(0);
325325
});
326326

327+
328+
it("can open source ai create async", async () => {
329+
const client = pgml.newOpenSourceAI();
330+
const results = await client.chat_completions_create_async(
331+
"mistralai/Mistral-7B-v0.1",
332+
[
333+
{
334+
role: "system",
335+
content: "You are a friendly chatbot who always responds in the style of a pirate",
336+
},
337+
{
338+
role: "user",
339+
content: "How many helicopters can a human eat in one sitting?",
340+
},
341+
],
342+
);
343+
expect(results.choices.length).toBeGreaterThan(0);
344+
});
345+
346+
347+
it("can open source ai create stream", () => {
348+
const client = pgml.newOpenSourceAI();
349+
const it = client.chat_completions_create_stream(
350+
"mistralai/Mistral-7B-v0.1",
351+
[
352+
{
353+
role: "system",
354+
content: "You are a friendly chatbot who always responds in the style of a pirate",
355+
},
356+
{
357+
role: "user",
358+
content: "How many helicopters can a human eat in one sitting?",
359+
},
360+
],
361+
);
362+
let result = it.next();
363+
while (!result.done) {
364+
expect(result.value.choices.length).toBeGreaterThan(0);
365+
result = it.next();
366+
}
367+
});
368+
369+
it("can open source ai create stream async", async () => {
370+
const client = pgml.newOpenSourceAI();
371+
const it = await client.chat_completions_create_stream_async(
372+
"mistralai/Mistral-7B-v0.1",
373+
[
374+
{
375+
role: "system",
376+
content: "You are a friendly chatbot who always responds in the style of a pirate",
377+
},
378+
{
379+
role: "user",
380+
content: "How many helicopters can a human eat in one sitting?",
381+
},
382+
],
383+
);
384+
let result = await it.next();
385+
while (!result.done) {
386+
expect(result.value.choices.length).toBeGreaterThan(0);
387+
result = await it.next();
388+
}
389+
});
390+
327391
///////////////////////////////////////////////////
328392
// Test migrations ////////////////////////////////
329393
///////////////////////////////////////////////////

pgml-sdks/pgml/python/tests/test.py

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ async def test_transformer_pipeline_stream():
321321

322322

323323
###################################################
324-
## Transformer Pipeline Tests #####################
324+
## OpenSourceAI tests ###########################
325325
###################################################
326326

327327

@@ -339,11 +339,76 @@ def test_open_source_ai_create():
339339
"content": "How many helicopters can a human eat in one sitting?",
340340
},
341341
],
342-
temperature=0.85
342+
temperature=0.85,
343343
)
344344
assert len(results["choices"]) > 0
345345

346346

347+
@pytest.mark.asyncio
348+
async def test_open_source_ai_create_async():
349+
client = pgml.OpenSourceAI()
350+
results = await client.chat_completions_create_async(
351+
"mistralai/Mistral-7B-v0.1",
352+
[
353+
{
354+
"role": "system",
355+
"content": "You are a friendly chatbot who always responds in the style of a pirate",
356+
},
357+
{
358+
"role": "user",
359+
"content": "How many helicopters can a human eat in one sitting?",
360+
},
361+
],
362+
temperature=0.85,
363+
)
364+
import json
365+
assert len(results["choices"]) > 0
366+
367+
368+
def test_open_source_ai_create_stream():
369+
client = pgml.OpenSourceAI()
370+
results = client.chat_completions_create_stream(
371+
"mistralai/Mistral-7B-v0.1",
372+
[
373+
{
374+
"role": "system",
375+
"content": "You are a friendly chatbot who always responds in the style of a pirate",
376+
},
377+
{
378+
"role": "user",
379+
"content": "How many helicopters can a human eat in one sitting?",
380+
},
381+
],
382+
temperature=0.85,
383+
n=3,
384+
)
385+
for c in results:
386+
assert len(c["choices"]) > 0
387+
388+
389+
@pytest.mark.asyncio
390+
async def test_open_source_ai_create_stream_async():
391+
client = pgml.OpenSourceAI()
392+
results = await client.chat_completions_create_stream_async(
393+
"mistralai/Mistral-7B-v0.1",
394+
[
395+
{
396+
"role": "system",
397+
"content": "You are a friendly chatbot who always responds in the style of a pirate",
398+
},
399+
{
400+
"role": "user",
401+
"content": "How many helicopters can a human eat in one sitting?",
402+
},
403+
],
404+
temperature=0.85,
405+
n=3,
406+
)
407+
import json
408+
async for c in results:
409+
assert len(c["choices"]) > 0
410+
411+
347412
###################################################
348413
## Migration tests ################################
349414
###################################################

pgml-sdks/pgml/src/languages/javascript.rs

Lines changed: 56 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
use futures::StreamExt;
22
use neon::prelude::*;
33
use rust_bridge::javascript::{FromJsType, IntoJsResult};
4+
use std::cell::RefCell;
45
use std::sync::Arc;
56

67
use crate::{
78
pipeline::PipelineSyncData,
8-
transformer_pipeline::TransformerStream,
9-
types::{DateTime, Json},
9+
types::{DateTime, GeneralJsonAsyncIterator, GeneralJsonIterator, Json},
1010
};
1111

1212
////////////////////////////////////////////////////////////////////////////////
@@ -74,17 +74,17 @@ impl IntoJsResult for PipelineSyncData {
7474
}
7575

7676
#[derive(Clone)]
77-
struct TransformerStreamArcMutex(Arc<tokio::sync::Mutex<TransformerStream>>);
77+
struct GeneralJsonAsyncIteratorArcMutex(Arc<tokio::sync::Mutex<GeneralJsonAsyncIterator>>);
7878

79-
impl Finalize for TransformerStreamArcMutex {}
79+
impl Finalize for GeneralJsonAsyncIteratorArcMutex {}
8080

8181
fn transform_stream_iterate_next(mut cx: FunctionContext) -> JsResult<JsPromise> {
8282
let this = cx.this();
83-
let s: Handle<JsBox<TransformerStreamArcMutex>> = this
83+
let s: Handle<JsBox<GeneralJsonAsyncIteratorArcMutex>> = this
8484
.get(&mut cx, "s")
8585
.expect("Error getting self in transformer_stream_iterate_next");
86-
let ts: &TransformerStreamArcMutex = &s;
87-
let ts: TransformerStreamArcMutex = ts.clone();
86+
let ts: &GeneralJsonAsyncIteratorArcMutex = &s;
87+
let ts: GeneralJsonAsyncIteratorArcMutex = ts.clone();
8888

8989
let channel = cx.channel();
9090
let (deferred, promise) = cx.promise();
@@ -95,7 +95,7 @@ fn transform_stream_iterate_next(mut cx: FunctionContext) -> JsResult<JsPromise>
9595
.try_settle_with(&channel, move |mut cx| {
9696
let o = cx.empty_object();
9797
if let Some(v) = v {
98-
let v: Json = v.expect("Error calling next on TransformerStream");
98+
let v: Json = v.expect("Error calling next on GeneralJsonAsyncIterator");
9999
let v = v
100100
.into_js_result(&mut cx)
101101
.expect("Error converting rust Json to JavaScript Object");
@@ -116,20 +116,64 @@ fn transform_stream_iterate_next(mut cx: FunctionContext) -> JsResult<JsPromise>
116116
Ok(promise)
117117
}
118118

119-
impl IntoJsResult for TransformerStream {
120-
type Output = JsObject;
119+
impl IntoJsResult for GeneralJsonAsyncIterator {
120+
type Output = JsValue;
121121
fn into_js_result<'a, 'b, 'c: 'b, C: Context<'c>>(
122122
self,
123123
cx: &mut C,
124124
) -> JsResult<'b, Self::Output> {
125125
let o = cx.empty_object();
126126
let f: Handle<JsFunction> = JsFunction::new(cx, transform_stream_iterate_next)?;
127127
o.set(cx, "next", f)?;
128-
let s = cx.boxed(TransformerStreamArcMutex(Arc::new(
128+
let s = cx.boxed(GeneralJsonAsyncIteratorArcMutex(Arc::new(
129129
tokio::sync::Mutex::new(self),
130130
)));
131131
o.set(cx, "s", s)?;
132-
Ok(o)
132+
Ok(o.as_value(cx))
133+
}
134+
}
135+
136+
struct GeneralJsonIteratorJavaScript(RefCell<GeneralJsonIterator>);
137+
138+
impl Finalize for GeneralJsonIteratorJavaScript {}
139+
140+
fn transform_iterate_next(mut cx: FunctionContext) -> JsResult<JsObject> {
141+
let this = cx.this();
142+
let s: Handle<JsBox<GeneralJsonIteratorJavaScript>> = this
143+
.get(&mut cx, "s")
144+
.expect("Error getting self in transformer_stream_iterate_next");
145+
let v = s.0.borrow_mut().next();
146+
let o = cx.empty_object();
147+
if let Some(v) = v {
148+
let v: Json = v.expect("Error calling next on GeneralJsonAsyncIterator");
149+
let v = v
150+
.into_js_result(&mut cx)
151+
.expect("Error converting rust Json to JavaScript Object");
152+
let d = cx.boolean(false);
153+
o.set(&mut cx, "value", v)
154+
.expect("Error setting object value in transformer_sream_iterate_next");
155+
o.set(&mut cx, "done", d)
156+
.expect("Error setting object value in transformer_sream_iterate_next");
157+
} else {
158+
let d = cx.boolean(true);
159+
o.set(&mut cx, "done", d)
160+
.expect("Error setting object value in transformer_sream_iterate_next");
161+
}
162+
Ok(o)
163+
}
164+
165+
impl IntoJsResult for GeneralJsonIterator {
166+
type Output = JsValue;
167+
fn into_js_result<'a, 'b, 'c: 'b, C: Context<'c>>(
168+
self,
169+
cx: &mut C,
170+
) -> JsResult<'b, Self::Output> {
171+
let o = cx.empty_object();
172+
let f: Handle<JsFunction> = JsFunction::new(cx, transform_iterate_next)?;
173+
o.set(cx, "next", f)?;
174+
let s = cx.boxed(GeneralJsonIteratorJavaScript(RefCell::new(self)));
175+
o.set(cx, "s", s)?;
176+
Ok(o.as_value(cx))
133177
}
134178
}
135179

pgml-sdks/pgml/src/languages/python.rs

Lines changed: 51 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@ use std::sync::Arc;
66

77
use rust_bridge::python::CustomInto;
88

9-
use crate::{pipeline::PipelineSyncData, transformer_pipeline::TransformerStream, types::Json};
9+
use crate::{
10+
pipeline::PipelineSyncData,
11+
types::{GeneralJsonAsyncIterator, GeneralJsonIterator, Json},
12+
};
1013

1114
////////////////////////////////////////////////////////////////////////////////
1215
// Rust to PY //////////////////////////////////////////////////////////////////
@@ -55,12 +58,12 @@ impl IntoPy<PyObject> for PipelineSyncData {
5558

5659
#[pyclass]
5760
#[derive(Clone)]
58-
struct TransformerStreamPython {
59-
wrapped: Arc<tokio::sync::Mutex<TransformerStream>>,
61+
struct GeneralJsonAsyncIteratorPython {
62+
wrapped: Arc<tokio::sync::Mutex<GeneralJsonAsyncIterator>>,
6063
}
6164

6265
#[pymethods]
63-
impl TransformerStreamPython {
66+
impl GeneralJsonAsyncIteratorPython {
6467
fn __aiter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
6568
slf
6669
}
@@ -71,7 +74,7 @@ impl TransformerStreamPython {
7174
let mut ts = ts.lock().await;
7275
if let Some(o) = ts.next().await {
7376
Ok(Some(Python::with_gil(|py| {
74-
o.expect("Error calling next on TransformerStream")
77+
o.expect("Error calling next on GeneralJsonAsyncIterator")
7578
.into_py(py)
7679
})))
7780
} else {
@@ -84,15 +87,47 @@ impl TransformerStreamPython {
8487
}
8588
}
8689

87-
impl IntoPy<PyObject> for TransformerStream {
90+
impl IntoPy<PyObject> for GeneralJsonAsyncIterator {
8891
fn into_py(self, py: Python) -> PyObject {
89-
let f: Py<TransformerStreamPython> = Py::new(
92+
let f: Py<GeneralJsonAsyncIteratorPython> = Py::new(
9093
py,
91-
TransformerStreamPython {
94+
GeneralJsonAsyncIteratorPython {
9295
wrapped: Arc::new(tokio::sync::Mutex::new(self)),
9396
},
9497
)
95-
.expect("Error converting TransformerStream to TransformerStreamPython");
98+
.expect("Error converting GeneralJsonAsyncIterator to GeneralJsonAsyncIteratorPython");
99+
f.to_object(py)
100+
}
101+
}
102+
103+
#[pyclass]
104+
struct GeneralJsonIteratorPython {
105+
wrapped: GeneralJsonIterator,
106+
}
107+
108+
#[pymethods]
109+
impl GeneralJsonIteratorPython {
110+
fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
111+
slf
112+
}
113+
114+
fn __next__(mut slf: PyRefMut<'_, Self>, py: Python) -> PyResult<Option<PyObject>> {
115+
if let Some(o) = slf.wrapped.next() {
116+
let o = o.expect("Error calling next on GeneralJsonIterator");
117+
Ok(Some(o.into_py(py)))
118+
} else {
119+
Err(pyo3::exceptions::PyStopIteration::new_err(
120+
"stream exhausted",
121+
))
122+
}
123+
}
124+
}
125+
126+
impl IntoPy<PyObject> for GeneralJsonIterator {
127+
fn into_py(self, py: Python) -> PyObject {
128+
let f: Py<GeneralJsonIteratorPython> =
129+
Py::new(py, GeneralJsonIteratorPython { wrapped: self })
130+
.expect("Error converting GeneralJsonIterator to GeneralJsonIteratorPython");
96131
f.to_object(py)
97132
}
98133
}
@@ -149,7 +184,13 @@ impl FromPyObject<'_> for PipelineSyncData {
149184
}
150185
}
151186

152-
impl FromPyObject<'_> for TransformerStream {
187+
impl FromPyObject<'_> for GeneralJsonAsyncIterator {
188+
fn extract(_ob: &PyAny) -> PyResult<Self> {
189+
panic!("We must implement this, but this is impossible to be reached")
190+
}
191+
}
192+
193+
impl FromPyObject<'_> for GeneralJsonIterator {
153194
fn extract(_ob: &PyAny) -> PyResult<Self> {
154195
panic!("We must implement this, but this is impossible to be reached")
155196
}

0 commit comments

Comments
 (0)