Skip to content

Commit 384921f

Browse files
authored
separate embed model creation and usage (#1022)
1 parent 3b088a4 commit 384921f

File tree

6 files changed: 33 additions and 32 deletions.

pgml-extension/src/bindings/langchain/mod.rs

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,9 @@
11
use anyhow::Result;
2-
use once_cell::sync::Lazy;
32
use pgrx::*;
43
use pyo3::prelude::*;
54
use pyo3::types::PyTuple;
65

7-
use crate::{bindings::TracebackError, create_pymodule};
6+
use crate::create_pymodule;
87

98
create_pymodule!("/src/bindings/langchain/langchain.py");
109

pgml-extension/src/bindings/python/mod.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
//! Use virtualenv.
22
33
use anyhow::Result;
4-
use once_cell::sync::Lazy;
54
use pgrx::iter::TableIterator;
65
use pgrx::*;
76
use pyo3::prelude::*;
87
use pyo3::types::PyTuple;
98

109
use crate::config::get_config;
11-
use crate::{bindings::TracebackError, create_pymodule};
10+
use crate::create_pymodule;
1211

1312
static CONFIG_NAME: &str = "pgml.venv";
1413

pgml-extension/src/bindings/sklearn/mod.rs

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,10 @@ use pgrx::*;
1111
use std::collections::HashMap;
1212

1313
use anyhow::Result;
14-
use once_cell::sync::Lazy;
1514
use pyo3::prelude::*;
1615
use pyo3::types::PyTuple;
1716

18-
use crate::{
19-
bindings::{Bindings, TracebackError},
20-
create_pymodule,
21-
orm::*,
22-
};
17+
use crate::{bindings::Bindings, create_pymodule, orm::*};
2318

2419
create_pymodule!("/src/bindings/sklearn/sklearn.py");
2520

pgml-extension/src/bindings/transformers/mod.rs

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ use std::str::FromStr;
44
use std::{collections::HashMap, path::Path};
55

66
use anyhow::{anyhow, bail, Context, Result};
7-
use once_cell::sync::Lazy;
87
use pgrx::*;
98
use pyo3::prelude::*;
109
use pyo3::types::PyTuple;
@@ -47,22 +46,22 @@ pub fn transform(
4746
)
4847
.format_traceback(py)?;
4948

50-
Ok(output.extract(py).format_traceback(py)?)
49+
output.extract(py).format_traceback(py)
5150
})?;
5251

5352
Ok(serde_json::from_str(&results)?)
5453
}
5554

5655
pub fn get_model_from(task: &Value) -> Result<String> {
57-
Ok(Python::with_gil(|py| -> Result<String> {
56+
Python::with_gil(|py| -> Result<String> {
5857
let get_model_from = get_module!(PY_MODULE)
5958
.getattr(py, "get_model_from")
6059
.format_traceback(py)?;
6160
let model = get_model_from
6261
.call1(py, PyTuple::new(py, &[task.to_string().into_py(py)]))
6362
.format_traceback(py)?;
64-
Ok(model.extract(py).format_traceback(py)?)
65-
})?)
63+
model.extract(py).format_traceback(py)
64+
})
6665
}
6766

6867
pub fn embed(
@@ -91,7 +90,7 @@ pub fn embed(
9190
)
9291
.format_traceback(py)?;
9392

94-
Ok(output.extract(py).format_traceback(py)?)
93+
output.extract(py).format_traceback(py)
9594
})
9695
}
9796

@@ -126,7 +125,7 @@ pub fn tune(
126125
)
127126
.format_traceback(py)?;
128127

129-
Ok(output.extract(py).format_traceback(py)?)
128+
output.extract(py).format_traceback(py)
130129
})
131130
}
132131

@@ -176,7 +175,7 @@ pub fn generate(model_id: i64, inputs: Vec<&str>, config: JsonB) -> Result<Vec<S
176175
}
177176
Ok(o) => o,
178177
};
179-
Ok(result.extract(py).format_traceback(py)?)
178+
result.extract(py).format_traceback(py)
180179
})
181180
}
182181

@@ -227,7 +226,7 @@ pub fn load_dataset(
227226
let load_dataset: Py<PyAny> = get_module!(PY_MODULE)
228227
.getattr(py, "load_dataset")
229228
.format_traceback(py)?;
230-
Ok(load_dataset
229+
load_dataset
231230
.call1(
232231
py,
233232
PyTuple::new(
@@ -242,7 +241,7 @@ pub fn load_dataset(
242241
)
243242
.format_traceback(py)?
244243
.extract(py)
245-
.format_traceback(py)?)
244+
.format_traceback(py)
246245
})?;
247246

248247
let table_name = format!("pgml.\"{}\"", name);

pgml-extension/src/bindings/transformers/transformers.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -241,29 +241,38 @@ def transform(task, args, inputs):
241241
return orjson.dumps(pipe(inputs, **args), default=orjson_default).decode()
242242

243243

244-
def embed(transformer, inputs, kwargs):
245-
kwargs = orjson.loads(kwargs)
244+
def create_embedding(transformer):
245+
instructor = transformer.startswith("hkunlp/instructor")
246+
klass = INSTRUCTOR if instructor else SentenceTransformer
247+
return klass(transformer)
248+
249+
250+
def embed_using(model, transformer, inputs, kwargs):
251+
if isinstance(kwargs, str):
252+
kwargs = orjson.loads(kwargs)
246253

247-
ensure_device(kwargs)
248254
instructor = transformer.startswith("hkunlp/instructor")
249-
250255
if instructor:
251-
klass = INSTRUCTOR
252-
253256
texts_with_instructions = []
254257
instruction = kwargs.pop("instruction")
255258
for text in inputs:
256259
texts_with_instructions.append([instruction, text])
257260

258261
inputs = texts_with_instructions
259-
else:
260-
klass = SentenceTransformer
262+
263+
return model.encode(inputs, **kwargs)
264+
265+
266+
def embed(transformer, inputs, kwargs):
267+
kwargs = orjson.loads(kwargs)
268+
269+
ensure_device(kwargs)
261270

262271
if transformer not in __cache_sentence_transformer_by_name:
263-
__cache_sentence_transformer_by_name[transformer] = klass(transformer)
272+
__cache_sentence_transformer_by_name[transformer] = create_embedding(transformer)
264273
model = __cache_sentence_transformer_by_name[transformer]
265274

266-
return model.encode(inputs, **kwargs)
275+
return embed_using(model, transformer, inputs, kwargs)
267276

268277

269278
def clear_gpu_cache(memory_usage: None):

pgml-extension/src/orm/model.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -378,12 +378,12 @@ impl Model {
378378
Ok(())
379379
})?;
380380

381-
Ok(model.ok_or_else(|| {
381+
model.ok_or_else(|| {
382382
anyhow!(
383383
"pgml.models WHERE id = {:?} could not be loaded. Does it exist?",
384384
id
385385
)
386-
})?)
386+
})
387387
}
388388

389389
pub fn find_cached(id: i64) -> Result<Arc<Model>> {

0 commit comments

Comments
 (0)