@@ -1258,13 +1258,126 @@ pyo3::create_exception!(
1258
1258
"Custom Python Exception for Safetensor errors."
1259
1259
) ;
1260
1260
1261
+ #[ pyclass]
1262
+ #[ allow( non_camel_case_types) ]
1263
+ struct _safe_open_handle {
1264
+ inner : Option < Open > ,
1265
+ }
1266
+
1267
+ impl _safe_open_handle {
1268
+ fn inner ( & self ) -> PyResult < & Open > {
1269
+ let inner = self
1270
+ . inner
1271
+ . as_ref ( )
1272
+ . ok_or_else ( || SafetensorError :: new_err ( "File is closed" . to_string ( ) ) ) ?;
1273
+ Ok ( inner)
1274
+ }
1275
+ }
1276
+
1277
+ #[ pymethods]
1278
+ impl _safe_open_handle {
1279
+ #[ new]
1280
+ #[ pyo3( signature = ( f, framework, device=Some ( Device :: Cpu ) ) ) ]
1281
+ fn new ( f : PyObject , framework : Framework , device : Option < Device > ) -> PyResult < Self > {
1282
+ let filename = Python :: with_gil ( |py| -> PyResult < PathBuf > {
1283
+ let _ = f. getattr ( py, "fileno" ) ?;
1284
+ let filename = f. getattr ( py, "name" ) ?;
1285
+ let filename: PathBuf = filename. extract ( py) ?;
1286
+ Ok ( filename)
1287
+ } ) ?;
1288
+ let inner = Some ( Open :: new ( filename, framework, device) ?) ;
1289
+ Ok ( Self { inner } )
1290
+ }
1291
+
1292
+ /// Return the special non tensor information in the header
1293
+ ///
1294
+ /// Returns:
1295
+ /// (`Dict[str, str]`):
1296
+ /// The freeform metadata.
1297
+ pub fn metadata ( & self ) -> PyResult < Option < HashMap < String , String > > > {
1298
+ Ok ( self . inner ( ) ?. metadata ( ) )
1299
+ }
1300
+
1301
+ /// Returns the names of the tensors in the file.
1302
+ ///
1303
+ /// Returns:
1304
+ /// (`List[str]`):
1305
+ /// The name of the tensors contained in that file
1306
+ pub fn keys ( & self ) -> PyResult < Vec < String > > {
1307
+ self . inner ( ) ?. keys ( )
1308
+ }
1309
+
1310
+ /// Returns the names of the tensors in the file, ordered by offset.
1311
+ ///
1312
+ /// Returns:
1313
+ /// (`List[str]`):
1314
+ /// The name of the tensors contained in that file
1315
+ pub fn offset_keys ( & self ) -> PyResult < Vec < String > > {
1316
+ self . inner ( ) ?. offset_keys ( )
1317
+ }
1318
+
1319
+ /// Returns a full tensor
1320
+ ///
1321
+ /// Args:
1322
+ /// name (`str`):
1323
+ /// The name of the tensor you want
1324
+ ///
1325
+ /// Returns:
1326
+ /// (`Tensor`):
1327
+ /// The tensor in the framework you opened the file for.
1328
+ ///
1329
+ /// Example:
1330
+ /// ```python
1331
+ /// from safetensors import safe_open
1332
+ ///
1333
+ /// with safe_open("model.safetensors", framework="pt", device=0) as f:
1334
+ /// tensor = f.get_tensor("embedding")
1335
+ ///
1336
+ /// ```
1337
+ pub fn get_tensor ( & self , name : & str ) -> PyResult < PyObject > {
1338
+ self . inner ( ) ?. get_tensor ( name)
1339
+ }
1340
+
1341
+ /// Returns a full slice view object
1342
+ ///
1343
+ /// Args:
1344
+ /// name (`str`):
1345
+ /// The name of the tensor you want
1346
+ ///
1347
+ /// Returns:
1348
+ /// (`PySafeSlice`):
1349
+ /// A dummy object you can slice into to get a real tensor
1350
+ /// Example:
1351
+ /// ```python
1352
+ /// from safetensors import safe_open
1353
+ ///
1354
+ /// with safe_open("model.safetensors", framework="pt", device=0) as f:
1355
+ /// tensor_part = f.get_slice("embedding")[:, ::8]
1356
+ ///
1357
+ /// ```
1358
+ pub fn get_slice ( & self , name : & str ) -> PyResult < PySafeSlice > {
1359
+ self . inner ( ) ?. get_slice ( name)
1360
+ }
1361
+
1362
+ /// Start the context manager
1363
+ pub fn __enter__ ( slf : Py < Self > ) -> Py < Self > {
1364
+ slf
1365
+ }
1366
+
1367
+ /// Exits the context manager
1368
+ pub fn __exit__ ( & mut self , _exc_type : PyObject , _exc_value : PyObject , _traceback : PyObject ) {
1369
+ self . inner = None ;
1370
+ }
1371
+ }
1372
+
1261
1373
/// A Python module implemented in Rust.
1262
1374
#[ pymodule( gil_used = false ) ]
1263
1375
fn _safetensors_rust ( m : & PyBound < ' _ , PyModule > ) -> PyResult < ( ) > {
1264
1376
m. add_function ( wrap_pyfunction ! ( serialize, m) ?) ?;
1265
1377
m. add_function ( wrap_pyfunction ! ( serialize_file, m) ?) ?;
1266
1378
m. add_function ( wrap_pyfunction ! ( deserialize, m) ?) ?;
1267
1379
m. add_class :: < safe_open > ( ) ?;
1380
+ m. add_class :: < _safe_open_handle > ( ) ?;
1268
1381
m. add ( "SafetensorError" , m. py ( ) . get_type :: < SafetensorError > ( ) ) ?;
1269
1382
m. add ( "__version__" , env ! ( "CARGO_PKG_VERSION" ) ) ?;
1270
1383
Ok ( ( ) )
0 commit comments