Skip to content

Commit 48933f3

Browse files
authored
Adding _safe_open_handle. (#608)
* [WIP]. Adding safe_handle. * Adding S3. * File handle becomes private for merge.
1 parent 8814598 commit 48933f3

File tree

4 files changed

+186
-8
lines changed

4 files changed

+186
-8
lines changed

.pre-commit-config.yaml

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,11 @@ repos:
3838
name: "Python (black)"
3939
args: ["--line-length", "119", "--target-version", "py35"]
4040
types: ["python"]
41-
- repo: https://github.com/pycqa/flake8
42-
rev: 7.2.0
41+
- repo: https://github.com/astral-sh/ruff-pre-commit
42+
# Ruff version.
43+
rev: v0.11.11
4344
hooks:
44-
- id: flake8
45-
args: ["--config", "bindings/python/setup.cfg"]
46-
- repo: https://github.com/pre-commit/mirrors-isort
47-
rev: v5.7.0 # Use the revision sha / tag you want to point at
48-
hooks:
49-
- id: isort
45+
# Run the linter.
46+
- id: ruff-check
47+
# Run the formatter.
48+
- id: ruff-format

bindings/python/py_src/safetensors/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
__version__,
55
deserialize,
66
safe_open,
7+
_safe_open_handle,
78
serialize,
89
serialize_file,
910
)

bindings/python/src/lib.rs

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1258,13 +1258,126 @@ pyo3::create_exception!(
12581258
"Custom Python Exception for Safetensor errors."
12591259
);
12601260

1261+
#[pyclass]
1262+
#[allow(non_camel_case_types)]
1263+
struct _safe_open_handle {
1264+
inner: Option<Open>,
1265+
}
1266+
1267+
impl _safe_open_handle {
1268+
fn inner(&self) -> PyResult<&Open> {
1269+
let inner = self
1270+
.inner
1271+
.as_ref()
1272+
.ok_or_else(|| SafetensorError::new_err("File is closed".to_string()))?;
1273+
Ok(inner)
1274+
}
1275+
}
1276+
1277+
#[pymethods]
1278+
impl _safe_open_handle {
1279+
#[new]
1280+
#[pyo3(signature = (f, framework, device=Some(Device::Cpu)))]
1281+
fn new(f: PyObject, framework: Framework, device: Option<Device>) -> PyResult<Self> {
1282+
let filename = Python::with_gil(|py| -> PyResult<PathBuf> {
1283+
let _ = f.getattr(py, "fileno")?;
1284+
let filename = f.getattr(py, "name")?;
1285+
let filename: PathBuf = filename.extract(py)?;
1286+
Ok(filename)
1287+
})?;
1288+
let inner = Some(Open::new(filename, framework, device)?);
1289+
Ok(Self { inner })
1290+
}
1291+
1292+
/// Return the special non tensor information in the header
1293+
///
1294+
/// Returns:
1295+
/// (`Dict[str, str]`):
1296+
/// The freeform metadata.
1297+
pub fn metadata(&self) -> PyResult<Option<HashMap<String, String>>> {
1298+
Ok(self.inner()?.metadata())
1299+
}
1300+
1301+
/// Returns the names of the tensors in the file.
1302+
///
1303+
/// Returns:
1304+
/// (`List[str]`):
1305+
/// The name of the tensors contained in that file
1306+
pub fn keys(&self) -> PyResult<Vec<String>> {
1307+
self.inner()?.keys()
1308+
}
1309+
1310+
/// Returns the names of the tensors in the file, ordered by offset.
1311+
///
1312+
/// Returns:
1313+
/// (`List[str]`):
1314+
/// The name of the tensors contained in that file
1315+
pub fn offset_keys(&self) -> PyResult<Vec<String>> {
1316+
self.inner()?.offset_keys()
1317+
}
1318+
1319+
/// Returns a full tensor
1320+
///
1321+
/// Args:
1322+
/// name (`str`):
1323+
/// The name of the tensor you want
1324+
///
1325+
/// Returns:
1326+
/// (`Tensor`):
1327+
/// The tensor in the framework you opened the file for.
1328+
///
1329+
/// Example:
1330+
/// ```python
1331+
/// from safetensors import safe_open
1332+
///
1333+
/// with safe_open("model.safetensors", framework="pt", device=0) as f:
1334+
/// tensor = f.get_tensor("embedding")
1335+
///
1336+
/// ```
1337+
pub fn get_tensor(&self, name: &str) -> PyResult<PyObject> {
1338+
self.inner()?.get_tensor(name)
1339+
}
1340+
1341+
/// Returns a full slice view object
1342+
///
1343+
/// Args:
1344+
/// name (`str`):
1345+
/// The name of the tensor you want
1346+
///
1347+
/// Returns:
1348+
/// (`PySafeSlice`):
1349+
/// A dummy object you can slice into to get a real tensor
1350+
/// Example:
1351+
/// ```python
1352+
/// from safetensors import safe_open
1353+
///
1354+
/// with safe_open("model.safetensors", framework="pt", device=0) as f:
1355+
/// tensor_part = f.get_slice("embedding")[:, ::8]
1356+
///
1357+
/// ```
1358+
pub fn get_slice(&self, name: &str) -> PyResult<PySafeSlice> {
1359+
self.inner()?.get_slice(name)
1360+
}
1361+
1362+
/// Start the context manager
1363+
pub fn __enter__(slf: Py<Self>) -> Py<Self> {
1364+
slf
1365+
}
1366+
1367+
/// Exits the context manager
1368+
pub fn __exit__(&mut self, _exc_type: PyObject, _exc_value: PyObject, _traceback: PyObject) {
1369+
self.inner = None;
1370+
}
1371+
}
1372+
12611373
/// A Python module implemented in Rust.
12621374
#[pymodule(gil_used = false)]
12631375
fn _safetensors_rust(m: &PyBound<'_, PyModule>) -> PyResult<()> {
12641376
m.add_function(wrap_pyfunction!(serialize, m)?)?;
12651377
m.add_function(wrap_pyfunction!(serialize_file, m)?)?;
12661378
m.add_function(wrap_pyfunction!(deserialize, m)?)?;
12671379
m.add_class::<safe_open>()?;
1380+
m.add_class::<_safe_open_handle>()?;
12681381
m.add("SafetensorError", m.py().get_type::<SafetensorError>())?;
12691382
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
12701383
Ok(())

bindings/python/tests/test_handle.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import unittest
2+
3+
import numpy as np
4+
5+
from safetensors import _safe_open_handle
6+
from safetensors.numpy import save_file, save
7+
8+
9+
class ReadmeTestCase(unittest.TestCase):
10+
def assertTensorEqual(self, tensors1, tensors2, equality_fn):
11+
self.assertEqual(tensors1.keys(), tensors2.keys(), "tensor keys don't match")
12+
13+
for k, v1 in tensors1.items():
14+
v2 = tensors2[k]
15+
16+
self.assertTrue(equality_fn(v1, v2), f"{k} tensors are different")
17+
18+
def test_numpy_example(self):
19+
tensors = {"a": np.zeros((2, 2)), "b": np.zeros((2, 3), dtype=np.uint8)}
20+
21+
save_file(tensors, "./out_np.safetensors")
22+
23+
# Now loading
24+
loaded = {}
25+
with open("./out_np.safetensors", "r") as f:
26+
with safe_open_handle(f, framework="np", device="cpu") as g:
27+
for key in g.keys():
28+
loaded[key] = g.get_tensor(key)
29+
self.assertTensorEqual(tensors, loaded, np.allclose)
30+
31+
def test_fsspec(self):
32+
import fsspec
33+
34+
tensors = {"a": np.zeros((2, 2)), "b": np.zeros((2, 3), dtype=np.uint8)}
35+
36+
fs = fsspec.filesystem("file")
37+
byts = save(tensors)
38+
with fs.open("fs.safetensors", "wb") as f:
39+
f.write(byts)
40+
# Now loading
41+
loaded = {}
42+
with fs.open("fs.safetensors", "rb") as f:
43+
with safe_open_handle(f, framework="np", device="cpu") as g:
44+
for key in g.keys():
45+
loaded[key] = g.get_tensor(key)
46+
self.assertTensorEqual(tensors, loaded, np.allclose)
47+
48+
@unittest.skip("Will not work without s3 access")
49+
def test_fsspec_s3(self):
50+
import s3fs
51+
52+
tensors = {"a": np.zeros((2, 2)), "b": np.zeros((2, 3), dtype=np.uint8)}
53+
54+
s3 = s3fs.S3FileSystem(anon=True)
55+
byts = save(tensors)
56+
print(s3.ls("my-bucket"))
57+
with s3.open("out/fs.safetensors", "wb") as f:
58+
f.write(byts)
59+
# Now loading
60+
loaded = {}
61+
with s3.open("out/fs.safetensors", "rb") as f:
62+
with safe_open_handle(f, framework="np", device="cpu") as g:
63+
for key in g.keys():
64+
loaded[key] = g.get_tensor(key)
65+
self.assertTensorEqual(tensors, loaded, np.allclose)

0 commit comments

Comments
 (0)