Skip to content

Commit 546a91a

Browse files
authored
Fixing tests for Perceiver (huggingface#14739)
* Adding some slow tests to check Perceiver at least from a high level. * Re-enabling fast tests for Perceiver ImageClassification. * Perceiver might try to run some text-only pipelines without a Tokenizer (the Fast one doesn't exist) and with a FeatureExtractor. * Oops. * Adding a comment for `update_config_with_model_class`. * Remove `model_architecture` to get `tiny_config`. * Finalize rebase. * Smarter way to handle undefined FastTokenizer. * Remove old code. * Addressing some nits. * Don't instantiate `None`.
1 parent 322d416 commit 546a91a

File tree

4 files changed

+83
-19
lines changed

4 files changed

+83
-19
lines changed

src/transformers/models/perceiver/modeling_perceiver.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1268,6 +1268,7 @@ def forward(
12681268
output_hidden_states=None,
12691269
labels=None,
12701270
return_dict=None,
1271+
pixel_values=None,
12711272
):
12721273
r"""
12731274
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -1296,6 +1297,10 @@ def forward(
12961297
>>> predicted_class_idx = logits.argmax(-1).item()
12971298
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
12981299
"""
1300+
if inputs is not None and pixel_values is not None:
1301+
raise ValueError("You cannot use both `inputs` and `pixel_values`")
1302+
elif inputs is None and pixel_values is not None:
1303+
inputs = pixel_values
12991304
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
13001305

13011306
outputs = self.perceiver(
@@ -1399,6 +1404,7 @@ def forward(
13991404
output_hidden_states=None,
14001405
labels=None,
14011406
return_dict=None,
1407+
pixel_values=None,
14021408
):
14031409
r"""
14041410
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -1427,6 +1433,10 @@ def forward(
14271433
>>> predicted_class_idx = logits.argmax(-1).item()
14281434
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
14291435
"""
1436+
if inputs is not None and pixel_values is not None:
1437+
raise ValueError("You cannot use both `inputs` and `pixel_values`")
1438+
elif inputs is None and pixel_values is not None:
1439+
inputs = pixel_values
14301440
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
14311441

14321442
outputs = self.perceiver(

src/transformers/pipelines/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -528,8 +528,8 @@ def pipeline(
528528
load_tokenizer = type(model_config) in TOKENIZER_MAPPING or model_config.tokenizer_class is not None
529529
load_feature_extractor = type(model_config) in FEATURE_EXTRACTOR_MAPPING or feature_extractor is not None
530530

531-
if task in {"audio-classification"}:
532-
# Audio classification will never require a tokenizer.
531+
if task in {"audio-classification", "image-classification"}:
532+
# These will never require a tokenizer.
533533
# the model on the other hand might have a tokenizer, but
534534
# the files could be missing from the hub, instead of failing
535535
# on such repos, we just force to not load it.

tests/test_pipelines_common.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -77,12 +77,15 @@ def get_tiny_config_from_class(configuration_class):
7777
model_tester = model_tester_class(parent=None)
7878

7979
if hasattr(model_tester, "get_pipeline_config"):
80-
return model_tester.get_pipeline_config()
80+
config = model_tester.get_pipeline_config()
8181
elif hasattr(model_tester, "get_config"):
82-
return model_tester.get_config()
82+
config = model_tester.get_config()
8383
else:
84+
config = None
8485
logger.warning(f"Model tester {model_tester_class.__name__} has no `get_config()`.")
8586

87+
return config
88+
8689

8790
@lru_cache(maxsize=100)
8891
def get_tiny_tokenizer_from_checkpoint(checkpoint):
@@ -100,11 +103,17 @@ def get_tiny_tokenizer_from_checkpoint(checkpoint):
100103
return tokenizer
101104

102105

103-
def get_tiny_feature_extractor_from_checkpoint(checkpoint, tiny_config):
106+
def get_tiny_feature_extractor_from_checkpoint(checkpoint, tiny_config, feature_extractor_class):
104107
try:
105108
feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint)
106109
except Exception:
107-
feature_extractor = None
110+
try:
111+
if feature_extractor_class is not None:
112+
feature_extractor = feature_extractor_class()
113+
else:
114+
feature_extractor = None
115+
except Exception:
116+
feature_extractor = None
108117
if hasattr(tiny_config, "image_size") and feature_extractor:
109118
feature_extractor = feature_extractor.__class__(size=tiny_config.image_size, crop_size=tiny_config.image_size)
110119

@@ -168,7 +177,9 @@ def test(self):
168177
self.skipTest(f"Ignoring {ModelClass}, cannot create a simple tokenizer")
169178
else:
170179
tokenizer = None
171-
feature_extractor = get_tiny_feature_extractor_from_checkpoint(checkpoint, tiny_config)
180+
feature_extractor = get_tiny_feature_extractor_from_checkpoint(
181+
checkpoint, tiny_config, feature_extractor_class
182+
)
172183

173184
if tokenizer is None and feature_extractor is None:
174185
self.skipTest(
@@ -218,6 +229,13 @@ def data(n):
218229
if not tokenizer_classes:
219230
# We need to test even if there are no tokenizers.
220231
tokenizer_classes = [None]
232+
else:
233+
# Remove the undefined tokenizers.
234+
# ByT5 and Perceiver are byte-level and don't define
235+
# FastTokenizer, we can just ignore those.
236+
tokenizer_classes = [
237+
tokenizer_class for tokenizer_class in tokenizer_classes if tokenizer_class is not None
238+
]
221239

222240
for tokenizer_class in tokenizer_classes:
223241
if tokenizer_class is not None:

tests/test_pipelines_image_classification.py

Lines changed: 48 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,7 @@
1414

1515
import unittest
1616

17-
from transformers import (
18-
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
19-
PerceiverConfig,
20-
PreTrainedTokenizer,
21-
is_vision_available,
22-
)
17+
from transformers import MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, PreTrainedTokenizer, is_vision_available
2318
from transformers.pipelines import ImageClassificationPipeline, pipeline
2419
from transformers.testing_utils import (
2520
is_pipeline_test,
@@ -28,6 +23,7 @@
2823
require_tf,
2924
require_torch,
3025
require_vision,
26+
slow,
3127
)
3228

3329
from .test_pipelines_common import ANY, PipelineTestCaseMeta
@@ -50,12 +46,7 @@ class ImageClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
5046
model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
5147

5248
def get_test_pipeline(self, model, tokenizer, feature_extractor):
53-
if isinstance(model.config, PerceiverConfig):
54-
self.skipTest(
55-
"Perceiver model tester is defined with a language one, which has no feature_extractor, so the automated test cannot work here"
56-
)
57-
58-
image_classifier = ImageClassificationPipeline(model=model, feature_extractor=feature_extractor)
49+
image_classifier = ImageClassificationPipeline(model=model, feature_extractor=feature_extractor, top_k=2)
5950
examples = [
6051
Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png"),
6152
"http://images.cocodataset.org/val2017/000000039769.jpg",
@@ -167,3 +158,48 @@ def test_custom_tokenizer(self):
167158
image_classifier = pipeline("image-classification", model="lysandre/tiny-vit-random", tokenizer=tokenizer)
168159

169160
self.assertIs(image_classifier.tokenizer, tokenizer)
161+
162+
@slow
163+
@require_torch
164+
def test_perceiver(self):
165+
# Perceiver is not tested by `run_pipeline_test` properly.
166+
# That is because the types of the feature_extractor and the model preprocessor need to be kept
167+
# in sync, which is not the case in the current design
168+
image_classifier = pipeline("image-classification", model="deepmind/vision-perceiver-conv")
169+
outputs = image_classifier("http://images.cocodataset.org/val2017/000000039769.jpg")
170+
self.assertEqual(
171+
nested_simplify(outputs, decimals=4),
172+
[
173+
{"score": 0.4385, "label": "tabby, tabby cat"},
174+
{"score": 0.321, "label": "tiger cat"},
175+
{"score": 0.0502, "label": "Egyptian cat"},
176+
{"score": 0.0137, "label": "crib, cot"},
177+
{"score": 0.007, "label": "radiator"},
178+
],
179+
)
180+
181+
image_classifier = pipeline("image-classification", model="deepmind/vision-perceiver-fourier")
182+
outputs = image_classifier("http://images.cocodataset.org/val2017/000000039769.jpg")
183+
self.assertEqual(
184+
nested_simplify(outputs, decimals=4),
185+
[
186+
{"score": 0.5658, "label": "tabby, tabby cat"},
187+
{"score": 0.1309, "label": "tiger cat"},
188+
{"score": 0.0722, "label": "Egyptian cat"},
189+
{"score": 0.0707, "label": "remote control, remote"},
190+
{"score": 0.0082, "label": "computer keyboard, keypad"},
191+
],
192+
)
193+
194+
image_classifier = pipeline("image-classification", model="deepmind/vision-perceiver-learned")
195+
outputs = image_classifier("http://images.cocodataset.org/val2017/000000039769.jpg")
196+
self.assertEqual(
197+
nested_simplify(outputs, decimals=4),
198+
[
199+
{"score": 0.3022, "label": "tabby, tabby cat"},
200+
{"score": 0.2362, "label": "Egyptian cat"},
201+
{"score": 0.1856, "label": "tiger cat"},
202+
{"score": 0.0324, "label": "remote control, remote"},
203+
{"score": 0.0096, "label": "quilt, comforter, comfort, puff"},
204+
],
205+
)

0 commit comments

Comments
 (0)