
Commit f48faf2

[benchmark] Add HF LLM benchmarks
1 parent 80cca83 commit f48faf2

File tree

5 files changed: +272 −61 lines


benchmarks/dynamo/common.py

Lines changed: 105 additions & 49 deletions
@@ -1098,14 +1098,27 @@ def maybe_mark_profile(*args, **kwargs):
     torch._dynamo.config.repro_tolerance = tolerance

     with maybe_profile(args.export_profiler_trace, **args.profile_details) as p:
-        if args.export_aot_inductor:
-            frozen_model_iter_fn = export_aot_inductor(
-                model, example_inputs, args.inductor_compile_mode
-            )
-        elif args.export_nativert:
-            frozen_model_iter_fn = export_nativert(model, example_inputs)
+        use_generate_mode = kwargs.get("use_generate_mode", False)
+        if use_generate_mode:
+            assert not args.training
+
+            if args.export_aot_inductor:
+                model.forward = export_aot_inductor_simple(
+                    model, example_inputs, args.inductor_compile_mode
+                )
+            elif args.export_nativert:
+                frozen_model_iter_fn = export_nativert(model, example_inputs)
+            else:
+                model.forward = torch._dynamo.run(model)
+
+            frozen_model_iter_fn = model_iter_fn
         else:
-            frozen_model_iter_fn = torch._dynamo.run(model_iter_fn)
+            if args.export_aot_inductor:
+                frozen_model_iter_fn = export_aot_inductor(
+                    model, example_inputs, args.inductor_compile_mode
+                )
+            else:
+                frozen_model_iter_fn = torch._dynamo.run(model_iter_fn)

         for rep in trange(args.repeat, desc="running benchmark"):
             inputs = (
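
Note: the generate-mode branch optimizes the model object instead of the iteration function. Only model.forward is replaced, while the unchanged model_iter_fn keeps calling model.generate(), which invokes the optimized forward once per decoded token. A minimal self-contained sketch of the pattern (the helper name is illustrative, not part of the harness):

import torch

def apply_generate_mode(model):
    # Compile only the forward method; generate() stays ordinary Python
    # and calls the compiled forward for every decoded token.
    model.forward = torch.compile(model.forward)
    return model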
@@ -1120,15 +1133,16 @@ def maybe_mark_profile(*args, **kwargs):

             # interleave the runs to handle frequency scaling and load changes
             with maybe_mark_profile(p=p, mark="expected"):
-                timings[rep, 0], expected_output = timed(
-                    model,
-                    model_iter_fn,
-                    inputs,
-                    return_result=True,
-                    times=times,
-                    collect_outputs=args.collect_outputs,
-                    batch_size=kwargs.get("batch_size"),
-                )
+                with torch.compiler.set_stance("force_eager"):
+                    timings[rep, 0], expected_output = timed(
+                        model,
+                        model_iter_fn,
+                        inputs,
+                        return_result=True,
+                        times=times,
+                        collect_outputs=args.collect_outputs,
+                        batch_size=kwargs.get("batch_size"),
+                    )

             # call mark_step between the 2 calls to make the comparison fair.
             maybe_mark_step(args)
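
torch.compiler.set_stance("force_eager") is used as a context manager so the baseline ("expected") measurement runs eagerly even when the model has already been wrapped by torch.compile. A self-contained sketch, assuming a PyTorch build recent enough to ship the set_stance API:

import torch

@torch.compile
def double(x):
    return x * 2

x = torch.randn(4)
with torch.compiler.set_stance("force_eager"):
    y = double(x)  # the compiled wrapper falls back to eager execution here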
@@ -1518,8 +1532,12 @@ def opt_nativert(_, example_inputs, collect_outputs=False):
     return opt_nativert


+def export_aot_inductor_simple(model, example_inputs, mode):
+    return AOTInductorModelCache.load(model, example_inputs, mode)
+
+
 def export_aot_inductor(model, example_inputs, mode):
-    optimized = AOTInductorModelCache.load(model, example_inputs, mode)
+    optimized = export_aot_inductor_simple(model, example_inputs, mode)

     def opt_aot_inductor(_, example_inputs, collect_outputs=False):
         example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
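
This refactor lets generate mode obtain the raw compiled callable from the AOTInductor cache and assign it directly to model.forward (see the first hunk), while export_aot_inductor keeps returning an iteration-function wrapper for the non-generate path.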
@@ -2200,11 +2218,12 @@ def record_status(accuracy_status, dynamo_start_stats):
         reset_rng_state()
         model_copy = None
         try:
-            model_copy = self.deepcopy_and_maybe_parallelize(model)
-            self.init_optimizer(name, current_device, model_copy.parameters())
-            correct_result = self.run_n_iterations(
-                model_copy, clone_inputs(example_inputs), self.model_iter_fn
-            )
+            with torch.compiler.set_stance("force_eager"):
+                model_copy = self.deepcopy_and_maybe_parallelize(model)
+                self.init_optimizer(name, current_device, model_copy.parameters())
+                correct_result = self.run_n_iterations(
+                    model_copy, clone_inputs(example_inputs), self.model_iter_fn
+                )
         except Exception as e:
             accuracy_status = (
                 "eager_1st_run_OOM"
@@ -2221,11 +2240,12 @@ def record_status(accuracy_status, dynamo_start_stats):
         reset_rng_state()
         model_copy = None
         try:
-            model_copy = self.deepcopy_and_maybe_parallelize(model)
-            self.init_optimizer(name, current_device, model_copy.parameters())
-            correct_rerun_result = self.run_n_iterations(
-                model_copy, clone_inputs(example_inputs), self.model_iter_fn
-            )
+            with torch.compiler.set_stance("force_eager"):
+                model_copy = self.deepcopy_and_maybe_parallelize(model)
+                self.init_optimizer(name, current_device, model_copy.parameters())
+                correct_rerun_result = self.run_n_iterations(
+                    model_copy, clone_inputs(example_inputs), self.model_iter_fn
+                )
         except Exception as e:
             accuracy_status = (
                 "eager_2nd_run_OOM"
@@ -2274,6 +2294,11 @@ def record_status(accuracy_status, dynamo_start_stats):
         try:
             model_copy = self.deepcopy_and_maybe_parallelize(model)
             self.init_optimizer(name, current_device, model_copy.parameters())
+
+            use_generate_mode = getattr(self, "use_generate_mode", False)
+            if use_generate_mode:
+                assert not self.args.training
+
             if (
                 self.args.export
                 or self.args.export_aot_inductor
@@ -2286,12 +2311,23 @@ def record_status(accuracy_status, dynamo_start_stats):
                 optimized_model_iter_fn = optimize_ctx(
                     model_copy, example_inputs
                 )
-                new_result = optimized_model_iter_fn(model_copy, example_inputs)
+                if use_generate_mode:
+                    new_result = self.model_iter_fn(model_copy, example_inputs)
+                else:
+                    new_result = optimized_model_iter_fn(
+                        model_copy, example_inputs
+                    )
             else:
-                optimized_model_iter_fn = optimize_ctx(self.model_iter_fn)
-                new_result = self.run_n_iterations(
-                    model_copy, example_inputs, optimized_model_iter_fn
-                )
+                if use_generate_mode:
+                    optimized_model = optimize_ctx(model_copy)
+                    new_result = self.run_n_iterations(
+                        optimized_model, example_inputs, self.model_iter_fn
+                    )
+                else:
+                    optimized_model_iter_fn = optimize_ctx(self.model_iter_fn)
+                    new_result = self.run_n_iterations(
+                        model_copy, example_inputs, optimized_model_iter_fn
+                    )
         except Exception as e:
             log.exception("")
             print(
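
In generate mode the accuracy comparison therefore happens at the model.generate() level: optimize_ctx is applied to the module itself, and the unchanged self.model_iter_fn drives both the eager and the optimized copy.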
@@ -2507,14 +2543,22 @@ def warmup(fn, model, example_inputs, mode, niters=10):
             self.model_iter_fn, model, example_inputs, "eager", niters=1
         )

-        baseline_timings = experiment(
-            model, example_inputs, mark="expected", **experiment_kwargs
-        )
+        with torch.compiler.set_stance("force_eager"):
+            baseline_timings = experiment(
+                model, example_inputs, mark="expected", **experiment_kwargs
+            )
+
+        use_generate_mode = getattr(self, "use_generate_mode", False)
+        if use_generate_mode:
+            assert not self.args.training
+            optimized_model_iter_fn = self.model_iter_fn
+            model = optimize_ctx(model)

-        if self.args.export_aot_inductor:
-            optimized_model_iter_fn = optimize_ctx
         else:
-            optimized_model_iter_fn = optimize_ctx(self.model_iter_fn)
+            if self.args.export_aot_inductor:
+                optimized_model_iter_fn = optimize_ctx
+            else:
+                optimized_model_iter_fn = optimize_ctx(self.model_iter_fn)

         with maybe_snapshot_memory(
             self.args.snapshot_memory, f"compiled_{self.args.only}"
@@ -2662,22 +2706,34 @@ def warmup(fn, model, example_inputs, mode, niters=5):
         with maybe_snapshot_memory(
             self.args.snapshot_memory, f"eager_{self.args.only}"
         ):
-            eager_latency, eager_peak_mem, _ = warmup(
-                self.model_iter_fn, copy.deepcopy(model), example_inputs, "eager"
-            )
-            if self.args.use_warm_peak_memory:
-                _, eager_peak_mem, _ = warmup(
+            with torch.compiler.set_stance("force_eager"):
+                eager_latency, eager_peak_mem, _ = warmup(
                     self.model_iter_fn,
                     copy.deepcopy(model),
                     example_inputs,
                     "eager",
-                    niters=1,
                 )
+                if self.args.use_warm_peak_memory:
+                    _, eager_peak_mem, _ = warmup(
+                        self.model_iter_fn,
+                        copy.deepcopy(model),
+                        example_inputs,
+                        "eager",
+                        niters=1,
+                    )
+
+        use_generate_mode = getattr(self, "use_generate_mode", False)
+        experiment_kwargs["use_generate_mode"] = use_generate_mode
+        if use_generate_mode:
+            assert not self.args.training
+            optimized_model_iter_fn = self.model_iter_fn
+            model = optimize_ctx(model)

-        if self.args.export_aot_inductor or self.args.export_nativert:
-            optimized_model_iter_fn = optimize_ctx
         else:
-            optimized_model_iter_fn = optimize_ctx(self.model_iter_fn)
+            if self.args.export_aot_inductor or self.args.export_nativert:
+                optimized_model_iter_fn = optimize_ctx
+            else:
+                optimized_model_iter_fn = optimize_ctx(self.model_iter_fn)

         with maybe_snapshot_memory(
             self.args.snapshot_memory, f"compiled_{self.args.only}"
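
Note the plumbing: the experiment_kwargs["use_generate_mode"] entry set here is what the performance loop in the first hunk reads back via kwargs.get("use_generate_mode", False).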
@@ -4037,8 +4093,6 @@ def model_iter_fn_and_mark_step(*args, **kwargs):
     # Overwrite 'translation_validation' config, if specified.
     torch.fx.experimental._config.translation_validation = False

-    experiment = functools.partial(experiment, args, runner.model_iter_fn)
-
     if args.only and should_diff_branch(args):
         import git

@@ -4227,6 +4281,8 @@ def detect_and_mark_batch(t):
     if name in runner.inline_inbuilt_nn_modules_models:
         inline_ctx = torch._dynamo.config.patch(inline_inbuilt_nn_modules=True)

+    experiment = functools.partial(experiment, args, runner.model_iter_fn)
+
     with guard_ctx:
         with inline_ctx:
             runner.run_one_model(

benchmarks/dynamo/huggingface.py

Lines changed: 40 additions & 12 deletions
@@ -128,6 +128,12 @@ def process_hf_reformer_output(out):
 assert len(BATCH_SIZE_KNOWN_MODELS)


+try:
+    from .huggingface_llm_models import HF_LLM_MODELS
+except ImportError:
+    from huggingface_llm_models import HF_LLM_MODELS
+
+
 def get_module_cls_by_model_name(model_cls_name):
     _module_by_model_name = {
         "Speech2Text2Decoder": "transformers.models.speech_to_text_2.modeling_speech_to_text_2",
@@ -417,11 +423,8 @@ def load_model(
         use_eval_mode = self.args.use_eval_mode
         dtype = torch.float32
         reset_rng_state()
-        model_cls, config = self._get_model_cls_and_config(model_name)
-        model = self._download_model(model_name)
-        model = model.to(device, dtype=dtype)
-        if self.args.enable_activation_checkpointing:
-            model.gradient_checkpointing_enable()
+
+        # Get batch size
         if model_name in BATCH_SIZE_KNOWN_MODELS:
             batch_size_default = BATCH_SIZE_KNOWN_MODELS[model_name]
         elif batch_size is None:
@@ -439,14 +442,39 @@ def load_model(
             f"Running smaller batch size={batch_size} for {model_name}, orig batch_size={batch_size_default}"  # noqa: G004
         )

-        example_inputs = generate_inputs_for_model(
-            model_cls, model, model_name, batch_size, device, include_loss_args=True
-        )
+        # Get model and example inputs
+        if model_name in HF_LLM_MODELS:
+            benchmark_cls = HF_LLM_MODELS[model_name]
+            model, example_inputs = benchmark_cls.get_model_and_inputs(
+                model_name, device
+            )

-        # So we can check for correct gradients without eliminating the dropout computation
-        for attr in dir(config):
-            if "drop" in attr and isinstance(getattr(config, attr), float):
-                setattr(config, attr, 1e-30)
+            def model_iter_fn(model, inputs, collect_outputs=True):
+                return benchmark_cls.model_iter_fn(self, model, inputs, collect_outputs)
+
+            self.model_iter_fn = model_iter_fn
+
+            # If we set use_generate_mode to True, we will only apply the
+            # optimizations (torch.compile/export) to model.forward, and not
+            # apply to self.model_iter_fn, which calls model.generate.
+            self.use_generate_mode = True
+
+        else:
+            model_cls, config = self._get_model_cls_and_config(model_name)
+            model = self._download_model(model_name)
+            model = model.to(device, dtype=dtype)
+
+            example_inputs = generate_inputs_for_model(
+                model_cls, model, model_name, batch_size, device, include_loss_args=True
+            )
+
+            # So we can check for correct gradients without eliminating the dropout computation
+            for attr in dir(config):
+                if "drop" in attr and isinstance(getattr(config, attr), float):
+                    setattr(config, attr, 1e-30)
+
+        if self.args.enable_activation_checkpointing:
+            model.gradient_checkpointing_enable()

         if (
             is_training
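
The new benchmarks/dynamo/huggingface_llm_models.py file is part of this commit's file tree but its diff is not shown in this view. A hypothetical sketch of the interface that load_model relies on (get_model_and_inputs and model_iter_fn), with illustrative class name and prompt:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

class TextGenerationBenchmark:
    @staticmethod
    def get_model_and_inputs(model_name, device):
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
        model.eval()
        inputs = tokenizer("The theory of relativity states", return_tensors="pt").to(device)
        return model, inputs

    @staticmethod
    def model_iter_fn(runner, model, inputs, collect_outputs=True):
        # generate() drives the (possibly compiled or exported) forward per token
        with torch.no_grad():
            return model.generate(**inputs, max_new_tokens=32)

HF_LLM_MODELS = {"Qwen/Qwen3-0.6B": TextGenerationBenchmark}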

benchmarks/dynamo/huggingface.yaml

Lines changed: 5 additions & 0 deletions
@@ -70,6 +70,11 @@ batch_size:
   XGLMForCausalLM: 4
   XLNetLMHeadModel: 2
   YituTechConvBert: 2
+  meta-llama/Llama-3.2-1B: 8
+  google/gemma-2-2b: 8
+  google/gemma-3-4b-it: 8
+  openai/whisper-tiny: 8
+  Qwen/Qwen3-0.6B: 8


 tolerance:
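
These entries register default batch sizes for the new LLM benchmarks; the models can then be selected with the usual harness flag, e.g. --only Qwen/Qwen3-0.6B, and the loader scales the batch size down when needed (see the "Running smaller batch size" message above).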
