
torch.compile + Huggingface GenerationMixin #141196

@xmfan

🐛 Describe the bug

error: https://gist.github.com/xmfan/7374fab55bdf73ba2501de15dd9de709

ValueError: The following `model_kwargs` are not used by the model: ['bos_token_id', 'pad_token_id', 'eos_token_id', 'max_length', 'do_sample', 'top_p', 'top_k', 'temperature', 'num_return_sequences', 'num_beams', 'length_penalty', 'repetition_penalty'] 

GenerationMixin.generate implements the generation loop that drives the forward pass of most HF generative models. During generate, GenerationMixin._validate_model_kwargs is called, and it raises an exception if any of the passed model kwargs is not used by the model: https://github.com/huggingface/transformers/blob/40821a247823b35d7ff10ba490d0d930fe8f5afa/src/transformers/generation/utils.py#L1380-L1384. The error only appears with a top-level compile, as in the repro below, where torch.compile wraps the whole warmup loop. Wrapping the HF model class directly with torch.compile works.
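To make the failing check concrete, here is a simplified sketch modeled on the linked lines. It is not the transformers source. `forward` and `validate_model_kwargs` are hypothetical stand-ins, and the sketch assumes, as the linked code does, that the check compares key names against the signature and skips keys whose value is None.

```python
import inspect


def forward(input_ids=None, attention_mask=None, past_key_values=None, use_cache=None, **kwargs):
    """Hypothetical stand-in for a model's forward(); only its signature matters here."""
    return None


def validate_model_kwargs(fn, model_kwargs):
    """Simplified sketch of the unused-kwargs check (not the transformers code).

    Collects the parameter *names* of `fn` and raises if any non-None entry of
    `model_kwargs` does not match one of them. Accepting **kwargs does not help:
    only the literal name "kwargs" ends up in the set.
    """
    model_args = set(inspect.signature(fn).parameters)
    unused = [k for k, v in model_kwargs.items() if v is not None and k not in model_args]
    if unused:
        raise ValueError(f"The following `model_kwargs` are not used by the model: {unused}")
    return model_args


# Passes: attention_mask is a named parameter, and None values are ignored.
validate_model_kwargs(forward, {"attention_mask": [1, 1], "voice_dirs": None})

# Raises: a generation-config value such as top_p is not a forward() parameter.
try:
    validate_model_kwargs(forward, {"attention_mask": [1, 1], "top_p": 0.85})
except ValueError as err:
    print(err)  # The following `model_kwargs` are not used by the model: ['top_p']
```

Under this sketch, any sampling option that ends up in model_kwargs instead of the generation config trips the check, which matches the shape of the error above.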

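After monkey-patching in some prints, the difference appears to come from model_kwargs. Under compile, generation config values (bos_token_id, do_sample, top_p, etc.) show up in model_kwargs, while model_args is the same in both runs: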

# eager
model_kwargs={
 'attention_mask': ...,
 'd_vector': None,
 'input_tokens': None,
 'voice_dirs': None}
model_args={'attention_mask',
 'encoder_attention_mask',
 'encoder_hidden_states',
 'head_mask',
 'input_ids',
 'inputs_embeds',
 'kwargs',
 'labels',
 'output_attentions',
 'output_hidden_states',
 'past_key_values',
 'position_ids',
 'return_dict',
 'token_type_ids',
 'use_cache'}

# compile 
model_kwargs={
 'attention_mask': ...,
 'bos_token_id': 1024,
 'd_vector': None,
 'do_sample': True,
 'eos_token_id': 1025,
 'input_tokens': None,
 'length_penalty': 1.0,
 'max_length': 650,
 'num_beams': 1,
 'num_return_sequences': 1,
 'output_attentions': False,
 'pad_token_id': 1025,
 'repetition_penalty': 5.0,
 'temperature': 0.75,
 'top_k': 50,
 'top_p': 0.85,
 'voice_dirs': None}
model_args={
 'attention_mask',
 'encoder_attention_mask',
 'encoder_hidden_states',
 'head_mask',
 'input_ids',
 'inputs_embeds',
 'kwargs',
 'labels',
 'output_attentions',
 'output_hidden_states',
 'past_key_values',
 'position_ids',
 'return_dict',
 'token_type_ids',
 'use_cache'}
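As a sanity check, the printed dicts can be run through a simplified version of the check. This assumes, as in the linked source, that keys whose value is None are skipped. Tensor values are replaced with a placeholder string, and `unused_kwargs` is a hypothetical helper, not a transformers function.

```python
# Parameter names printed as model_args (identical in the eager and compiled runs).
MODEL_ARGS = {
    "attention_mask", "encoder_attention_mask", "encoder_hidden_states", "head_mask",
    "input_ids", "inputs_embeds", "kwargs", "labels", "output_attentions",
    "output_hidden_states", "past_key_values", "position_ids", "return_dict",
    "token_type_ids", "use_cache",
}

# model_kwargs as printed above; "<tensor>" stands in for the attention mask.
EAGER_KWARGS = {"attention_mask": "<tensor>", "d_vector": None, "input_tokens": None, "voice_dirs": None}
COMPILED_KWARGS = {
    "attention_mask": "<tensor>", "bos_token_id": 1024, "d_vector": None, "do_sample": True,
    "eos_token_id": 1025, "input_tokens": None, "length_penalty": 1.0, "max_length": 650,
    "num_beams": 1, "num_return_sequences": 1, "output_attentions": False, "pad_token_id": 1025,
    "repetition_penalty": 5.0, "temperature": 0.75, "top_k": 50, "top_p": 0.85, "voice_dirs": None,
}


def unused_kwargs(model_kwargs, model_args):
    """Sorted non-None keys that do not name a model argument (simplified check)."""
    return sorted(k for k, v in model_kwargs.items() if v is not None and k not in model_args)


print(unused_kwargs(EAGER_KWARGS, MODEL_ARGS))  # → []
print(unused_kwargs(COMPILED_KWARGS, MODEL_ARGS))
```

The eager run passes because its only extra keys (d_vector, input_tokens, voice_dirs) are None. The compiled run yields exactly the 12 keys named in the ValueError, so the failure appears to be fully explained by the generation config values leaking into model_kwargs.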

Repro
Install the repo via frozen_requirements.txt: https://github.com/xmfan/coqui-ai-TTS/tree/empathy

import torch
from TTS.api import TTS
import time

# Get device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Init TTS and move it to the device selected above
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(device)

# Run TTS
def fn():
    tts.tts(
        text="Hello from XTTS2. I am being tested for the torch.compile User Empathy Day on Nov 20th 2024.",
        speaker_wav="en_sample.wav",
        language="en",
    )

@torch.compile(backend="eager")
def warmup(its=5):
    for i in range(its):
        start = time.time()
        fn()
        duration = time.time() - start
        print(f"warm up i={i} took {duration}s")

warmup()

Versions

Install the repo using frozen_requirements.txt.

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @amjames
