Skip to content

Commit 56cb69b

Browse files
levkkSilasMarvin
andauthored
OpenAI runtime (#944)
Co-authored-by: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com>
1 parent 1d6119e commit 56cb69b

File tree

8 files changed

+23
-3
lines changed

8 files changed

+23
-3
lines changed

pgml-extension/Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pgml-extension/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "pgml"
3-
version = "2.7.7"
3+
version = "2.7.8"
44
edition = "2021"
55

66
[lib]

pgml-extension/sql/pgml--2.7.7--2.7.8.sql

Whitespace-only changes.

pgml-extension/sql/schema.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ SELECT pgml.auto_updated_at('pgml.snapshots');
8282
CREATE TABLE IF NOT EXISTS pgml.models(
8383
id BIGSERIAL PRIMARY KEY,
8484
project_id BIGINT NOT NULL,
85-
snapshot_id BIGINT NOT NULL,
85+
snapshot_id BIGINT,
8686
num_features INT NOT NULL,
8787
algorithm TEXT NOT NULL,
8888
runtime pgml.runtime DEFAULT 'python'::pgml.runtime,

pgml-extension/src/orm/file.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,10 @@ pub fn find_deployed_estimator_by_model_id(model_id: i64) -> Result<Arc<Box<dyn
115115
Runtime::python => {
116116
anyhow::bail!("Python runtime not supported, recompile with `--features python`")
117117
}
118+
119+
Runtime::openai => {
120+
error!("OpenAI runtime is not supported for training or inference");
121+
}
118122
};
119123

120124
// Cache the estimator in process memory.

pgml-extension/src/orm/model.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,10 @@ impl Model {
306306
.unwrap();
307307

308308
let bindings: Box<dyn Bindings> = match runtime {
309+
Runtime::openai => {
310+
error!("OpenAI runtime is not supported for training or inference");
311+
}
312+
309313
Runtime::rust => {
310314
match algorithm {
311315
Algorithm::xgboost => {
@@ -396,6 +400,10 @@ impl Model {
396400

397401
fn get_fit_function(&self) -> crate::bindings::Fit {
398402
match self.runtime {
403+
Runtime::openai => {
404+
error!("OpenAI runtime is not supported for training or inference");
405+
}
406+
399407
Runtime::rust => match self.project.task {
400408
Task::regression => match self.algorithm {
401409
Algorithm::xgboost => xgboost::fit_regression,

pgml-extension/src/orm/runtime.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use serde::Deserialize;
66
pub enum Runtime {
77
python,
88
rust,
9+
openai,
910
}
1011

1112
impl std::str::FromStr for Runtime {
@@ -15,6 +16,7 @@ impl std::str::FromStr for Runtime {
1516
match input {
1617
"python" => Ok(Runtime::python),
1718
"rust" => Ok(Runtime::rust),
19+
"openai" => Ok(Runtime::openai),
1820
_ => Err(()),
1921
}
2022
}
@@ -25,6 +27,7 @@ impl std::string::ToString for Runtime {
2527
match *self {
2628
Runtime::python => "python".to_string(),
2729
Runtime::rust => "rust".to_string(),
30+
Runtime::openai => "openai".to_string(),
2831
}
2932
}
3033
}

pgml-extension/src/orm/task.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ pub enum Task {
1313
text_generation,
1414
text2text,
1515
cluster,
16+
embedding,
1617
}
1718

1819
// unfortunately the pgrx macro expands the enum names to underscore, but huggingface uses dash
@@ -28,6 +29,7 @@ impl Task {
2829
Task::text_generation => "text_generation".to_string(),
2930
Task::text2text => "text2text".to_string(),
3031
Task::cluster => "cluster".to_string(),
32+
Task::embedding => "embedding".to_string(),
3133
}
3234
}
3335

@@ -46,6 +48,7 @@ impl Task {
4648
Task::text_generation => "perplexity",
4749
Task::text2text => "perplexity",
4850
Task::cluster => "silhouette",
51+
Task::embedding => error!("No default target metric for embedding task")
4952
}
5053
.to_string()
5154
}
@@ -61,6 +64,7 @@ impl Task {
6164
Task::text_generation => false,
6265
Task::text2text => false,
6366
Task::cluster => true,
67+
Task::embedding => error!("No default target metric positive for embedding task")
6468
}
6569
}
6670

@@ -117,6 +121,7 @@ impl std::string::ToString for Task {
117121
Task::text_generation => "text-generation".to_string(),
118122
Task::text2text => "text2text".to_string(),
119123
Task::cluster => "cluster".to_string(),
124+
Task::embedding => "embedding".to_string(),
120125
}
121126
}
122127
}

0 commit comments

Comments
 (0)