Skip to content

Commit 160fe57

Browse files
authored
Fix clear_gpu_cache (#898)
1 parent d7fb281 commit 160fe57

File tree

2 files changed

+3
-0
lines changed

2 files changed

+3
-0
lines changed

pgml-extension/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,4 @@ transformers==4.31.0
2525
xgboost==1.7.6
2626
langchain==0.0.237
2727
einops==0.6.1
28+
pynvml==11.5.0

pgml-extension/src/bindings/transformers/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,8 @@ pub fn load_dataset(
348348
}
349349

350350
pub fn clear_gpu_cache(memory_usage: Option<f32>) -> Result<bool> {
351+
crate::bindings::venv::activate();
352+
351353
Python::with_gil(|py| -> Result<bool> {
352354
let clear_gpu_cache: Py<PyAny> = PY_MODULE.getattr(py, "clear_gpu_cache")?;
353355
let success = clear_gpu_cache

0 commit comments

Comments
 (0)