@@ -48,13 +48,19 @@ export type Dtype =
48
48
| "F16"
49
49
| "F8_E4M3"
50
50
| "F8_E5M2"
51
+ | "E8M0"
52
+ | "F6_E3M2"
53
+ | "F6_E2M3"
54
+ | "F4"
55
+ | "FP4"
51
56
| "BF16"
52
57
| "I64"
53
58
| "I32"
54
59
| "I16"
55
60
| "I8"
56
61
| "U16"
57
62
| "U8"
63
+ | "UE8"
58
64
| "BOOL" ;
59
65
60
66
export interface TensorInfo {
@@ -92,6 +98,35 @@ export type SafetensorsParseFromRepo =
92
98
parameterTotal ?: number ;
93
99
} ;
94
100
101
+ /**
102
+ * Fetches and parses model config.json
103
+ */
104
+ async function fetchModelConfig (
105
+ params : {
106
+ repo : RepoDesignation ;
107
+ revision ?: string ;
108
+ hubUrl ?: string ;
109
+ fetch ?: typeof fetch ;
110
+ } & Partial < CredentialsParams >
111
+ ) : Promise < ModelConfig | null > {
112
+ try {
113
+ const configBlob = await downloadFile ( {
114
+ ...params ,
115
+ path : "config.json" ,
116
+ } ) ;
117
+
118
+ if ( ! configBlob ) {
119
+ return null ;
120
+ }
121
+
122
+ const config = JSON . parse ( await configBlob . text ( ) ) ;
123
+ return config as ModelConfig ;
124
+ } catch ( error ) {
125
+ // Config file might not exist or be inaccessible
126
+ return null ;
127
+ }
128
+ }
129
+
95
130
async function parseSingleFile (
96
131
path : string ,
97
132
params : {
@@ -252,6 +287,10 @@ export async function parseSafetensorsMetadata(
252
287
throw new TypeError ( "Only model repos should contain safetensors files." ) ;
253
288
}
254
289
290
+ // Fetch model config for quantization information
291
+ const modelConfig = params . computeParametersCount ? await fetchModelConfig ( params ) : null ;
292
+ const quantConfig = modelConfig ?. quantization_config ;
293
+
255
294
if (
256
295
( params . path && RE_SAFETENSORS_FILE . test ( params . path ) ) ||
257
296
( await fileExists ( { ...params , path : SAFETENSORS_FILE } ) )
@@ -262,17 +301,17 @@ export async function parseSafetensorsMetadata(
262
301
header,
263
302
...( params . computeParametersCount
264
303
? {
265
- parameterCount : computeNumOfParamsByDtypeSingleFile ( header ) ,
304
+ parameterCount : computeNumOfParamsByDtypeSingleFile ( header , quantConfig ) ,
266
305
parameterTotal :
267
306
/// shortcut: get param count directly from metadata
268
307
header . __metadata__ . total_parameters
269
308
? typeof header . __metadata__ . total_parameters === "number"
270
309
? header . __metadata__ . total_parameters
271
310
: typeof header . __metadata__ . total_parameters === "string"
272
- ? parseInt ( header . __metadata__ . total_parameters )
273
- : undefined
311
+ ? parseInt ( header . __metadata__ . total_parameters )
312
+ : undefined
274
313
: undefined ,
275
- }
314
+ }
276
315
: undefined ) ,
277
316
} ;
278
317
} else if (
@@ -289,41 +328,126 @@ export async function parseSafetensorsMetadata(
289
328
headers : shardedMap ,
290
329
...( params . computeParametersCount
291
330
? {
292
- parameterCount : computeNumOfParamsByDtypeSharded ( shardedMap ) ,
331
+ parameterCount : computeNumOfParamsByDtypeSharded ( shardedMap , quantConfig ) ,
293
332
parameterTotal :
294
333
/// shortcut: get param count directly from metadata
295
334
index . metadata ?. total_parameters
296
335
? typeof index . metadata . total_parameters === "number"
297
336
? index . metadata . total_parameters
298
337
: typeof index . metadata . total_parameters === "string"
299
- ? parseInt ( index . metadata . total_parameters )
300
- : undefined
338
+ ? parseInt ( index . metadata . total_parameters )
339
+ : undefined
301
340
: undefined ,
302
- }
341
+ }
303
342
: undefined ) ,
304
343
} ;
305
344
} else {
306
345
throw new Error ( "model id does not seem to contain safetensors weights" ) ;
307
346
}
308
347
}
309
348
310
- function computeNumOfParamsByDtypeSingleFile ( header : SafetensorsFileHeader ) : Partial < Record < Dtype , number > > {
349
+ export interface QuantizationConfig {
350
+ quant_method ?: string ;
351
+ modules_to_not_convert ?: string [ ] ;
352
+ bits ?: number ;
353
+ load_in_4bit ?: boolean ;
354
+ load_in_8bit ?: boolean ;
355
+ }
356
+
357
+ export interface ModelConfig {
358
+ quantization_config ?: QuantizationConfig ;
359
+ }
360
+
361
+ /**
362
+ * Determines if a tensor is quantized based on quantization config and tensor name
363
+ */
364
+ function isQuantizedTensor ( tensorName : string , quantConfig ?: QuantizationConfig ) : boolean {
365
+ if ( ! quantConfig || ! quantConfig . modules_to_not_convert ) {
366
+ return false ;
367
+ }
368
+
369
+ for ( const pattern of quantConfig . modules_to_not_convert ) {
370
+ const regexPattern = pattern . replace ( / \* / g, ".*" ) ;
371
+ const regex = new RegExp ( regexPattern ) ;
372
+ if ( regex . test ( tensorName ) ) {
373
+ return false ;
374
+ }
375
+ }
376
+
377
+ return true ;
378
+ }
379
+
380
+ /**
381
+ * Gets the parameter multiplier for a quantized tensor based on quantization method
382
+ */
383
+ function getQuantizationMultiplier ( tensorName : string , dtype : Dtype , quantConfig ?: QuantizationConfig ) : number {
384
+ if ( ! quantConfig || ! isQuantizedTensor ( tensorName , quantConfig ) ) {
385
+ return 1 ;
386
+ }
387
+
388
+ switch ( quantConfig . quant_method ) {
389
+ case "mxfp4" :
390
+ if ( dtype === "U8" && tensorName . includes ( "_blocks" ) ) {
391
+ return 2 ;
392
+ }
393
+ return 1 ;
394
+
395
+ case "gptq" :
396
+ case "awq" :
397
+ if ( quantConfig . bits === 4 && dtype === "U8" ) {
398
+ return 2 ;
399
+ }
400
+ if ( quantConfig . bits === 2 && dtype === "U8" ) {
401
+ return 4 ;
402
+ }
403
+ return 1 ;
404
+
405
+ case "bitsandbytes" :
406
+ if ( quantConfig . load_in_4bit && dtype === "U8" ) {
407
+ return 2 ;
408
+ }
409
+ if ( quantConfig . load_in_8bit ) {
410
+ return 1 ;
411
+ }
412
+ return 1 ;
413
+
414
+ default :
415
+ if ( quantConfig . load_in_4bit && dtype === "U8" ) {
416
+ return 2 ;
417
+ }
418
+ if ( quantConfig . bits === 4 && dtype === "U8" ) {
419
+ return 2 ;
420
+ }
421
+ return 1 ;
422
+ }
423
+ }
424
+
425
+ function computeNumOfParamsByDtypeSingleFile (
426
+ header : SafetensorsFileHeader ,
427
+ quantConfig ?: QuantizationConfig
428
+ ) : Partial < Record < Dtype , number > > {
311
429
const counter : Partial < Record < Dtype , number > > = { } ;
312
430
const tensors = omit ( header , "__metadata__" ) ;
313
431
314
- for ( const [ , v ] of typedEntries ( tensors ) ) {
432
+ for ( const [ tensorName , v ] of typedEntries ( tensors ) ) {
315
433
if ( v . shape . length === 0 ) {
316
434
continue ;
317
435
}
318
- counter [ v . dtype ] = ( counter [ v . dtype ] ?? 0 ) + v . shape . reduce ( ( a , b ) => a * b ) ;
436
+
437
+ const elements = v . shape . reduce ( ( a , b ) => a * b ) ;
438
+ const multiplier = quantConfig ? getQuantizationMultiplier ( tensorName , v . dtype , quantConfig ) : 1 ;
439
+ counter [ v . dtype ] = ( counter [ v . dtype ] ?? 0 ) + elements * multiplier ;
319
440
}
320
441
return counter ;
321
442
}
322
443
323
- function computeNumOfParamsByDtypeSharded ( shardedMap : SafetensorsShardedHeaders ) : Partial < Record < Dtype , number > > {
444
+ function computeNumOfParamsByDtypeSharded (
445
+ shardedMap : SafetensorsShardedHeaders ,
446
+ quantConfig ?: QuantizationConfig
447
+ ) : Partial < Record < Dtype , number > > {
324
448
const counter : Partial < Record < Dtype , number > > = { } ;
325
449
for ( const header of Object . values ( shardedMap ) ) {
326
- for ( const [ k , v ] of typedEntries ( computeNumOfParamsByDtypeSingleFile ( header ) ) ) {
450
+ for ( const [ k , v ] of typedEntries ( computeNumOfParamsByDtypeSingleFile ( header , quantConfig ) ) ) {
327
451
counter [ k ] = ( counter [ k ] ?? 0 ) + ( v ?? 0 ) ;
328
452
}
329
453
}
0 commit comments