diff --git a/pgml-extension/requirements.txt b/pgml-extension/requirements.txt index 1884801e5..3fdfeb4b7 100644 --- a/pgml-extension/requirements.txt +++ b/pgml-extension/requirements.txt @@ -25,3 +25,4 @@ transformers==4.31.0 xgboost==1.7.6 langchain==0.0.237 einops==0.6.1 +pynvml==11.5.0 diff --git a/pgml-extension/src/bindings/transformers/mod.rs b/pgml-extension/src/bindings/transformers/mod.rs index 85f71d3c8..d94e87de7 100644 --- a/pgml-extension/src/bindings/transformers/mod.rs +++ b/pgml-extension/src/bindings/transformers/mod.rs @@ -348,6 +348,8 @@ pub fn load_dataset( } pub fn clear_gpu_cache(memory_usage: Option) -> Result { + crate::bindings::venv::activate(); + Python::with_gil(|py| -> Result { let clear_gpu_cache: Py = PY_MODULE.getattr(py, "clear_gpu_cache")?; let success = clear_gpu_cache