
[benchmark] Add HF LLM benchmarks #156967

Open · wants to merge 1 commit into main
157 changes: 108 additions & 49 deletions benchmarks/dynamo/common.py
@@ -1098,14 +1098,27 @@ def maybe_mark_profile(*args, **kwargs):
torch._dynamo.config.repro_tolerance = tolerance

with maybe_profile(args.export_profiler_trace, **args.profile_details) as p:
if args.export_aot_inductor:
frozen_model_iter_fn = export_aot_inductor(
model, example_inputs, args.inductor_compile_mode
)
elif args.export_nativert:
frozen_model_iter_fn = export_nativert(model, example_inputs)
use_generate_mode = kwargs.get("use_generate_mode", False)
if use_generate_mode:
Inline comment from the Contributor Author: I added this use_generate_mode flag so that we apply torch.compile/export only to model.forward, rather than to model.generate.
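A minimal sketch of the distinction, outside the harness (the model is one of those this PR adds; the snippet itself is illustrative and not part of the diff):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B").eval()

# Compile only the per-token forward pass. generate() stays in eager
# Python and repeatedly invokes the compiled forward, so the sampling
# loop and KV-cache bookkeeping are never traced.
model.forward = torch.compile(model.forward)

inputs = tokenizer("Hello, world", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(output_ids[0]))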

assert not args.training

if args.export_aot_inductor:
model.forward = export_aot_inductor_simple(
model, example_inputs, args.inductor_compile_mode
)
elif args.export_nativert:
frozen_model_iter_fn = export_nativert(model, example_inputs)
else:
model.forward = torch._dynamo.run(model)

frozen_model_iter_fn = model_iter_fn
else:
frozen_model_iter_fn = torch._dynamo.run(model_iter_fn)
if args.export_aot_inductor:
frozen_model_iter_fn = export_aot_inductor(
model, example_inputs, args.inductor_compile_mode
)
else:
frozen_model_iter_fn = torch._dynamo.run(model_iter_fn)

for rep in trange(args.repeat, desc="running benchmark"):
inputs = (
@@ -1120,15 +1133,16 @@ def maybe_mark_profile(*args, **kwargs):

# interleave the runs to handle frequency scaling and load changes
with maybe_mark_profile(p=p, mark="expected"):
timings[rep, 0], expected_output = timed(
model,
model_iter_fn,
inputs,
return_result=True,
times=times,
collect_outputs=args.collect_outputs,
batch_size=kwargs.get("batch_size"),
)
with torch.compiler.set_stance("force_eager"):
timings[rep, 0], expected_output = timed(
model,
model_iter_fn,
inputs,
return_result=True,
times=times,
collect_outputs=args.collect_outputs,
batch_size=kwargs.get("batch_size"),
)

# call mark_step between the 2 calls to make the comparison fair.
maybe_mark_step(args)
@@ -1518,8 +1532,12 @@ def opt_nativert(_, example_inputs, collect_outputs=False):
return opt_nativert


def export_aot_inductor_simple(model, example_inputs, mode):
return AOTInductorModelCache.load(model, example_inputs, mode)


def export_aot_inductor(model, example_inputs, mode):
optimized = AOTInductorModelCache.load(model, example_inputs, mode)
optimized = export_aot_inductor_simple(model, example_inputs, mode)

def opt_aot_inductor(_, example_inputs, collect_outputs=False):
example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
@@ -2200,11 +2218,12 @@ def record_status(accuracy_status, dynamo_start_stats):
reset_rng_state()
model_copy = None
try:
model_copy = self.deepcopy_and_maybe_parallelize(model)
self.init_optimizer(name, current_device, model_copy.parameters())
correct_result = self.run_n_iterations(
model_copy, clone_inputs(example_inputs), self.model_iter_fn
)
with torch.compiler.set_stance("force_eager"):
model_copy = self.deepcopy_and_maybe_parallelize(model)
self.init_optimizer(name, current_device, model_copy.parameters())
correct_result = self.run_n_iterations(
model_copy, clone_inputs(example_inputs), self.model_iter_fn
)
except Exception as e:
accuracy_status = (
"eager_1st_run_OOM"
@@ -2221,11 +2240,12 @@ def record_status(accuracy_status, dynamo_start_stats):
reset_rng_state()
model_copy = None
try:
model_copy = self.deepcopy_and_maybe_parallelize(model)
self.init_optimizer(name, current_device, model_copy.parameters())
correct_rerun_result = self.run_n_iterations(
model_copy, clone_inputs(example_inputs), self.model_iter_fn
)
with torch.compiler.set_stance("force_eager"):
model_copy = self.deepcopy_and_maybe_parallelize(model)
self.init_optimizer(name, current_device, model_copy.parameters())
correct_rerun_result = self.run_n_iterations(
model_copy, clone_inputs(example_inputs), self.model_iter_fn
)
except Exception as e:
accuracy_status = (
"eager_2nd_run_OOM"
@@ -2274,6 +2294,11 @@ def record_status(accuracy_status, dynamo_start_stats):
try:
model_copy = self.deepcopy_and_maybe_parallelize(model)
self.init_optimizer(name, current_device, model_copy.parameters())

use_generate_mode = getattr(self, "use_generate_mode", False)
if use_generate_mode:
assert not self.args.training

if (
self.args.export
or self.args.export_aot_inductor
@@ -2286,12 +2311,23 @@ def record_status(accuracy_status, dynamo_start_stats):
optimized_model_iter_fn = optimize_ctx(
model_copy, example_inputs
)
new_result = optimized_model_iter_fn(model_copy, example_inputs)
if use_generate_mode:
new_result = self.model_iter_fn(model_copy, example_inputs)
else:
new_result = optimized_model_iter_fn(
model_copy, example_inputs
)
else:
optimized_model_iter_fn = optimize_ctx(self.model_iter_fn)
new_result = self.run_n_iterations(
model_copy, example_inputs, optimized_model_iter_fn
)
if use_generate_mode:
optimized_model = optimize_ctx(model_copy)
new_result = self.run_n_iterations(
optimized_model, example_inputs, self.model_iter_fn
)
else:
optimized_model_iter_fn = optimize_ctx(self.model_iter_fn)
new_result = self.run_n_iterations(
model_copy, example_inputs, optimized_model_iter_fn
)
except Exception as e:
log.exception("")
print(
@@ -2507,14 +2543,22 @@ def warmup(fn, model, example_inputs, mode, niters=10):
self.model_iter_fn, model, example_inputs, "eager", niters=1
)

baseline_timings = experiment(
model, example_inputs, mark="expected", **experiment_kwargs
)
with torch.compiler.set_stance("force_eager"):
baseline_timings = experiment(
model, example_inputs, mark="expected", **experiment_kwargs
)

use_generate_mode = getattr(self, "use_generate_mode", False)
if use_generate_mode:
assert not self.args.training
optimized_model_iter_fn = self.model_iter_fn
model = optimize_ctx(model)

if self.args.export_aot_inductor:
optimized_model_iter_fn = optimize_ctx
else:
optimized_model_iter_fn = optimize_ctx(self.model_iter_fn)
if self.args.export_aot_inductor:
optimized_model_iter_fn = optimize_ctx
else:
optimized_model_iter_fn = optimize_ctx(self.model_iter_fn)

with maybe_snapshot_memory(
self.args.snapshot_memory, f"compiled_{self.args.only}"
@@ -2662,22 +2706,37 @@ def warmup(fn, model, example_inputs, mode, niters=5):
with maybe_snapshot_memory(
self.args.snapshot_memory, f"eager_{self.args.only}"
):
eager_latency, eager_peak_mem, _ = warmup(
self.model_iter_fn, copy.deepcopy(model), example_inputs, "eager"
)
if self.args.use_warm_peak_memory:
_, eager_peak_mem, _ = warmup(
with torch.compiler.set_stance("force_eager"):
eager_latency, eager_peak_mem, _ = warmup(
self.model_iter_fn,
copy.deepcopy(model),
example_inputs,
"eager",
niters=1,
)
if self.args.use_warm_peak_memory:
_, eager_peak_mem, _ = warmup(
self.model_iter_fn,
copy.deepcopy(model),
example_inputs,
"eager",
niters=1,
)

use_generate_mode = getattr(self, "use_generate_mode", False)
experiment_kwargs["use_generate_mode"] = use_generate_mode
if use_generate_mode:
assert not self.args.training
optimized_model_iter_fn = self.model_iter_fn
if self.args.export_aot_inductor or self.args.export_nativert:
model.forward = optimize_ctx
else:
model.forward = optimize_ctx(model)

if self.args.export_aot_inductor or self.args.export_nativert:
optimized_model_iter_fn = optimize_ctx
else:
optimized_model_iter_fn = optimize_ctx(self.model_iter_fn)
if self.args.export_aot_inductor or self.args.export_nativert:
optimized_model_iter_fn = optimize_ctx
else:
optimized_model_iter_fn = optimize_ctx(self.model_iter_fn)

with maybe_snapshot_memory(
self.args.snapshot_memory, f"compiled_{self.args.only}"
@@ -4037,8 +4096,6 @@ def model_iter_fn_and_mark_step(*args, **kwargs):
# Overwrite 'translation_validation' config, if specified.
torch.fx.experimental._config.translation_validation = False

experiment = functools.partial(experiment, args, runner.model_iter_fn)

if args.only and should_diff_branch(args):
import git

@@ -4227,6 +4284,8 @@ def detect_and_mark_batch(t):
if name in runner.inline_inbuilt_nn_modules_models:
inline_ctx = torch._dynamo.config.patch(inline_inbuilt_nn_modules=True)

experiment = functools.partial(experiment, args, runner.model_iter_fn)

with guard_ctx:
with inline_ctx:
runner.run_one_model(
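A change that recurs throughout common.py above: every eager baseline and eager reference run is now wrapped in torch.compiler.set_stance("force_eager"), so the measurement stays eager even after model.forward has been swapped for a compiled artifact in generate mode. A minimal sketch of the stance API on a toy function (not from this PR):

import torch

@torch.compile
def f(x):
    return torch.sin(x) + torch.cos(x)

x = torch.randn(8)
compiled_out = f(x)  # executes the compiled artifact

# Inside this block, torch.compile-wrapped code falls back to its
# original eager Python implementation.
with torch.compiler.set_stance("force_eager"):
    eager_out = f(x)

torch.testing.assert_close(compiled_out, eager_out)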
51 changes: 39 additions & 12 deletions benchmarks/dynamo/huggingface.py
@@ -128,6 +128,12 @@ def process_hf_reformer_output(out):
assert len(BATCH_SIZE_KNOWN_MODELS)


try:
from .huggingface_llm_models import HF_LLM_MODELS
except ImportError:
from huggingface_llm_models import HF_LLM_MODELS


def get_module_cls_by_model_name(model_cls_name):
_module_by_model_name = {
"Speech2Text2Decoder": "transformers.models.speech_to_text_2.modeling_speech_to_text_2",
@@ -417,11 +423,8 @@ def load_model(
use_eval_mode = self.args.use_eval_mode
dtype = torch.float32
reset_rng_state()
model_cls, config = self._get_model_cls_and_config(model_name)
model = self._download_model(model_name)
model = model.to(device, dtype=dtype)
if self.args.enable_activation_checkpointing:
model.gradient_checkpointing_enable()

# Get batch size
if model_name in BATCH_SIZE_KNOWN_MODELS:
batch_size_default = BATCH_SIZE_KNOWN_MODELS[model_name]
elif batch_size is None:
@@ -439,14 +442,35 @@
f"Running smaller batch size={batch_size} for {model_name}, orig batch_size={batch_size_default}" # noqa: G004
)

example_inputs = generate_inputs_for_model(
model_cls, model, model_name, batch_size, device, include_loss_args=True
)
# Get model and example inputs
if model_name in HF_LLM_MODELS:
benchmark_cls = HF_LLM_MODELS[model_name]
model, example_inputs = benchmark_cls.get_model_and_inputs(
model_name, device
)

# So we can check for correct gradients without eliminating the dropout computation
for attr in dir(config):
if "drop" in attr and isinstance(getattr(config, attr), float):
setattr(config, attr, 1e-30)
# If use_generate_mode is True, we only apply the optimizations
# (torch.compile/export) to model.forward, not to self.model_iter_fn,
# which calls model.generate.
self.use_generate_mode = True
self.model_iter_fn = self.generate

else:
model_cls, config = self._get_model_cls_and_config(model_name)
model = self._download_model(model_name)
model = model.to(device, dtype=dtype)

example_inputs = generate_inputs_for_model(
model_cls, model, model_name, batch_size, device, include_loss_args=True
)

# So we can check for correct gradients without eliminating the dropout computation
for attr in dir(config):
if "drop" in attr and isinstance(getattr(config, attr), float):
setattr(config, attr, 1e-30)

if self.args.enable_activation_checkpointing:
model.gradient_checkpointing_enable()

if (
is_training
@@ -525,6 +549,9 @@ def forward_pass(self, mod, inputs, collect_outputs=True):
with self.autocast(**self.autocast_arg):
return mod(**inputs)

def generate(self, mod, inputs, collect_outputs=True):
return mod.generate(**inputs)

def forward_and_backward_pass(self, mod, inputs, collect_outputs=True):
cloned_inputs = clone_inputs(inputs)
self.optimizer_zero_grad(mod)
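The HF_LLM_MODELS registry imported above lives in a new huggingface_llm_models module that this excerpt does not show. As a hypothetical sketch of an entry compatible with the get_model_and_inputs call (class name, dtype, prompt, and generation length are assumptions, not the PR's code):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

class TextGenerationBenchmark:  # hypothetical stand-in for the real class
    @staticmethod
    def get_model_and_inputs(model_name, device):
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = (
            AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
            .to(device)
            .eval()
        )
        example_inputs = tokenizer("The capital of France is", return_tensors="pt").to(device)
        # The harness's generate() iteration function calls mod.generate(**inputs),
        # so generation kwargs travel alongside the tensor inputs.
        example_inputs["max_new_tokens"] = 64
        return model, example_inputs

HF_LLM_MODELS = {"meta-llama/Llama-3.2-1B": TextGenerationBenchmark}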
5 changes: 5 additions & 0 deletions benchmarks/dynamo/huggingface.yaml
@@ -70,6 +70,11 @@ batch_size:
XGLMForCausalLM: 4
XLNetLMHeadModel: 2
YituTechConvBert: 2
meta-llama/Llama-3.2-1B: 8
google/gemma-2-2b: 8
google/gemma-3-4b-it: 8
openai/whisper-tiny: 8
Qwen/Qwen3-0.6B: 8


tolerance:
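With the batch sizes registered, the new models run through the existing benchmark entry point like any other HuggingFace model, e.g. (flags taken from the existing harness, not introduced here):

python benchmarks/dynamo/huggingface.py --performance --inference --backend=inductor --only meta-llama/Llama-3.2-1B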