5
5
import json
6
6
import logging
7
7
import math
8
- import mmap
9
8
import os
10
9
import struct
11
10
import time
17
16
from torch .distributed .checkpoint ._hf_utils import (
18
17
_gen_file_name ,
19
18
_get_dcp_custom_metadata ,
20
- _get_dtype ,
21
19
_get_safetensors_file_metadata ,
22
20
_metadata_fn ,
23
21
DATA_OFFSETS_KEY ,
@@ -95,6 +93,9 @@ def _parse_input_metadata(
95
93
Raises:
96
94
ValueError: If no DCP custom metadata is found in a safetensors file
97
95
"""
96
+
97
+ from safetensors .torch import _getdtype # type: ignore[import]
98
+
98
99
# Dictionary to track the full size of each tensor across all shards
99
100
fqn_to_size_mapping : dict [str , tuple [list [int ], str ]] = {}
100
101
@@ -133,7 +134,7 @@ def _parse_input_metadata(
133
134
if fqn in output_data .fqn_data :
134
135
output_data .fqn_data [fqn ] = _FqnData (
135
136
shape_in_file = tensor_size ,
136
- dtype_size = torch .finfo (_get_dtype (dtype_str )).bits
137
+ dtype_size = torch .finfo (_getdtype (dtype_str )).bits
137
138
// 8 , # Convert bits to bytes
138
139
dtype_str = dtype_str ,
139
140
)
@@ -197,12 +198,7 @@ def _write_metadata(
197
198
output_data .metadata_size = f .tell ()
198
199
199
200
200
- def _read_tensor_data_mmap (
201
- file_path : str ,
202
- start_offset : int ,
203
- end_offset : int ,
204
- metadata_size : int ,
205
- ) -> bytes :
201
+ def _read_tensor_data (file_path : str , fqn : str ) -> bytes :
206
202
"""
207
203
Read tensor data from a safetensors file using memory mapping for efficiency.
208
204
@@ -215,12 +211,12 @@ def _read_tensor_data_mmap(
215
211
Returns:
216
212
Raw tensor data as bytes
217
213
"""
218
- # Use mmap for efficient access
219
- with open ( file_path , "rb" ) as f :
220
- with mmap . mmap ( f . fileno (), 0 , access = mmap . ACCESS_READ ) as mm :
221
- absolute_start = metadata_size + start_offset
222
- absolute_end = metadata_size + end_offset
223
- return bytes ( mm [ absolute_start : absolute_end ] )
214
+
215
+ from safetensors import safe_open # type: ignore[import]
216
+ from safetensors . torch import _tobytes # type: ignore[import]
217
+
218
+ with safe_open ( file_path , framework = "pt" ) as f :
219
+ return _tobytes ( f . get_tensor ( fqn ), fqn )
224
220
225
221
226
222
def _process_output_file (
@@ -257,21 +253,16 @@ def _process_output_file(
257
253
# Process each input safetensors file
258
254
for safetensors_file in input_files_data .keys ():
259
255
file_metadata = input_files_data [safetensors_file ].metadata
260
- input_metadata_size = input_files_data [safetensors_file ].metadata_size
261
256
262
257
if tensor_fqn not in file_metadata .keys ():
263
258
continue
264
259
265
260
metadata = file_metadata [tensor_fqn ]
266
261
267
- data_offsets = metadata [DATA_OFFSETS_KEY ]
268
-
269
262
# Use memory mapping to read tensor data efficiently
270
- data_to_write = _read_tensor_data_mmap (
263
+ data_to_write = _read_tensor_data (
271
264
safetensors_file ,
272
- data_offsets [0 ],
273
- data_offsets [1 ],
274
- input_metadata_size ,
265
+ tensor_fqn ,
275
266
)
276
267
277
268
# Get the offsets of this tensor shard within the full tensor
0 commit comments