diff --git a/torch/distributed/checkpoint/_consolidate_hf_safetensors.py b/torch/distributed/checkpoint/_consolidate_hf_safetensors.py
index 8577180e9f89..a0d205f80821 100644
--- a/torch/distributed/checkpoint/_consolidate_hf_safetensors.py
+++ b/torch/distributed/checkpoint/_consolidate_hf_safetensors.py
@@ -1,33 +1,26 @@
 # pyre-strict
 
 import concurrent.futures
+import glob
 import json
 import logging
 import math
 import mmap
 import os
-import shutil
 import struct
-import tempfile
 import time
 from dataclasses import dataclass, field
 from typing import Any, Optional
 
-import fsspec  # type: ignore[import-untyped]
-from fsspec.core import url_to_fs  # type: ignore[import-untyped]
-from fsspec.implementations.local import LocalFileSystem  # type: ignore[import-untyped]
-
 import torch
 from torch.distributed.checkpoint._hf_utils import (
     _gen_file_name,
     _get_dcp_custom_metadata,
-    _get_dtype,
     _get_safetensors_file_metadata,
     _metadata_fn,
     DATA_OFFSETS_KEY,
     DEFAULT_EXTRA_METADATA_KEY,
     DTYPE_KEY,
-    FILE_NAME,
     SAVED_OFFSETS_KEY,
     SHAPE_KEY,
     SUFFIX,
@@ -100,6 +93,9 @@ def _parse_input_metadata(
     Raises:
         ValueError: If no DCP custom metadata is found in a safetensors file
     """
+
+    from safetensors.torch import _getdtype  # type: ignore[import]
+
     # Dictionary to track the full size of each tensor across all shards
     fqn_to_size_mapping: dict[str, tuple[list[int], str]] = {}
 
@@ -138,14 +134,13 @@ def _parse_input_metadata(
             if fqn in output_data.fqn_data or len(output_files_data) == 1:
                 output_data.fqn_data[fqn] = _FqnData(
                     shape_in_file=tensor_size,
-                    dtype_size=torch.finfo(_get_dtype(dtype_str)).bits
+                    dtype_size=torch.finfo(_getdtype(dtype_str)).bits
                     // 8,  # Convert bits to bytes
                     dtype_str=dtype_str,
                 )
 
 
 def _write_metadata(
-    fs: fsspec.AbstractFileSystem,
     output_files_data: dict[str, _OutputFileData],
 ) -> None:
     """
@@ -156,12 +151,11 @@
     field for each tensor in the output_files_data.
 
     Args:
-        fs: Filesystem interface for file operations
         output_files_data: Dictionary mapping output file paths to their metadata
     """
     # Process each output file
     for file_path, output_data in output_files_data.items():
-        with fs.open(file_path, "wb") as f:
+        with open(file_path, "wb") as f:
             metadata = {}
             curr_offset = 0
 
@@ -205,7 +199,6 @@ def _write_metadata(
 
 
 def _read_tensor_data_mmap(
-    input_fs: fsspec.AbstractFileSystem,
     file_path: str,
     start_offset: int,
     end_offset: int,
@@ -215,7 +208,6 @@
     Read tensor data from a safetensors file using memory mapping for efficiency.
 
     Args:
-        input_fs: Filesystem interface for input file operations
         file_path: Path to the safetensors file
         start_offset: Start offset of tensor data within the data section
         end_offset: End offset of tensor data within the data section
@@ -224,24 +216,15 @@
     Returns:
         Raw tensor data as bytes
     """
-    # For local files, use mmap for efficient access
-    if isinstance(input_fs, LocalFileSystem):
-        # Local file - use mmap
-        with open(file_path, "rb") as f:
-            with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
-                absolute_start = metadata_size + start_offset
-                absolute_end = metadata_size + end_offset
-                return bytes(mm[absolute_start:absolute_end])
-    else:
-        # Remote file - fall back to regular read
-        with input_fs.open(file_path, "rb") as f:
-            f.seek(metadata_size + start_offset)
-            return f.read(end_offset - start_offset)
+    # Use mmap for efficient access
+    with open(file_path, "rb") as f:
+        with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
+            absolute_start = metadata_size + start_offset
+            absolute_end = metadata_size + end_offset
+            return bytes(mm[absolute_start:absolute_end])
 
 
 def _process_output_file(
-    input_fs: fsspec.AbstractFileSystem,
-    output_fs: fsspec.AbstractFileSystem,
     output_file: str,
     output_data: _OutputFileData,
     input_files_data: dict[str, _InputFileData],
@@ -252,8 +235,6 @@
     This function is designed to be run in parallel for different output files.
 
     Args:
-        input_fs: Filesystem interface for input file operations
-        output_fs: Filesystem interface for output file operations
         output_file: Path to the output file
         output_data: Metadata for the output file
         input_files_data: Dictionary mapping input file paths to their metadata
@@ -275,7 +256,6 @@
 
             # Use memory mapping to read tensor data efficiently
             data_to_write = _read_tensor_data_mmap(
-                input_fs,
                 safetensors_file,
                 data_offsets[0],
                 data_offsets[1],
@@ -291,7 +271,6 @@
 
             # Write this tensor shard to the appropriate position in the output file
             _write_sub_tensor_to_file_optimized(
-                output_fs,
                 data_to_write,
                 fqn_data.dtype_size,  # Size of each element in bytes
                 fqn_data.shape_in_file,  # Full tensor shape
@@ -304,8 +283,6 @@
 
 
 def _write_data(
-    input_fs: fsspec.AbstractFileSystem,
-    output_fs: fsspec.AbstractFileSystem,
     input_files_data: dict[str, _InputFileData],
     output_files_data: dict[str, _OutputFileData],
     num_threads: int = 1,
@@ -318,8 +295,6 @@
     the work is split across threads with each thread handling a different output file.
 
     Args:
-        input_fs: Filesystem interface for input file operations
-        output_fs: Filesystem interface for output file operations
         input_files_data: Dictionary mapping input file paths to their metadata
         output_files_data: Dictionary mapping output file paths to their metadata
         num_threads: Number of threads to use for parallel processing
@@ -327,9 +302,7 @@
     if num_threads <= 1 or len(output_files_data) <= 1:
         # Sequential processing
         for output_file, output_data in output_files_data.items():
-            _process_output_file(
-                input_fs, output_fs, output_file, output_data, input_files_data
-            )
+            _process_output_file(output_file, output_data, input_files_data)
     else:
         # Parallel processing with ThreadPoolExecutor
         with concurrent.futures.ThreadPoolExecutor(
@@ -340,8 +313,6 @@
                 futures.append(
                     executor.submit(
                         _process_output_file,
-                        input_fs,
-                        output_fs,
                         output_file,
                         output_data,
                         input_files_data,
@@ -359,7 +330,6 @@
 
 
 def _write_sub_tensor_to_file_optimized(
-    fs: fsspec.AbstractFileSystem,
    sub_tensor_bytes: bytes,
     element_size: int,
     tensor_shape: list[int],
@@ -379,7 +349,6 @@
     - Optimized chunks for other patterns
 
     Args:
-        fs: Filesystem interface for file operations
         sub_tensor_bytes: Raw tensor data as bytes
         element_size: Size of each element in bytes
         tensor_shape: Shape of the full tensor
@@ -403,7 +372,7 @@
 
     total_elements = math.prod(sub_tensor_shape)
 
-    with fs.open(output_file_path, "r+b") as out_f:
+    with open(output_file_path, "r+b") as out_f:
         elements_written = 0
 
         while elements_written < total_elements:
@@ -524,10 +493,19 @@ def _calculate_max_contiguous_elements(
 
 
 def _write_overall_metadata_file(
-    fs: fsspec.AbstractFileSystem,
     output_dir: str,
     output_files_data: dict[str, _OutputFileData],
 ) -> None:
+    """
+    Write the overall metadata file that maps tensor names to their file locations.
+
+    This creates a model.safetensors.index.json file that HuggingFace models use
+    to locate tensors across multiple files.
+
+    Args:
+        output_dir: Directory where the metadata file will be written
+        output_files_data: Dictionary mapping output file paths to their metadata
+    """
     total_size = 0
     weight_map = {}
     for output_path, value in output_files_data.items():
@@ -540,32 +518,10 @@
     metadata_to_write["weight_map"] = weight_map
 
     metadata_path = os.path.join(output_dir, f"{_metadata_fn}")
-    with fs.open(metadata_path, "w") as metadata_file:
+    with open(metadata_path, "w") as metadata_file:
         json.dump(metadata_to_write, metadata_file, indent=2)
 
 
-def _upload_files_to_remote_fs(
-    local_fs: fsspec.AbstractFileSystem,
-    local_dir: str,
-    output_fs: fsspec.AbstractFileSystem,
-    output_dir: str,
-) -> None:
-    """
-    Uploads the consolidated files to the remote filesystem.
-    """
-    for path in local_fs.ls(local_dir, detail=False):
-        file = os.path.basename(path)
-        model_str = FILE_NAME.split("-")[0]
-        # Upload only the consolidated files with full tensors or the metadata file.
-        # The check for file.startwith(model_str) is to ensure that we only upload
-        # the consolidated files in the format "model-0000n-of-0000m.safetensors"
-        # and not the files with sharded tensors.
-        if file.endswith(SUFFIX) and file.startswith(model_str) or file == _metadata_fn:
-            local_path = os.path.join(local_dir, file)
-            remote_path = os.path.join(output_dir, file)
-            output_fs.put_file(local_path, remote_path)
-
-
 def consolidate_safetensors_files(
     input_dir: str,
     output_dir: str,
@@ -597,17 +553,6 @@
         output_dir,
         start_time,
     )
-    # Create filesystem using fsspec for file operations
-    input_fs, _ = url_to_fs(input_dir)
-    output_fs, _ = url_to_fs(output_dir)
-
-    if not isinstance(output_fs, LocalFileSystem):
-        local_output_dir = tempfile.mkdtemp()
-        logger.info("Created temporary directory %s", local_output_dir)
-        local_output_fs, _ = url_to_fs(local_output_dir)
-    else:
-        local_output_fs = output_fs
-        local_output_dir = output_dir
 
     # Initialize the output file structure
     output_files_data: dict[str, _OutputFileData] = {}
@@ -616,7 +561,7 @@
         for fqn, index in fqn_to_index_mapping.items():
             # Generate names like "model-00001-of-00005.safetensors"
             file_name = _gen_file_name(index, max(fqn_to_index_mapping.values()))
-            output_path = os.path.join(local_output_dir, file_name)
+            output_path = os.path.join(output_dir, file_name)
 
             if output_path not in output_files_data:
                 output_files_data[output_path] = _OutputFileData(
@@ -627,19 +572,16 @@
     else:
         # If no mapping is provided, create a single output file
         file_name = _gen_file_name(1, 1)
-        output_path = os.path.join(local_output_dir, file_name)
+        output_path = os.path.join(output_dir, file_name)
         output_files_data[output_path] = _OutputFileData()
 
     # Find all safetensors files in the input directory
-    safetensors_files = []
-    for file in input_fs.ls(input_dir, detail=False):
-        if file.endswith(SUFFIX):
-            safetensors_files.append(file)
+    safetensors_files = glob.glob(os.path.join(input_dir, f"*{SUFFIX}"))
 
     # Read metadata from all input files
     input_files_data: dict[str, _InputFileData] = {}
     for safetensor_file in safetensors_files:
-        with input_fs.open(safetensor_file, "rb") as f:
+        with open(safetensor_file, "rb") as f:
             metadata, size = _get_safetensors_file_metadata(f)
             input_files_data[safetensor_file] = _InputFileData(
                 metadata_size=size, metadata=metadata
@@ -649,22 +591,12 @@
     _parse_input_metadata(input_files_data, output_files_data)
 
     # Step 2: Write metadata headers to output files
-    _write_metadata(local_output_fs, output_files_data)
+    _write_metadata(output_files_data)
 
     # Step 3: Write actual tensor data from input files to output files
-    _write_data(
-        input_fs, local_output_fs, input_files_data, output_files_data, num_threads
-    )
+    _write_data(input_files_data, output_files_data, num_threads)
 
     # Step 4: Write overall model.index.safetensors.json file with weight map
-    _write_overall_metadata_file(local_output_fs, local_output_dir, output_files_data)
+    _write_overall_metadata_file(output_dir, output_files_data)
 
     logger.info("Done consolidating. Took %.2f secs.", time.time() - start_time)
-
-    if local_output_dir != output_dir:
-        logger.info("Copying consolidated files to remote storage %s", output_dir)
-        _upload_files_to_remote_fs(
-            local_output_fs, local_output_dir, output_fs, output_dir
-        )
-        shutil.rmtree(local_output_dir)
-        logger.info("Deleting temporary directory %s", local_output_dir)
diff --git a/torch/distributed/checkpoint/hf_storage.py b/torch/distributed/checkpoint/hf_storage.py
index 6b36e619f7ce..542203ed82cf 100644
--- a/torch/distributed/checkpoint/hf_storage.py
+++ b/torch/distributed/checkpoint/hf_storage.py
@@ -47,9 +47,7 @@
 
 class HuggingFaceStorageWriter(FileSystemWriter):
     """
-    A writer that writes to a huggingface repository in the huggingface format.
-    Uses Fsspec back-end to communicate with back-end storage.
-    Fsspec registration of the storage solution is required.
+    A writer that writes to storage in the huggingface safetensors format.
     """
 
     def __init__(
@@ -196,9 +194,7 @@ def metadata_path(self) -> str:
 
 class HuggingFaceStorageReader(FileSystemReader):
     """
-    A reader that reads from a huggingface repository in the huggingface format.
-    Uses in Fsspec back-end to communicate with storage.
-    Fsspec registration of the storage solution is required.
+    A reader that reads a checkpoint in the huggingface safetensors format.
     """
 
     def __init__(self, path: str) -> None:
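
Usage note: with fsspec removed, both directories passed to the consolidation entry point are plain local paths. Below is a minimal sketch of how it might be invoked after this change; the directory names are hypothetical, and the num_threads keyword is assumed from its use inside the function body rather than confirmed from the full signature.

# Hypothetical sketch: consolidate DCP-saved sharded safetensors files into
# full-tensor files plus a model.safetensors.index.json weight map.
from torch.distributed.checkpoint._consolidate_hf_safetensors import (
    consolidate_safetensors_files,
)

consolidate_safetensors_files(
    input_dir="/tmp/sharded_checkpoint",        # hypothetical local dir of *.safetensors shards
    output_dir="/tmp/consolidated_checkpoint",  # hypothetical local dir for consolidated output
    num_threads=4,  # assumed keyword; threads are spread one per output file
)

# Without an fqn_to_index_mapping, the diff shows everything is written to a
# single model-00001-of-00001.safetensors file in output_dir.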