Skip to content

Remove usage of fsspec in HF consolidation script #159392

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: gh/ankitageorge/14/base
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 28 additions & 98 deletions torch/distributed/checkpoint/_consolidate_hf_safetensors.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,17 @@
# pyre-strict

import concurrent.futures
import glob
import json
import logging
import math
import mmap
import os
import shutil
import struct
import tempfile
import time
from dataclasses import dataclass, field
from typing import Any, Optional

import fsspec # type: ignore[import-untyped]
from fsspec.core import url_to_fs # type: ignore[import-untyped]
from fsspec.implementations.local import LocalFileSystem # type: ignore[import-untyped]

import torch
from torch.distributed.checkpoint._hf_utils import (
_gen_file_name,
Expand All @@ -27,7 +22,6 @@
DATA_OFFSETS_KEY,
DEFAULT_EXTRA_METADATA_KEY,
DTYPE_KEY,
FILE_NAME,
SAVED_OFFSETS_KEY,
SHAPE_KEY,
SUFFIX,
Expand Down Expand Up @@ -145,7 +139,6 @@ def _parse_input_metadata(


def _write_metadata(
fs: fsspec.AbstractFileSystem,
output_files_data: dict[str, _OutputFileData],
) -> None:
"""
Expand All @@ -156,12 +149,11 @@ def _write_metadata(
field for each tensor in the output_files_data.

Args:
fs: Filesystem interface for file operations
output_files_data: Dictionary mapping output file paths to their metadata
"""
# Process each output file
for file_path, output_data in output_files_data.items():
with fs.open(file_path, "wb") as f:
with open(file_path, "wb") as f:
metadata = {}
curr_offset = 0

Expand Down Expand Up @@ -205,7 +197,6 @@ def _write_metadata(


def _read_tensor_data_mmap(
input_fs: fsspec.AbstractFileSystem,
file_path: str,
start_offset: int,
end_offset: int,
Expand All @@ -215,7 +206,6 @@ def _read_tensor_data_mmap(
Read tensor data from a safetensors file using memory mapping for efficiency.

Args:
input_fs: Filesystem interface for input file operations
file_path: Path to the safetensors file
start_offset: Start offset of tensor data within the data section
end_offset: End offset of tensor data within the data section
Expand All @@ -224,24 +214,15 @@ def _read_tensor_data_mmap(
Returns:
Raw tensor data as bytes
"""
# For local files, use mmap for efficient access
if isinstance(input_fs, LocalFileSystem):
# Local file - use mmap
with open(file_path, "rb") as f:
with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
absolute_start = metadata_size + start_offset
absolute_end = metadata_size + end_offset
return bytes(mm[absolute_start:absolute_end])
else:
# Remote file - fall back to regular read
with input_fs.open(file_path, "rb") as f:
f.seek(metadata_size + start_offset)
return f.read(end_offset - start_offset)
# Use mmap for efficient access
with open(file_path, "rb") as f:
with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Put both context managers on the same with statement to reduce unnecessary indent

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this doesn't work because the mmap call needs to be nested in the open

absolute_start = metadata_size + start_offset
absolute_end = metadata_size + end_offset
return bytes(mm[absolute_start:absolute_end])


def _process_output_file(
input_fs: fsspec.AbstractFileSystem,
output_fs: fsspec.AbstractFileSystem,
output_file: str,
output_data: _OutputFileData,
input_files_data: dict[str, _InputFileData],
Expand All @@ -252,8 +233,6 @@ def _process_output_file(
This function is designed to be run in parallel for different output files.

Args:
input_fs: Filesystem interface for input file operations
output_fs: Filesystem interface for output file operations
output_file: Path to the output file
output_data: Metadata for the output file
input_files_data: Dictionary mapping input file paths to their metadata
Expand All @@ -275,7 +254,6 @@ def _process_output_file(

# Use memory mapping to read tensor data efficiently
data_to_write = _read_tensor_data_mmap(
input_fs,
safetensors_file,
data_offsets[0],
data_offsets[1],
Expand All @@ -291,7 +269,6 @@ def _process_output_file(

# Write this tensor shard to the appropriate position in the output file
_write_sub_tensor_to_file_optimized(
output_fs,
data_to_write,
fqn_data.dtype_size, # Size of each element in bytes
fqn_data.shape_in_file, # Full tensor shape
Expand All @@ -304,8 +281,6 @@ def _process_output_file(


def _write_data(
input_fs: fsspec.AbstractFileSystem,
output_fs: fsspec.AbstractFileSystem,
input_files_data: dict[str, _InputFileData],
output_files_data: dict[str, _OutputFileData],
num_threads: int = 1,
Expand All @@ -318,18 +293,14 @@ def _write_data(
the work is split across threads with each thread handling a different output file.

Args:
input_fs: Filesystem interface for input file operations
output_fs: Filesystem interface for output file operations
input_files_data: Dictionary mapping input file paths to their metadata
output_files_data: Dictionary mapping output file paths to their metadata
num_threads: Number of threads to use for parallel processing
"""
if num_threads <= 1 or len(output_files_data) <= 1:
# Sequential processing
for output_file, output_data in output_files_data.items():
_process_output_file(
input_fs, output_fs, output_file, output_data, input_files_data
)
_process_output_file(output_file, output_data, input_files_data)
else:
# Parallel processing with ThreadPoolExecutor
with concurrent.futures.ThreadPoolExecutor(
Expand All @@ -340,8 +311,6 @@ def _write_data(
futures.append(
executor.submit(
_process_output_file,
input_fs,
output_fs,
output_file,
output_data,
input_files_data,
Expand All @@ -359,7 +328,6 @@ def _write_data(


def _write_sub_tensor_to_file_optimized(
fs: fsspec.AbstractFileSystem,
sub_tensor_bytes: bytes,
element_size: int,
tensor_shape: list[int],
Expand All @@ -379,7 +347,6 @@ def _write_sub_tensor_to_file_optimized(
- Optimized chunks for other patterns

Args:
fs: Filesystem interface for file operations
sub_tensor_bytes: Raw tensor data as bytes
element_size: Size of each element in bytes
tensor_shape: Shape of the full tensor
Expand All @@ -403,7 +370,7 @@ def _write_sub_tensor_to_file_optimized(

total_elements = math.prod(sub_tensor_shape)

with fs.open(output_file_path, "r+b") as out_f:
with open(output_file_path, "r+b") as out_f:
elements_written = 0

while elements_written < total_elements:
Expand Down Expand Up @@ -491,10 +458,19 @@ def _calculate_max_contiguous_elements(


def _write_overall_metadata_file(
fs: fsspec.AbstractFileSystem,
output_dir: str,
output_files_data: dict[str, _OutputFileData],
) -> None:
"""
Write the overall metadata file that maps tensor names to their file locations.

This creates a model.safetensors.index.json file that HuggingFace models use
to locate tensors across multiple files.

Args:
output_dir: Directory where the metadata file will be written
output_files_data: Dictionary mapping output file paths to their metadata
"""
total_size = 0
weight_map = {}
for output_path, value in output_files_data.items():
Expand All @@ -507,32 +483,10 @@ def _write_overall_metadata_file(
metadata_to_write["weight_map"] = weight_map

metadata_path = os.path.join(output_dir, f"{_metadata_fn}")
with fs.open(metadata_path, "w") as metadata_file:
with open(metadata_path, "w") as metadata_file:
json.dump(metadata_to_write, metadata_file, indent=2)


def _upload_files_to_remote_fs(
    local_fs: fsspec.AbstractFileSystem,
    local_dir: str,
    output_fs: fsspec.AbstractFileSystem,
    output_dir: str,
) -> None:
    """
    Upload the consolidated files from a local staging directory to the
    remote filesystem.

    Only two kinds of files are uploaded:
      * consolidated output files with full tensors, named like
        "model-0000n-of-0000m.safetensors" (the prefix check excludes
        input files that hold sharded tensors), and
      * the overall metadata index file (``_metadata_fn``).

    Args:
        local_fs: Filesystem interface for the local staging directory
        local_dir: Local directory containing the consolidated files
        output_fs: Filesystem interface for the remote destination
        output_dir: Remote directory to upload the files into
    """
    # "model" prefix from the file-name template; hoisted out of the loop
    # since it is loop-invariant.
    model_str = FILE_NAME.split("-")[0]
    for path in local_fs.ls(local_dir, detail=False):
        file = os.path.basename(path)
        # Upload only the consolidated files with full tensors or the metadata
        # file. The file.startswith(model_str) check ensures that we only
        # upload consolidated files in the "model-0000n-of-0000m.safetensors"
        # format and not files with sharded tensors. Parentheses make the
        # and/or precedence explicit (same grouping Python applies anyway).
        if (file.endswith(SUFFIX) and file.startswith(model_str)) or file == _metadata_fn:
            local_path = os.path.join(local_dir, file)
            remote_path = os.path.join(output_dir, file)
            output_fs.put_file(local_path, remote_path)


def consolidate_safetensors_files(
input_dir: str,
output_dir: str,
Expand Down Expand Up @@ -564,17 +518,6 @@ def consolidate_safetensors_files(
output_dir,
start_time,
)
# Create filesystem using fsspec for file operations
input_fs, _ = url_to_fs(input_dir)
output_fs, _ = url_to_fs(output_dir)

if not isinstance(output_fs, LocalFileSystem):
local_output_dir = tempfile.mkdtemp()
logger.info("Created temporary directory %s", local_output_dir)
local_output_fs, _ = url_to_fs(local_output_dir)
else:
local_output_fs = output_fs
local_output_dir = output_dir

# Initialize the output file structure
output_files_data: dict[str, _OutputFileData] = {}
Expand All @@ -583,7 +526,7 @@ def consolidate_safetensors_files(
for fqn, index in fqn_to_index_mapping.items():
# Generate names like "model-00001-of-00005.safetensors"
file_name = _gen_file_name(index, max(fqn_to_index_mapping.values()))
output_path = os.path.join(local_output_dir, file_name)
output_path = os.path.join(output_dir, file_name)

if output_path not in output_files_data:
output_files_data[output_path] = _OutputFileData(
Expand All @@ -594,19 +537,16 @@ def consolidate_safetensors_files(
else:
# If no mapping is provided, create a single output file
file_name = _gen_file_name(1, 1)
output_path = os.path.join(local_output_dir, file_name)
output_path = os.path.join(output_dir, file_name)
output_files_data[output_path] = _OutputFileData()

# Find all safetensors files in the input directory
safetensors_files = []
for file in input_fs.ls(input_dir, detail=False):
if file.endswith(SUFFIX):
safetensors_files.append(file)
safetensors_files = glob.glob(os.path.join(input_dir, f"*{SUFFIX}"))

# Read metadata from all input files
input_files_data: dict[str, _InputFileData] = {}
for safetensor_file in safetensors_files:
with input_fs.open(safetensor_file, "rb") as f:
with open(safetensor_file, "rb") as f:
metadata, size = _get_safetensors_file_metadata(f)
input_files_data[safetensor_file] = _InputFileData(
metadata_size=size, metadata=metadata
Expand All @@ -616,22 +556,12 @@ def consolidate_safetensors_files(
_parse_input_metadata(input_files_data, output_files_data)

# Step 2: Write metadata headers to output files
_write_metadata(local_output_fs, output_files_data)
_write_metadata(output_files_data)

# Step 3: Write actual tensor data from input files to output files
_write_data(
input_fs, local_output_fs, input_files_data, output_files_data, num_threads
)
_write_data(input_files_data, output_files_data, num_threads)

# Step 4: Write overall model.index.safetensors.json file with weight map
_write_overall_metadata_file(local_output_fs, local_output_dir, output_files_data)
_write_overall_metadata_file(output_dir, output_files_data)

logger.info("Done consolidating. Took %.2f secs.", time.time() - start_time)

if local_output_dir != output_dir:
logger.info("Copying consolidated files to remote storage %s", output_dir)
_upload_files_to_remote_fs(
local_output_fs, local_output_dir, output_fs, output_dir
)
shutil.rmtree(local_output_dir)
logger.info("Deleting temporary directory %s", local_output_dir)
Loading