8
8
9
9
import torch
10
10
11
- from huggingface_hub import CommitInfo , CommitOperationAdd , Discussion , HfApi , hf_hub_download
11
+ from huggingface_hub import (
12
+ CommitInfo ,
13
+ CommitOperationAdd ,
14
+ Discussion ,
15
+ HfApi ,
16
+ hf_hub_download ,
17
+ )
12
18
from huggingface_hub .file_download import repo_folder_name
13
19
from safetensors .torch import _find_shared_tensors , _is_complete , load_file , save_file
14
20
@@ -49,7 +55,9 @@ def _remove_duplicate_names(
49
55
shareds = _find_shared_tensors (state_dict )
50
56
to_remove = defaultdict (list )
51
57
for shared in shareds :
52
- complete_names = set ([name for name in shared if _is_complete (state_dict [name ])])
58
+ complete_names = set (
59
+ [name for name in shared if _is_complete (state_dict [name ])]
60
+ )
53
61
if not complete_names :
54
62
if len (shared ) == 1 :
55
63
# Force contiguous
@@ -81,14 +89,20 @@ def _remove_duplicate_names(
81
89
return to_remove
82
90
83
91
84
- def get_discard_names (model_id : str , revision : Optional [str ], folder : str , token : Optional [str ]) -> List [str ]:
92
+ def get_discard_names (
93
+ model_id : str , revision : Optional [str ], folder : str , token : Optional [str ]
94
+ ) -> List [str ]:
85
95
try :
86
96
import json
87
97
88
98
import transformers
89
99
90
100
config_filename = hf_hub_download (
91
- model_id , revision = revision , filename = "config.json" , token = token , cache_dir = folder
101
+ model_id ,
102
+ revision = revision ,
103
+ filename = "config.json" ,
104
+ token = token ,
105
+ cache_dir = folder ,
92
106
)
93
107
with open (config_filename , "r" ) as f :
94
108
config = json .load (f )
@@ -129,18 +143,29 @@ def rename(pt_filename: str) -> str:
129
143
130
144
131
145
def convert_multi (
132
- model_id : str , * , revision = Optional [str ], folder : str , token : Optional [str ], discard_names : List [str ]
146
+ model_id : str ,
147
+ * ,
148
+ revision = Optional [str ],
149
+ folder : str ,
150
+ token : Optional [str ],
151
+ discard_names : List [str ],
133
152
) -> ConversionResult :
134
153
filename = hf_hub_download (
135
- repo_id = model_id , revision = revision , filename = "pytorch_model.bin.index.json" , token = token , cache_dir = folder
154
+ repo_id = model_id ,
155
+ revision = revision ,
156
+ filename = "pytorch_model.bin.index.json" ,
157
+ token = token ,
158
+ cache_dir = folder ,
136
159
)
137
160
with open (filename , "r" ) as f :
138
161
data = json .load (f )
139
162
140
163
filenames = set (data ["weight_map" ].values ())
141
164
local_filenames = []
142
165
for filename in filenames :
143
- pt_filename = hf_hub_download (repo_id = model_id , filename = filename , token = token , cache_dir = folder )
166
+ pt_filename = hf_hub_download (
167
+ repo_id = model_id , filename = filename , token = token , cache_dir = folder
168
+ )
144
169
145
170
sf_filename = rename (pt_filename )
146
171
sf_filename = os .path .join (folder , sf_filename )
@@ -156,18 +181,28 @@ def convert_multi(
156
181
local_filenames .append (index )
157
182
158
183
operations = [
159
- CommitOperationAdd (path_in_repo = os .path .basename (local ), path_or_fileobj = local ) for local in local_filenames
184
+ CommitOperationAdd (path_in_repo = os .path .basename (local ), path_or_fileobj = local )
185
+ for local in local_filenames
160
186
]
161
187
errors : List [Tuple [str , "Exception" ]] = []
162
188
163
189
return operations , errors
164
190
165
191
166
192
def convert_single (
167
- model_id : str , * , revision : Optional [str ], folder : str , token : Optional [str ], discard_names : List [str ]
193
+ model_id : str ,
194
+ * ,
195
+ revision : Optional [str ],
196
+ folder : str ,
197
+ token : Optional [str ],
198
+ discard_names : List [str ],
168
199
) -> ConversionResult :
169
200
pt_filename = hf_hub_download (
170
- repo_id = model_id , revision = revision , filename = "pytorch_model.bin" , token = token , cache_dir = folder
201
+ repo_id = model_id ,
202
+ revision = revision ,
203
+ filename = "pytorch_model.bin" ,
204
+ token = token ,
205
+ cache_dir = folder ,
171
206
)
172
207
173
208
sf_name = "model.safetensors"
@@ -219,20 +254,30 @@ def create_diff(pt_infos: Dict[str, List[str]], sf_infos: Dict[str, List[str]])
219
254
sf_only = sf_set - pt_set
220
255
221
256
if pt_only :
222
- errors .append (f"{ key } : PT warnings contain { pt_only } which are not present in SF warnings" )
257
+ errors .append (
258
+ f"{ key } : PT warnings contain { pt_only } which are not present in SF warnings"
259
+ )
223
260
if sf_only :
224
- errors .append (f"{ key } : SF warnings contain { sf_only } which are not present in PT warnings" )
261
+ errors .append (
262
+ f"{ key } : SF warnings contain { sf_only } which are not present in PT warnings"
263
+ )
225
264
return "\n " .join (errors )
226
265
227
266
228
- def previous_pr (api : "HfApi" , model_id : str , pr_title : str , revision = Optional [str ]) -> Optional ["Discussion" ]:
267
+ def previous_pr (
268
+ api : "HfApi" , model_id : str , pr_title : str , revision = Optional [str ]
269
+ ) -> Optional ["Discussion" ]:
229
270
try :
230
271
revision_commit = api .model_info (model_id , revision = revision ).sha
231
272
discussions = api .get_repo_discussions (repo_id = model_id )
232
273
except Exception :
233
274
return None
234
275
for discussion in discussions :
235
- if discussion .status in {"open" , "closed" } and discussion .is_pull_request and discussion .title == pr_title :
276
+ if (
277
+ discussion .status in {"open" , "closed" }
278
+ and discussion .is_pull_request
279
+ and discussion .title == pr_title
280
+ ):
236
281
commits = api .list_repo_commits (model_id , revision = discussion .git_reference )
237
282
238
283
if revision_commit == commits [1 ].commit_id :
@@ -241,7 +286,12 @@ def previous_pr(api: "HfApi", model_id: str, pr_title: str, revision=Optional[st
241
286
242
287
243
288
def convert_generic (
244
- model_id : str , * , revision = Optional [str ], folder : str , filenames : Set [str ], token : Optional [str ]
289
+ model_id : str ,
290
+ * ,
291
+ revision = Optional [str ],
292
+ folder : str ,
293
+ filenames : Set [str ],
294
+ token : Optional [str ],
245
295
) -> ConversionResult :
246
296
operations = []
247
297
errors = []
@@ -251,7 +301,11 @@ def convert_generic(
251
301
prefix , ext = os .path .splitext (filename )
252
302
if ext in extensions :
253
303
pt_filename = hf_hub_download (
254
- model_id , revision = revision , filename = filename , token = token , cache_dir = folder
304
+ model_id ,
305
+ revision = revision ,
306
+ filename = filename ,
307
+ token = token ,
308
+ cache_dir = folder ,
255
309
)
256
310
dirname , raw_filename = os .path .split (filename )
257
311
if raw_filename == "pytorch_model.bin" :
@@ -263,7 +317,11 @@ def convert_generic(
263
317
sf_filename = os .path .join (folder , sf_in_repo )
264
318
try :
265
319
convert_file (pt_filename , sf_filename , discard_names = [])
266
- operations .append (CommitOperationAdd (path_in_repo = sf_in_repo , path_or_fileobj = sf_filename ))
320
+ operations .append (
321
+ CommitOperationAdd (
322
+ path_in_repo = sf_in_repo , path_or_fileobj = sf_filename
323
+ )
324
+ )
267
325
except Exception as e :
268
326
errors .append ((pt_filename , e ))
269
327
return operations , errors
@@ -285,28 +343,50 @@ def convert(
285
343
pr = previous_pr (api , model_id , pr_title , revision = revision )
286
344
287
345
library_name = getattr (info , "library_name" , None )
288
- if any (filename .endswith (".safetensors" ) for filename in filenames ) and not force :
289
- raise AlreadyExists (f"Model { model_id } is already converted, skipping.." )
346
+ if (
347
+ any (filename .endswith (".safetensors" ) for filename in filenames )
348
+ and not force
349
+ ):
350
+ raise AlreadyExists (
351
+ f"Model { model_id } is already converted, skipping.."
352
+ )
290
353
elif pr is not None and not force :
291
354
url = f"https://huggingface.co/{ model_id } /discussions/{ pr .num } "
292
355
new_pr = pr
293
- raise AlreadyExists (f"Model { model_id } already has an open PR check out { url } " )
356
+ raise AlreadyExists (
357
+ f"Model { model_id } already has an open PR check out { url } "
358
+ )
294
359
elif library_name == "transformers" :
295
-
296
- discard_names = get_discard_names (model_id , revision = revision , folder = folder , token = api .token )
360
+ discard_names = get_discard_names (
361
+ model_id , revision = revision , folder = folder , token = api .token
362
+ )
297
363
if "pytorch_model.bin" in filenames :
298
364
operations , errors = convert_single (
299
- model_id , revision = revision , folder = folder , token = api .token , discard_names = discard_names
365
+ model_id ,
366
+ revision = revision ,
367
+ folder = folder ,
368
+ token = api .token ,
369
+ discard_names = discard_names ,
300
370
)
301
371
elif "pytorch_model.bin.index.json" in filenames :
302
372
operations , errors = convert_multi (
303
- model_id , revision = revision , folder = folder , token = api .token , discard_names = discard_names
373
+ model_id ,
374
+ revision = revision ,
375
+ folder = folder ,
376
+ token = api .token ,
377
+ discard_names = discard_names ,
304
378
)
305
379
else :
306
- raise RuntimeError (f"Model { model_id } doesn't seem to be a valid pytorch model. Cannot convert" )
380
+ raise RuntimeError (
381
+ f"Model { model_id } doesn't seem to be a valid pytorch model. Cannot convert"
382
+ )
307
383
else :
308
384
operations , errors = convert_generic (
309
- model_id , revision = revision , folder = folder , filenames = filenames , token = api .token
385
+ model_id ,
386
+ revision = revision ,
387
+ folder = folder ,
388
+ filenames = filenames ,
389
+ token = api .token ,
310
390
)
311
391
312
392
if operations :
@@ -366,7 +446,9 @@ def convert(
366
446
" Continue [Y/n] ?"
367
447
)
368
448
if txt .lower () in {"" , "y" }:
369
- commit_info , errors = convert (api , model_id , revision = args .revision , force = args .force )
449
+ commit_info , errors = convert (
450
+ api , model_id , revision = args .revision , force = args .force
451
+ )
370
452
string = f"""
371
453
### Success 🔥
372
454
Yay! This model was successfully converted and a PR was open using your token, here:
@@ -375,7 +457,8 @@ def convert(
375
457
if errors :
376
458
string += "\n Errors during conversion:\n "
377
459
string += "\n " .join (
378
- f"Error while converting { filename } : { e } , skipped conversion" for filename , e in errors
460
+ f"Error while converting { filename } : { e } , skipped conversion"
461
+ for filename , e in errors
379
462
)
380
463
print (string )
381
464
else :
0 commit comments