From a1dda883b2772a350f809bb715d41191e0b38a29 Mon Sep 17 00:00:00 2001
From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com>
Date: Fri, 1 Dec 2023 13:39:11 -0800
Subject: [PATCH 1/2] General clean ups

---
 pgml-sdks/pgml/build.rs                    | 4 ++++
 pgml-sdks/pgml/python/tests/test.py        | 1 -
 pgml-sdks/pgml/src/languages/javascript.rs | 6 +++---
 3 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/pgml-sdks/pgml/build.rs b/pgml-sdks/pgml/build.rs
index 82b51670c..f017a04db 100644
--- a/pgml-sdks/pgml/build.rs
+++ b/pgml-sdks/pgml/build.rs
@@ -8,6 +8,8 @@ async def migrate() -> None
 
 Json = Any
 DateTime = int
+GeneralJsonIterator = Any
+GeneralJsonAsyncIterator = Any
 "#;
 
 const ADDITIONAL_DEFAULTS_FOR_JAVASCRIPT: &[u8] = br#"
@@ -16,6 +18,8 @@ export function migrate(): Promise;
 
 export type Json = any;
 export type DateTime = Date;
+export type GeneralJsonIterator = any;
+export type GeneralJsonAsyncIterator = any;
 
 export function newCollection(name: string, database_url?: string): Collection;
 export function newModel(name?: string, source?: string, parameters?: Json): Model;
diff --git a/pgml-sdks/pgml/python/tests/test.py b/pgml-sdks/pgml/python/tests/test.py
index f3b1fbec9..5c3a4df33 100644
--- a/pgml-sdks/pgml/python/tests/test.py
+++ b/pgml-sdks/pgml/python/tests/test.py
@@ -361,7 +361,6 @@ async def test_open_source_ai_create_async():
         ],
         temperature=0.85,
     )
-    import json
     assert len(results["choices"]) > 0
 
 
diff --git a/pgml-sdks/pgml/src/languages/javascript.rs b/pgml-sdks/pgml/src/languages/javascript.rs
index c9a09326d..c49b5c493 100644
--- a/pgml-sdks/pgml/src/languages/javascript.rs
+++ b/pgml-sdks/pgml/src/languages/javascript.rs
@@ -101,13 +101,13 @@ fn transform_stream_iterate_next(mut cx: FunctionContext) -> JsResult
             .expect("Error converting rust Json to JavaScript Object");
         let d = cx.boolean(false);
         o.set(&mut cx, "value", v)
-            .expect("Error setting object value in transform_sream_iterate_next");
+            .expect("Error setting object value in transform_stream_iterate_next");
         o.set(&mut cx, "done", d)
-            .expect("Error setting object value in transform_sream_iterate_next");
+            .expect("Error setting object value in transform_stream_iterate_next");
     } else {
         let d = cx.boolean(true);
         o.set(&mut cx, "done", d)
-            .expect("Error setting object value in transform_sream_iterate_next");
+            .expect("Error setting object value in transform_stream_iterate_next");
     }
     Ok(o)
 })

From 2f26eda585ee3a22b7874d06415ca60795efd2d0 Mon Sep 17 00:00:00 2001
From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com>
Date: Fri, 1 Dec 2023 15:32:33 -0800
Subject: [PATCH 2/2] Made the text iterator streamer return spaces correctly

---
 .../src/bindings/transformers/transformers.py | 41 +++++++++++++++----
 1 file changed, 33 insertions(+), 8 deletions(-)

diff --git a/pgml-extension/src/bindings/transformers/transformers.py b/pgml-extension/src/bindings/transformers/transformers.py
index 5c6078785..c738fe1f5 100644
--- a/pgml-extension/src/bindings/transformers/transformers.py
+++ b/pgml-extension/src/bindings/transformers/transformers.py
@@ -104,19 +104,42 @@ def __init__(self, tokenizer, skip_prompt=False, timeout=None, **decode_kwargs):
         self.next_tokens_are_prompt = True
         self.stop_signal = None
         self.text_queue = queue.Queue()
+        self.token_cache = []
+        self.text_index_cache = []
 
-    def put(self, value):
+    def put(self, values):
         if self.skip_prompt and self.next_tokens_are_prompt:
             self.next_tokens_are_prompt = False
             return
-        # Can't batch this decode
-        decoded_values = []
-        for v in value:
-            decoded_values.append(self.tokenizer.decode(v, **self.decode_kwargs))
-        self.text_queue.put(decoded_values, self.timeout)
+        output = []
+        for i, v in enumerate(values):
+            if len(self.token_cache) <= i:
+                self.token_cache.append([])
+                self.text_index_cache.append(0)
+            token = v.tolist()  # Returns a list or number
+            if type(token) == list:
+                self.token_cache[i].extend(token)
+            else:
+                self.token_cache[i].append(token)
+            text = self.tokenizer.decode(self.token_cache[i], **self.decode_kwargs)
+            if text.endswith("\n"):
+                output.append(text[self.text_index_cache[i] :])
+                self.token_cache[i] = []
+                self.text_index_cache[i] = 0
+            else:
+                printable_text = text[self.text_index_cache[i] : text.rfind(" ") + 1]
+                self.text_index_cache[i] += len(printable_text)
+                output.append(printable_text)
+        if any(output):
+            self.text_queue.put(output, self.timeout)
 
     def end(self):
         self.next_tokens_are_prompt = True
+        output = []
+        for i, tokens in enumerate(self.token_cache):
+            text = self.tokenizer.decode(tokens, **self.decode_kwargs)
+            output.append(text[self.text_index_cache[i] :])
+        self.text_queue.put(output, self.timeout)
         self.text_queue.put(self.stop_signal, self.timeout)
 
     def __iter__(self):
@@ -127,6 +150,7 @@ def __next__(self):
         if value != self.stop_signal:
             return value
 
+
 class GGMLPipeline(object):
     def __init__(self, model_name, **task):
         import ctransformers
@@ -245,7 +269,8 @@ def stream(self, input, **kwargs):
         generation_kwargs = None
         if self.task == "conversational":
             streamer = TextIteratorStreamer(
-                self.tokenizer, skip_prompt=True, skip_special_tokens=True
+                self.tokenizer,
+                skip_prompt=True,
             )
             if "chat_template" in kwargs:
                 input = self.tokenizer.apply_chat_template(
@@ -261,7 +286,7 @@
             input = self.tokenizer(input, return_tensors="pt").to(self.model.device)
             generation_kwargs = dict(input, streamer=streamer, **kwargs)
         else:
-            streamer = TextIteratorStreamer(self.tokenizer, skip_special_tokens=True)
+            streamer = TextIteratorStreamer(self.tokenizer)
             input = self.tokenizer(input, return_tensors="pt", padding=True).to(
                 self.model.device
             )
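
A minimal sketch of the word-boundary buffering that the new put()/end() implement, assuming a hypothetical WordBoundaryBuffer class with plain string fragments standing in for real tokenizer.decode() output:

# Hypothetical illustration -- not part of the patch. The real streamer caches
# token IDs and re-decodes the cache each step; here plain strings stand in.

class WordBoundaryBuffer:
    """Accumulates text and releases it only up to the last space, so a word
    split across incoming fragments is never emitted half-finished."""

    def __init__(self):
        self.text = ""    # all text received so far
        self.emitted = 0  # index of the first character not yet released

    def put(self, fragment):
        self.text += fragment
        if self.text.endswith("\n"):
            # A newline is a safe boundary: flush everything and reset.
            out = self.text[self.emitted:]
            self.text, self.emitted = "", 0
            return out
        # Otherwise release only through the last space (rfind returns -1
        # when there is no space yet, so the slice is empty in that case).
        printable = self.text[self.emitted : self.text.rfind(" ") + 1]
        self.emitted += len(printable)
        return printable

    def end(self):
        # Flush whatever remains once generation finishes.
        return self.text[self.emitted:]


buffer = WordBoundaryBuffer()
pieces = [buffer.put(f) for f in ["Hel", "lo wor", "ld, stre", "aming!"]]
pieces.append(buffer.end())
assert "".join(pieces) == "Hello world, streaming!"
print([p for p in pieces if p])  # ['Hello ', 'world, ', 'streaming!']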