Skip to content

Commit b9ccdc6

Browse files
authored
GH action... once again. (#635)
* GH action... once again. * Using ruff instead of * Moving to ruff instead of black * Fixing the stub to not use black either. * . * Forgot to flush. * Old torch versions don't have unsigned variants. * Fixing the handle tests. * Installing missing numpy. * Use numpy > 2 * ... * Yaml. * .. * ... * ..... * So annoying. * Fix uv lock. * Fix the fp4 tests. * Download hdf5 library. * .... * .. * .. * .. * .. * .. * .. * .. * . * .. * Fixing MLX implementation with newer versions. * Split macos into x86 legacy and latest aarch64. * .. * .. * Lock. * .. * . * Numpy version.
1 parent 7dfa63c commit b9ccdc6

File tree

18 files changed

+391
-255
lines changed

18 files changed

+391
-255
lines changed

.github/workflows/python.yml

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@ jobs:
99
runs-on: ${{ matrix.os }}
1010
strategy:
1111
matrix:
12-
os: [ubuntu-latest, macos-13, windows-latest]
12+
os: [ubuntu-latest, windows-latest]
1313
# Lowest and highest, no version specified so that
1414
# new releases get automatically tested against
15-
version: [{torch: torch==1.10, python: "3.9", arch: "x64"}, {torch: torch, python: "3.12", arch: "x64"}]
15+
version: [{torch: torch==1.10, python: "3.9", arch: "x64", numpy: numpy==1.26.4}, {torch: torch, python: "3.12", arch: "x64", numpy: numpy}]
1616
# TODO this would include macos ARM target.
1717
# however jax has an illegal instruction issue
1818
# that exists only in CI (probably a difference in instruction support).
@@ -26,7 +26,20 @@ jobs:
2626
version:
2727
torch: torch
2828
python: "3.13"
29+
numpy: numpy
2930
arch: "x64-freethreaded"
31+
- os: macos-13
32+
version:
33+
torch: torch==1.10
34+
numpy: "numpy==1.26"
35+
python: "3.9"
36+
arch: "x64"
37+
- os: macos-latest
38+
version:
39+
torch: torch
40+
python: "3.12"
41+
numpy: numpy
42+
arch: "arm64"
3043
defaults:
3144
run:
3245
working-directory: ./bindings/python
@@ -63,14 +76,14 @@ jobs:
6376
- name: Run Audit
6477
run: cargo audit -D warnings
6578

66-
- name: Install
67-
run: |
68-
pip install -U pip
69-
pip install .[numpy]
79+
# - name: Install
80+
# run: |
81+
# pip install -U pip
7082

7183
- name: Install (torch)
7284
if: matrix.version.arch != 'x64-freethreaded'
7385
run: |
86+
pip install ${{ matrix.version.numpy }}
7487
pip install ${{ matrix.version.torch }}
7588
shell: bash
7689

@@ -80,14 +93,22 @@ jobs:
8093
pip install ${{ matrix.version.torch }} --index-url https://download.pytorch.org/whl/cu126
8194
shell: bash
8295

96+
- name: Install (hdf5 non windows)
97+
if: matrix.os == 'ubuntu-latest' && matrix.version.arch != 'x64-freethreaded'
98+
run: |
99+
sudo apt-get update
100+
sudo apt-get install libhdf5-dev
101+
83102
- name: Install (tensorflow)
84103
if: matrix.version.arch != 'x64-freethreaded'
85104
run: |
86105
pip install .[tensorflow]
106+
# Force reinstall of numpy: tensorflow uses numpy 2 even on Python 3.9
107+
pip install ${{ matrix.version.numpy }}
87108
shell: bash
88109

89110
- name: Install (jax, flax)
90-
if: matrix.os != 'windows-latest' && matrix.version.arch != "x64-freethreaded"
111+
if: matrix.os != 'windows-latest' && matrix.version.arch != 'x64-freethreaded'
91112
run:
92113
pip install .[jax]
93114
shell: bash
@@ -101,14 +122,24 @@ jobs:
101122
- name: Check style
102123
run: |
103124
pip install .[quality]
104-
black --check --line-length 119 --target-version py35 py_src/safetensors tests
125+
ruff format --check .
105126
106127
- name: Run tests
128+
if: matrix.version.arch != 'x64-freethreaded'
107129
run: |
108130
cargo test
109-
pip install .[testing]
131+
pip install ".[testing]"
110132
pytest -sv tests/
111133
134+
- name: Run tests (freethreaded)
135+
if: matrix.version.arch == 'x64-freethreaded'
136+
run: |
137+
cargo test
138+
pip install ".[testingfree]"
139+
pip install pytest numpy
140+
pytest -sv tests/test_pt*
141+
pytest -sv tests/test_simple.py
142+
112143
test_s390x_big_endian:
113144
runs-on: ubuntu-latest
114145
permissions:

bindings/python/benches/test_pt.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,10 @@ def test_pt_sf_load_gpu(benchmark):
118118
assert torch.allclose(v, tv)
119119

120120

121-
@pytest.mark.skipif(not hasattr(torch.backends, "mps") or not torch.backends.mps.is_available(), reason="requires mps")
121+
@pytest.mark.skipif(
122+
not hasattr(torch.backends, "mps") or not torch.backends.mps.is_available(),
123+
reason="requires mps",
124+
)
122125
def test_pt_pt_load_mps(benchmark):
123126
# benchmark something
124127
weights = create_gpt2(12)
@@ -133,7 +136,10 @@ def test_pt_pt_load_mps(benchmark):
133136
assert torch.allclose(v, tv)
134137

135138

136-
@pytest.mark.skipif(not hasattr(torch.backends, "mps") or not torch.backends.mps.is_available(), reason="requires mps")
139+
@pytest.mark.skipif(
140+
not hasattr(torch.backends, "mps") or not torch.backends.mps.is_available(),
141+
reason="requires mps",
142+
)
137143
def test_pt_sf_load_mps(benchmark):
138144
# benchmark something
139145
weights = create_gpt2(12)

bindings/python/convert.py

Lines changed: 111 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,13 @@
88

99
import torch
1010

11-
from huggingface_hub import CommitInfo, CommitOperationAdd, Discussion, HfApi, hf_hub_download
11+
from huggingface_hub import (
12+
CommitInfo,
13+
CommitOperationAdd,
14+
Discussion,
15+
HfApi,
16+
hf_hub_download,
17+
)
1218
from huggingface_hub.file_download import repo_folder_name
1319
from safetensors.torch import _find_shared_tensors, _is_complete, load_file, save_file
1420

@@ -49,7 +55,9 @@ def _remove_duplicate_names(
4955
shareds = _find_shared_tensors(state_dict)
5056
to_remove = defaultdict(list)
5157
for shared in shareds:
52-
complete_names = set([name for name in shared if _is_complete(state_dict[name])])
58+
complete_names = set(
59+
[name for name in shared if _is_complete(state_dict[name])]
60+
)
5361
if not complete_names:
5462
if len(shared) == 1:
5563
# Force contiguous
@@ -81,14 +89,20 @@ def _remove_duplicate_names(
8189
return to_remove
8290

8391

84-
def get_discard_names(model_id: str, revision: Optional[str], folder: str, token: Optional[str]) -> List[str]:
92+
def get_discard_names(
93+
model_id: str, revision: Optional[str], folder: str, token: Optional[str]
94+
) -> List[str]:
8595
try:
8696
import json
8797

8898
import transformers
8999

90100
config_filename = hf_hub_download(
91-
model_id, revision=revision, filename="config.json", token=token, cache_dir=folder
101+
model_id,
102+
revision=revision,
103+
filename="config.json",
104+
token=token,
105+
cache_dir=folder,
92106
)
93107
with open(config_filename, "r") as f:
94108
config = json.load(f)
@@ -129,18 +143,29 @@ def rename(pt_filename: str) -> str:
129143

130144

131145
def convert_multi(
132-
model_id: str, *, revision=Optional[str], folder: str, token: Optional[str], discard_names: List[str]
146+
model_id: str,
147+
*,
148+
revision=Optional[str],
149+
folder: str,
150+
token: Optional[str],
151+
discard_names: List[str],
133152
) -> ConversionResult:
134153
filename = hf_hub_download(
135-
repo_id=model_id, revision=revision, filename="pytorch_model.bin.index.json", token=token, cache_dir=folder
154+
repo_id=model_id,
155+
revision=revision,
156+
filename="pytorch_model.bin.index.json",
157+
token=token,
158+
cache_dir=folder,
136159
)
137160
with open(filename, "r") as f:
138161
data = json.load(f)
139162

140163
filenames = set(data["weight_map"].values())
141164
local_filenames = []
142165
for filename in filenames:
143-
pt_filename = hf_hub_download(repo_id=model_id, filename=filename, token=token, cache_dir=folder)
166+
pt_filename = hf_hub_download(
167+
repo_id=model_id, filename=filename, token=token, cache_dir=folder
168+
)
144169

145170
sf_filename = rename(pt_filename)
146171
sf_filename = os.path.join(folder, sf_filename)
@@ -156,18 +181,28 @@ def convert_multi(
156181
local_filenames.append(index)
157182

158183
operations = [
159-
CommitOperationAdd(path_in_repo=os.path.basename(local), path_or_fileobj=local) for local in local_filenames
184+
CommitOperationAdd(path_in_repo=os.path.basename(local), path_or_fileobj=local)
185+
for local in local_filenames
160186
]
161187
errors: List[Tuple[str, "Exception"]] = []
162188

163189
return operations, errors
164190

165191

166192
def convert_single(
167-
model_id: str, *, revision: Optional[str], folder: str, token: Optional[str], discard_names: List[str]
193+
model_id: str,
194+
*,
195+
revision: Optional[str],
196+
folder: str,
197+
token: Optional[str],
198+
discard_names: List[str],
168199
) -> ConversionResult:
169200
pt_filename = hf_hub_download(
170-
repo_id=model_id, revision=revision, filename="pytorch_model.bin", token=token, cache_dir=folder
201+
repo_id=model_id,
202+
revision=revision,
203+
filename="pytorch_model.bin",
204+
token=token,
205+
cache_dir=folder,
171206
)
172207

173208
sf_name = "model.safetensors"
@@ -219,20 +254,30 @@ def create_diff(pt_infos: Dict[str, List[str]], sf_infos: Dict[str, List[str]])
219254
sf_only = sf_set - pt_set
220255

221256
if pt_only:
222-
errors.append(f"{key} : PT warnings contain {pt_only} which are not present in SF warnings")
257+
errors.append(
258+
f"{key} : PT warnings contain {pt_only} which are not present in SF warnings"
259+
)
223260
if sf_only:
224-
errors.append(f"{key} : SF warnings contain {sf_only} which are not present in PT warnings")
261+
errors.append(
262+
f"{key} : SF warnings contain {sf_only} which are not present in PT warnings"
263+
)
225264
return "\n".join(errors)
226265

227266

228-
def previous_pr(api: "HfApi", model_id: str, pr_title: str, revision=Optional[str]) -> Optional["Discussion"]:
267+
def previous_pr(
268+
api: "HfApi", model_id: str, pr_title: str, revision=Optional[str]
269+
) -> Optional["Discussion"]:
229270
try:
230271
revision_commit = api.model_info(model_id, revision=revision).sha
231272
discussions = api.get_repo_discussions(repo_id=model_id)
232273
except Exception:
233274
return None
234275
for discussion in discussions:
235-
if discussion.status in {"open", "closed"} and discussion.is_pull_request and discussion.title == pr_title:
276+
if (
277+
discussion.status in {"open", "closed"}
278+
and discussion.is_pull_request
279+
and discussion.title == pr_title
280+
):
236281
commits = api.list_repo_commits(model_id, revision=discussion.git_reference)
237282

238283
if revision_commit == commits[1].commit_id:
@@ -241,7 +286,12 @@ def previous_pr(api: "HfApi", model_id: str, pr_title: str, revision=Optional[st
241286

242287

243288
def convert_generic(
244-
model_id: str, *, revision=Optional[str], folder: str, filenames: Set[str], token: Optional[str]
289+
model_id: str,
290+
*,
291+
revision=Optional[str],
292+
folder: str,
293+
filenames: Set[str],
294+
token: Optional[str],
245295
) -> ConversionResult:
246296
operations = []
247297
errors = []
@@ -251,7 +301,11 @@ def convert_generic(
251301
prefix, ext = os.path.splitext(filename)
252302
if ext in extensions:
253303
pt_filename = hf_hub_download(
254-
model_id, revision=revision, filename=filename, token=token, cache_dir=folder
304+
model_id,
305+
revision=revision,
306+
filename=filename,
307+
token=token,
308+
cache_dir=folder,
255309
)
256310
dirname, raw_filename = os.path.split(filename)
257311
if raw_filename == "pytorch_model.bin":
@@ -263,7 +317,11 @@ def convert_generic(
263317
sf_filename = os.path.join(folder, sf_in_repo)
264318
try:
265319
convert_file(pt_filename, sf_filename, discard_names=[])
266-
operations.append(CommitOperationAdd(path_in_repo=sf_in_repo, path_or_fileobj=sf_filename))
320+
operations.append(
321+
CommitOperationAdd(
322+
path_in_repo=sf_in_repo, path_or_fileobj=sf_filename
323+
)
324+
)
267325
except Exception as e:
268326
errors.append((pt_filename, e))
269327
return operations, errors
@@ -285,28 +343,50 @@ def convert(
285343
pr = previous_pr(api, model_id, pr_title, revision=revision)
286344

287345
library_name = getattr(info, "library_name", None)
288-
if any(filename.endswith(".safetensors") for filename in filenames) and not force:
289-
raise AlreadyExists(f"Model {model_id} is already converted, skipping..")
346+
if (
347+
any(filename.endswith(".safetensors") for filename in filenames)
348+
and not force
349+
):
350+
raise AlreadyExists(
351+
f"Model {model_id} is already converted, skipping.."
352+
)
290353
elif pr is not None and not force:
291354
url = f"https://huggingface.co/{model_id}/discussions/{pr.num}"
292355
new_pr = pr
293-
raise AlreadyExists(f"Model {model_id} already has an open PR check out {url}")
356+
raise AlreadyExists(
357+
f"Model {model_id} already has an open PR check out {url}"
358+
)
294359
elif library_name == "transformers":
295-
296-
discard_names = get_discard_names(model_id, revision=revision, folder=folder, token=api.token)
360+
discard_names = get_discard_names(
361+
model_id, revision=revision, folder=folder, token=api.token
362+
)
297363
if "pytorch_model.bin" in filenames:
298364
operations, errors = convert_single(
299-
model_id, revision=revision, folder=folder, token=api.token, discard_names=discard_names
365+
model_id,
366+
revision=revision,
367+
folder=folder,
368+
token=api.token,
369+
discard_names=discard_names,
300370
)
301371
elif "pytorch_model.bin.index.json" in filenames:
302372
operations, errors = convert_multi(
303-
model_id, revision=revision, folder=folder, token=api.token, discard_names=discard_names
373+
model_id,
374+
revision=revision,
375+
folder=folder,
376+
token=api.token,
377+
discard_names=discard_names,
304378
)
305379
else:
306-
raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert")
380+
raise RuntimeError(
381+
f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert"
382+
)
307383
else:
308384
operations, errors = convert_generic(
309-
model_id, revision=revision, folder=folder, filenames=filenames, token=api.token
385+
model_id,
386+
revision=revision,
387+
folder=folder,
388+
filenames=filenames,
389+
token=api.token,
310390
)
311391

312392
if operations:
@@ -366,7 +446,9 @@ def convert(
366446
" Continue [Y/n] ?"
367447
)
368448
if txt.lower() in {"", "y"}:
369-
commit_info, errors = convert(api, model_id, revision=args.revision, force=args.force)
449+
commit_info, errors = convert(
450+
api, model_id, revision=args.revision, force=args.force
451+
)
370452
string = f"""
371453
### Success 🔥
372454
Yay! This model was successfully converted and a PR was open using your token, here:
@@ -375,7 +457,8 @@ def convert(
375457
if errors:
376458
string += "\nErrors during conversion:\n"
377459
string += "\n".join(
378-
f"Error while converting {filename}: {e}, skipped conversion" for filename, e in errors
460+
f"Error while converting {filename}: {e}, skipped conversion"
461+
for filename, e in errors
379462
)
380463
print(string)
381464
else:

0 commit comments

Comments
 (0)