From ea1a2d0f3d7d7ad85f3894260ce8997603935017 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 9 Jan 2025 11:37:55 +0100 Subject: [PATCH 1/9] Updating the dev number. (#558) --- bindings/python/Cargo.toml | 2 +- safetensors/Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/bindings/python/Cargo.toml b/bindings/python/Cargo.toml index c963fd64..a1dadeed 100644 --- a/bindings/python/Cargo.toml +++ b/bindings/python/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "safetensors-python" -version = "0.5.1-dev.0" +version = "0.5.3-dev.0" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/safetensors/Cargo.toml b/safetensors/Cargo.toml index 02bcc2f0..037104a3 100644 --- a/safetensors/Cargo.toml +++ b/safetensors/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "safetensors" -version = "0.5.1-dev.0" +version = "0.5.3-dev.0" edition = "2021" homepage = "https://github.com/huggingface/safetensors" repository = "https://github.com/huggingface/safetensors" From ee109c6098d4cb2573adcb334a36516703eb4d8c Mon Sep 17 00:00:00 2001 From: Asaf Karnieli Date: Tue, 4 Feb 2025 10:06:26 +0200 Subject: [PATCH 2/9] Add support for Intel Gaudi hpu accelerators (#566) * Add support for Intel Gaudi hpu accelerators * Fixing the `find_spec` dep. * Fixing unused import. --------- Co-authored-by: Nicolas Patry --- bindings/python/src/lib.rs | 5 ++++- bindings/python/tests/test_pt_comparison.py | 18 ++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs index 2ecd9a7f..c9ac26c3 100644 --- a/bindings/python/src/lib.rs +++ b/bindings/python/src/lib.rs @@ -267,7 +267,8 @@ enum Device { Xpu(usize), Xla(usize), Mlu(usize), - /// User didn't specify acceletor, torch + Hpu, + /// User didn't specify accelerator, torch /// is responsible for choosing. Anonymous(usize), } @@ -296,6 +297,7 @@ impl<'source> FromPyObject<'source> for Device { "xpu" => Ok(Device::Xpu(0)), "xla" => Ok(Device::Xla(0)), "mlu" => Ok(Device::Mlu(0)), + "hpu" => Ok(Device::Hpu), name if name.starts_with("cuda:") => parse_device(name).map(Device::Cuda), name if name.starts_with("npu:") => parse_device(name).map(Device::Npu), name if name.starts_with("xpu:") => parse_device(name).map(Device::Xpu), @@ -327,6 +329,7 @@ impl<'py> IntoPyObject<'py> for Device { Device::Xpu(n) => format!("xpu:{n}").into_pyobject(py).map(|x| x.into_any()), Device::Xla(n) => format!("xla:{n}").into_pyobject(py).map(|x| x.into_any()), Device::Mlu(n) => format!("mlu:{n}").into_pyobject(py).map(|x| x.into_any()), + Device::Hpu => "hpu".into_pyobject(py).map(|x| x.into_any()), Device::Anonymous(n) => n.into_pyobject(py).map(|x| x.into_any()), } } diff --git a/bindings/python/tests/test_pt_comparison.py b/bindings/python/tests/test_pt_comparison.py index 6e569f50..fc616445 100644 --- a/bindings/python/tests/test_pt_comparison.py +++ b/bindings/python/tests/test_pt_comparison.py @@ -170,6 +170,24 @@ def test_npu(self): for k, v in reloaded.items(): self.assertTrue(torch.allclose(data[k], reloaded[k])) + def test_hpu(self): + # must be run to load torch with Intel Gaudi bindings + try: + import habana_frameworks.torch.core as htcore + except ImportError: + self.skipTest("HPU is not available") + + data = { + "test1": torch.zeros((2, 2), dtype=torch.float32).to("hpu"), + "test2": torch.zeros((2, 2), dtype=torch.float16).to("hpu"), + } + local = "./tests/data/out_safe_pt_mmap_small_hpu.safetensors" + save_file(data, local) + + reloaded = load_file(local, device="hpu") + for k, v in reloaded.items(): + self.assertTrue(torch.allclose(data[k], reloaded[k])) + @unittest.skipIf(not torch.cuda.is_available(), "Cuda is not available") def test_anonymous_accelerator(self): data = { From 581f43bf212a32aaf00021efbea7fca49620811b Mon Sep 17 00:00:00 2001 From: Christian Heimes Date: Tue, 4 Feb 2025 09:25:26 +0100 Subject: [PATCH 3/9] Restore compatibility with Rust 1.74 (#563) * Restore compatibility with Rust 1.74 PR #544 added support for `no_std` feature. The PR changed `std::error::Error` to `core::error::Error`. The `core::error` trait was stabilized in Rust 1.81, so the change bumped MSRV to 1.81. Before the Python package built with Rust 1.66 and the `safetensors` create with all features built with 1.74. This commit restores compatibility with Rust 1.74 for `std` builds: - `mixed_integer_ops` feature requires 1.66 - `half v2.4.1` requires 1.70 - `clap_lex v0.7.4` requires 1.74 I'm also adding `rust-version` to `Cargo.toml`, so cargo creates a backwards compatible `Cargo.lock`. By default, Cargo >= 1.83 creates a `v4` lock file, which is not compatible with Cargo < 1.78. Signed-off-by: Christian Heimes * Merging the test matrix. --------- Signed-off-by: Christian Heimes Co-authored-by: Nicolas Patry --- .github/workflows/rust.yml | 4 ++++ bindings/python/Cargo.toml | 1 + safetensors/Cargo.toml | 1 + safetensors/src/tensor.rs | 4 ++++ 4 files changed, 10 insertions(+) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index e7ac0684..a1eb3e60 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -9,6 +9,10 @@ jobs: strategy: matrix: os: [ubuntu-latest, windows-latest, macOS-latest] + toolchain: [stable] + include: + - os: ubuntu-latest + toolchain: "1.74" defaults: run: working-directory: ./safetensors diff --git a/bindings/python/Cargo.toml b/bindings/python/Cargo.toml index a1dadeed..de7692bc 100644 --- a/bindings/python/Cargo.toml +++ b/bindings/python/Cargo.toml @@ -2,6 +2,7 @@ name = "safetensors-python" version = "0.5.3-dev.0" edition = "2021" +rust-version = "1.74" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [lib] diff --git a/safetensors/Cargo.toml b/safetensors/Cargo.toml index 037104a3..dd303d65 100644 --- a/safetensors/Cargo.toml +++ b/safetensors/Cargo.toml @@ -2,6 +2,7 @@ name = "safetensors" version = "0.5.3-dev.0" edition = "2021" +rust-version = "1.74" homepage = "https://github.com/huggingface/safetensors" repository = "https://github.com/huggingface/safetensors" documentation = "https://docs.rs/safetensors/" diff --git a/safetensors/src/tensor.rs b/safetensors/src/tensor.rs index bee71782..42d59e04 100644 --- a/safetensors/src/tensor.rs +++ b/safetensors/src/tensor.rs @@ -64,8 +64,12 @@ impl core::fmt::Display for SafeTensorError { } } +#[cfg(not(feature = "std"))] impl core::error::Error for SafeTensorError {} +#[cfg(feature = "std")] +impl std::error::Error for SafeTensorError {} + struct PreparedData { n: u64, header_bytes: Vec, From fa833511664338bfc927fc02653ddb7d38d40be9 Mon Sep 17 00:00:00 2001 From: Oliver Ford Date: Tue, 4 Feb 2025 10:31:44 +0000 Subject: [PATCH 4/9] Return error on out of range index (#565) * Return error on out of range index Fix issue #560, return a SliceOutOfRange error if the stop value exceeds the available span. * Improve the fix. * Revert this change. * Adding unit test around invalid range * Checking for too many slices too. * Small cleanup. --------- Co-authored-by: Nicolas Patry --- .pre-commit-config.yaml | 2 - bindings/python/src/lib.rs | 8 +++- bindings/python/tests/test_simple.py | 16 +++++++- flake.nix | 2 +- safetensors/src/slice.rs | 61 +++++++++++++++++++++++++++- 5 files changed, 81 insertions(+), 8 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 06f77d27..1a28594b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,7 +11,6 @@ repos: [ "--manifest-path", "safetensors/Cargo.toml", - "--all-features", "--all-targets", "--", "-Dwarnings", @@ -28,7 +27,6 @@ repos: [ "--manifest-path", "bindings/python/Cargo.toml", - "--all-features", "--all-targets", "--", "-Dwarnings", diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs index c9ac26c3..517bac3e 100644 --- a/bindings/python/src/lib.rs +++ b/bindings/python/src/lib.rs @@ -798,8 +798,12 @@ struct Disp(Vec); impl fmt::Display for Disp { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "[")?; - for item in &self.0 { - write!(f, "{item}")?; + for (i, item) in self.0.iter().enumerate() { + if i != self.0.len() - 1 { + write!(f, "{item}, ")?; + } else { + write!(f, "{item}")?; + } } write!(f, "]") } diff --git a/bindings/python/tests/test_simple.py b/bindings/python/tests/test_simple.py index 6930226e..8e840138 100644 --- a/bindings/python/tests/test_simple.py +++ b/bindings/python/tests/test_simple.py @@ -340,5 +340,19 @@ def test_numpy_slice(self): tensor = slice_[2:, 20] self.assertEqual( str(cm.exception), - "Error during slicing [2:20] with shape [10, 5]: SliceOutOfRange { dim_index: 1, asked: 20, dim_size: 5 }", + "Error during slicing [2:, 20] with shape [10, 5]: SliceOutOfRange { dim_index: 1, asked: 20, dim_size: 5 }", + ) + + with self.assertRaises(SafetensorError) as cm: + tensor = slice_[:20] + self.assertEqual( + str(cm.exception), + "Error during slicing [:20] with shape [10, 5]: SliceOutOfRange { dim_index: 0, asked: 19, dim_size: 10 }", + ) + + with self.assertRaises(SafetensorError) as cm: + tensor = slice_[:, :20] + self.assertEqual( + str(cm.exception), + "Error during slicing [:, :20] with shape [10, 5]: SliceOutOfRange { dim_index: 1, asked: 19, dim_size: 5 }", ) diff --git a/flake.nix b/flake.nix index d9effbaf..a0cba1b7 100644 --- a/flake.nix +++ b/flake.nix @@ -32,7 +32,7 @@ postShellHook = '' unset SOURCE_DATE_EPOCH ''; - LD_LIBRARY_PATH = "$LD_LIBRARY_PATH:${pkgs.stdenv.cc.cc.lib}/lib"; + LD_LIBRARY_PATH = "$LD_LIBRARY_PATH:${pkgs.stdenv.cc.cc.lib}/lib:${pkgs.zlib}/lib"; }; } diff --git a/safetensors/src/slice.rs b/safetensors/src/slice.rs index 91087170..9f9ee8d3 100644 --- a/safetensors/src/slice.rs +++ b/safetensors/src/slice.rs @@ -7,6 +7,7 @@ use core::ops::{ /// Error representing invalid slicing attempt #[derive(Debug)] +#[cfg_attr(test, derive(Eq, PartialEq))] pub enum InvalidSlice { /// When the client asked for more slices than the tensors has dimensions TooManySlices, @@ -235,6 +236,7 @@ where /// Iterator used to return the bits of the overall tensor buffer /// when client asks for a slice of the original tensor. +#[cfg_attr(test, derive(Debug, Eq, PartialEq))] pub struct SliceIterator<'data> { view: &'data TensorView<'data>, indices: Vec<(usize, usize)>, @@ -284,10 +286,15 @@ impl<'data> SliceIterator<'data> { } TensorIndexer::Select(s) => (*s, *s + 1), }; - if start >= shape && stop > shape { + if start >= shape || stop > shape { + let asked = if start >= shape { + start + } else { + stop.saturating_sub(1) + }; return Err(InvalidSlice::SliceOutOfRange { dim_index: i, - asked: stop.saturating_sub(1), + asked, dim_size: shape, }); } @@ -573,4 +580,54 @@ mod tests { assert_eq!(iterator.next(), Some(&data[12..16])); assert_eq!(iterator.next(), None); } + + #[test] + fn test_invalid_range() { + let data: Vec = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0] + .into_iter() + .flat_map(|f| f.to_le_bytes()) + .collect(); + + let attn_0 = TensorView::new(Dtype::F32, vec![2, 3], &data).unwrap(); + + assert_eq!( + SliceIterator::new( + &attn_0, + &[ + TensorIndexer::Select(1), + TensorIndexer::Narrow(Bound::Included(1), Bound::Excluded(4)), + ], + ), + Err(InvalidSlice::SliceOutOfRange { + asked: 3, + dim_index: 1, + dim_size: 3, + }) + ); + assert_eq!( + SliceIterator::new( + &attn_0, + &[ + TensorIndexer::Select(1), + TensorIndexer::Narrow(Bound::Included(3), Bound::Excluded(2)), + ], + ), + Err(InvalidSlice::SliceOutOfRange { + asked: 3, + dim_index: 1, + dim_size: 3, + }) + ); + assert_eq!( + SliceIterator::new( + &attn_0, + &[ + TensorIndexer::Select(1), + TensorIndexer::Select(1), + TensorIndexer::Select(1), + ], + ), + Err(InvalidSlice::TooManySlices) + ); + } } From 4b3864c802d727be3cd67e5107a0f873d047ae69 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 6 Feb 2025 12:53:30 +0100 Subject: [PATCH 5/9] Demoing zero-copy save. (#567) * Demoing zero-copy save. * Fixing clippy issue. * Some cleanup. * Sdist doesn't require feature ? * Incorrect clean. * Clippy ? * Sanity check * Fixing the doc builder? * Using pre-commit for quality. * This should work. * Clippy variant. * pyfeature typo. * Bypassing the necessity for an env ? * Remove maturin. * BigEndian fix. * Only black. * We need to check both features. * ?? * Asking for readonly is not possible. * Before the build. * Fixing byte-endian? --- .github/workflows/build_documentation.yml | 2 + .github/workflows/build_pr_documentation.yml | 2 + .github/workflows/python-release.yml | 20 +-- .github/workflows/python.yml | 10 +- .pre-commit-config.yaml | 2 + Dockerfile.s390x.test | 1 + bindings/python/Cargo.toml | 6 +- bindings/python/py_src/safetensors/torch.py | 19 ++- bindings/python/src/lib.rs | 85 ++---------- bindings/python/src/view.rs | 137 +++++++++++++++++++ 10 files changed, 191 insertions(+), 93 deletions(-) create mode 100644 bindings/python/src/view.rs diff --git a/.github/workflows/build_documentation.yml b/.github/workflows/build_documentation.yml index 8b501942..a491bc15 100644 --- a/.github/workflows/build_documentation.yml +++ b/.github/workflows/build_documentation.yml @@ -11,6 +11,8 @@ on: jobs: build: uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main + env: + MATURIN_PEP517_ARGS: "--features py311,pyo3/extension-module" with: commit_sha: ${{ github.sha }} package: safetensors diff --git a/.github/workflows/build_pr_documentation.yml b/.github/workflows/build_pr_documentation.yml index efafcd48..81214464 100644 --- a/.github/workflows/build_pr_documentation.yml +++ b/.github/workflows/build_pr_documentation.yml @@ -14,6 +14,8 @@ concurrency: jobs: build: uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main + env: + MATURIN_PEP517_ARGS: "--features py311,pyo3/extension-module" with: commit_sha: ${{ github.event.pull_request.head.sha }} pr_number: ${{ github.event.number }} diff --git a/.github/workflows/python-release.yml b/.github/workflows/python-release.yml index 2cc96438..9399909b 100644 --- a/.github/workflows/python-release.yml +++ b/.github/workflows/python-release.yml @@ -23,6 +23,7 @@ jobs: runs-on: ${{ matrix.platform.runner }} strategy: matrix: + pyfeature: ["py38", "py311"] platform: - runner: ubuntu-latest target: x86_64 @@ -45,19 +46,20 @@ jobs: uses: PyO3/maturin-action@v1 with: target: ${{ matrix.platform.target }} - args: --release --out dist --manifest-path bindings/python/Cargo.toml + args: --release --out dist --manifest-path bindings/python/Cargo.toml --features pyo3/extension-module,${{ matrix.pyfeature }} sccache: 'true' manylinux: auto - name: Upload wheels uses: actions/upload-artifact@v4 with: - name: wheels-linux-${{ matrix.platform.target }} + name: wheels-linux-${{ matrix.platform.target }}-${{ matrix.pyfeature }} path: dist musllinux: runs-on: ${{ matrix.platform.runner }} strategy: matrix: + pyfeature: ["py38", "py311"] platform: - runner: ubuntu-latest target: x86_64 @@ -76,19 +78,20 @@ jobs: uses: PyO3/maturin-action@v1 with: target: ${{ matrix.platform.target }} - args: --release --out dist --manifest-path bindings/python/Cargo.toml + args: --release --out dist --manifest-path bindings/python/Cargo.toml --features pyo3/extension-module,${{ matrix.pyfeature }} sccache: 'true' manylinux: musllinux_1_2 - name: Upload wheels uses: actions/upload-artifact@v4 with: - name: wheels-musllinux-${{ matrix.platform.target }} + name: wheels-musllinux-${{ matrix.platform.target }}-${{ matrix.pyfeature }} path: dist windows: runs-on: ${{ matrix.platform.runner }} strategy: matrix: + pyfeature: ["py38", "py311"] platform: - runner: windows-latest target: x64 @@ -104,18 +107,19 @@ jobs: uses: PyO3/maturin-action@v1 with: target: ${{ matrix.platform.target }} - args: --release --out dist --manifest-path bindings/python/Cargo.toml + args: --release --out dist --manifest-path bindings/python/Cargo.toml --features pyo3/extension-module,${{ matrix.pyfeature }} sccache: 'true' - name: Upload wheels uses: actions/upload-artifact@v4 with: - name: wheels-windows-${{ matrix.platform.target }} + name: wheels-windows-${{ matrix.platform.target }}-${{ matrix.pyfeature }} path: dist macos: runs-on: ${{ matrix.platform.runner }} strategy: matrix: + pyfeature: ["py38", "py311"] platform: - runner: macos-13 target: x86_64 @@ -130,12 +134,12 @@ jobs: uses: PyO3/maturin-action@v1 with: target: ${{ matrix.platform.target }} - args: --release --out dist --manifest-path bindings/python/Cargo.toml + args: --release --out dist --manifest-path bindings/python/Cargo.toml --features pyo3/extension-module,${{ matrix.pyfeature }} sccache: 'true' - name: Upload wheels uses: actions/upload-artifact@v4 with: - name: wheels-macos-${{ matrix.platform.target }} + name: wheels-macos-${{ matrix.platform.target }}-${{ matrix.pyfeature }} path: dist sdist: diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index 9fbbf2a7..0c5437c7 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -7,12 +7,14 @@ jobs: build_and_test: name: Check everything builds & tests runs-on: ${{ matrix.os }} + env: + MATURIN_PEP517_ARGS: "--features ${{ matrix.version.pyfeature }},pyo3/extension-module" strategy: matrix: os: [ubuntu-latest, macos-13, windows-latest] # Lowest and highest, no version specified so that # new releases get automatically tested against - version: [{torch: torch==1.10, python: "3.8"}, {torch: torch, python: "3.12"}] + version: [{torch: torch==1.10, python: "3.8", pyfeature: "py38"}, {torch: torch, python: "3.12", pyfeature: "py311"}] # TODO this would include macos ARM target. # however jax has an illegal instruction issue # that exists only in CI (probably difference in instruction support). @@ -52,14 +54,14 @@ jobs: run: cargo fmt -- --check - name: Lint with Clippy - run: cargo clippy --all-targets --all-features -- -D warnings + run: | + cargo clippy --features ${{ matrix.version.pyfeature }} -- -D warnings - name: Run Audit run: cargo audit -D warnings - name: Install run: | - pip install -U pip pip install .[numpy,tensorflow] pip install ${{ matrix.version.torch }} @@ -82,7 +84,7 @@ jobs: - name: Run tests run: | - cargo test + cargo test --features ${{ matrix.version.pyfeature }} pip install .[testing] pytest -sv tests/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1a28594b..a323e52c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -27,6 +27,8 @@ repos: [ "--manifest-path", "bindings/python/Cargo.toml", + "--features", + "py311", "--all-targets", "--", "-Dwarnings", diff --git a/Dockerfile.s390x.test b/Dockerfile.s390x.test index d1dc7583..cb5bdba7 100644 --- a/Dockerfile.s390x.test +++ b/Dockerfile.s390x.test @@ -11,6 +11,7 @@ RUN /root/miniconda3/bin/pip install -U pip pytest COPY . . SHELL ["/bin/bash", "-c"] WORKDIR /safetensors/bindings/python/ +ENV MATURIN_PEP517_ARGS="--features py311,pyo3/extension-module" RUN source /root/.cargo/env && /root/miniconda3/bin/pip install -e . RUN /root/miniconda3/bin/pytest -sv tests/test_pt_* tests/test_simple.py # RUN /root/miniconda3/bin/python -c 'from huggingface_hub import hf_hub_download; filename = hf_hub_download("roberta-base", "model.safetensors"); from safetensors.torch import load_file; weights = load_file(filename); assert weights["roberta.embeddings.position_embeddings.weight"][0][0].abs().item() > 1e-10' diff --git a/bindings/python/Cargo.toml b/bindings/python/Cargo.toml index de7692bc..fe8bbf44 100644 --- a/bindings/python/Cargo.toml +++ b/bindings/python/Cargo.toml @@ -9,8 +9,12 @@ rust-version = "1.74" name = "safetensors_rust" crate-type = ["cdylib"] +[features] +py38 = ["pyo3/abi3-py38"] +py311 = ["pyo3/abi3-py311"] + [dependencies] -pyo3 = { version = "0.23", features = ["abi3", "abi3-py38"] } +pyo3 = { version = "0.23", features = ["abi3"] } memmap2 = "0.9" serde_json = "1.0" diff --git a/bindings/python/py_src/safetensors/torch.py b/bindings/python/py_src/safetensors/torch.py index 48532ea5..4476e754 100644 --- a/bindings/python/py_src/safetensors/torch.py +++ b/bindings/python/py_src/safetensors/torch.py @@ -128,7 +128,10 @@ def _remove_duplicate_names( def save_model( - model: torch.nn.Module, filename: str, metadata: Optional[Dict[str, str]] = None, force_contiguous: bool = True + model: torch.nn.Module, + filename: str, + metadata: Optional[Dict[str, str]] = None, + force_contiguous: bool = True, ): """ Saves a given torch model to specified filename. @@ -174,7 +177,10 @@ def save_model( def load_model( - model: torch.nn.Module, filename: Union[str, os.PathLike], strict: bool = True, device: Union[str, int] = "cpu" + model: torch.nn.Module, + filename: Union[str, os.PathLike], + strict: bool = True, + device: Union[str, int] = "cpu", ) -> Tuple[List[str], List[str]]: """ Loads a given filename onto a torch model. @@ -402,7 +408,7 @@ def _view2torch(safeview) -> Dict[str, torch.Tensor]: return result -def _tobytes(tensor: torch.Tensor, name: str) -> bytes: +def _tobytes(tensor: torch.Tensor, name: str) -> Union[memoryview, bytes]: if tensor.layout != torch.strided: raise ValueError( f"You are trying to save a sparse tensor: `{name}` which this library does not support." @@ -456,8 +462,11 @@ def _tobytes(tensor: torch.Tensor, name: str) -> bytes: } npdtype = NPDTYPES[tensor.dtype] # Not in place as that would potentially modify a live running model - data = data.view(npdtype).byteswap(inplace=False) - return data.tobytes() + data = data.view(npdtype).byteswap(inplace=False).view(np.uint8) + if sys.version_info >= (3, 11): + return data.data + else: + return data.tobytes() def _flatten(tensors: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, Any]]: diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs index 517bac3e..df3de7d3 100644 --- a/bindings/python/src/lib.rs +++ b/bindings/python/src/lib.rs @@ -1,5 +1,7 @@ #![deny(missing_docs)] //! Dummy doc +#[cfg(any(feature = "py38", feature = "py311"))] +mod view; use memmap2::{Mmap, MmapOptions}; use pyo3::exceptions::{PyException, PyFileNotFoundError}; use pyo3::prelude::*; @@ -10,14 +12,14 @@ use pyo3::Bound as PyBound; use pyo3::{intern, PyErr}; use safetensors::slice::TensorIndexer; use safetensors::tensor::{Dtype, Metadata, SafeTensors, TensorInfo, TensorView}; -use safetensors::View; -use std::borrow::Cow; use std::collections::HashMap; use std::fs::File; use std::iter::FromIterator; use std::ops::Bound; use std::path::PathBuf; use std::sync::Arc; +#[cfg(any(feature = "py38", feature = "py311"))] +use view::prepare; static TORCH_MODULE: GILOnceCell> = GILOnceCell::new(); static NUMPY_MODULE: GILOnceCell> = GILOnceCell::new(); @@ -25,79 +27,12 @@ static TENSORFLOW_MODULE: GILOnceCell> = GILOnceCell::new(); static FLAX_MODULE: GILOnceCell> = GILOnceCell::new(); static MLX_MODULE: GILOnceCell> = GILOnceCell::new(); -struct PyView<'a> { - shape: Vec, - dtype: Dtype, - data: PyBound<'a, PyBytes>, - data_len: usize, -} - -impl View for &PyView<'_> { - fn data(&self) -> std::borrow::Cow<[u8]> { - Cow::Borrowed(self.data.as_bytes()) - } - fn shape(&self) -> &[usize] { - &self.shape - } - fn dtype(&self) -> Dtype { - self.dtype - } - fn data_len(&self) -> usize { - self.data_len - } -} - -fn prepare(tensor_dict: HashMap>) -> PyResult> { - let mut tensors = HashMap::with_capacity(tensor_dict.len()); - for (tensor_name, tensor_desc) in &tensor_dict { - let shape: Vec = tensor_desc - .get_item("shape")? - .ok_or_else(|| SafetensorError::new_err(format!("Missing `shape` in {tensor_desc:?}")))? - .extract()?; - let pydata: PyBound = tensor_desc.get_item("data")?.ok_or_else(|| { - SafetensorError::new_err(format!("Missing `data` in {tensor_desc:?}")) - })?; - // Make sure it's extractable first. - let data: &[u8] = pydata.extract()?; - let data_len = data.len(); - let data: PyBound = pydata.extract()?; - let pydtype = tensor_desc.get_item("dtype")?.ok_or_else(|| { - SafetensorError::new_err(format!("Missing `dtype` in {tensor_desc:?}")) - })?; - let dtype: String = pydtype.extract()?; - let dtype = match dtype.as_ref() { - "bool" => Dtype::BOOL, - "int8" => Dtype::I8, - "uint8" => Dtype::U8, - "int16" => Dtype::I16, - "uint16" => Dtype::U16, - "int32" => Dtype::I32, - "uint32" => Dtype::U32, - "int64" => Dtype::I64, - "uint64" => Dtype::U64, - "float16" => Dtype::F16, - "float32" => Dtype::F32, - "float64" => Dtype::F64, - "bfloat16" => Dtype::BF16, - "float8_e4m3fn" => Dtype::F8_E4M3, - "float8_e5m2" => Dtype::F8_E5M2, - dtype_str => { - return Err(SafetensorError::new_err(format!( - "dtype {dtype_str} is not covered", - ))); - } - }; - - let tensor = PyView { - shape, - dtype, - data, - data_len, - }; - tensors.insert(tensor_name.to_string(), tensor); - } - Ok(tensors) -} +#[cfg(not(any(feature = "py38", feature = "py311")))] +compile_error!( + "At least one python version must be enabled, use `maturin develop --features py311,pyo3/extension-module`" +); +#[cfg(all(feature = "py38", feature = "py311"))] +compile_error!("Only one python version must be enabled"); /// Serializes raw data. /// diff --git a/bindings/python/src/view.rs b/bindings/python/src/view.rs new file mode 100644 index 00000000..1ce22ca7 --- /dev/null +++ b/bindings/python/src/view.rs @@ -0,0 +1,137 @@ +use crate::SafetensorError; +#[cfg(feature = "py311")] +use pyo3::buffer::PyBuffer; +use pyo3::prelude::*; +#[cfg(feature = "py38")] +use pyo3::types::PyBytes; +use pyo3::types::PyDict; +use pyo3::Bound as PyBound; +use safetensors::{Dtype, View}; +use std::borrow::Cow; +use std::collections::HashMap; + +#[cfg(feature = "py38")] +pub struct PyView<'a> { + shape: Vec, + dtype: Dtype, + data: PyBound<'a, PyBytes>, + data_len: usize, +} + +#[cfg(feature = "py311")] +pub struct PyView<'a> { + shape: Vec, + dtype: Dtype, + data: PyBuffer, + data_len: usize, + // Kept to keep the GIL open while we hold the buffer + _py: Python<'a>, +} + +impl View for &PyView<'_> { + #[cfg(feature = "py38")] + fn data(&self) -> std::borrow::Cow<[u8]> { + Cow::Borrowed(self.data.as_bytes()) + } + #[cfg(feature = "py311")] + fn data(&self) -> std::borrow::Cow<[u8]> { + // We already checked this in the Python side. + assert!(self.data.is_c_contiguous()); + // XXX: Ideally we could have at least readonly tensors + // assert!(self.data.readonly()); + // SAFETY: + // This is actually totally unsafe, PyBuffer is not immutable and could be changed from + // under us. + // This is made safer because we're still hanging to the GIL while treating + // this structure + Cow::Borrowed(unsafe { + std::slice::from_raw_parts(self.data.buf_ptr() as *const u8, self.data.item_count()) + }) + } + fn shape(&self) -> &[usize] { + &self.shape + } + fn dtype(&self) -> Dtype { + self.dtype + } + fn data_len(&self) -> usize { + self.data_len + } +} + +pub fn prepare(tensor_dict: HashMap>) -> PyResult> { + let mut tensors = HashMap::with_capacity(tensor_dict.len()); + for (tensor_name, tensor_desc) in &tensor_dict { + let shape: Vec = tensor_desc + .get_item("shape")? + .ok_or_else(|| SafetensorError::new_err(format!("Missing `shape` in {tensor_desc:?}")))? + .extract()?; + let pydata: PyBound = tensor_desc.get_item("data")?.ok_or_else(|| { + SafetensorError::new_err(format!("Missing `data` in {tensor_desc:?}")) + })?; + + let pydtype = tensor_desc.get_item("dtype")?.ok_or_else(|| { + SafetensorError::new_err(format!("Missing `dtype` in {tensor_desc:?}")) + })?; + let dtype: String = pydtype.extract()?; + let dtype = match dtype.as_ref() { + "bool" => Dtype::BOOL, + "int8" => Dtype::I8, + "uint8" => Dtype::U8, + "int16" => Dtype::I16, + "uint16" => Dtype::U16, + "int32" => Dtype::I32, + "uint32" => Dtype::U32, + "int64" => Dtype::I64, + "uint64" => Dtype::U64, + "float16" => Dtype::F16, + "float32" => Dtype::F32, + "float64" => Dtype::F64, + "bfloat16" => Dtype::BF16, + "float8_e4m3fn" => Dtype::F8_E4M3, + "float8_e5m2" => Dtype::F8_E5M2, + dtype_str => { + return Err(SafetensorError::new_err(format!( + "dtype {dtype_str} is not covered", + ))); + } + }; + + #[cfg(feature = "py311")] + let tensor = { + let data: PyBuffer = pydata.extract()?; + if !data.is_c_contiguous() { + return Err(SafetensorError::new_err("Python buffer is not contiguous")); + } + // XXX Ideally this would be true. + // if !data.readonly() { + // return Err(SafetensorError::new_err("Python buffer is not readonly")); + // } + let data_len = data.item_count(); + let py = pydata.py(); + PyView { + shape, + dtype, + data, + data_len, + _py: py, + } + }; + + #[cfg(feature = "py38")] + let tensor = { + let data: &[u8] = pydata.extract()?; + let data_len = data.len(); + let data: PyBound = pydata.extract()?; + PyView { + shape, + dtype, + data, + data_len, + } + }; + + tensors.insert(tensor_name.to_string(), tensor); + } + Ok(tensors) +} From 7bf65ad7d56be10331dd9c15b67d82d1c5f39cc0 Mon Sep 17 00:00:00 2001 From: spikedoanz <125920936+spikedoanz@users.noreply.github.com> Date: Sun, 23 Feb 2025 12:48:52 -0500 Subject: [PATCH 6/9] removed DS_Store (#573) --- bindings/.DS_Store | Bin 6148 -> 0 bytes 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 bindings/.DS_Store diff --git a/bindings/.DS_Store b/bindings/.DS_Store deleted file mode 100644 index d8d5081b080ca349d71b6c23cb9c49a894541b73..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeHK%}T>S5Z-NTn^J@v6nb3nS}^TN@e*Qv0V8@)sR=1I7_+5G4N?j@>kIiLK94iI zn?o?*O~lT??l(I>yO|HNKa4T%t)nqxHe<|!hR9KA5HvTsS|%8gt2rb~!y;RTi3}Si z`imy~_AXnpSC+E{tAGDTxJ=?Kcbre&YIpX!J*#I8tb2bU3qKF?b>;=DTO3_UnS_-d zgxB%BnAr#CGR=cHov&0v9L*u*?j}wnS$J}tMw!a>wZrOL{h2-7Y)&0#Bqk@5?MPJm zc)RUeheyX}m$RqjIhAjkNDiDU**93hJ1Di9Ui?*>$n+8HRZbO4NDL4I!~ij{dkmPf zz#8mc=~OW>Kn(oA0PYV08lr2k)Tp)&=Q6l=%ABp=%EPl(Kzo3Of^j7(AfVSS0WiRQWT>3lFHna#*I=m;XFpF From 579ddf974953ce28c0ae63b2f3b1ec441b7b6644 Mon Sep 17 00:00:00 2001 From: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com> Date: Tue, 25 Feb 2025 12:11:12 +0100 Subject: [PATCH 7/9] hpu index (#578) --- bindings/python/src/lib.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs index df3de7d3..1e6ba853 100644 --- a/bindings/python/src/lib.rs +++ b/bindings/python/src/lib.rs @@ -202,7 +202,7 @@ enum Device { Xpu(usize), Xla(usize), Mlu(usize), - Hpu, + Hpu(usize), /// User didn't specify accelerator, torch /// is responsible for choosing. Anonymous(usize), @@ -232,12 +232,13 @@ impl<'source> FromPyObject<'source> for Device { "xpu" => Ok(Device::Xpu(0)), "xla" => Ok(Device::Xla(0)), "mlu" => Ok(Device::Mlu(0)), - "hpu" => Ok(Device::Hpu), + "hpu" => Ok(Device::Hpu(0)), name if name.starts_with("cuda:") => parse_device(name).map(Device::Cuda), name if name.starts_with("npu:") => parse_device(name).map(Device::Npu), name if name.starts_with("xpu:") => parse_device(name).map(Device::Xpu), name if name.starts_with("xla:") => parse_device(name).map(Device::Xla), name if name.starts_with("mlu:") => parse_device(name).map(Device::Mlu), + name if name.starts_with("hpu:") => parse_device(name).map(Device::Hpu), name => Err(SafetensorError::new_err(format!( "device {name} is invalid" ))), @@ -264,7 +265,7 @@ impl<'py> IntoPyObject<'py> for Device { Device::Xpu(n) => format!("xpu:{n}").into_pyobject(py).map(|x| x.into_any()), Device::Xla(n) => format!("xla:{n}").into_pyobject(py).map(|x| x.into_any()), Device::Mlu(n) => format!("mlu:{n}").into_pyobject(py).map(|x| x.into_any()), - Device::Hpu => "hpu".into_pyobject(py).map(|x| x.into_any()), + Device::Hpu(n) => format!("hpu:{n}").into_pyobject(py).map(|x| x.into_any()), Device::Anonymous(n) => n.into_pyobject(py).map(|x| x.into_any()), } } From 80763e3318f16a07889dbf1ad061ad74b9d1b95b Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 26 Feb 2025 10:07:35 +0100 Subject: [PATCH 8/9] Revert "Demoing zero-copy save. (#567)" This reverts commit 4b3864c802d727be3cd67e5107a0f873d047ae69. --- .github/workflows/build_documentation.yml | 2 - .github/workflows/build_pr_documentation.yml | 2 - .github/workflows/python-release.yml | 20 ++- .github/workflows/python.yml | 10 +- .pre-commit-config.yaml | 2 - Dockerfile.s390x.test | 1 - bindings/python/Cargo.toml | 6 +- bindings/python/py_src/safetensors/torch.py | 19 +-- bindings/python/src/lib.rs | 85 ++++++++++-- bindings/python/src/view.rs | 137 ------------------- 10 files changed, 93 insertions(+), 191 deletions(-) delete mode 100644 bindings/python/src/view.rs diff --git a/.github/workflows/build_documentation.yml b/.github/workflows/build_documentation.yml index a491bc15..8b501942 100644 --- a/.github/workflows/build_documentation.yml +++ b/.github/workflows/build_documentation.yml @@ -11,8 +11,6 @@ on: jobs: build: uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main - env: - MATURIN_PEP517_ARGS: "--features py311,pyo3/extension-module" with: commit_sha: ${{ github.sha }} package: safetensors diff --git a/.github/workflows/build_pr_documentation.yml b/.github/workflows/build_pr_documentation.yml index 81214464..efafcd48 100644 --- a/.github/workflows/build_pr_documentation.yml +++ b/.github/workflows/build_pr_documentation.yml @@ -14,8 +14,6 @@ concurrency: jobs: build: uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main - env: - MATURIN_PEP517_ARGS: "--features py311,pyo3/extension-module" with: commit_sha: ${{ github.event.pull_request.head.sha }} pr_number: ${{ github.event.number }} diff --git a/.github/workflows/python-release.yml b/.github/workflows/python-release.yml index 9399909b..2cc96438 100644 --- a/.github/workflows/python-release.yml +++ b/.github/workflows/python-release.yml @@ -23,7 +23,6 @@ jobs: runs-on: ${{ matrix.platform.runner }} strategy: matrix: - pyfeature: ["py38", "py311"] platform: - runner: ubuntu-latest target: x86_64 @@ -46,20 +45,19 @@ jobs: uses: PyO3/maturin-action@v1 with: target: ${{ matrix.platform.target }} - args: --release --out dist --manifest-path bindings/python/Cargo.toml --features pyo3/extension-module,${{ matrix.pyfeature }} + args: --release --out dist --manifest-path bindings/python/Cargo.toml sccache: 'true' manylinux: auto - name: Upload wheels uses: actions/upload-artifact@v4 with: - name: wheels-linux-${{ matrix.platform.target }}-${{ matrix.pyfeature }} + name: wheels-linux-${{ matrix.platform.target }} path: dist musllinux: runs-on: ${{ matrix.platform.runner }} strategy: matrix: - pyfeature: ["py38", "py311"] platform: - runner: ubuntu-latest target: x86_64 @@ -78,20 +76,19 @@ jobs: uses: PyO3/maturin-action@v1 with: target: ${{ matrix.platform.target }} - args: --release --out dist --manifest-path bindings/python/Cargo.toml --features pyo3/extension-module,${{ matrix.pyfeature }} + args: --release --out dist --manifest-path bindings/python/Cargo.toml sccache: 'true' manylinux: musllinux_1_2 - name: Upload wheels uses: actions/upload-artifact@v4 with: - name: wheels-musllinux-${{ matrix.platform.target }}-${{ matrix.pyfeature }} + name: wheels-musllinux-${{ matrix.platform.target }} path: dist windows: runs-on: ${{ matrix.platform.runner }} strategy: matrix: - pyfeature: ["py38", "py311"] platform: - runner: windows-latest target: x64 @@ -107,19 +104,18 @@ jobs: uses: PyO3/maturin-action@v1 with: target: ${{ matrix.platform.target }} - args: --release --out dist --manifest-path bindings/python/Cargo.toml --features pyo3/extension-module,${{ matrix.pyfeature }} + args: --release --out dist --manifest-path bindings/python/Cargo.toml sccache: 'true' - name: Upload wheels uses: actions/upload-artifact@v4 with: - name: wheels-windows-${{ matrix.platform.target }}-${{ matrix.pyfeature }} + name: wheels-windows-${{ matrix.platform.target }} path: dist macos: runs-on: ${{ matrix.platform.runner }} strategy: matrix: - pyfeature: ["py38", "py311"] platform: - runner: macos-13 target: x86_64 @@ -134,12 +130,12 @@ jobs: uses: PyO3/maturin-action@v1 with: target: ${{ matrix.platform.target }} - args: --release --out dist --manifest-path bindings/python/Cargo.toml --features pyo3/extension-module,${{ matrix.pyfeature }} + args: --release --out dist --manifest-path bindings/python/Cargo.toml sccache: 'true' - name: Upload wheels uses: actions/upload-artifact@v4 with: - name: wheels-macos-${{ matrix.platform.target }}-${{ matrix.pyfeature }} + name: wheels-macos-${{ matrix.platform.target }} path: dist sdist: diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index 0c5437c7..9fbbf2a7 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -7,14 +7,12 @@ jobs: build_and_test: name: Check everything builds & tests runs-on: ${{ matrix.os }} - env: - MATURIN_PEP517_ARGS: "--features ${{ matrix.version.pyfeature }},pyo3/extension-module" strategy: matrix: os: [ubuntu-latest, macos-13, windows-latest] # Lowest and highest, no version specified so that # new releases get automatically tested against - version: [{torch: torch==1.10, python: "3.8", pyfeature: "py38"}, {torch: torch, python: "3.12", pyfeature: "py311"}] + version: [{torch: torch==1.10, python: "3.8"}, {torch: torch, python: "3.12"}] # TODO this would include macos ARM target. # however jax has an illegal instruction issue # that exists only in CI (probably difference in instruction support). @@ -54,14 +52,14 @@ jobs: run: cargo fmt -- --check - name: Lint with Clippy - run: | - cargo clippy --features ${{ matrix.version.pyfeature }} -- -D warnings + run: cargo clippy --all-targets --all-features -- -D warnings - name: Run Audit run: cargo audit -D warnings - name: Install run: | + pip install -U pip pip install .[numpy,tensorflow] pip install ${{ matrix.version.torch }} @@ -84,7 +82,7 @@ jobs: - name: Run tests run: | - cargo test --features ${{ matrix.version.pyfeature }} + cargo test pip install .[testing] pytest -sv tests/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a323e52c..1a28594b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -27,8 +27,6 @@ repos: [ "--manifest-path", "bindings/python/Cargo.toml", - "--features", - "py311", "--all-targets", "--", "-Dwarnings", diff --git a/Dockerfile.s390x.test b/Dockerfile.s390x.test index cb5bdba7..d1dc7583 100644 --- a/Dockerfile.s390x.test +++ b/Dockerfile.s390x.test @@ -11,7 +11,6 @@ RUN /root/miniconda3/bin/pip install -U pip pytest COPY . . SHELL ["/bin/bash", "-c"] WORKDIR /safetensors/bindings/python/ -ENV MATURIN_PEP517_ARGS="--features py311,pyo3/extension-module" RUN source /root/.cargo/env && /root/miniconda3/bin/pip install -e . RUN /root/miniconda3/bin/pytest -sv tests/test_pt_* tests/test_simple.py # RUN /root/miniconda3/bin/python -c 'from huggingface_hub import hf_hub_download; filename = hf_hub_download("roberta-base", "model.safetensors"); from safetensors.torch import load_file; weights = load_file(filename); assert weights["roberta.embeddings.position_embeddings.weight"][0][0].abs().item() > 1e-10' diff --git a/bindings/python/Cargo.toml b/bindings/python/Cargo.toml index fe8bbf44..de7692bc 100644 --- a/bindings/python/Cargo.toml +++ b/bindings/python/Cargo.toml @@ -9,12 +9,8 @@ rust-version = "1.74" name = "safetensors_rust" crate-type = ["cdylib"] -[features] -py38 = ["pyo3/abi3-py38"] -py311 = ["pyo3/abi3-py311"] - [dependencies] -pyo3 = { version = "0.23", features = ["abi3"] } +pyo3 = { version = "0.23", features = ["abi3", "abi3-py38"] } memmap2 = "0.9" serde_json = "1.0" diff --git a/bindings/python/py_src/safetensors/torch.py b/bindings/python/py_src/safetensors/torch.py index 4476e754..48532ea5 100644 --- a/bindings/python/py_src/safetensors/torch.py +++ b/bindings/python/py_src/safetensors/torch.py @@ -128,10 +128,7 @@ def _remove_duplicate_names( def save_model( - model: torch.nn.Module, - filename: str, - metadata: Optional[Dict[str, str]] = None, - force_contiguous: bool = True, + model: torch.nn.Module, filename: str, metadata: Optional[Dict[str, str]] = None, force_contiguous: bool = True ): """ Saves a given torch model to specified filename. @@ -177,10 +174,7 @@ def save_model( def load_model( - model: torch.nn.Module, - filename: Union[str, os.PathLike], - strict: bool = True, - device: Union[str, int] = "cpu", + model: torch.nn.Module, filename: Union[str, os.PathLike], strict: bool = True, device: Union[str, int] = "cpu" ) -> Tuple[List[str], List[str]]: """ Loads a given filename onto a torch model. @@ -408,7 +402,7 @@ def _view2torch(safeview) -> Dict[str, torch.Tensor]: return result -def _tobytes(tensor: torch.Tensor, name: str) -> Union[memoryview, bytes]: +def _tobytes(tensor: torch.Tensor, name: str) -> bytes: if tensor.layout != torch.strided: raise ValueError( f"You are trying to save a sparse tensor: `{name}` which this library does not support." @@ -462,11 +456,8 @@ def _tobytes(tensor: torch.Tensor, name: str) -> Union[memoryview, bytes]: } npdtype = NPDTYPES[tensor.dtype] # Not in place as that would potentially modify a live running model - data = data.view(npdtype).byteswap(inplace=False).view(np.uint8) - if sys.version_info >= (3, 11): - return data.data - else: - return data.tobytes() + data = data.view(npdtype).byteswap(inplace=False) + return data.tobytes() def _flatten(tensors: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, Any]]: diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs index 1e6ba853..88661c41 100644 --- a/bindings/python/src/lib.rs +++ b/bindings/python/src/lib.rs @@ -1,7 +1,5 @@ #![deny(missing_docs)] //! Dummy doc -#[cfg(any(feature = "py38", feature = "py311"))] -mod view; use memmap2::{Mmap, MmapOptions}; use pyo3::exceptions::{PyException, PyFileNotFoundError}; use pyo3::prelude::*; @@ -12,14 +10,14 @@ use pyo3::Bound as PyBound; use pyo3::{intern, PyErr}; use safetensors::slice::TensorIndexer; use safetensors::tensor::{Dtype, Metadata, SafeTensors, TensorInfo, TensorView}; +use safetensors::View; +use std::borrow::Cow; use std::collections::HashMap; use std::fs::File; use std::iter::FromIterator; use std::ops::Bound; use std::path::PathBuf; use std::sync::Arc; -#[cfg(any(feature = "py38", feature = "py311"))] -use view::prepare; static TORCH_MODULE: GILOnceCell> = GILOnceCell::new(); static NUMPY_MODULE: GILOnceCell> = GILOnceCell::new(); @@ -27,12 +25,79 @@ static TENSORFLOW_MODULE: GILOnceCell> = GILOnceCell::new(); static FLAX_MODULE: GILOnceCell> = GILOnceCell::new(); static MLX_MODULE: GILOnceCell> = GILOnceCell::new(); -#[cfg(not(any(feature = "py38", feature = "py311")))] -compile_error!( - "At least one python version must be enabled, use `maturin develop --features py311,pyo3/extension-module`" -); -#[cfg(all(feature = "py38", feature = "py311"))] -compile_error!("Only one python version must be enabled"); +struct PyView<'a> { + shape: Vec, + dtype: Dtype, + data: PyBound<'a, PyBytes>, + data_len: usize, +} + +impl View for &PyView<'_> { + fn data(&self) -> std::borrow::Cow<[u8]> { + Cow::Borrowed(self.data.as_bytes()) + } + fn shape(&self) -> &[usize] { + &self.shape + } + fn dtype(&self) -> Dtype { + self.dtype + } + fn data_len(&self) -> usize { + self.data_len + } +} + +fn prepare(tensor_dict: HashMap>) -> PyResult> { + let mut tensors = HashMap::with_capacity(tensor_dict.len()); + for (tensor_name, tensor_desc) in &tensor_dict { + let shape: Vec = tensor_desc + .get_item("shape")? + .ok_or_else(|| SafetensorError::new_err(format!("Missing `shape` in {tensor_desc:?}")))? + .extract()?; + let pydata: PyBound = tensor_desc.get_item("data")?.ok_or_else(|| { + SafetensorError::new_err(format!("Missing `data` in {tensor_desc:?}")) + })?; + // Make sure it's extractable first. + let data: &[u8] = pydata.extract()?; + let data_len = data.len(); + let data: PyBound = pydata.extract()?; + let pydtype = tensor_desc.get_item("dtype")?.ok_or_else(|| { + SafetensorError::new_err(format!("Missing `dtype` in {tensor_desc:?}")) + })?; + let dtype: String = pydtype.extract()?; + let dtype = match dtype.as_ref() { + "bool" => Dtype::BOOL, + "int8" => Dtype::I8, + "uint8" => Dtype::U8, + "int16" => Dtype::I16, + "uint16" => Dtype::U16, + "int32" => Dtype::I32, + "uint32" => Dtype::U32, + "int64" => Dtype::I64, + "uint64" => Dtype::U64, + "float16" => Dtype::F16, + "float32" => Dtype::F32, + "float64" => Dtype::F64, + "bfloat16" => Dtype::BF16, + "float8_e4m3fn" => Dtype::F8_E4M3, + "float8_e5m2" => Dtype::F8_E5M2, + dtype_str => { + return Err(SafetensorError::new_err(format!( + "dtype {dtype_str} is not covered", + ))); + } + }; + + let tensor = PyView { + shape, + dtype, + data, + data_len, + }; + tensors.insert(tensor_name.to_string(), tensor); + } + Ok(tensors) +} /// Serializes raw data. /// diff --git a/bindings/python/src/view.rs b/bindings/python/src/view.rs deleted file mode 100644 index 1ce22ca7..00000000 --- a/bindings/python/src/view.rs +++ /dev/null @@ -1,137 +0,0 @@ -use crate::SafetensorError; -#[cfg(feature = "py311")] -use pyo3::buffer::PyBuffer; -use pyo3::prelude::*; -#[cfg(feature = "py38")] -use pyo3::types::PyBytes; -use pyo3::types::PyDict; -use pyo3::Bound as PyBound; -use safetensors::{Dtype, View}; -use std::borrow::Cow; -use std::collections::HashMap; - -#[cfg(feature = "py38")] -pub struct PyView<'a> { - shape: Vec, - dtype: Dtype, - data: PyBound<'a, PyBytes>, - data_len: usize, -} - -#[cfg(feature = "py311")] -pub struct PyView<'a> { - shape: Vec, - dtype: Dtype, - data: PyBuffer, - data_len: usize, - // Kept to keep the GIL open while we hold the buffer - _py: Python<'a>, -} - -impl View for &PyView<'_> { - #[cfg(feature = "py38")] - fn data(&self) -> std::borrow::Cow<[u8]> { - Cow::Borrowed(self.data.as_bytes()) - } - #[cfg(feature = "py311")] - fn data(&self) -> std::borrow::Cow<[u8]> { - // We already checked this in the Python side. - assert!(self.data.is_c_contiguous()); - // XXX: Ideally we could have at least readonly tensors - // assert!(self.data.readonly()); - // SAFETY: - // This is actually totally unsafe, PyBuffer is not immutable and could be changed from - // under us. - // This is made safer because we're still hanging to the GIL while treating - // this structure - Cow::Borrowed(unsafe { - std::slice::from_raw_parts(self.data.buf_ptr() as *const u8, self.data.item_count()) - }) - } - fn shape(&self) -> &[usize] { - &self.shape - } - fn dtype(&self) -> Dtype { - self.dtype - } - fn data_len(&self) -> usize { - self.data_len - } -} - -pub fn prepare(tensor_dict: HashMap>) -> PyResult> { - let mut tensors = HashMap::with_capacity(tensor_dict.len()); - for (tensor_name, tensor_desc) in &tensor_dict { - let shape: Vec = tensor_desc - .get_item("shape")? - .ok_or_else(|| SafetensorError::new_err(format!("Missing `shape` in {tensor_desc:?}")))? - .extract()?; - let pydata: PyBound = tensor_desc.get_item("data")?.ok_or_else(|| { - SafetensorError::new_err(format!("Missing `data` in {tensor_desc:?}")) - })?; - - let pydtype = tensor_desc.get_item("dtype")?.ok_or_else(|| { - SafetensorError::new_err(format!("Missing `dtype` in {tensor_desc:?}")) - })?; - let dtype: String = pydtype.extract()?; - let dtype = match dtype.as_ref() { - "bool" => Dtype::BOOL, - "int8" => Dtype::I8, - "uint8" => Dtype::U8, - "int16" => Dtype::I16, - "uint16" => Dtype::U16, - "int32" => Dtype::I32, - "uint32" => Dtype::U32, - "int64" => Dtype::I64, - "uint64" => Dtype::U64, - "float16" => Dtype::F16, - "float32" => Dtype::F32, - "float64" => Dtype::F64, - "bfloat16" => Dtype::BF16, - "float8_e4m3fn" => Dtype::F8_E4M3, - "float8_e5m2" => Dtype::F8_E5M2, - dtype_str => { - return Err(SafetensorError::new_err(format!( - "dtype {dtype_str} is not covered", - ))); - } - }; - - #[cfg(feature = "py311")] - let tensor = { - let data: PyBuffer = pydata.extract()?; - if !data.is_c_contiguous() { - return Err(SafetensorError::new_err("Python buffer is not contiguous")); - } - // XXX Ideally this would be true. - // if !data.readonly() { - // return Err(SafetensorError::new_err("Python buffer is not readonly")); - // } - let data_len = data.item_count(); - let py = pydata.py(); - PyView { - shape, - dtype, - data, - data_len, - _py: py, - } - }; - - #[cfg(feature = "py38")] - let tensor = { - let data: &[u8] = pydata.extract()?; - let data_len = data.len(); - let data: PyBound = pydata.extract()?; - PyView { - shape, - dtype, - data, - data_len, - } - }; - - tensors.insert(tensor_name.to_string(), tensor); - } - Ok(tensors) -} From 066773fe638fa62e79c704e25772520a00ea35e7 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 26 Feb 2025 10:08:41 +0100 Subject: [PATCH 9/9] Patch release (without unsafe copy). --- bindings/python/Cargo.toml | 2 +- safetensors/Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/bindings/python/Cargo.toml b/bindings/python/Cargo.toml index de7692bc..1d4cfbd0 100644 --- a/bindings/python/Cargo.toml +++ b/bindings/python/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "safetensors-python" -version = "0.5.3-dev.0" +version = "0.5.3" edition = "2021" rust-version = "1.74" diff --git a/safetensors/Cargo.toml b/safetensors/Cargo.toml index dd303d65..d54c969d 100644 --- a/safetensors/Cargo.toml +++ b/safetensors/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "safetensors" -version = "0.5.3-dev.0" +version = "0.5.3" edition = "2021" rust-version = "1.74" homepage = "https://github.com/huggingface/safetensors"