Skip to content

Commit 1ca5dc8

Browse files
committed
Clean up errors and guard rails around conversational api
1 parent 3ff2f07 commit 1ca5dc8

File tree

1 file changed

+19
-3
lines changed

1 file changed

+19
-3
lines changed

pgml-extension/src/api.rs

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -641,6 +641,14 @@ pub fn transform_conversational_json(
641641
inputs: default!(Vec<JsonB>, "ARRAY[]::JSONB[]"),
642642
cache: default!(bool, false),
643643
) -> JsonB {
644+
if !task.0["task"]
645+
.as_str()
646+
.is_some_and(|v| v == "conversational")
647+
{
648+
error!(
649+
"ARRAY[]::JSONB inputs for transformer should only be used with a conversational task"
650+
);
651+
}
644652
match crate::bindings::transformers::transform(&task.0, &args.0, inputs) {
645653
Ok(output) => JsonB(output),
646654
Err(e) => error!("{e}"),
@@ -656,6 +664,11 @@ pub fn transform_conversational_string(
656664
inputs: default!(Vec<JsonB>, "ARRAY[]::JSONB[]"),
657665
cache: default!(bool, false),
658666
) -> JsonB {
667+
if task != "conversational" {
668+
error!(
669+
"ARRAY[]::JSONB inputs for transformer should only be used with a conversational task"
670+
);
671+
}
659672
let task_json = json!({ "task": task });
660673
match crate::bindings::transformers::transform(&task_json, &args.0, inputs) {
661674
Ok(output) => JsonB(output),
@@ -710,12 +723,13 @@ pub fn transform_stream_conversational_json(
710723
input: default!(JsonB, "'[]'::JSONB"),
711724
cache: default!(bool, false),
712725
) -> SetOfIterator<'static, String> {
713-
// If they have Vec<JsonB> inputs lets make sure they have the write task
714726
if !task.0["task"]
715727
.as_str()
716728
.is_some_and(|v| v == "conversational")
717729
{
718-
error!("ARRAY[]::JSONB inputs for transformer_stream should only be used with a conversational task");
730+
error!(
731+
"JSONB inputs for transformer_stream should only be used with a conversational task"
732+
);
719733
}
720734
// We can unwrap this becuase if there is an error the current transaction is aborted in the map_err call
721735
let python_iter =
@@ -735,7 +749,9 @@ pub fn transform_stream_conversational_string(
735749
cache: default!(bool, false),
736750
) -> SetOfIterator<'static, String> {
737751
if task != "conversational" {
738-
error!("ARRAY[]::JSONB inputs for transformer_stream should only be used with a conversational task");
752+
error!(
753+
"JSONB inputs for transformer_stream should only be used with a conversational task"
754+
);
739755
}
740756
let task_json = json!({ "task": task });
741757
// We can unwrap this becuase if there is an error the current transaction is aborted in the map_err call

0 commit comments

Comments
 (0)