
Commit e7ed7ff

Narsil and njafer authored
Adding support for multiple mask tokens. (huggingface#14716)
* Adding support for multiple mask tokens. Original implementation: huggingface#10222
* In order to accommodate optionally multimodal models like Perceiver, we add type information to the tasks, marking the tasks where we know for sure whether the tokenizer/feature_extractor is needed or not.
* Adding info in the documentation about multi masks, marked as experimental.
* Add a copy() to prevent overriding the same tensor over and over.
* Fixup.
* Adding small test for multi mask with real values.

Co-authored-by: njafer <naveen.jafer@oracle.com>
1 parent 2a606f9 commit e7ed7ff

File tree: 3 files changed, +116 -36 lines

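For context, here is what the new behavior looks like from the user side. This is a sketch, not code from the commit: the checkpoint choice is arbitrary, and, as the docs change below explains, multi-mask scores are disjoint per-mask probabilities rather than joint probabilities.

```python
from transformers import pipeline

unmasker = pipeline("fill-mask", model="distilroberta-base")

# Single mask: a flat list of top_k proposals, as before.
print(unmasker("My name is <mask>.", top_k=2))

# Multiple masks (new, experimental): one list of proposals per mask.
# Each mask is scored independently, so these are disjoint per-mask
# probabilities, not joint probabilities over both positions.
for row in unmasker("My name is <mask> <mask>.", top_k=2):
    for proposal in row:
        print(proposal["token_str"], proposal["score"], proposal["sequence"])
```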

src/transformers/pipelines/__init__.py

Lines changed: 30 additions & 1 deletion
```diff
@@ -125,18 +125,21 @@
         "tf": (),
         "pt": (AutoModelForAudioClassification,) if is_torch_available() else (),
         "default": {"model": {"pt": "superb/wav2vec2-base-superb-ks"}},
+        "type": "audio",
     },
     "automatic-speech-recognition": {
         "impl": AutomaticSpeechRecognitionPipeline,
         "tf": (),
         "pt": (AutoModelForCTC, AutoModelForSpeechSeq2Seq) if is_torch_available() else (),
         "default": {"model": {"pt": "facebook/wav2vec2-base-960h"}},
+        "type": "multimodal",
     },
     "feature-extraction": {
         "impl": FeatureExtractionPipeline,
         "tf": (TFAutoModel,) if is_tf_available() else (),
         "pt": (AutoModel,) if is_torch_available() else (),
         "default": {"model": {"pt": "distilbert-base-cased", "tf": "distilbert-base-cased"}},
+        "type": "multimodal",
     },
     "text-classification": {
         "impl": TextClassificationPipeline,
@@ -148,6 +151,7 @@
                 "tf": "distilbert-base-uncased-finetuned-sst-2-english",
             },
         },
+        "type": "text",
     },
     "token-classification": {
         "impl": TokenClassificationPipeline,
@@ -159,6 +163,7 @@
                 "tf": "dbmdz/bert-large-cased-finetuned-conll03-english",
             },
         },
+        "type": "text",
     },
     "question-answering": {
         "impl": QuestionAnsweringPipeline,
@@ -167,6 +172,7 @@
         "default": {
             "model": {"pt": "distilbert-base-cased-distilled-squad", "tf": "distilbert-base-cased-distilled-squad"},
         },
+        "type": "text",
     },
     "table-question-answering": {
         "impl": TableQuestionAnsweringPipeline,
@@ -179,18 +185,21 @@
                 "tf": "google/tapas-base-finetuned-wtq",
             },
         },
+        "type": "text",
     },
     "fill-mask": {
         "impl": FillMaskPipeline,
         "tf": (TFAutoModelForMaskedLM,) if is_tf_available() else (),
         "pt": (AutoModelForMaskedLM,) if is_torch_available() else (),
         "default": {"model": {"pt": "distilroberta-base", "tf": "distilroberta-base"}},
+        "type": "text",
     },
     "summarization": {
         "impl": SummarizationPipeline,
         "tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
         "pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
         "default": {"model": {"pt": "sshleifer/distilbart-cnn-12-6", "tf": "t5-small"}},
+        "type": "text",
     },
     # This task is a special case as it's parametrized by SRC, TGT languages.
     "translation": {
@@ -202,18 +211,21 @@
             ("en", "de"): {"model": {"pt": "t5-base", "tf": "t5-base"}},
             ("en", "ro"): {"model": {"pt": "t5-base", "tf": "t5-base"}},
         },
+        "type": "text",
     },
     "text2text-generation": {
         "impl": Text2TextGenerationPipeline,
         "tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
         "pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
         "default": {"model": {"pt": "t5-base", "tf": "t5-base"}},
+        "type": "text",
     },
     "text-generation": {
         "impl": TextGenerationPipeline,
         "tf": (TFAutoModelForCausalLM,) if is_tf_available() else (),
         "pt": (AutoModelForCausalLM,) if is_torch_available() else (),
         "default": {"model": {"pt": "gpt2", "tf": "gpt2"}},
+        "type": "text",
     },
     "zero-shot-classification": {
         "impl": ZeroShotClassificationPipeline,
@@ -224,33 +236,48 @@
             "config": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"},
             "tokenizer": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"},
         },
+        "type": "text",
     },
     "conversational": {
         "impl": ConversationalPipeline,
         "tf": (TFAutoModelForSeq2SeqLM, TFAutoModelForCausalLM) if is_tf_available() else (),
         "pt": (AutoModelForSeq2SeqLM, AutoModelForCausalLM) if is_torch_available() else (),
         "default": {"model": {"pt": "microsoft/DialoGPT-medium", "tf": "microsoft/DialoGPT-medium"}},
+        "type": "text",
     },
     "image-classification": {
         "impl": ImageClassificationPipeline,
         "tf": (),
         "pt": (AutoModelForImageClassification,) if is_torch_available() else (),
         "default": {"model": {"pt": "google/vit-base-patch16-224"}},
+        "type": "image",
     },
     "image-segmentation": {
         "impl": ImageSegmentationPipeline,
         "tf": (),
         "pt": (AutoModelForImageSegmentation,) if is_torch_available() else (),
         "default": {"model": {"pt": "facebook/detr-resnet-50-panoptic"}},
+        "type": "image",
     },
     "object-detection": {
         "impl": ObjectDetectionPipeline,
         "tf": (),
         "pt": (AutoModelForObjectDetection,) if is_torch_available() else (),
         "default": {"model": {"pt": "facebook/detr-resnet-50"}},
+        "type": "image",
     },
 }
 
+NO_FEATURE_EXTRACTOR_TASKS = set()
+NO_TOKENIZER_TASKS = set()
+for task, values in SUPPORTED_TASKS.items():
+    if values["type"] == "text":
+        NO_FEATURE_EXTRACTOR_TASKS.add(task)
+    elif values["type"] in {"audio", "image"}:
+        NO_TOKENIZER_TASKS.add(task)
+    elif values["type"] != "multimodal":
+        raise ValueError(f"SUPPORTED_TASK {task} contains invalid type {values['type']}")
+
 
 def get_supported_tasks() -> List[str]:
     """
@@ -528,12 +555,14 @@ def pipeline(
     load_tokenizer = type(model_config) in TOKENIZER_MAPPING or model_config.tokenizer_class is not None
     load_feature_extractor = type(model_config) in FEATURE_EXTRACTOR_MAPPING or feature_extractor is not None
 
-    if task in {"audio-classification", "image-classification"}:
+    if task in NO_TOKENIZER_TASKS:
         # These will never require a tokenizer.
         # the model on the other hand might have a tokenizer, but
         # the files could be missing from the hub, instead of failing
         # on such repos, we just force to not load it.
         load_tokenizer = False
+    if task in NO_FEATURE_EXTRACTOR_TASKS:
+        load_feature_extractor = False
 
     if load_tokenizer:
         # Try to infer tokenizer from model or config name (if provided as str)
```
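The derived sets are easiest to see on a toy copy of the table. The following is a minimal sketch of the routing logic added above, run on a hypothetical four-task subset rather than the full SUPPORTED_TASKS:

```python
# Toy subset of SUPPORTED_TASKS keeping only the new "type" field;
# the real table lives in src/transformers/pipelines/__init__.py.
SUPPORTED_TASKS = {
    "fill-mask": {"type": "text"},
    "image-classification": {"type": "image"},
    "audio-classification": {"type": "audio"},
    "automatic-speech-recognition": {"type": "multimodal"},
}

NO_FEATURE_EXTRACTOR_TASKS = set()
NO_TOKENIZER_TASKS = set()
for task, values in SUPPORTED_TASKS.items():
    if values["type"] == "text":
        # Text-only tasks never need a feature extractor.
        NO_FEATURE_EXTRACTOR_TASKS.add(task)
    elif values["type"] in {"audio", "image"}:
        # Pure audio/image tasks never need a tokenizer.
        NO_TOKENIZER_TASKS.add(task)
    elif values["type"] != "multimodal":
        raise ValueError(f"SUPPORTED_TASK {task} contains invalid type {values['type']}")

print(NO_TOKENIZER_TASKS)          # contains audio-classification, image-classification
print(NO_FEATURE_EXTRACTOR_TASKS)  # contains fill-mask
```

Multimodal tasks land in neither set, so `pipeline()` keeps trying to load both the tokenizer and the feature extractor for them; that is what accommodates models like Perceiver.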

src/transformers/pipelines/fill_mask.py

Lines changed: 35 additions & 32 deletions
```diff
@@ -44,7 +44,9 @@ class FillMaskPipeline(Pipeline):
 
     .. note::
 
-        This pipeline only works for inputs with exactly one token masked.
+        This pipeline only works for inputs with exactly one token masked. Experimental: We added support for multiple
+        masks. The returned values are raw model output, and correspond to disjoint probabilities where one might
+        expect joint probabilities (See `discussion <https://github.com/huggingface/transformers/pull/10222>`__).
     """
 
     def get_masked_index(self, input_ids: GenericTensor) -> np.ndarray:
@@ -59,13 +61,7 @@ def get_masked_index(self, input_ids: GenericTensor) -> np.ndarray:
     def _ensure_exactly_one_mask_token(self, input_ids: GenericTensor) -> np.ndarray:
         masked_index = self.get_masked_index(input_ids)
         numel = np.prod(masked_index.shape)
-        if numel > 1:
-            raise PipelineException(
-                "fill-mask",
-                self.model.base_model_prefix,
-                f"More than one mask_token ({self.tokenizer.mask_token}) is not supported",
-            )
-        elif numel < 1:
+        if numel < 1:
             raise PipelineException(
                 "fill-mask",
                 self.model.base_model_prefix,
@@ -98,46 +94,53 @@ def postprocess(self, model_outputs, top_k=5, target_ids=None):
             top_k = target_ids.shape[0]
         input_ids = model_outputs["input_ids"][0]
         outputs = model_outputs["logits"]
-        result = []
 
         if self.framework == "tf":
-            masked_index = tf.where(input_ids == self.tokenizer.mask_token_id).numpy()
+            masked_index = tf.where(input_ids == self.tokenizer.mask_token_id).numpy()[:, 0]
 
-            # Fill mask pipeline supports only one ${mask_token} per sample
+            outputs = outputs.numpy()
 
-            logits = outputs[0, masked_index.item(), :]
-            probs = tf.nn.softmax(logits)
+            logits = outputs[0, masked_index, :]
+            probs = tf.nn.softmax(logits, axis=-1)
             if target_ids is not None:
-                probs = tf.gather_nd(probs, tf.reshape(target_ids, (-1, 1)))
+                probs = tf.gather_nd(tf.squeeze(probs, 0), target_ids.reshape(-1, 1))
+                probs = tf.expand_dims(probs, 0)
 
             topk = tf.math.top_k(probs, k=top_k)
             values, predictions = topk.values.numpy(), topk.indices.numpy()
         else:
-            masked_index = torch.nonzero(input_ids == self.tokenizer.mask_token_id, as_tuple=False)
+            masked_index = torch.nonzero(input_ids == self.tokenizer.mask_token_id, as_tuple=False).squeeze(-1)
             # Fill mask pipeline supports only one ${mask_token} per sample
 
-            logits = outputs[0, masked_index.item(), :]
-            probs = logits.softmax(dim=0)
+            logits = outputs[0, masked_index, :]
+            probs = logits.softmax(dim=-1)
             if target_ids is not None:
                 probs = probs[..., target_ids]
 
             values, predictions = probs.topk(top_k)
 
-        for v, p in zip(values.tolist(), predictions.tolist()):
-            tokens = input_ids.numpy()
-            if target_ids is not None:
-                p = target_ids[p].tolist()
-            tokens[masked_index] = p
-            # Filter padding out:
-            tokens = tokens[np.where(tokens != self.tokenizer.pad_token_id)]
-            result.append(
-                {
-                    "sequence": self.tokenizer.decode(tokens, skip_special_tokens=True),
-                    "score": v,
-                    "token": p,
-                    "token_str": self.tokenizer.decode(p),
-                }
-            )
+        result = []
+        single_mask = values.shape[0] == 1
+        for i, (_values, _predictions) in enumerate(zip(values.tolist(), predictions.tolist())):
+            row = []
+            for v, p in zip(_values, _predictions):
+                # Copy is important since we're going to modify this array in place
+                tokens = input_ids.numpy().copy()
+                if target_ids is not None:
+                    p = target_ids[p].tolist()
+
+                tokens[masked_index[i]] = p
+                # Filter padding out:
+                tokens = tokens[np.where(tokens != self.tokenizer.pad_token_id)]
+                # Originally we skip special tokens to give readable output.
+                # For multi masks though, the other [MASK] would be removed otherwise
+                # making the output look odd, so we add them back
+                sequence = self.tokenizer.decode(tokens, skip_special_tokens=single_mask)
+                proposition = {"score": v, "token": p, "token_str": self.tokenizer.decode(p), "sequence": sequence}
+                row.append(proposition)
+            result.append(row)
+        if single_mask:
+            return result[0]
         return result
 
     def get_target_ids(self, targets, top_k=None):
```
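To make the new return shape concrete, here is a standalone NumPy sketch of the per-mask top-k selection performed above. Everything in it is illustrative (random logits, a 10-token vocabulary, masks assumed at positions 2 and 3); it mirrors the indexing pattern, not the pipeline itself:

```python
import numpy as np

rng = np.random.default_rng(0)
logits = rng.standard_normal((5, 10))  # (sequence_length, vocab_size)
masked_index = np.array([2, 3])        # positions of the mask tokens
top_k = 2

# Gather logits at all masked positions at once: shape (n_masks, vocab_size).
mask_logits = logits[masked_index, :]
probs = np.exp(mask_logits) / np.exp(mask_logits).sum(axis=-1, keepdims=True)

# Per-mask top-k: one row of (token, score) proposals per mask, which is
# why the pipeline now returns a list of lists for multi-mask inputs.
predictions = np.argsort(-probs, axis=-1)[:, :top_k]
values = np.take_along_axis(probs, predictions, axis=-1)
for i, position in enumerate(masked_index):
    print(f"mask at position {position}: tokens {predictions[i]}, scores {values[i]}")
```

Note also the `.copy()` in the loop above: `input_ids.numpy()` can share memory with the underlying tensor, so writing one mask's prediction into `tokens` without copying would leak into the sequences built for the other masks, which is exactly the bug the commit message calls out.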

tests/test_pipelines_fill_mask.py

Lines changed: 51 additions & 3 deletions
```diff
@@ -104,6 +104,32 @@ def test_small_model_pt(self):
             ],
         )
 
+        outputs = unmasker("My name is <mask> <mask>", top_k=2)
+
+        self.assertEqual(
+            nested_simplify(outputs, decimals=6),
+            [
+                [
+                    {
+                        "score": 2.2e-05,
+                        "token": 35676,
+                        "token_str": " Maul",
+                        "sequence": "<s>My name is Maul<mask></s>",
+                    },
+                    {"score": 2.2e-05, "token": 16416, "token_str": "ELS", "sequence": "<s>My name isELS<mask></s>"},
+                ],
+                [
+                    {
+                        "score": 2.2e-05,
+                        "token": 35676,
+                        "token_str": " Maul",
+                        "sequence": "<s>My name is<mask> Maul</s>",
+                    },
+                    {"score": 2.2e-05, "token": 16416, "token_str": "ELS", "sequence": "<s>My name is<mask>ELS</s>"},
+                ],
+            ],
+        )
+
     @slow
     @require_torch
     def test_large_model_pt(self):
@@ -231,9 +257,6 @@ def run_pipeline_test(self, fill_masker, examples):
 
         with self.assertRaises(ValueError):
             fill_masker([None])
-        # Multiple masks
-        with self.assertRaises(PipelineException):
-            fill_masker(f"This is {tokenizer.mask_token} {tokenizer.mask_token}")
         # No mask_token is not supported
         with self.assertRaises(PipelineException):
             fill_masker("This is")
@@ -242,6 +265,7 @@ def run_pipeline_test(self, fill_masker, examples):
         self.run_test_targets(model, tokenizer)
         self.run_test_top_k_targets(model, tokenizer)
         self.fill_mask_with_duplicate_targets_and_top_k(model, tokenizer)
+        self.fill_mask_with_multiple_masks(model, tokenizer)
 
     def run_test_targets(self, model, tokenizer):
         vocab = tokenizer.get_vocab()
@@ -340,3 +364,27 @@ def fill_mask_with_duplicate_targets_and_top_k(self, model, tokenizer):
         # The target list contains duplicates, so we can't output more
         # than them
         self.assertEqual(len(outputs), 3)
+
+    def fill_mask_with_multiple_masks(self, model, tokenizer):
+        fill_masker = FillMaskPipeline(model=model, tokenizer=tokenizer)
+
+        outputs = fill_masker(
+            f"This is a {tokenizer.mask_token} {tokenizer.mask_token} {tokenizer.mask_token}", top_k=2
+        )
+        self.assertEqual(
+            outputs,
+            [
+                [
+                    {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                    {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                ],
+                [
+                    {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                    {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                ],
+                [
+                    {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                    {"sequence": ANY(str), "score": ANY(float), "token": ANY(int), "token_str": ANY(str)},
+                ],
+            ],
+        )
```
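The `ANY(...)` placeholders above come from the pipeline test utilities and match any value of the given type, so the shape of the output (three masks, two proposals each) is asserted without pinning exact scores. A minimal sketch of how such a matcher can be implemented (an assumed reimplementation for illustration, not the actual import):

```python
class ANY:
    """Compares equal to any value of the given type(s), for use in assertEqual."""

    def __init__(self, *expected_types):
        self.expected_types = expected_types

    def __eq__(self, other):
        # A tuple of types is a valid second argument to isinstance.
        return isinstance(other, self.expected_types)

    def __repr__(self):
        return f"ANY({', '.join(t.__name__ for t in self.expected_types)})"


assert {"score": 0.5} == {"score": ANY(float)}  # matches any float score
```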
