REALM: Retrieval-Augmented Language Model Pre-Training
Kelvin Guu*  Kenton Lee*  Zora Tung  Panupong Pasupat  Ming-Wei Chang
Abstract
Language model pre-training has been shown to capture a surprising amount of world knowledge, crucial for NLP tasks such as question answering. However, this knowledge is stored implicitly in the parameters of a neural network, requiring ever-larger networks to cover more facts. To capture knowledge in a more modular and interpretable way, we augment language model pre-training with a latent knowledge retriever, which allows the model to retrieve and attend over documents from a large corpus such as Wikipedia, used during pre-training, fine-tuning and inference. For the first time, we show how to pre-train such a knowledge retriever in an unsupervised manner, using masked language modeling as the learning signal and backpropagating through a retrieval step that considers millions of documents. We demonstrate the effectiveness of Retrieval-Augmented Language Model pre-training (REALM) by fine-tuning on the challenging task of Open-domain Question Answering (Open-QA). We compare against state-of-the-art models for both explicit and implicit knowledge storage on three popular Open-QA benchmarks, and find that we outperform all previous methods by a significant margin (4-16% absolute accuracy), while also providing qualitative benefits such as interpretability and modularity.
1. Introduction

In contrast to models that store knowledge in their parameters, this approach explicitly exposes the role of world knowledge by asking the model to decide what knowledge to retrieve and use during inference. Before making each prediction, the language model uses the retriever to retrieve documents¹ from a large corpus such as Wikipedia, and then attends over those documents to help inform its prediction. Learning this model end-to-end requires backpropagating through a retrieval step that considers an entire corpus of textual knowledge, as shown in Figure 1.

¹ We use the term "document" loosely to refer to a passage from the knowledge corpus, not necessarily a whole article.

The key intuition of REALM is to train the retriever using a performance-based signal from unsupervised text: a retrieval that improves the language model's perplexity is helpful and should be rewarded, while an uninformative retrieval should be penalized. For example, in Figure 1, if the model needs to fill the blank in "the [MASK] at the top of the pyramid", the retriever should be rewarded for selecting a document containing "The pyramidion on top allows for less material higher up the pyramid". We achieve this behavior by modeling our retrieve-then-predict approach as a latent variable language model and optimizing the marginal likelihood.

Incorporating a large-scale neural retrieval module during pre-training constitutes a significant computational challenge, since the retriever must consider millions of candidate documents for each pre-training step, and we must backpropagate through its decisions. To address this, we structure the retriever such that the computation performed for each document can be cached and asynchronously updated, and selection of the best documents can be formulated as Maximum Inner Product Search (MIPS).

Numerous prior works have demonstrated the benefit of adding a discrete retrieval step to neural networks (Miller et al., 2016; Chen et al., 2017), but did not apply the framework to language model pre-training and employed non-learned retrievers to handle large-scale document collections. In the language modeling literature, the k-Nearest Neighbor Language Model (Khandelwal et al., 2019) (kNN-LM) retrieves similar LM examples to improve memorization. However, kNN-LM was not fine-tuned for downstream tasks, perhaps because it is unclear how to adapt the retrieval mechanism: a kNN can only use examples labeled for the target task, which during fine-tuning precludes LM examples, the very examples that contain the desired world knowledge. In contrast, REALM's retriever is designed to transfer to other tasks, and the retrieval is just text, not a labeled example.

We evaluate our approach by fine-tuning models pre-trained with REALM on the task of Open-domain Question Answering (Open-QA), one of the most knowledge-intensive tasks in natural language processing. We evaluate on three popular Open-QA benchmarks (NaturalQuestions-Open, WebQuestions, and CuratedTrec) and compare to state-of-the-art Open-QA models, including both extremely large models that store knowledge implicitly (such as T5) and previous approaches that also use a knowledge retriever to access external knowledge but implement retrieval in a more heuristic fashion (Lee et al., 2019; Min et al., 2019a; Asai et al., 2019). REALM achieves new state-of-the-art results on all three benchmarks, significantly outperforming all previous systems by 4-16% absolute accuracy. We also demonstrate qualitative benefits of REALM, including interpretability and modularity.

2. Background

Language model pre-training  The goal of language model pre-training is to learn useful representations of language, usually from unlabeled text corpora. The resulting pre-trained model can then be further trained (fine-tuned) for a downstream task of primary interest (in our case, Open-QA), often leading to better generalization than training from scratch (Dai & Le, 2015; Radford et al., 2019).

We focus on the masked language model² (MLM) variant of pre-training popularized by BERT (Devlin et al., 2018). In its basic form, an MLM is trained to predict the missing tokens in an input text passage. Given an unlabeled pre-training corpus X (e.g., Wikipedia text), a training example (x, y) can be generated by randomly masking tokens in a sampled piece of text (e.g., x = "The [MASK] is the currency [MASK] the UK"; y = ("pound", "of")). The model uses its representation of the masked input x to predict the token that should go in each mask. A good MLM must learn to encode syntactic and semantic information (e.g., to predict "of") as well as some world knowledge (e.g., to predict "pound").

² Strictly speaking, MLM is not a standard language model, since it does not define a distribution over the entire sequence of tokens. In the paper we sometimes abuse the term "language model" slightly to make the phrase shorter.
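As a concrete illustration of this (x, y) construction, the short sketch below (our illustration, not code from the paper) masks whole whitespace-separated tokens; the actual pipeline masks wordpieces, and REALM's pre-training further uses salient span masking (Section 3.4):

```python
import random

def make_mlm_example(text, mask_rate=0.15, seed=0):
    """Randomly mask whole tokens to build an MLM example (x, y).

    Illustrative only: BERT/REALM mask wordpieces rather than whole
    whitespace tokens, and REALM later uses salient span masking.
    """
    rng = random.Random(seed)
    tokens = text.split()
    x, y = [], []
    for tok in tokens:
        if rng.random() < mask_rate:
            x.append("[MASK]")   # masked position goes into the input x
            y.append(tok)        # original token becomes a target in y
        else:
            x.append(tok)
    return " ".join(x), tuple(y)

x, y = make_mlm_example("The pound is the currency of the UK", mask_rate=0.3)
print(x)  # e.g. "The [MASK] is the currency [MASK] the UK" (depends on seed)
print(y)  # e.g. ("pound", "of")
```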
Open-domain question answering (Open-QA)  To measure a model's ability to incorporate world knowledge, we need a downstream task where world knowledge is critical. Perhaps one of the most knowledge-intensive tasks in natural language processing is open-domain question answering (Open-QA): given a question x such as "What is the currency of the UK?", a model must output the correct answer string y, "pound". The "open" part of Open-QA refers to the fact that the model does not receive a pre-identified document that is known to contain the answer, unlike traditional reading comprehension (RC) tasks such as SQuAD (Rajpurkar et al., 2016; 2018). While RC models comprehend a single document, Open-QA models must retain knowledge from millions of documents, since a question could be about any of them.
We focus on Open-QA systems that utilize a textual knowledge corpus Z as the knowledge source. Many of these systems employ a retrieval-based approach: given a question x, retrieve potentially relevant documents z from the corpus Z, and then extract an answer y from the documents (Brill et al., 2002; Chen et al., 2017; Lee et al., 2019). Our approach, REALM, is inspired by this paradigm and extends it to language model pre-training.

Alternatively, some recent work has proposed generation-based systems that apply a sequence-to-sequence model on x to directly generate y token-by-token (Lewis et al., 2019; Raffel et al., 2019). We will compare against state-of-the-art systems from both paradigms in our experiments.

3. Approach

We start by formalizing REALM's pre-training and fine-tuning tasks as a retrieve-then-predict generative process in Section 3.1. Then in Section 3.2, we describe the model architectures for each component of that process. In Section 3.3, we show how to implement REALM pre-training and fine-tuning by maximizing the likelihood of REALM's generative process. En route, we address important computational challenges, explain why training works, and also discuss strategies for injecting useful inductive biases. The overall framework is illustrated in Figure 2.

3.1. REALM's generative process

For both pre-training and fine-tuning, REALM takes some input x and learns a distribution p(y | x) over possible outputs y. For pre-training, the task is masked language modeling: x is a sentence from a pre-training corpus X with some tokens masked out, and the model must predict the value of those missing tokens, y. For fine-tuning, the task is Open-QA: x is a question, and y is the answer.

REALM decomposes p(y | x) into two steps: retrieve, then predict. Given an input x, we first retrieve possibly helpful documents z from a knowledge corpus Z. We model this as a sample from the distribution p(z | x). Then, we condition on both the retrieved z and the original input x to generate the output y, modeled as p(y | z, x). To obtain the overall likelihood of generating y, we treat z as a latent variable and marginalize over all possible documents z, yielding

    p(y | x) = Σ_{z ∈ Z} p(y | z, x) p(z | x).    (1)
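To make Equation (1) concrete, here is a toy numerical sketch of the marginalization (our illustration; the relevance scores and encoder likelihoods are made-up stand-ins):

```python
import numpy as np

# Toy relevance scores f(x, z) for a corpus of 4 documents.
f = np.array([2.0, 0.5, -1.0, 0.1])
p_z_given_x = np.exp(f) / np.exp(f).sum()      # retrieval distribution (softmax)

# Toy encoder likelihoods p(y | z, x) of the correct output under each document.
p_y_given_zx = np.array([0.9, 0.2, 0.05, 0.1])

# Equation (1): marginalize out the latent document z.
p_y_given_x = (p_y_given_zx * p_z_given_x).sum()
print(p_y_given_x)
```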
3.2. Model architecture

We now describe the two key components: the neural knowledge retriever, which models p(z | x), and the knowledge-augmented encoder, which models p(y | z, x).

Knowledge Retriever  The retriever is defined using a dense inner product model:

    p(z | x) = exp f(x, z) / Σ_{z'} exp f(x, z'),
    f(x, z) = Embed_input(x)⊤ Embed_doc(z),

where Embed_input and Embed_doc are embedding functions that map x and z respectively to d-dimensional vectors. The relevance score f(x, z) between x and z is defined as the inner product of the vector embeddings. The retrieval distribution is the softmax over all relevance scores.

We implement the embedding functions using BERT-style Transformers (Devlin et al., 2018). Following standard practices, we join spans of text by applying wordpiece tokenization, separating them with [SEP] tokens, prefixing a [CLS] token, and appending a final [SEP] token:

    join_BERT(x) = [CLS] x [SEP]
    join_BERT(x1, x2) = [CLS] x1 [SEP] x2 [SEP]

As in Devlin et al. (2018), we pass this into a Transformer, which produces one vector for each token, including the vector corresponding to [CLS], which is used as a "pooled" representation of the sequence (denoted BERT_CLS). Finally, we perform a linear projection to reduce the dimensionality of the vector, denoted as a projection matrix W:

    Embed_input(x) = W_input BERT_CLS(join_BERT(x))
    Embed_doc(z) = W_doc BERT_CLS(join_BERT(z_title, z_body))

where z_title is the document's title and z_body is its body. We let θ denote all parameters associated with the retriever, which include the Transformer and projection matrices.
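The following sketch illustrates this dual-encoder retriever (our illustration: a hash-based stand-in replaces the actual BERT Transformer, and all weights are random, purely so the snippet is self-contained and runnable):

```python
import zlib
import numpy as np

def bert_cls(text, dim=768):
    """Stand-in for BERT_CLS(join_BERT(...)). A real implementation runs a
    BERT-style Transformer; here we hash the text into a deterministic
    random vector purely for illustration."""
    seed = zlib.crc32(text.encode("utf-8"))
    return np.random.default_rng(seed).standard_normal(dim)

rng = np.random.default_rng(0)
d = 128                                    # projected embedding size
W_input = rng.standard_normal((d, 768))    # projection matrix W_input
W_doc = rng.standard_normal((d, 768))      # projection matrix W_doc

def embed_input(x):
    return W_input @ bert_cls("[CLS] " + x + " [SEP]")

def embed_doc(title, body):
    return W_doc @ bert_cls("[CLS] " + title + " [SEP] " + body + " [SEP]")

# Relevance is the inner product f(x, z); the retrieval distribution
# p(z | x) is the softmax over all relevance scores.
docs = [("Pyramidion", "The pyramidion on top allows for less material ..."),
        ("Pound sterling", "The pound is the currency of the United Kingdom ...")]
q = embed_input("The [MASK] at the top of the pyramid")
f = np.array([q @ embed_doc(title, body) for title, body in docs])
p_z_given_x = np.exp(f - f.max()) / np.exp(f - f.max()).sum()
print(p_z_given_x)
```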
Figure 2. The overall framework of REALM. Left: Unsupervised pre-training. The knowledge retriever and knowledge-augmented encoder are jointly pre-trained on the unsupervised language modeling task. Right: Supervised fine-tuning. After the parameters of the retriever (θ) and encoder (φ) have been pre-trained, they are then fine-tuned on a task of primary interest, using supervised examples.
Knowledge-Augmented Encoder  Given an input x and a retrieved document z, the knowledge-augmented encoder defines p(y | z, x). We join x and z into a single sequence that we feed into a Transformer (distinct from the one used in the retriever). This allows us to perform rich cross-attention between x and z before predicting y. See Figure 1 for a concrete example.

At this stage, the architectures for pre-training and fine-tuning differ slightly. For the masked language model pre-training task, we must predict the original value of each [MASK] token in x. To do so, we use the same masked language modeling (MLM) loss as in Devlin et al. (2018):

    p(y | z, x) = Π_{j=1}^{J_x} p(y_j | z, x)
    p(y_j | z, x) ∝ exp( w_j⊤ BERT_MASK(j)(join_BERT(x, z_body)) )

where BERT_MASK(j) denotes the Transformer output vector corresponding to the j-th masked token, J_x is the total number of [MASK] tokens in x, and w_j is a learned word embedding for token y_j.
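A minimal sketch of this MLM head follows (our illustration; the word embeddings and the Transformer output at the masked position are random stand-ins):

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size, hidden = 10, 16

# Stand-in quantities: rows of W are the learned word embeddings w, and
# h_mask_j plays the role of BERT_MASK(j)(join_BERT(x, z_body)).
W = rng.standard_normal((vocab_size, hidden))
h_mask_j = rng.standard_normal(hidden)

# p(y_j | z, x) is proportional to exp(w_j^T h): a softmax over the vocabulary.
logits = W @ h_mask_j
p_yj = np.exp(logits - logits.max())
p_yj /= p_yj.sum()

# log p(y | z, x) then sums log p(y_j | z, x) over the J_x masked positions.
print(p_yj.round(3), p_yj.sum())  # a distribution over the vocab; sums to 1.0
```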
For Open-QA fine-tuning, we wish to produce the answer string y. Following previous reading comprehension work (Rajpurkar et al., 2016; Seo et al., 2016; Lee et al., 2016; Clark & Gardner, 2017), we will assume that the answer y can be found as a contiguous sequence of tokens in some document z. Let S(z, y) be the set of spans matching y in z. Then we can define p(y | z, x) as:

    p(y | z, x) ∝ Σ_{s ∈ S(z,y)} exp( MLP([h_START(s); h_END(s)]) )
    h_START(s) = BERT_START(s)(join_BERT(x, z_body)),
    h_END(s) = BERT_END(s)(join_BERT(x, z_body)),

where BERT_START(s) and BERT_END(s) denote the Transformer output vectors corresponding to the start and end tokens of span s, respectively, while MLP denotes a feed-forward neural network. We let φ denote all parameters associated with the knowledge-augmented encoder.
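The span-scoring computation can be sketched as follows (our illustration; the Transformer outputs, MLP weights, and matching spans are made-up stand-ins, and the softmax normalization over all candidate spans is omitted):

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, hidden, mlp_dim = 12, 16, 8

# Stand-in Transformer outputs over join_BERT(x, z_body); row t plays the
# role of the output vector for token t.
H = rng.standard_normal((seq_len, hidden))
W1 = rng.standard_normal((mlp_dim, 2 * hidden))  # toy MLP weights
w2 = rng.standard_normal(mlp_dim)

def span_score(start, end):
    """MLP([h_START(s); h_END(s)]) with a one-hidden-layer feed-forward net."""
    return w2 @ np.tanh(W1 @ np.concatenate([H[start], H[end]]))

# Suppose the answer string y matches spans (2, 3) and (7, 8) in document z:
# p(y | z, x) sums exp(span score) over S(z, y), up to normalization.
matching_spans = [(2, 3), (7, 8)]
unnormalized = sum(np.exp(span_score(s, e)) for s, e in matching_spans)
print(unnormalized)
```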
3.3. Training

For both pre-training and fine-tuning, we train by maximizing the log-likelihood log p(y | x) of the correct output y. Since both the knowledge retriever and knowledge-augmented encoder are differentiable neural networks, we can compute the gradient of log p(y | x) (defined in Equation 1) with respect to the model parameters θ and φ, and optimize using stochastic gradient descent.

The key computational challenge is that the marginal probability p(y | x) = Σ_{z ∈ Z} p(y | x, z) p(z | x) involves a summation over all documents z in the knowledge corpus Z. We approximate this by instead summing over the top k documents with highest probability under p(z | x); this is reasonable if most documents have near-zero probability.

Even with this approximation, we still need an efficient way to find the top k documents. Note that the ordering of documents under p(z | x) is the same as under the relevance score f(x, z) = Embed_input(x)⊤ Embed_doc(z), which is an inner product. Thus, we can employ Maximum Inner Product Search (MIPS) algorithms to find the approximate top k documents, using running time and storage space that scale sub-linearly with the number of documents (Ram & Gray, 2012; Shrivastava & Li, 2014; Shen et al., 2015).

To employ MIPS, we must pre-compute Embed_doc(z) for every z ∈ Z and construct an efficient search index over these embeddings. However, this data structure will no longer be consistent with p(z | x) if the parameters θ of Embed_doc are later updated. Hence, the search index goes "stale" after every gradient update on θ.

Our solution is to "refresh" the index by asynchronously re-embedding and re-indexing all documents every several hundred training steps. The MIPS index is slightly stale between refreshes, but note that it is only used to select the top k documents. We recompute p(z | x) and its gradient, using the fresh θ, for these top k documents after retrieving them. In Section 4.5, we empirically demonstrate that this procedure results in stable optimization, provided that refreshes happen at a sufficiently frequent rate.

Implementing asynchronous MIPS refreshes  We asynchronously refresh the MIPS index by running two jobs in parallel: a primary trainer job, which performs gradient updates on the parameters, and a secondary index builder job, which embeds and indexes the documents. The trainer periodically sends the index builder a snapshot of its parameters; the index builder rebuilds the index with that snapshot in the background while the trainer continues training, and the refreshed index replaces the stale one as soon as it is ready.
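The two-step use of the slightly stale index can be sketched as follows (our illustration; brute-force search stands in for a real sub-linear MIPS structure, and random vectors stand in for Embed_doc and Embed_input):

```python
import numpy as np

rng = np.random.default_rng(0)
n_docs, d, k = 1000, 32, 8

# Embed_doc(z) under the current parameters theta ("fresh"), and the version
# frozen into the search index at the last refresh ("stale").
doc_emb_fresh = rng.standard_normal((n_docs, d))
doc_emb_index = doc_emb_fresh + 0.01 * rng.standard_normal((n_docs, d))

query = rng.standard_normal(d)  # Embed_input(x) under the fresh parameters

# Step 1: use the (slightly stale) index only to select the top-k candidates.
# Brute-force argsort stands in for a sub-linear MIPS data structure.
top_k = np.argsort(doc_emb_index @ query)[-k:]

# Step 2: recompute f(x, z) and p(z | x) over just those k documents with the
# fresh parameters, so the gradient sees up-to-date embeddings.
f = doc_emb_fresh[top_k] @ query
p_z_given_x = np.exp(f - f.max()) / np.exp(f - f.max()).sum()
print(top_k, p_z_given_x.round(3))
```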
Table 1. Test results on Open-QA benchmarks. The number of train/test examples is shown in parentheses below each benchmark. Predictions are evaluated with exact match against any reference answer. "Sparse retrieval" denotes methods that use sparse features such as TF-IDF and BM25. Our model, REALM, outperforms all existing systems.

Name                                 | Architectures              | Pre-training   | NQ (79k/4k) | WQ (3k/2k) | CT (1k/1k) | # params
BERT-Baseline (Lee et al., 2019)     | Sparse Retr.+Transformer   | BERT           | 26.5        | 17.7       | 21.3       | 110m
T5 (base) (Roberts et al., 2020)     | Transformer Seq2Seq        | T5 (Multitask) | 27.0        | 29.1       | -          | 223m
T5 (large) (Roberts et al., 2020)    | Transformer Seq2Seq        | T5 (Multitask) | 29.8        | 32.2       | -          | 738m
T5 (11b) (Roberts et al., 2020)      | Transformer Seq2Seq        | T5 (Multitask) | 34.5        | 37.4       | -          | 11318m
DrQA (Chen et al., 2017)             | Sparse Retr.+DocReader     | N/A            | -           | 20.7       | 25.7       | 34m
HardEM (Min et al., 2019a)           | Sparse Retr.+Transformer   | BERT           | 28.1        | -          | -          | 110m
GraphRetriever (Min et al., 2019b)   | GraphRetriever+Transformer | BERT           | 31.8        | 31.6       | -          | 110m
PathRetriever (Asai et al., 2019)    | PathRetriever+Transformer  | MLM            | 32.6        | -          | -          | 110m
ORQA (Lee et al., 2019)              | Dense Retr.+Transformer    | ICT+BERT       | 33.3        | 36.4       | 30.1       | 330m
Ours (X = Wikipedia, Z = Wikipedia)  | Dense Retr.+Transformer    | REALM          | 39.2        | 40.2       | 46.8       | 330m
Ours (X = CC-News, Z = Wikipedia)    | Dense Retr.+Transformer    | REALM          | 40.4        | 40.7       | 42.9       | 330m
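Table 1's exact-match metric can be sketched as follows (our illustration; we assume the usual Open-QA normalization of lowercasing and stripping punctuation and articles, which the caption does not spell out):

```python
import re
import string

def normalize(s):
    """Lowercase, strip punctuation and the articles a/an/the, collapse
    whitespace. An assumed, conventional Open-QA normalizer."""
    s = "".join(ch for ch in s.lower() if ch not in string.punctuation)
    s = re.sub(r"\b(a|an|the)\b", " ", s)
    return " ".join(s.split())

def exact_match(prediction, references):
    """1.0 if the normalized prediction equals any normalized reference."""
    return float(any(normalize(prediction) == normalize(r) for r in references))

print(exact_match("The pound.", ["pound", "pound sterling"]))  # 1.0
```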
Table 3. An example where REALM utilizes retrieved documents to better predict masked tokens. It assigns a much higher probability (0.129) to the correct term, "Fermat", than BERT does. (Note that the blank corresponds to 3 BERT wordpieces.)

x: An equilateral triangle is easily constructed using a straightedge and compass, because 3 is a ___ prime.

(a) BERT    p(y = "Fermat" | x) = 1.1 × 10⁻¹⁴    (No retrieval.)
(b) REALM   p(y = "Fermat" | x, z) = 1.0         (Conditional probability with document z = "257 is ... a Fermat prime. Thus a regular polygon with 257 sides is constructible with compass ...")
(c) REALM   p(y = "Fermat" | x) = 0.129          (Marginal probability, marginalizing over the top 8 retrieved documents.)
Encoder or Retriever  We first aim to determine whether REALM pre-training improves the retriever or the encoder, or both. To do so, we can reset the parameters of either the retriever or the encoder to their baseline state before REALM pre-training, and feed that into fine-tuning. Resetting both the retriever and the encoder reduces the system to our main baseline, ORQA. We find that both the encoder and the retriever benefit from REALM training separately, but the best result requires both components acting in unison.

Masking scheme  We compare our salient span masking scheme (Section 3.4) with (1) random token masking introduced in BERT (Devlin et al., 2018) and (2) random span masking proposed by SpanBERT (Joshi et al., 2019). While such salient span masking has not been shown to be impactful in previous work with standard BERT training (Joshi et al., 2019), it is crucial for REALM. Intuitively, the latent variable learning relies heavily on the utility of retrieval and is therefore more sensitive to a consistent learning signal.

MIPS index refresh rate  During pre-training, we run a parallel process to re-embed corpus documents and rebuild the MIPS index, which results in one index refresh per approximately 500 training steps. To demonstrate the importance of frequent index refreshes, we compare against a slower refresh rate. The results in Table 2 suggest that a stale index can hurt model training, and that further reducing this staleness could offer better optimization.

Examples of retrieved documents  Table 3 shows an example of the REALM masked language model prediction. In this example, "Fermat" is the correct word, and REALM (row (c)) gives the word a much higher probability than the BERT model (row (a)). Since REALM manages to retrieve some documents with a related fact (row (b)), the marginalized probability of the correct answer dramatically increases. This shows that REALM is able to retrieve documents to fill in the masked word even though it is trained with unsupervised text only.

5. Discussion and Related Work

We previously discussed related methods for Open-QA. Here we present several alternate ways of viewing REALM that connect it to a broader set of ideas beyond Open-QA.

Language modeling with corpus as context  Language representation models have been incorporating contexts of increasingly large scope when making predictions. Examples of this progression include models that condition on surrounding words (Mikolov et al., 2013a;b), sentences (Kiros et al., 2015; Peters et al., 2018), and paragraphs (Radford et al., 2018; Devlin et al., 2018). We can view REALM as a generalization of the above work to the next level of scope: the entire text corpus.

Retrieve-and-edit with learned retrieval  In order to better explain the variance in the input text and enable controllable generation, Guu et al. (2018) proposed a language model with the retrieve-and-edit framework (Hashimoto et al., 2018) that conditions on text with high lexical overlap. REALM has a similar approach, except that the model learns for itself which texts are most useful for reducing perplexity. By jointly learning the retriever, REALM has the capacity to depend on information beyond lexical overlap.

Scalable grounded neural memory  The document index can be viewed as a memory where the keys are the document embeddings. From this view, our work shares motivations with works such as product key memory (Lample et al., 2019), which enables sub-linear memory access in a memory network (Weston et al., 2014; Graves et al., 2014; Sukhbaatar et al., 2015), allowing these scalable memory layers to be integrated into large language models. One main difference is that our memories are grounded: each memory is associated with a document rather than unnamed value vectors. This level of interpretability is crucial for applications like Open-QA, where users would require provenance for a predicted answer to be trustworthy.
Unsupervised Corpus Alignment  In sequence-to-sequence models with attention (Bahdanau et al., 2014), text is generated with latent selection of relevant tokens. This results in a set of model-centric unsupervised alignments between target and source tokens. Analogously, REALM also generates text with latent selection of relevant documents. A by-product of our method is that we offer a set of model-centric unsupervised alignments between text in the pre-training corpus X and the knowledge corpus Z.
6. Future Work

The work presented here is the minimal instantiation of a family of REALM-like approaches where a representation is pre-trained to perform reasoning over a large corpus of knowledge on-the-fly during inference. We are particularly optimistic about generalizations of this work to (1) structured knowledge, which would result in a generalization of Peters et al. (2019) where we would also learn the decision of which entities are informative, (2) the multi-lingual setting, e.g., retrieving knowledge in a high-resource language to better represent text in a low-resource language, and (3) the multi-modal setting, e.g., retrieving images or videos that can provide knowledge rarely observed in text.
References

Asai, A., Hashimoto, K., Hajishirzi, H., Socher, R., and Xiong, C. Learning to retrieve reasoning paths over Wikipedia graph for question answering. arXiv preprint arXiv:1911.10470, 2019.

Bahdanau, D., Cho, K., and Bengio, Y. Neural machine translation by jointly learning to align and translate. arXiv preprint arXiv:1409.0473, 2014.

Berant, J., Chou, A., Frostig, R., and Liang, P. Semantic parsing on Freebase from question-answer pairs. In Proceedings of the 2013 Conference on Empirical Methods in Natural Language Processing, pp. 1533-1544, 2013.

Brill, E., Dumais, S., and Banko, M. An analysis of the AskMSR question-answering system. In Empirical Methods in Natural Language Processing, 2002.

Chen, D., Fisch, A., Weston, J., and Bordes, A. Reading Wikipedia to answer open-domain questions. In Proceedings of the 55th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), pp. 1870-1879, 2017.

Clark, C. and Gardner, M. Simple and effective multi-paragraph reading comprehension. In Annual Meeting of the Association for Computational Linguistics, 2017.

Dai, A. M. and Le, Q. V. Semi-supervised sequence learning. In Advances in Neural Information Processing Systems, pp. 3079-3087, 2015.

Devlin, J., Chang, M.-W., Lee, K., and Toutanova, K. BERT: Pre-training of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805, 2018.

Graves, A., Wayne, G., and Danihelka, I. Neural Turing machines. arXiv preprint arXiv:1410.5401, 2014.

Guu, K., Hashimoto, T. B., Oren, Y., and Liang, P. Generating sentences by editing prototypes. Transactions of the Association for Computational Linguistics, 6:437-450, 2018.

Hashimoto, T. B., Guu, K., Oren, Y., and Liang, P. S. A retrieve-and-edit framework for predicting structured outputs. In Advances in Neural Information Processing Systems, pp. 10052-10062, 2018.

Joshi, M., Chen, D., Liu, Y., Weld, D. S., Zettlemoyer, L., and Levy, O. SpanBERT: Improving pre-training by representing and predicting spans. arXiv preprint arXiv:1907.10529, 2019.

Khandelwal, U., Levy, O., Jurafsky, D., Zettlemoyer, L., and Lewis, M. Generalization through memorization: Nearest neighbor language models. arXiv preprint arXiv:1911.00172, 2019.

Kiros, R., Zhu, Y., Salakhutdinov, R. R., Zemel, R., Urtasun, R., Torralba, A., and Fidler, S. Skip-thought vectors. In Advances in Neural Information Processing Systems, pp. 3294-3302, 2015.

Kwiatkowski, T., Palomaki, J., Rhinehart, O., Collins, M., Parikh, A., Alberti, C., Epstein, D., Polosukhin, I., Kelcey, M., Devlin, J., et al. Natural Questions: a benchmark for question answering research. Transactions of the Association for Computational Linguistics, 2019.

Lample, G., Sablayrolles, A., Ranzato, M., Denoyer, L., and Jégou, H. Large memory layers with product keys. In Advances in Neural Information Processing Systems, pp. 8546-8557, 2019.

Lee, K., Salant, S., Kwiatkowski, T., Parikh, A., Das, D., and Berant, J. Learning recurrent span representations for extractive question answering. arXiv preprint arXiv:1611.01436, 2016.

Lee, K., Chang, M.-W., and Toutanova, K. Latent retrieval for weakly supervised open domain question answering. In Proceedings of the Conference of the Association for Computational Linguistics, 2019.

Lewis, M., Liu, Y., Goyal, N., Ghazvininejad, M., Mohamed, A., Levy, O., Stoyanov, V., and Zettlemoyer, L. BART: Denoising sequence-to-sequence pre-training for natural language generation, translation, and comprehension. arXiv preprint arXiv:1910.13461, 2019.

Liu, Y., Ott, M., Goyal, N., Du, J., Joshi, M., Chen, D., Levy, O., Lewis, M., Zettlemoyer, L., and Stoyanov, V. RoBERTa: A robustly optimized BERT pretraining approach. arXiv preprint arXiv:1907.11692, 2019.

Mikolov, T., Chen, K., Corrado, G., and Dean, J. Efficient estimation of word representations in vector space. arXiv preprint arXiv:1301.3781, 2013a.

Mikolov, T., Sutskever, I., Chen, K., Corrado, G. S., and Dean, J. Distributed representations of words and phrases and their compositionality. In Advances in Neural Information Processing Systems, pp. 3111-3119, 2013b.

Miller, A., Fisch, A., Dodge, J., Karimi, A.-H., Bordes, A., and Weston, J. Key-value memory networks for directly reading documents. arXiv preprint arXiv:1606.03126, 2016.

Min, S., Chen, D., Hajishirzi, H., and Zettlemoyer, L. A discrete hard EM approach for weakly supervised question answering. arXiv preprint arXiv:1909.04849, 2019a.

Min, S., Chen, D., Zettlemoyer, L., and Hajishirzi, H. Knowledge guided text retrieval and reading for open domain question answering. arXiv preprint arXiv:1911.03868, 2019b.

Peters, M. E., Neumann, M., Iyyer, M., Gardner, M., Clark, C., Lee, K., and Zettlemoyer, L. Deep contextualized word representations. In Proceedings of NAACL, 2018.

Peters, M. E., Neumann, M., Logan IV, R. L., Schwartz, R., Joshi, V., Singh, S., and Smith, N. A. Knowledge enhanced contextual word representations, 2019.

Petroni, F., Rocktäschel, T., Lewis, P., Bakhtin, A., Wu, Y., Miller, A. H., and Riedel, S. Language models as knowledge bases? arXiv preprint arXiv:1909.01066, 2019.

Radford, A., Narasimhan, K., Salimans, T., and Sutskever, I. Improving language understanding with unsupervised learning. Technical report, OpenAI, 2018.

Radford, A., Wu, J., Child, R., Luan, D., Amodei, D., and Sutskever, I. Language models are unsupervised multitask learners. OpenAI Blog, 2019.

Raffel, C., Shazeer, N., Roberts, A., Lee, K., Narang, S., Matena, M., Zhou, Y., Li, W., and Liu, P. J. Exploring the limits of transfer learning with a unified text-to-text transformer. arXiv preprint arXiv:1910.10683, 2019.

Rajpurkar, P., Zhang, J., Lopyrev, K., and Liang, P. SQuAD: 100,000+ questions for machine comprehension of text. In Proceedings of the 2016 Conference on Empirical Methods in Natural Language Processing, pp. 2383-2392, 2016.

Rajpurkar, P., Jia, R., and Liang, P. Know what you don't know: Unanswerable questions for SQuAD. arXiv preprint arXiv:1806.03822, 2018.

Ram, P. and Gray, A. G. Maximum inner-product search using cone trees. In Proceedings of the 18th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, pp. 931-939, 2012.

Roberts, A., Raffel, C., and Shazeer, N. How much knowledge can you pack into the parameters of a language model? arXiv preprint arXiv:TBD, 2020.

Robertson, S., Zaragoza, H., et al. The probabilistic relevance framework: BM25 and beyond. Foundations and Trends in Information Retrieval, 3(4):333-389, 2009.

Sang, E. T. K. and De Meulder, F. Introduction to the CoNLL-2003 shared task: Language-independent named entity recognition. In Proceedings of the Seventh Conference on Natural Language Learning at HLT-NAACL 2003, pp. 142-147, 2003.

Seo, M., Kembhavi, A., Farhadi, A., and Hajishirzi, H. Bidirectional attention flow for machine comprehension. In International Conference on Learning Representations, 2016.

Shen, F., Liu, W., Zhang, S., Yang, Y., and Tao Shen, H. Learning binary codes for maximum inner product search. In Proceedings of the IEEE International Conference on Computer Vision, pp. 4148-4156, 2015.

Shrivastava, A. and Li, P. Asymmetric LSH (ALSH) for sublinear time maximum inner product search (MIPS). In Advances in Neural Information Processing Systems, pp. 2321-2329, 2014.

Sukhbaatar, S., Weston, J., Fergus, R., et al. End-to-end memory networks. In Advances in Neural Information Processing Systems, 2015.

Weston, J., Chopra, S., and Bordes, A. Memory networks. arXiv preprint arXiv:1410.3916, 2014.
A. Derivation of the gradient with respect to the knowledge retriever

We compute the gradient of the REALM pre-training objective (a log-likelihood) with respect to the parameters of the knowledge retriever, θ:

    ∇ log p(y | x) = p(y | x)⁻¹ ∇ p(y | x)
                   = p(y | x)⁻¹ Σ_z p(y | z, x) ∇ p(z | x)
                   = p(y | x)⁻¹ Σ_z p(y | z, x) p(z | x) ∇ log p(z | x)
                   = Σ_z p(z | y, x) ∇ log p(z | x),

where the last line follows from applying conditional Bayes' rule. We can then expand ∇ log p(z | x) as:

    ∇ log p(z | x) = ∇ log [ exp f(x, z) / Σ_{z'} exp f(x, z') ]
                   = ∇ [ f(x, z) − log Σ_{z'} exp f(x, z') ]
                   = ∇ f(x, z) − Σ_{z'} p(z' | x) ∇ f(x, z').

Plugging this back into the first set of equations yields:

    ∇ log p(y | x) = Σ_z p(z | y, x) [ ∇ f(x, z) − Σ_{z'} p(z' | x) ∇ f(x, z') ]
                   = Σ_z p(z | y, x) ∇ f(x, z) − Σ_{z'} p(z' | x) ∇ f(x, z')
                   = Σ_z [ p(z | y, x) − p(z | x) ] ∇ f(x, z)
                   = Σ_z [ p(y | z, x) p(z | x) / p(y | x) − p(z | x) ] ∇ f(x, z)
                   = Σ_z [ p(y | z, x) / p(y | x) − 1 ] p(z | x) ∇ f(x, z).

In the second line, we used the fact that the overall expression is an expectation with respect to p(z | y, x), and the terms which depend on z' but not z can be moved out of that expectation.
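As a numerical sanity check of the final identity, the sketch below (ours) compares the closed-form coefficient [p(y | z, x)/p(y | x) − 1] p(z | x) against finite-difference derivatives of log p(y | x) with respect to each score f(x, z):

```python
import numpy as np

rng = np.random.default_rng(0)
n = 5
f = rng.standard_normal(n)            # relevance scores f(x, z)
p_y_zx = rng.uniform(0.05, 0.95, n)   # toy encoder likelihoods p(y | z, x)

def log_p_y_given_x(f):
    p_z = np.exp(f - f.max()) / np.exp(f - f.max()).sum()
    return np.log((p_y_zx * p_z).sum())

# Closed form from the derivation above:
# d log p(y|x) / d f(x,z) = [p(y|z,x) / p(y|x) - 1] p(z|x).
p_z = np.exp(f - f.max()) / np.exp(f - f.max()).sum()
p_y = (p_y_zx * p_z).sum()
analytic = (p_y_zx / p_y - 1.0) * p_z

# Central finite differences as an independent check.
eps = 1e-5
numeric = np.array([
    (log_p_y_given_x(f + eps * np.eye(n)[i]) -
     log_p_y_given_x(f - eps * np.eye(n)[i])) / (2 * eps)
    for i in range(n)])
print(np.allclose(analytic, numeric, atol=1e-6))  # True
```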
B. Connection between REALM and supervised learning

From the equations in Appendix A, we saw that

    ∇ log p(y | x) = Σ_z [ p(z | y, x) − p(z | x) ] ∇ f(x, z).

Suppose that there exists one document z* which causes the model to achieve perfect prediction accuracy (i.e., p(y | z*, x) = 1), while all other documents z' result in zero accuracy (i.e., p(y | z', x) = 0). Under this setting, p(z* | y, x) = 1 (provided that p(z* | x) is non-zero), which causes the gradient to become

    ∇ log p(y | x) = ∇ f(x, z*) − Σ_z p(z | x) ∇ f(x, z) = ∇ log p(z* | x).

From this, we see that gradient descent on the REALM objective is equivalent to gradient descent on log p(z* | x). This is none other than the typical maximum likelihood training objective used in supervised learning, where z* is the "gold" document.

C. Adapting to new knowledge

An explicit retrieval system allows us to adapt to new world knowledge simply by modifying the corpus documents. To demonstrate this ability, we replace the knowledge corpus with a more recent version of the Wikipedia corpus after pre-training is done. When the input query is about a fact where the two corpora disagree, REALM can change the prediction to reflect the updated information, as exemplified in Table 4. However, even with an explicit retrieval mechanism, the knowledge-augmented encoder will still end up remembering some world knowledge, making the prediction of some input sentences not updated with the new corpus. (For instance, the model predicts "Thatcher" for "___ is the prime minister of the United Kingdom." on both corpora, perhaps due to the frequent mention of her name in Wikipedia articles.)

D. Retrieval Utility

The null document ∅ described in Section 3.4 provides a way to measure the importance of a retrieved document z: we define the retrieval utility (RU) of z for the masked input x as the difference between the log-likelihood of the knowledge-augmented encoder when conditioning on z versus on ∅:

    RU(z | x) = log p(y | z, x) − log p(y | ∅, x).    (2)

A negative RU shows that z is less useful for predicting y than the null document. This could mean that z is irrelevant to x, but could also mean that the masked tokens in x do not require world knowledge to predict, or that the world knowledge is sufficiently commonplace that it has been baked into the model's parameters. In practice, we find that RU increases steadily over the course of pre-training, and is more predictive of good performance on the downstream task of Open-QA than even the overall log-likelihood. An example of how RU behaves over time and across different settings is shown in Figure 4.
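Equation (2) amounts to a one-line computation; a minimal sketch with toy numbers of our choosing:

```python
import math

def retrieval_utility(log_p_y_given_z_x, log_p_y_given_null_x):
    """RU(z | x) = log p(y | z, x) - log p(y | null, x), as in Eq. (2)."""
    return log_p_y_given_z_x - log_p_y_given_null_x

# Conditioning on z makes y three times as likely as conditioning on the
# null document, so RU is positive.
print(retrieval_utility(math.log(0.3), math.log(0.1)))  # ~1.0986
```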
Table 4. An example where REALM adapts to the updated knowledge corpus. The Wikipedia page "Excellent Cadaver" was added in 2019, so the model was not able to recover the masked word when the knowledge corpus was outdated (2018). Interestingly, the same REALM model pre-trained on the 2018 corpus is able to retrieve the document in the updated corpus (2020) and generate the correct token, "Lawrence".
Figure 4. The Retrieval Utility (RU, described in Eq. 2) vs. the number of pre-training steps (in thousands, 0-200). RU roughly estimates the "usefulness" of retrieval. RU is impacted by the choice of masking and the number of pre-training steps.