From 8a72cfff496fae16ade65166c9a79448d77b2ff2 Mon Sep 17 00:00:00 2001 From: Kevin Zimmerman <4733573+kczimm@users.noreply.github.com> Date: Tue, 19 Sep 2023 09:53:42 -0500 Subject: [PATCH 1/3] separate embed model creation and usage --- .../src/bindings/transformers/transformers.py | 29 ++++++++++++------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/pgml-extension/src/bindings/transformers/transformers.py b/pgml-extension/src/bindings/transformers/transformers.py index f220be89d..98c843691 100644 --- a/pgml-extension/src/bindings/transformers/transformers.py +++ b/pgml-extension/src/bindings/transformers/transformers.py @@ -241,29 +241,38 @@ def transform(task, args, inputs): return orjson.dumps(pipe(inputs, **args), default=orjson_default).decode() -def embed(transformer, inputs, kwargs): - kwargs = orjson.loads(kwargs) +def create_embedding(transformer): + instructor = transformer.startswith("hkunlp/instructor") + klass = INSTRUCTOR if instructor else SentenceTransformer + return klass(transformer) + + +def embed_using(model, transformer, inputs, kwargs): + if isinstance(kwargs, str): + kwargs = orjson.loads(kwargs) - ensure_device(kwargs) instructor = transformer.startswith("hkunlp/instructor") - if instructor: - klass = INSTRUCTOR - texts_with_instructions = [] instruction = kwargs.pop("instruction") for text in inputs: texts_with_instructions.append([instruction, text]) inputs = texts_with_instructions - else: - klass = SentenceTransformer + + return model.encode(inputs, **kwargs) + + +def embed(transformer, inputs, kwargs): + kwargs = orjson.loads(kwargs) + + ensure_device(kwargs) if transformer not in __cache_sentence_transformer_by_name: - __cache_sentence_transformer_by_name[transformer] = klass(transformer) + __cache_sentence_transformer_by_name[transformer] = create_embedding(transformer) model = __cache_sentence_transformer_by_name[transformer] - return model.encode(inputs, **kwargs) + return embed_using(model, transformer, inputs, kwargs) def clear_gpu_cache(memory_usage: None): From 5bf5d001396268eba87c7e0c848f0b08f0659cdc Mon Sep 17 00:00:00 2001 From: Kevin Zimmerman <4733573+kczimm@users.noreply.github.com> Date: Tue, 19 Sep 2023 11:16:38 -0500 Subject: [PATCH 2/3] fix dead code --- pgml-extension/src/bindings/langchain/mod.rs | 3 +-- pgml-extension/src/bindings/python/mod.rs | 3 +-- pgml-extension/src/bindings/sklearn/mod.rs | 7 +------ pgml-extension/src/bindings/transformers/mod.rs | 1 - 4 files changed, 3 insertions(+), 11 deletions(-) diff --git a/pgml-extension/src/bindings/langchain/mod.rs b/pgml-extension/src/bindings/langchain/mod.rs index 00ee593fd..7d8d2582f 100644 --- a/pgml-extension/src/bindings/langchain/mod.rs +++ b/pgml-extension/src/bindings/langchain/mod.rs @@ -1,10 +1,9 @@ use anyhow::Result; -use once_cell::sync::Lazy; use pgrx::*; use pyo3::prelude::*; use pyo3::types::PyTuple; -use crate::{bindings::TracebackError, create_pymodule}; +use crate::create_pymodule; create_pymodule!("/src/bindings/langchain/langchain.py"); diff --git a/pgml-extension/src/bindings/python/mod.rs b/pgml-extension/src/bindings/python/mod.rs index 7f527b0fc..9ab7300c0 100644 --- a/pgml-extension/src/bindings/python/mod.rs +++ b/pgml-extension/src/bindings/python/mod.rs @@ -1,14 +1,13 @@ //! Use virtualenv. use anyhow::Result; -use once_cell::sync::Lazy; use pgrx::iter::TableIterator; use pgrx::*; use pyo3::prelude::*; use pyo3::types::PyTuple; use crate::config::get_config; -use crate::{bindings::TracebackError, create_pymodule}; +use crate::create_pymodule; static CONFIG_NAME: &str = "pgml.venv"; diff --git a/pgml-extension/src/bindings/sklearn/mod.rs b/pgml-extension/src/bindings/sklearn/mod.rs index 05e85d97c..4b8ce6625 100644 --- a/pgml-extension/src/bindings/sklearn/mod.rs +++ b/pgml-extension/src/bindings/sklearn/mod.rs @@ -11,15 +11,10 @@ use pgrx::*; use std::collections::HashMap; use anyhow::Result; -use once_cell::sync::Lazy; use pyo3::prelude::*; use pyo3::types::PyTuple; -use crate::{ - bindings::{Bindings, TracebackError}, - create_pymodule, - orm::*, -}; +use crate::{bindings::Bindings, create_pymodule, orm::*}; create_pymodule!("/src/bindings/sklearn/sklearn.py"); diff --git a/pgml-extension/src/bindings/transformers/mod.rs b/pgml-extension/src/bindings/transformers/mod.rs index 91158f860..fbdeec4f8 100644 --- a/pgml-extension/src/bindings/transformers/mod.rs +++ b/pgml-extension/src/bindings/transformers/mod.rs @@ -4,7 +4,6 @@ use std::str::FromStr; use std::{collections::HashMap, path::Path}; use anyhow::{anyhow, bail, Context, Result}; -use once_cell::sync::Lazy; use pgrx::*; use pyo3::prelude::*; use pyo3::types::PyTuple; From 4e67b52774646f440513b3331bc7a9fe2381f403 Mon Sep 17 00:00:00 2001 From: Kevin Zimmerman <4733573+kczimm@users.noreply.github.com> Date: Tue, 19 Sep 2023 11:17:19 -0500 Subject: [PATCH 3/3] fix clippy lints --- .../src/bindings/transformers/mod.rs | 18 +++++++++--------- pgml-extension/src/orm/model.rs | 4 ++-- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/pgml-extension/src/bindings/transformers/mod.rs b/pgml-extension/src/bindings/transformers/mod.rs index fbdeec4f8..c4e262761 100644 --- a/pgml-extension/src/bindings/transformers/mod.rs +++ b/pgml-extension/src/bindings/transformers/mod.rs @@ -46,22 +46,22 @@ pub fn transform( ) .format_traceback(py)?; - Ok(output.extract(py).format_traceback(py)?) + output.extract(py).format_traceback(py) })?; Ok(serde_json::from_str(&results)?) } pub fn get_model_from(task: &Value) -> Result { - Ok(Python::with_gil(|py| -> Result { + Python::with_gil(|py| -> Result { let get_model_from = get_module!(PY_MODULE) .getattr(py, "get_model_from") .format_traceback(py)?; let model = get_model_from .call1(py, PyTuple::new(py, &[task.to_string().into_py(py)])) .format_traceback(py)?; - Ok(model.extract(py).format_traceback(py)?) - })?) + model.extract(py).format_traceback(py) + }) } pub fn embed( @@ -90,7 +90,7 @@ pub fn embed( ) .format_traceback(py)?; - Ok(output.extract(py).format_traceback(py)?) + output.extract(py).format_traceback(py) }) } @@ -125,7 +125,7 @@ pub fn tune( ) .format_traceback(py)?; - Ok(output.extract(py).format_traceback(py)?) + output.extract(py).format_traceback(py) }) } @@ -175,7 +175,7 @@ pub fn generate(model_id: i64, inputs: Vec<&str>, config: JsonB) -> Result o, }; - Ok(result.extract(py).format_traceback(py)?) + result.extract(py).format_traceback(py) }) } @@ -226,7 +226,7 @@ pub fn load_dataset( let load_dataset: Py = get_module!(PY_MODULE) .getattr(py, "load_dataset") .format_traceback(py)?; - Ok(load_dataset + load_dataset .call1( py, PyTuple::new( @@ -241,7 +241,7 @@ pub fn load_dataset( ) .format_traceback(py)? .extract(py) - .format_traceback(py)?) + .format_traceback(py) })?; let table_name = format!("pgml.\"{}\"", name); diff --git a/pgml-extension/src/orm/model.rs b/pgml-extension/src/orm/model.rs index f87ff736a..89a23888c 100644 --- a/pgml-extension/src/orm/model.rs +++ b/pgml-extension/src/orm/model.rs @@ -378,12 +378,12 @@ impl Model { Ok(()) })?; - Ok(model.ok_or_else(|| { + model.ok_or_else(|| { anyhow!( "pgml.models WHERE id = {:?} could not be loaded. Does it exist?", id ) - })?) + }) } pub fn find_cached(id: i64) -> Result> {