Commit c8cf6cf

Use safe_open in HF consolidation
Pull Request resolved: #159395

Use safe_open to read tensors for HF consolidation instead of reading the bytes with f.read(), which has worse performance.

ghstack-source-id: 300135655
@exported-using-ghexport

Differential Revision: [D79105491](https://our.internmc.facebook.com/intern/diff/D79105491/)
1 parent 34c3e5b commit c8cf6cf
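
For context, here is a minimal sketch (not part of the commit) of the two read paths the commit message refers to. The file name "model-00001.safetensors" and tensor name "layer.weight" are placeholders; the "before" path parses the safetensors header by hand, the "after" path goes through safetensors' safe_open and its private _tobytes helper, as the diff below does.

import json
import struct

from safetensors import safe_open
from safetensors.torch import _tobytes  # private helper used in the diff below

path = "model-00001.safetensors"  # placeholder shard file
name = "layer.weight"             # placeholder tensor FQN

# Before: read the raw bytes by hand. A safetensors file starts with an
# 8-byte little-endian header length, then a JSON header whose
# "data_offsets" are relative to the start of the data section.
with open(path, "rb") as f:
    header_len = struct.unpack("<Q", f.read(8))[0]
    header = json.loads(f.read(header_len))
    start, end = header[name]["data_offsets"]
    f.seek(8 + header_len + start)
    raw_before = f.read(end - start)

# After: let safe_open resolve the tensor by name and take its bytes.
with safe_open(path, framework="pt") as f:
    raw_after = _tobytes(f.get_tensor(name), name)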

File tree

3 files changed: +17 additions, −46 deletions

torch/distributed/checkpoint/_consolidate_hf_safetensors.py

Lines changed: 13 additions & 22 deletions
@@ -5,7 +5,6 @@
 import json
 import logging
 import math
-import mmap
 import os
 import struct
 import time
@@ -17,7 +16,6 @@
 from torch.distributed.checkpoint._hf_utils import (
     _gen_file_name,
     _get_dcp_custom_metadata,
-    _get_dtype,
     _get_safetensors_file_metadata,
     _metadata_fn,
     DATA_OFFSETS_KEY,
@@ -95,6 +93,9 @@ def _parse_input_metadata(
     Raises:
         ValueError: If no DCP custom metadata is found in a safetensors file
     """
+
+    from safetensors.torch import _getdtype  # type: ignore[import]
+
     # Dictionary to track the full size of each tensor across all shards
     fqn_to_size_mapping: dict[str, tuple[list[int], str]] = {}

@@ -133,7 +134,7 @@ def _parse_input_metadata(
             if fqn in output_data.fqn_data:
                 output_data.fqn_data[fqn] = _FqnData(
                     shape_in_file=tensor_size,
-                    dtype_size=torch.finfo(_get_dtype(dtype_str)).bits
+                    dtype_size=torch.finfo(_getdtype(dtype_str)).bits
                     // 8,  # Convert bits to bytes
                     dtype_str=dtype_str,
                 )
@@ -197,12 +198,7 @@ def _write_metadata(
         output_data.metadata_size = f.tell()


-def _read_tensor_data_mmap(
-    file_path: str,
-    start_offset: int,
-    end_offset: int,
-    metadata_size: int,
-) -> bytes:
+def _read_tensor_data(file_path: str, fqn: str) -> bytes:
     """
     Read tensor data from a safetensors file using memory mapping for efficiency.

@@ -215,12 +211,12 @@ def _read_tensor_data_mmap(
     Returns:
         Raw tensor data as bytes
     """
-    # Use mmap for efficient access
-    with open(file_path, "rb") as f:
-        with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
-            absolute_start = metadata_size + start_offset
-            absolute_end = metadata_size + end_offset
-            return bytes(mm[absolute_start:absolute_end])
+
+    from safetensors import safe_open  # type: ignore[import]
+    from safetensors.torch import _tobytes  # type: ignore[import]
+
+    with safe_open(file_path, framework="pt") as f:
+        return _tobytes(f.get_tensor(fqn), fqn)


 def _process_output_file(
@@ -257,21 +253,16 @@ def _process_output_file(
         # Process each input safetensors file
        for safetensors_file in input_files_data.keys():
             file_metadata = input_files_data[safetensors_file].metadata
-            input_metadata_size = input_files_data[safetensors_file].metadata_size

             if tensor_fqn not in file_metadata.keys():
                 continue

             metadata = file_metadata[tensor_fqn]

-            data_offsets = metadata[DATA_OFFSETS_KEY]
-
             # Use memory mapping to read tensor data efficiently
-            data_to_write = _read_tensor_data_mmap(
+            data_to_write = _read_tensor_data(
                 safetensors_file,
-                data_offsets[0],
-                data_offsets[1],
-                input_metadata_size,
+                tensor_fqn,
             )

             # Get the offsets of this tensor shard within the full tensor
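
Taken out of the diff above, the new helper is small enough to show on its own. The sketch below is illustrative only: _read_tensor_data mirrors the function added in this file, while consolidate_one_tensor, its shard list, and the writable output stream are simplified stand-ins for the real consolidation loop (which writes each shard at its computed offset rather than appending).

from safetensors import safe_open
from safetensors.torch import _tobytes  # type: ignore[import]


def _read_tensor_data(file_path: str, fqn: str) -> bytes:
    # Same body as the function added above: let safe_open locate the tensor
    # by its fully qualified name and return its raw bytes.
    with safe_open(file_path, framework="pt") as f:
        return _tobytes(f.get_tensor(fqn), fqn)


def consolidate_one_tensor(shard_files: list[str], fqn: str, out) -> None:
    """Simplified stand-in: write each shard's bytes of `fqn` to `out` in order."""
    for shard_file in shard_files:
        out.write(_read_tensor_data(shard_file, fqn))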

torch/distributed/checkpoint/_hf_utils.py

Lines changed: 0 additions & 21 deletions
@@ -22,18 +22,6 @@
 DTYPE_KEY = "dtype"
 DATA_OFFSETS_KEY = "data_offsets"

-DTYPE_MAP = {
-    "F16": torch.float16,
-    "F32": torch.float32,
-    "F64": torch.float64,
-    "I8": torch.int8,
-    "U8": torch.uint8,
-    "I16": torch.int16,
-    "I32": torch.int32,
-    "I64": torch.int64,
-    "BF16": torch.bfloat16,
-}
-
 HF_DCP_VERSION: float = 1.0
 DCP_VERSION_KEY = "DCP_VERSION"
 DCP_SHARDING_INFO_KEY = "DCP_SHARDING_INFO"
@@ -91,15 +79,6 @@ def _get_safetensors_file_metadata(file_bytes: io.IOBase) -> tuple[Any, int]:
     return (metadata, header_len + NUM_BYTES_FOR_HEADER_LEN)


-def _get_dtype(dtype_str: str) -> torch.dtype:
-    try:
-        dtype = DTYPE_MAP[dtype_str]
-    except KeyError:
-        dtype = torch.get_default_dtype()
-
-    return dtype
-
-
 def _get_dcp_custom_metadata(metadata: Any) -> Optional[Any]:
     if DEFAULT_EXTRA_METADATA_KEY in metadata:
         custom_metadata = metadata[DEFAULT_EXTRA_METADATA_KEY]
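
The deleted DTYPE_MAP and _get_dtype are replaced by safetensors' own private _getdtype helper, which resolves the same dtype strings. A quick illustrative check (not part of the commit), assuming _getdtype accepts the safetensors dtype keys the removed map used:

import torch
from safetensors.torch import _getdtype  # type: ignore[import]

# Spot-check a few of the keys the removed DTYPE_MAP covered.
for key, expected in [("F16", torch.float16), ("F32", torch.float32),
                      ("BF16", torch.bfloat16), ("I64", torch.int64)]:
    assert _getdtype(key) is expected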

torch/distributed/checkpoint/hf_storage.py

Lines changed: 4 additions & 3 deletions
@@ -13,7 +13,6 @@
 from torch.distributed.checkpoint._fsspec_filesystem import FsspecReader, FsspecWriter
 from torch.distributed.checkpoint._hf_utils import (
     _gen_file_name,
-    _get_dtype,
     _get_safetensors_file_metadata,
     _HFStorageInfo,
     _metadata_fn,
@@ -282,6 +281,8 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
         return fut

     def read_metadata(self) -> Metadata:
+        from safetensors.torch import _getdtype  # type: ignore[import]
+
         state_dict_metadata: dict[str, TensorStorageMetadata] = {}
         storage_data: dict[MetadataIndex, _HFStorageInfo] = {}

@@ -314,7 +315,7 @@ def read_metadata(self) -> Metadata:
                 if key not in state_dict_metadata:
                     state_dict_metadata[key] = TensorStorageMetadata(
                         properties=TensorProperties(
-                            dtype=_get_dtype(val[DTYPE_KEY])
+                            dtype=_getdtype(val[DTYPE_KEY])
                         ),
                         size=torch.Size(
                             [
@@ -354,7 +355,7 @@ def read_metadata(self) -> Metadata:
                     offset=val[DATA_OFFSETS_KEY][0] + metadata_size,
                     length=val[DATA_OFFSETS_KEY][1] - val[DATA_OFFSETS_KEY][0],
                     shape=torch.Size(val[SHAPE_KEY]),
-                    dtype=_get_dtype(val[DTYPE_KEY]),
+                    dtype=_getdtype(val[DTYPE_KEY]),
                 )

         metadata = Metadata(
