Skip to content

Commit e3892bc

Browse files
authored
separate pipeline creation and usage for transform (#921)
1 parent f2f5b44 commit e3892bc

File tree

3 files changed

+40
-20
lines changed

3 files changed

+40
-20
lines changed

pgml-dashboard/Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pgml-extension/src/bindings/mod.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,15 +73,18 @@ pub trait Bindings: Send + Sync + Debug {
7373
Self: Sized;
7474
}
7575

76-
trait TracebackError<T> {
76+
pub trait TracebackError<T> {
7777
fn format_traceback(self, py: Python<'_>) -> Result<T>;
7878
}
7979

8080
impl<T> TracebackError<T> for PyResult<T> {
8181
fn format_traceback(self, py: Python<'_>) -> Result<T> {
82-
self.map_err(|e| {
83-
let traceback = e.traceback(py).unwrap().format().unwrap();
84-
anyhow!("{traceback} {e}")
82+
self.map_err(|e| match e.traceback(py) {
83+
Some(traceback) => match traceback.format() {
84+
Ok(traceback) => anyhow!("{traceback} {e}"),
85+
Err(format_e) => anyhow!("{e} {format_e}"),
86+
},
87+
None => anyhow!("{e}"),
8588
})
8689
}
8790
}

pgml-extension/src/bindings/transformers/transformers.py

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -184,28 +184,45 @@ def get_model_from(task):
184184
return model[ty][0]
185185

186186

187+
def create_pipeline(task):
188+
if isinstance(task, str):
189+
task = orjson.loads(task)
190+
ensure_device(task)
191+
convert_dtype(task)
192+
model_name = task.get("model", None)
193+
if model_name and "-ggml" in model_name:
194+
pipe = GGMLPipeline(model_name, **task)
195+
elif model_name and "-gptq" in model_name:
196+
pipe = GPTQPipeline(model_name, **task)
197+
else:
198+
try:
199+
pipe = StandardPipeline(model_name, **task)
200+
except TypeError:
201+
# some models fail when given "device" kwargs, remove and try again
202+
task.pop("device")
203+
pipe = StandardPipeline(model_name, **task)
204+
return pipe
205+
206+
207+
def transform_using(pipeline, args, inputs):
208+
args = orjson.loads(args)
209+
inputs = orjson.loads(inputs)
210+
211+
if pipeline.task == "question-answering":
212+
inputs = [orjson.loads(input) for input in inputs]
213+
convert_eos_token(pipeline.tokenizer, args)
214+
215+
return orjson.dumps(pipeline(inputs, **args), default=orjson_default).decode()
216+
217+
187218
def transform(task, args, inputs):
188219
task = orjson.loads(task)
189220
args = orjson.loads(args)
190221
inputs = orjson.loads(inputs)
191222

192223
key = ",".join([f"{key}:{val}" for (key, val) in sorted(task.items())])
193224
if key not in __cache_transform_pipeline_by_task:
194-
ensure_device(task)
195-
convert_dtype(task)
196-
model_name = task.get("model", None)
197-
if model_name and "-ggml" in model_name:
198-
pipe = GGMLPipeline(model_name, **task)
199-
elif model_name and "-gptq" in model_name:
200-
pipe = GPTQPipeline(model_name, **task)
201-
else:
202-
try:
203-
pipe = StandardPipeline(model_name, **task)
204-
except TypeError:
205-
# some models fail when given "device" kwargs, remove and try again
206-
task.pop("device")
207-
pipe = StandardPipeline(model_name, **task)
208-
225+
pipe = create_pipeline(task)
209226
__cache_transform_pipeline_by_task[key] = pipe
210227

211228
pipe = __cache_transform_pipeline_by_task[key]

0 commit comments

Comments
 (0)