From 7ee31ea92060da8b3695db5d550051d8857056dd Mon Sep 17 00:00:00 2001 From: Lev Date: Tue, 12 Mar 2024 19:13:53 -0700 Subject: [PATCH] Move task verification out of internal binding --- pgml-extension/src/api.rs | 13 +++++++++++++ .../src/bindings/transformers/transform.rs | 2 -- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs index 7fd5012c8..a9167fe5d 100644 --- a/pgml-extension/src/api.rs +++ b/pgml-extension/src/api.rs @@ -647,6 +647,10 @@ pub fn transform_json( inputs: default!(Vec<&str>, "ARRAY[]::TEXT[]"), cache: default!(bool, false), ) -> JsonB { + if let Err(err) = crate::bindings::transformers::whitelist::verify_task(&task.0) { + error!("{err}"); + } + match crate::bindings::transformers::transform(&task.0, &args.0, inputs) { Ok(output) => JsonB(output), Err(e) => error!("{e}"), @@ -663,6 +667,9 @@ pub fn transform_string( cache: default!(bool, false), ) -> JsonB { let task_json = json!({ "task": task }); + if let Err(err) = crate::bindings::transformers::whitelist::verify_task(&task_json) { + error!("{err}"); + } match crate::bindings::transformers::transform(&task_json, &args.0, inputs) { Ok(output) => JsonB(output), Err(e) => error!("{e}"), @@ -681,6 +688,9 @@ pub fn transform_conversational_json( if !task.0["task"].as_str().is_some_and(|v| v == "conversational") { error!("ARRAY[]::JSONB inputs for transform should only be used with a conversational task"); } + if let Err(err) = crate::bindings::transformers::whitelist::verify_task(&task.0) { + error!("{err}"); + } match crate::bindings::transformers::transform(&task.0, &args.0, inputs) { Ok(output) => JsonB(output), Err(e) => error!("{e}"), @@ -700,6 +710,9 @@ pub fn transform_conversational_string( error!("ARRAY[]::JSONB inputs for transform should only be used with a conversational task"); } let task_json = json!({ "task": task }); + if let Err(err) = crate::bindings::transformers::whitelist::verify_task(&task_json) { + error!("{err}"); + } match crate::bindings::transformers::transform(&task_json, &args.0, inputs) { Ok(output) => JsonB(output), Err(e) => error!("{e}"), diff --git a/pgml-extension/src/bindings/transformers/transform.rs b/pgml-extension/src/bindings/transformers/transform.rs index 41fd04512..7b8db768e 100644 --- a/pgml-extension/src/bindings/transformers/transform.rs +++ b/pgml-extension/src/bindings/transformers/transform.rs @@ -46,8 +46,6 @@ pub fn transform( args: &serde_json::Value, inputs: T, ) -> Result { - whitelist::verify_task(task)?; - let task = serde_json::to_string(task)?; let args = serde_json::to_string(args)?; let inputs = serde_json::to_string(&inputs)?;