From 441527dc0bdba9cbbbcc0a838f14f79e35409892 Mon Sep 17 00:00:00 2001 From: angelayi Date: Thu, 26 Jun 2025 10:02:06 -0700 Subject: [PATCH] [benchmark] Add HF LLM benchmarks --- benchmarks/dynamo/common.py | 157 ++++++++++++------ benchmarks/dynamo/huggingface.py | 51 ++++-- benchmarks/dynamo/huggingface.yaml | 5 + benchmarks/dynamo/huggingface_llm_models.py | 96 +++++++++++ benchmarks/dynamo/huggingface_models_list.txt | 5 + 5 files changed, 253 insertions(+), 61 deletions(-) create mode 100644 benchmarks/dynamo/huggingface_llm_models.py diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index 651bc90ba194..60de7493a18f 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -1098,14 +1098,27 @@ def maybe_mark_profile(*args, **kwargs): torch._dynamo.config.repro_tolerance = tolerance with maybe_profile(args.export_profiler_trace, **args.profile_details) as p: - if args.export_aot_inductor: - frozen_model_iter_fn = export_aot_inductor( - model, example_inputs, args.inductor_compile_mode - ) - elif args.export_nativert: - frozen_model_iter_fn = export_nativert(model, example_inputs) + use_generate_mode = kwargs.get("use_generate_mode", False) + if use_generate_mode: + assert not args.training + + if args.export_aot_inductor: + model.forward = export_aot_inductor_simple( + model, example_inputs, args.inductor_compile_mode + ) + elif args.export_nativert: + frozen_model_iter_fn = export_nativert(model, example_inputs) + else: + model.forward = torch._dynamo.run(model) + + frozen_model_iter_fn = model_iter_fn else: - frozen_model_iter_fn = torch._dynamo.run(model_iter_fn) + if args.export_aot_inductor: + frozen_model_iter_fn = export_aot_inductor( + model, example_inputs, args.inductor_compile_mode + ) + else: + frozen_model_iter_fn = torch._dynamo.run(model_iter_fn) for rep in trange(args.repeat, desc="running benchmark"): inputs = ( @@ -1120,15 +1133,16 @@ def maybe_mark_profile(*args, **kwargs): # interleave the runs to handle frequency scaling and load changes with maybe_mark_profile(p=p, mark="expected"): - timings[rep, 0], expected_output = timed( - model, - model_iter_fn, - inputs, - return_result=True, - times=times, - collect_outputs=args.collect_outputs, - batch_size=kwargs.get("batch_size"), - ) + with torch.compiler.set_stance("force_eager"): + timings[rep, 0], expected_output = timed( + model, + model_iter_fn, + inputs, + return_result=True, + times=times, + collect_outputs=args.collect_outputs, + batch_size=kwargs.get("batch_size"), + ) # call mark_step between the 2 calls to make the comparison fair. 
maybe_mark_step(args) @@ -1518,8 +1532,12 @@ def opt_nativert(_, example_inputs, collect_outputs=False): return opt_nativert +def export_aot_inductor_simple(model, example_inputs, mode): + return AOTInductorModelCache.load(model, example_inputs, mode) + + def export_aot_inductor(model, example_inputs, mode): - optimized = AOTInductorModelCache.load(model, example_inputs, mode) + optimized = export_aot_inductor_simple(model, example_inputs, mode) def opt_aot_inductor(_, example_inputs, collect_outputs=False): example_args, example_kwargs = _normalize_bench_inputs(example_inputs) @@ -2200,11 +2218,12 @@ def record_status(accuracy_status, dynamo_start_stats): reset_rng_state() model_copy = None try: - model_copy = self.deepcopy_and_maybe_parallelize(model) - self.init_optimizer(name, current_device, model_copy.parameters()) - correct_result = self.run_n_iterations( - model_copy, clone_inputs(example_inputs), self.model_iter_fn - ) + with torch.compiler.set_stance("force_eager"): + model_copy = self.deepcopy_and_maybe_parallelize(model) + self.init_optimizer(name, current_device, model_copy.parameters()) + correct_result = self.run_n_iterations( + model_copy, clone_inputs(example_inputs), self.model_iter_fn + ) except Exception as e: accuracy_status = ( "eager_1st_run_OOM" @@ -2221,11 +2240,12 @@ def record_status(accuracy_status, dynamo_start_stats): reset_rng_state() model_copy = None try: - model_copy = self.deepcopy_and_maybe_parallelize(model) - self.init_optimizer(name, current_device, model_copy.parameters()) - correct_rerun_result = self.run_n_iterations( - model_copy, clone_inputs(example_inputs), self.model_iter_fn - ) + with torch.compiler.set_stance("force_eager"): + model_copy = self.deepcopy_and_maybe_parallelize(model) + self.init_optimizer(name, current_device, model_copy.parameters()) + correct_rerun_result = self.run_n_iterations( + model_copy, clone_inputs(example_inputs), self.model_iter_fn + ) except Exception as e: accuracy_status = ( "eager_2nd_run_OOM" @@ -2274,6 +2294,11 @@ def record_status(accuracy_status, dynamo_start_stats): try: model_copy = self.deepcopy_and_maybe_parallelize(model) self.init_optimizer(name, current_device, model_copy.parameters()) + + use_generate_mode = getattr(self, "use_generate_mode", False) + if use_generate_mode: + assert not self.args.training + if ( self.args.export or self.args.export_aot_inductor @@ -2286,12 +2311,23 @@ def record_status(accuracy_status, dynamo_start_stats): optimized_model_iter_fn = optimize_ctx( model_copy, example_inputs ) - new_result = optimized_model_iter_fn(model_copy, example_inputs) + if use_generate_mode: + new_result = self.model_iter_fn(model_copy, example_inputs) + else: + new_result = optimized_model_iter_fn( + model_copy, example_inputs + ) else: - optimized_model_iter_fn = optimize_ctx(self.model_iter_fn) - new_result = self.run_n_iterations( - model_copy, example_inputs, optimized_model_iter_fn - ) + if use_generate_mode: + optimized_model = optimize_ctx(model_copy) + new_result = self.run_n_iterations( + optimized_model, example_inputs, self.model_iter_fn + ) + else: + optimized_model_iter_fn = optimize_ctx(self.model_iter_fn) + new_result = self.run_n_iterations( + model_copy, example_inputs, optimized_model_iter_fn + ) except Exception as e: log.exception("") print( @@ -2507,14 +2543,22 @@ def warmup(fn, model, example_inputs, mode, niters=10): self.model_iter_fn, model, example_inputs, "eager", niters=1 ) - baseline_timings = experiment( - model, example_inputs, mark="expected", **experiment_kwargs 
- ) + with torch.compiler.set_stance("force_eager"): + baseline_timings = experiment( + model, example_inputs, mark="expected", **experiment_kwargs + ) + + use_generate_mode = getattr(self, "use_generate_mode", False) + if use_generate_mode: + assert not self.args.training + optimized_model_iter_fn = self.model_iter_fn + model = optimize_ctx(model) - if self.args.export_aot_inductor: - optimized_model_iter_fn = optimize_ctx else: - optimized_model_iter_fn = optimize_ctx(self.model_iter_fn) + if self.args.export_aot_inductor: + optimized_model_iter_fn = optimize_ctx + else: + optimized_model_iter_fn = optimize_ctx(self.model_iter_fn) with maybe_snapshot_memory( self.args.snapshot_memory, f"compiled_{self.args.only}" @@ -2662,22 +2706,37 @@ def warmup(fn, model, example_inputs, mode, niters=5): with maybe_snapshot_memory( self.args.snapshot_memory, f"eager_{self.args.only}" ): - eager_latency, eager_peak_mem, _ = warmup( - self.model_iter_fn, copy.deepcopy(model), example_inputs, "eager" - ) - if self.args.use_warm_peak_memory: - _, eager_peak_mem, _ = warmup( + with torch.compiler.set_stance("force_eager"): + eager_latency, eager_peak_mem, _ = warmup( self.model_iter_fn, copy.deepcopy(model), example_inputs, "eager", - niters=1, ) + if self.args.use_warm_peak_memory: + _, eager_peak_mem, _ = warmup( + self.model_iter_fn, + copy.deepcopy(model), + example_inputs, + "eager", + niters=1, + ) + + use_generate_mode = getattr(self, "use_generate_mode", False) + experiment_kwargs["use_generate_mode"] = use_generate_mode + if use_generate_mode: + assert not self.args.training + optimized_model_iter_fn = self.model_iter_fn + if self.args.export_aot_inductor or self.args.export_nativert: + model.forward = optimize_ctx + else: + model.forward = optimize_ctx(model) - if self.args.export_aot_inductor or self.args.export_nativert: - optimized_model_iter_fn = optimize_ctx else: - optimized_model_iter_fn = optimize_ctx(self.model_iter_fn) + if self.args.export_aot_inductor or self.args.export_nativert: + optimized_model_iter_fn = optimize_ctx + else: + optimized_model_iter_fn = optimize_ctx(self.model_iter_fn) with maybe_snapshot_memory( self.args.snapshot_memory, f"compiled_{self.args.only}" @@ -4037,8 +4096,6 @@ def model_iter_fn_and_mark_step(*args, **kwargs): # Overwrite 'translation_validation' config, if specified. 
torch.fx.experimental._config.translation_validation = False - experiment = functools.partial(experiment, args, runner.model_iter_fn) - if args.only and should_diff_branch(args): import git @@ -4227,6 +4284,8 @@ def detect_and_mark_batch(t): if name in runner.inline_inbuilt_nn_modules_models: inline_ctx = torch._dynamo.config.patch(inline_inbuilt_nn_modules=True) + experiment = functools.partial(experiment, args, runner.model_iter_fn) + with guard_ctx: with inline_ctx: runner.run_one_model( diff --git a/benchmarks/dynamo/huggingface.py b/benchmarks/dynamo/huggingface.py index 916a33276d99..f7e13821b87f 100755 --- a/benchmarks/dynamo/huggingface.py +++ b/benchmarks/dynamo/huggingface.py @@ -128,6 +128,12 @@ def process_hf_reformer_output(out): assert len(BATCH_SIZE_KNOWN_MODELS) +try: + from .huggingface_llm_models import HF_LLM_MODELS +except ImportError: + from huggingface_llm_models import HF_LLM_MODELS + + def get_module_cls_by_model_name(model_cls_name): _module_by_model_name = { "Speech2Text2Decoder": "transformers.models.speech_to_text_2.modeling_speech_to_text_2", @@ -417,11 +423,8 @@ def load_model( use_eval_mode = self.args.use_eval_mode dtype = torch.float32 reset_rng_state() - model_cls, config = self._get_model_cls_and_config(model_name) - model = self._download_model(model_name) - model = model.to(device, dtype=dtype) - if self.args.enable_activation_checkpointing: - model.gradient_checkpointing_enable() + + # Get batch size if model_name in BATCH_SIZE_KNOWN_MODELS: batch_size_default = BATCH_SIZE_KNOWN_MODELS[model_name] elif batch_size is None: @@ -439,14 +442,35 @@ def load_model( f"Running smaller batch size={batch_size} for {model_name}, orig batch_size={batch_size_default}" # noqa: G004 ) - example_inputs = generate_inputs_for_model( - model_cls, model, model_name, batch_size, device, include_loss_args=True - ) + # Get model and example inputs + if model_name in HF_LLM_MODELS: + benchmark_cls = HF_LLM_MODELS[model_name] + model, example_inputs = benchmark_cls.get_model_and_inputs( + model_name, device + ) - # So we can check for correct gradients without eliminating the dropout computation - for attr in dir(config): - if "drop" in attr and isinstance(getattr(config, attr), float): - setattr(config, attr, 1e-30) + # If we set use_generate_mode to True, we will only apply the + # optimizations (torch.compile/export) to model.forward, and not + # apply to self.model_iter_fn, which calls model.generate. 
+ self.use_generate_mode = True + self.model_iter_fn = self.generate + + else: + model_cls, config = self._get_model_cls_and_config(model_name) + model = self._download_model(model_name) + model = model.to(device, dtype=dtype) + + example_inputs = generate_inputs_for_model( + model_cls, model, model_name, batch_size, device, include_loss_args=True + ) + + # So we can check for correct gradients without eliminating the dropout computation + for attr in dir(config): + if "drop" in attr and isinstance(getattr(config, attr), float): + setattr(config, attr, 1e-30) + + if self.args.enable_activation_checkpointing: + model.gradient_checkpointing_enable() if ( is_training @@ -525,6 +549,9 @@ def forward_pass(self, mod, inputs, collect_outputs=True): with self.autocast(**self.autocast_arg): return mod(**inputs) + def generate(self, mod, inputs, collect_outputs=True): + return mod.generate(**inputs) + def forward_and_backward_pass(self, mod, inputs, collect_outputs=True): cloned_inputs = clone_inputs(inputs) self.optimizer_zero_grad(mod) diff --git a/benchmarks/dynamo/huggingface.yaml b/benchmarks/dynamo/huggingface.yaml index f0ee57a58965..af5e0f862414 100644 --- a/benchmarks/dynamo/huggingface.yaml +++ b/benchmarks/dynamo/huggingface.yaml @@ -70,6 +70,11 @@ batch_size: XGLMForCausalLM: 4 XLNetLMHeadModel: 2 YituTechConvBert: 2 + meta-llama/Llama-3.2-1B: 8 + google/gemma-2-2b: 8 + google/gemma-3-4b-it: 8 + openai/whisper-tiny: 8 + Qwen/Qwen3-0.6B: 8 tolerance: diff --git a/benchmarks/dynamo/huggingface_llm_models.py b/benchmarks/dynamo/huggingface_llm_models.py new file mode 100644 index 000000000000..fd06be21a4c0 --- /dev/null +++ b/benchmarks/dynamo/huggingface_llm_models.py @@ -0,0 +1,96 @@ +import subprocess +import sys + +import torch + + +def pip_install(package): + subprocess.check_call([sys.executable, "-m", "pip", "install", package]) + + +try: + from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + WhisperForConditionalGeneration, + WhisperProcessor, + ) +except ModuleNotFoundError: + print("Installing HuggingFace Transformers...") + pip_install("git+https://github.com/huggingface/transformers.git#egg=transformers") +finally: + from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + WhisperForConditionalGeneration, + WhisperProcessor, + ) + + +class Benchmark: + @staticmethod + def get_model_and_inputs(model_name, device): + raise NotImplementedError("get_model_and_inputs() not implemented") + + +class WhisperBenchmark(Benchmark): + SAMPLE_RATE = 16000 + DURATION = 30.0 # seconds + + @staticmethod + def get_model_and_inputs(model_name, device): + processor = WhisperProcessor.from_pretrained(model_name) + model = WhisperForConditionalGeneration.from_pretrained(model_name).to(device) + model.config.forced_decoder_ids = None + + model.generation_config.do_sample = False + model.generation_config.temperature = 0.0 + + num_samples = int(WhisperBenchmark.DURATION * WhisperBenchmark.SAMPLE_RATE) + audio = torch.randn(num_samples) * 0.1 + inputs = processor( + audio, sampling_rate=WhisperBenchmark.SAMPLE_RATE, return_tensors="pt" + ) + inputs["input_features"] = inputs["input_features"].to(device) + + return model, dict(inputs) + + +class TextGenerationBenchmark(Benchmark): + INPUT_LENGTH = 1000 + OUTPUT_LENGTH = 2000 + + @staticmethod + def get_model_and_inputs(model_name, device): + tokenizer = AutoTokenizer.from_pretrained(model_name) + model = AutoModelForCausalLM.from_pretrained( + model_name, + ).to(device) + model.tokenizer = tokenizer + + 
model.generation_config.do_sample = False + model.generation_config.use_cache = True + model.generation_config.cache_implementation = "static" + model.generation_config.max_new_tokens = 2000 + model.generation_config.pad_token_id = tokenizer.eos_token_id + model.generation_config.temperature = 0.0 + + vocab_size = tokenizer.vocab_size + input_ids = torch.randint( + low=0, + high=vocab_size, + size=(1, TextGenerationBenchmark.INPUT_LENGTH), + device=device, + dtype=torch.long, + ) + + return model, {"input_ids": input_ids} + + +HF_LLM_MODELS: dict[str, Benchmark] = { + "meta-llama/Llama-3.2-1B": TextGenerationBenchmark, + "google/gemma-2-2b": TextGenerationBenchmark, + "google/gemma-3-4b-it": TextGenerationBenchmark, + "openai/whisper-tiny": WhisperBenchmark, + "Qwen/Qwen3-0.6B": TextGenerationBenchmark, +} diff --git a/benchmarks/dynamo/huggingface_models_list.txt b/benchmarks/dynamo/huggingface_models_list.txt index 6e3cf19a783d..a6ca7c6a7eeb 100644 --- a/benchmarks/dynamo/huggingface_models_list.txt +++ b/benchmarks/dynamo/huggingface_models_list.txt @@ -49,3 +49,8 @@ TrOCRForCausalLM,64 XGLMForCausalLM,32 XLNetLMHeadModel,16 YituTechConvBert,32 +meta-llama/Llama-3.2-1B,8 +google/gemma-2-2b,8 +google/gemma-3-4b-it,8 +openai/whisper-tiny,8 +Qwen/Qwen3-0.6B,8
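
A minimal sketch (not part of the patch) of the "generate mode" idea the patch applies to the HF LLM benchmarks: only model.forward is optimized, while model.generate() keeps driving decoding eagerly, mirroring the TextGenerationBenchmark settings (greedy decoding, static KV cache). The model name comes from the list added above; the device, prompt, and token budget are illustrative assumptions.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Any causal-LM entry added to huggingface_models_list.txt could be used here.
model_name = "Qwen/Qwen3-0.6B"
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

# Same generation settings as TextGenerationBenchmark: greedy, static KV cache.
model.generation_config.do_sample = False
model.generation_config.cache_implementation = "static"
model.generation_config.max_new_tokens = 64
model.generation_config.pad_token_id = tokenizer.eos_token_id

# "Generate mode": compile only the forward pass; generate() itself stays eager,
# analogous to `model.forward = torch._dynamo.run(model)` in the benchmark harness.
model.forward = torch.compile(model.forward)

input_ids = tokenizer("Hello, world", return_tensors="pt").input_ids.to(device)
output = model.generate(input_ids)
print(tokenizer.decode(output[0], skip_special_tokens=True))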