
KV Prediction for Improved Time to First Token

KV Prediction is a method for improving the time to first token (TTFT) of transformer models. It uses a small "auxiliary" transformer network to process the prompt efficiently, then uses the auxiliary network's KV cache to predict the KV cache of a larger "base" network. The base network then performs autoregressive generation from the predicted cache, without querying the auxiliary model again. Compared to baselines on benchmark datasets, our method achieves a Pareto-optimal efficiency-accuracy trade-off for TTFT. See our paper for details.
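To make the data flow concrete, here is a toy sketch of the prediction step in plain Python. It assumes the predictor maps each base-model layer to one auxiliary layer and applies learned linear projections to that layer's keys and values; the names `predict_base_kv`, `layer_map`, `key_proj`, and `value_proj` are hypothetical, real caches are batched per-head tensors rather than lists, and the paper should be consulted for the exact predictor architecture.

```python
from dataclasses import dataclass
from typing import List

Vector = List[float]
Matrix = List[List[float]]  # shape: (out_dim, in_dim)


def matvec(w: Matrix, x: Vector) -> Vector:
    """Multiply matrix w by vector x."""
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) for row in w]


@dataclass
class LayerKV:
    keys: List[Vector]    # one key vector per prompt token
    values: List[Vector]  # one value vector per prompt token


def predict_base_kv(
    aux_cache: List[LayerKV],
    layer_map: List[int],
    key_proj: List[Matrix],
    value_proj: List[Matrix],
) -> List[LayerKV]:
    """Predict the base model's KV cache from the auxiliary model's KV cache.

    Hypothetical predictor: base layer i reads auxiliary layer layer_map[i]
    and applies its own linear projections to every key and value vector.
    """
    base_cache = []
    for base_layer, aux_layer in enumerate(layer_map):
        src = aux_cache[aux_layer]
        keys = [matvec(key_proj[base_layer], k) for k in src.keys]
        values = [matvec(value_proj[base_layer], v) for v in src.values]
        base_cache.append(LayerKV(keys, values))
    return base_cache


# Toy example: 1 auxiliary layer (dim 2) predicting 2 base layers (dim 3).
aux_cache = [LayerKV(keys=[[1.0, 2.0]], values=[[0.5, -1.0]])]
proj = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # lifts dim 2 -> dim 3
base_cache = predict_base_kv(aux_cache, layer_map=[0, 0],
                             key_proj=[proj, proj], value_proj=[proj, proj])
print(base_cache[1].keys)  # [[1.0, 2.0, 3.0]]
```

Only the prediction step is shown; in the method described above, the base network would then continue generation from `base_cache` without calling the auxiliary model again.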

Training

We experiment with OpenELM models; configs are located in the openelm/ subdirectory. We trained with multi-node jobs on 8 nodes with 8 H100 GPUs each (64 GPUs in total).

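An example command for training on the i-th node is:

```bash
NODE_ID=0; NUM_NODES=8; NUM_GPUS_PER_NODE=8  # set NODE_ID to 0..NUM_NODES-1 on each node
export CFG_FILE="PATH_TO_KV_PREDICTION_MODEL_CONFIGURATION_FILE"
export RANK=$(( NODE_ID * NUM_GPUS_PER_NODE ))          # starting global rank for this node
export WORLD_SIZE=$(( NUM_NODES * NUM_GPUS_PER_NODE ))  # total number of GPUs
corenet-train --common.config-file "$CFG_FILE" --ddp.rank "$RANK" --ddp.world-size "$WORLD_SIZE" --ddp.dist-url 'tcp://IP_OF_NODE0:FREEPORT'
```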
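The rank arithmetic behind `RANK` and `WORLD_SIZE` can be checked with a short Python sketch for the 8-node, 8-GPU-per-node setup described above. It assumes that `--ddp.rank` receives the starting global rank of each node and that the local GPUs on that node take the following consecutive ranks; the helper `node_ranks` is hypothetical and not part of CoreNet.

```python
from typing import List, Tuple

NUM_NODES = 8
NUM_GPUS_PER_NODE = 8


def node_ranks(node_id: int, gpus_per_node: int) -> Tuple[int, List[int]]:
    """Return (value passed as --ddp.rank, global ranks of this node's GPUs).

    Assumption: local GPU i on a node gets global rank start_rank + i.
    """
    start_rank = node_id * gpus_per_node
    return start_rank, list(range(start_rank, start_rank + gpus_per_node))


world_size = NUM_NODES * NUM_GPUS_PER_NODE
all_ranks: List[int] = []
for node_id in range(NUM_NODES):
    _, ranks = node_ranks(node_id, NUM_GPUS_PER_NODE)
    all_ranks.extend(ranks)

# Every global rank in [0, world_size) is assigned exactly once.
assert sorted(all_ranks) == list(range(world_size))
print(world_size, node_ranks(3, NUM_GPUS_PER_NODE)[0])  # 64 24
```

Under this assumption, node 3 would export `RANK=24` and every node would export `WORLD_SIZE=64`.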

Evaluation

We evaluate with the LM Eval Harness at commit 3196e907fa195b684470a913c7235ed7f08a4383. We use the prompt template in triviaqa-template.yaml because the default template adds an extra question mark to each question.
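The question-mark issue is easy to reproduce with plain string formatting. The sketch below assumes the harness's default TriviaQA prompt has the form `Question: {{question}}?\nAnswer:` and that TriviaQA questions usually already end in `?`; `DEFAULT_TEMPLATE`, `FIXED_TEMPLATE`, and `render` are hypothetical stand-ins, the question is illustrative, and the prompt actually used for evaluation is the one in triviaqa-template.yaml.

```python
# Hypothetical stand-ins for the harness's Jinja prompt templates.
DEFAULT_TEMPLATE = "Question: {{question}}?\nAnswer:"  # assumed default form
FIXED_TEMPLATE = "Question: {{question}}\nAnswer:"      # no appended "?"


def render(template: str, question: str) -> str:
    """Substitute the question into a template (a tiny subset of Jinja)."""
    return template.replace("{{question}}", question)


question = "Which planet is known as the Red Planet?"  # illustrative only
print(render(DEFAULT_TEMPLATE, question).splitlines()[0])
# Question: Which planet is known as the Red Planet??
print(render(FIXED_TEMPLATE, question).splitlines()[0])
# Question: Which planet is known as the Red Planet?
```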

Citation

If you find our work useful, please cite:

@misc{horton2024kvpredictionimprovedtime,
      title={KV Prediction for Improved Time to First Token},
      author={Maxwell Horton and Qingqing Cao and Chenfan Sun and Yanzi Jin and Sachin Mehta and Mohammad Rastegari and Moin Nabi},
      year={2024},
      eprint={2410.08391},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2410.08391},
}

@inproceedings{mehta2022cvnets,
      author={Mehta, Sachin and Abdolhosseini, Farzad and Rastegari, Mohammad},
      title={CVNets: High Performance Library for Computer Vision},
      year={2022},
      booktitle={Proceedings of the 30th ACM International Conference on Multimedia},
      series={MM '22}
}