Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Address review comments.
- Set the minimum value of OMP_NUM_THREADS to 1.
- Move PGML_OMP_NUM_THREADS.get() into initialize_server_params().
  • Loading branch information
higuoxing committed Apr 10, 2024
commit c81c42e2a64588dd5496a27d2e36b65222584dd9
13 changes: 11 additions & 2 deletions pgml-extension/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@ pub static PGML_HF_WHITELIST: GucSetting<Option<&'static CStr>> = GucSetting::<O
pub static PGML_HF_TRUST_REMOTE_CODE: GucSetting<bool> = GucSetting::<bool>::new(false);
pub static PGML_HF_TRUST_REMOTE_CODE_WHITELIST: GucSetting<Option<&'static CStr>> =
GucSetting::<Option<&'static CStr>>::new(None);
pub static PGML_OMP_NUM_THREADS: GucSetting<i32> = GucSetting::<i32>::new(0);
pub static PGML_OMP_NUM_THREADS: GucSetting<i32> = GucSetting::<i32>::new(1);

extern "C" {
fn omp_set_num_threads(num_threads: i32);
}

pub fn initialize_server_params() {
GucRegistry::define_string_guc(
Expand Down Expand Up @@ -53,11 +57,16 @@ pub fn initialize_server_params() {
"Specifies the number of threads used by default of underlying OpenMP library. Only positive integers are valid",
"",
&PGML_OMP_NUM_THREADS,
0,
1,
i32::max_value(),
GucContext::Backend,
GucFlags::default(),
);

let omp_num_threads = PGML_OMP_NUM_THREADS.get();
unsafe {
omp_set_num_threads(omp_num_threads);
}
}

#[cfg(any(test, feature = "pg_test"))]
Expand Down
10 changes: 0 additions & 10 deletions pgml-extension/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,10 @@ pg_module_magic!();

extension_sql_file!("../sql/schema.sql", name = "schema");

extern "C" {
fn omp_set_num_threads(num_threads: i32);
}

#[cfg(not(feature = "use_as_lib"))]
#[pg_guard]
pub extern "C" fn _PG_init() {
config::initialize_server_params();
let omp_num_threads = config::PGML_OMP_NUM_THREADS.get();
if omp_num_threads > 0 {
unsafe {
omp_set_num_threads(omp_num_threads);
}
}
bindings::python::activate().expect("Error setting python venv");
orm::project::init();
}
Expand Down