
Commit e126515

[safetensors] parameters count based on quantization config
1 parent eebfb1f commit e126515
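
For context, a minimal usage sketch of what this commit enables, mirroring the new spec below (the import path assumes the published @huggingface/hub package; the logged values are the ones asserted in the spec):

import { parseSafetensorsMetadata } from "@huggingface/hub";

// With computeParametersCount enabled, the parser also fetches config.json and
// applies the model's quantization_config when counting parameters per dtype.
const parse = await parseSafetensorsMetadata({
	repo: "openai/gpt-oss-20b",
	computeParametersCount: true,
	revision: "bbf09307421df45099c1e7dcbd64e3106ce5b403",
});

// mxfp4 "_blocks" tensors stored as U8 count two 4-bit parameters per byte,
// so the total comes out to 21_511_953_984 (~21.5B) instead of undercounting.
console.log(parse.parameterCount); // per-dtype counts, e.g. { BF16: ..., U8: ... }
console.log(parse.parameterTotal);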

File tree: 2 files changed, 272 additions and 13 deletions

packages/hub/src/lib/parse-safetensors-metadata.spec.ts

Lines changed: 135 additions & 0 deletions
@@ -142,4 +142,139 @@ describe("parseSafetensorsMetadata", () => {
 		assert.strictEqual(safetensorsShardFileInfo?.shard, "00005");
 		assert.strictEqual(safetensorsShardFileInfo?.total, "00072");
 	});
+
+	it("should support sub-byte data types", async () => {
+		const newDataTypes: Array<"F4" | "F6_E2M3" | "F6_E3M2" | "E8M0"> = ["F4", "F6_E2M3", "F6_E3M2", "E8M0"];
+
+		for (const dtype of newDataTypes) {
+			const tensorInfo = {
+				dtype,
+				shape: [1, 2],
+				data_offsets: [0, 1] as [number, number],
+			};
+
+			assert.ok(typeof tensorInfo.dtype === "string");
+			assert.ok(["F4", "F6_E2M3", "F6_E3M2", "E8M0"].includes(tensorInfo.dtype));
+		}
+	});
+
+	it("should handle parameter counting with sub-byte data types", () => {
+		const mockHeader = {
+			tensor_f4: {
+				dtype: "F4" as const,
+				shape: [10, 20],
+				data_offsets: [0, 100] as [number, number],
+			},
+			tensor_f6_e2m3: {
+				dtype: "F6_E2M3" as const,
+				shape: [5, 10],
+				data_offsets: [100, 150] as [number, number],
+			},
+			tensor_f6_e3m2: {
+				dtype: "F6_E3M2" as const,
+				shape: [8, 12],
+				data_offsets: [150, 246] as [number, number],
+			},
+			tensor_e8m0: {
+				dtype: "E8M0" as const,
+				shape: [4, 6],
+				data_offsets: [246, 270] as [number, number],
+			},
+			__metadata__: { format: "pt" },
+		};
+
+		const computeNumOfParamsByDtypeSingleFile = (header: typeof mockHeader) => {
+			const counter: Partial<Record<string, number>> = {};
+			const tensors = Object.fromEntries(Object.entries(header).filter(([key]) => key !== "__metadata__"));
+
+			for (const [, v] of Object.entries(tensors) as [
+				string,
+				{ dtype: string; shape: number[]; data_offsets: [number, number] },
+			][]) {
+				if (v.shape.length === 0) {
+					continue;
+				}
+				counter[v.dtype] = (counter[v.dtype] ?? 0) + v.shape.reduce((a: number, b: number) => a * b);
+			}
+			return counter;
+		};
+
+		const parameterCount = computeNumOfParamsByDtypeSingleFile(mockHeader);
+
+		assert.strictEqual(parameterCount.F4, 200);
+		assert.strictEqual(parameterCount.F6_E2M3, 50);
+		assert.strictEqual(parameterCount.F6_E3M2, 96);
+		assert.strictEqual(parameterCount.E8M0, 24);
+	});
+
+	it("fetch info for openai/gpt-oss-20b (large sharded model)", async () => {
+		const parse = await parseSafetensorsMetadata({
+			repo: "openai/gpt-oss-20b",
+			computeParametersCount: true,
+			revision: "bbf09307421df45099c1e7dcbd64e3106ce5b403",
+		});
+
+		assert(parse.sharded);
+
+		assert.ok(Object.keys(parse.headers).length > 1);
+		assert.ok(parse.parameterCount);
+
+		const totalParams = parse.parameterTotal || sum(Object.values(parse.parameterCount));
+
+		assert.strictEqual(totalParams, 21_511_953_984); // 21.5B
+
+		assert.ok(parse.parameterCount.BF16 && parse.parameterCount.U8);
+
+		assert.strictEqual(Object.keys(parse.headers).length, 3);
+	});
+
+	it("should support FP4 and UE8 data types in type system", () => {
+		const newDataTypes: Array<"FP4" | "UE8"> = ["FP4", "UE8"];
+
+		for (const dtype of newDataTypes) {
+			const tensorInfo = {
+				dtype,
+				shape: [1, 2],
+				data_offsets: [0, 1] as [number, number],
+			};
+
+			assert.ok(typeof tensorInfo.dtype === "string");
+			assert.ok(["FP4", "UE8"].includes(tensorInfo.dtype));
+		}
+
+		const mockHeader = {
+			tensor_fp4: {
+				dtype: "FP4" as const,
+				shape: [100, 200],
+				data_offsets: [0, 5000] as [number, number],
+			},
+			tensor_ue8: {
+				dtype: "UE8" as const,
+				shape: [50, 100],
+				data_offsets: [5000, 10000] as [number, number],
+			},
+			__metadata__: { format: "pt" },
+		};
+
+		const computeNumOfParamsByDtypeSingleFile = (header: typeof mockHeader) => {
+			const counter: Partial<Record<string, number>> = {};
+			const tensors = Object.fromEntries(Object.entries(header).filter(([key]) => key !== "__metadata__"));
+
+			for (const [, v] of Object.entries(tensors) as [
+				string,
+				{ dtype: string; shape: number[]; data_offsets: [number, number] },
+			][]) {
+				if (v.shape.length === 0) {
+					continue;
+				}
+				counter[v.dtype] = (counter[v.dtype] ?? 0) + v.shape.reduce((a: number, b: number) => a * b);
+			}
+			return counter;
+		};
+
+		const parameterCount = computeNumOfParamsByDtypeSingleFile(mockHeader);
+
+		assert.strictEqual(parameterCount.FP4, 20000);
+		assert.strictEqual(parameterCount.UE8, 5000);
+	});
 });

packages/hub/src/lib/parse-safetensors-metadata.ts

Lines changed: 137 additions & 13 deletions
@@ -48,13 +48,19 @@ export type Dtype =
 	| "F16"
 	| "F8_E4M3"
 	| "F8_E5M2"
+	| "E8M0"
+	| "F6_E3M2"
+	| "F6_E2M3"
+	| "F4"
+	| "FP4"
 	| "BF16"
 	| "I64"
 	| "I32"
 	| "I16"
 	| "I8"
 	| "U16"
 	| "U8"
+	| "UE8"
 	| "BOOL";
 
 export interface TensorInfo {
@@ -92,6 +98,35 @@ export type SafetensorsParseFromRepo =
 			parameterTotal?: number;
 	  };
 
+/**
+ * Fetches and parses model config.json
+ */
+async function fetchModelConfig(
+	params: {
+		repo: RepoDesignation;
+		revision?: string;
+		hubUrl?: string;
+		fetch?: typeof fetch;
+	} & Partial<CredentialsParams>
+): Promise<ModelConfig | null> {
+	try {
+		const configBlob = await downloadFile({
+			...params,
+			path: "config.json",
+		});
+
+		if (!configBlob) {
+			return null;
+		}
+
+		const config = JSON.parse(await configBlob.text());
+		return config as ModelConfig;
+	} catch (error) {
+		// Config file might not exist or be inaccessible
+		return null;
+	}
+}
+
 async function parseSingleFile(
 	path: string,
 	params: {
@@ -252,6 +287,10 @@ export async function parseSafetensorsMetadata(
 		throw new TypeError("Only model repos should contain safetensors files.");
 	}
 
+	// Fetch model config for quantization information
+	const modelConfig = params.computeParametersCount ? await fetchModelConfig(params) : null;
+	const quantConfig = modelConfig?.quantization_config;
+
 	if (
 		(params.path && RE_SAFETENSORS_FILE.test(params.path)) ||
 		(await fileExists({ ...params, path: SAFETENSORS_FILE }))
@@ -262,17 +301,17 @@ export async function parseSafetensorsMetadata(
 			header,
 			...(params.computeParametersCount
 				? {
-						parameterCount: computeNumOfParamsByDtypeSingleFile(header),
+						parameterCount: computeNumOfParamsByDtypeSingleFile(header, quantConfig),
 						parameterTotal:
 							/// shortcut: get param count directly from metadata
 							header.__metadata__.total_parameters
 								? typeof header.__metadata__.total_parameters === "number"
 									? header.__metadata__.total_parameters
 									: typeof header.__metadata__.total_parameters === "string"
-									? parseInt(header.__metadata__.total_parameters)
-									: undefined
+										? parseInt(header.__metadata__.total_parameters)
+										: undefined
 								: undefined,
-				  }
+					}
 				: undefined),
 		};
 	} else if (
@@ -289,41 +328,126 @@ export async function parseSafetensorsMetadata(
 				headers: shardedMap,
 				...(params.computeParametersCount
 					? {
-							parameterCount: computeNumOfParamsByDtypeSharded(shardedMap),
+							parameterCount: computeNumOfParamsByDtypeSharded(shardedMap, quantConfig),
 							parameterTotal:
 								/// shortcut: get param count directly from metadata
 								index.metadata?.total_parameters
 									? typeof index.metadata.total_parameters === "number"
 										? index.metadata.total_parameters
 										: typeof index.metadata.total_parameters === "string"
-										? parseInt(index.metadata.total_parameters)
-										: undefined
+											? parseInt(index.metadata.total_parameters)
+											: undefined
 									: undefined,
-					  }
+						}
 					: undefined),
 			};
 		} else {
 			throw new Error("model id does not seem to contain safetensors weights");
 		}
 	}
 
-function computeNumOfParamsByDtypeSingleFile(header: SafetensorsFileHeader): Partial<Record<Dtype, number>> {
+export interface QuantizationConfig {
+	quant_method?: string;
+	modules_to_not_convert?: string[];
+	bits?: number;
+	load_in_4bit?: boolean;
+	load_in_8bit?: boolean;
+}
+
+export interface ModelConfig {
+	quantization_config?: QuantizationConfig;
+}
+
+/**
+ * Determines if a tensor is quantized based on quantization config and tensor name
+ */
+function isQuantizedTensor(tensorName: string, quantConfig?: QuantizationConfig): boolean {
+	if (!quantConfig || !quantConfig.modules_to_not_convert) {
+		return false;
+	}
+
+	for (const pattern of quantConfig.modules_to_not_convert) {
+		const regexPattern = pattern.replace(/\*/g, ".*");
+		const regex = new RegExp(regexPattern);
+		if (regex.test(tensorName)) {
+			return false;
+		}
+	}
+
+	return true;
+}
+
+/**
+ * Gets the parameter multiplier for a quantized tensor based on quantization method
+ */
+function getQuantizationMultiplier(tensorName: string, dtype: Dtype, quantConfig?: QuantizationConfig): number {
+	if (!quantConfig || !isQuantizedTensor(tensorName, quantConfig)) {
+		return 1;
+	}
+
+	switch (quantConfig.quant_method) {
+		case "mxfp4":
+			if (dtype === "U8" && tensorName.includes("_blocks")) {
+				return 2;
+			}
+			return 1;
+
+		case "gptq":
+		case "awq":
+			if (quantConfig.bits === 4 && dtype === "U8") {
+				return 2;
+			}
+			if (quantConfig.bits === 2 && dtype === "U8") {
+				return 4;
+			}
+			return 1;
+
+		case "bitsandbytes":
+			if (quantConfig.load_in_4bit && dtype === "U8") {
+				return 2;
+			}
+			if (quantConfig.load_in_8bit) {
+				return 1;
+			}
+			return 1;
+
+		default:
+			if (quantConfig.load_in_4bit && dtype === "U8") {
+				return 2;
+			}
+			if (quantConfig.bits === 4 && dtype === "U8") {
+				return 2;
+			}
+			return 1;
+	}
+}
+
+function computeNumOfParamsByDtypeSingleFile(
+	header: SafetensorsFileHeader,
+	quantConfig?: QuantizationConfig
+): Partial<Record<Dtype, number>> {
 	const counter: Partial<Record<Dtype, number>> = {};
 	const tensors = omit(header, "__metadata__");
 
-	for (const [, v] of typedEntries(tensors)) {
+	for (const [tensorName, v] of typedEntries(tensors)) {
 		if (v.shape.length === 0) {
 			continue;
 		}
-		counter[v.dtype] = (counter[v.dtype] ?? 0) + v.shape.reduce((a, b) => a * b);
+
+		const elements = v.shape.reduce((a, b) => a * b);
+		const multiplier = quantConfig ? getQuantizationMultiplier(tensorName, v.dtype, quantConfig) : 1;
+		counter[v.dtype] = (counter[v.dtype] ?? 0) + elements * multiplier;
 	}
 	return counter;
 }
 
-function computeNumOfParamsByDtypeSharded(shardedMap: SafetensorsShardedHeaders): Partial<Record<Dtype, number>> {
+function computeNumOfParamsByDtypeSharded(
+	shardedMap: SafetensorsShardedHeaders,
+	quantConfig?: QuantizationConfig
+): Partial<Record<Dtype, number>> {
 	const counter: Partial<Record<Dtype, number>> = {};
 	for (const header of Object.values(shardedMap)) {
-		for (const [k, v] of typedEntries(computeNumOfParamsByDtypeSingleFile(header))) {
+		for (const [k, v] of typedEntries(computeNumOfParamsByDtypeSingleFile(header, quantConfig))) {
 			counter[k] = (counter[k] ?? 0) + (v ?? 0);
 		}
 	}
