
Commit e926ea2

Improve perceiver (huggingface#14750)
* First draft
* Improve docstring + clean up tests
* Remove unused code
* Add check in case one doesn't provide a preprocessor
1 parent 971e366 commit e926ea2

File tree

3 files changed (+34, -45 lines)


src/transformers/models/perceiver/configuration_perceiver.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -42,7 +42,8 @@ class PerceiverConfig(PretrainedConfig):
         d_latents (:obj:`int`, `optional`, defaults to 1280):
             Dimension of the latent embeddings.
         d_model (:obj:`int`, `optional`, defaults to 768):
-            Dimension of the inputs.
+            Dimension of the inputs. Should only be provided in case [`PerceiverTextPreprocessor`] is used or no
+            preprocessor is provided.
         num_blocks (:obj:`int`, `optional`, defaults to 1):
             Number of blocks in the Transformer encoder.
         num_self_attends_per_block (:obj:`int`, `optional`, defaults to 26):
```
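The clarified docstring encodes a usage rule worth spelling out: `d_model` only describes the inputs when [`PerceiverTextPreprocessor`] embeds them, or when no preprocessor is used at all; other preprocessors now report their own output width. A minimal sketch, assuming the post-commit behavior of the public classes:

```python
from transformers import PerceiverConfig, PerceiverForMaskedLM

# Text: d_model is the width PerceiverTextPreprocessor embeds tokens into.
config = PerceiverConfig(d_model=768)
model = PerceiverForMaskedLM(config)

# Image/audio heads can leave d_model at its default: their encoders take
# kv_dim from the preprocessor's num_channels, not from config.d_model.
```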

src/transformers/models/perceiver/modeling_perceiver.py

Lines changed: 32 additions & 28 deletions
```diff
@@ -499,7 +499,7 @@ def feed_forward_chunk(self, attention_output):
 class PerceiverEncoder(nn.Module):
     """The Perceiver Encoder: a scalable, fully attentional encoder."""

-    def __init__(self, config):
+    def __init__(self, config, kv_dim=None):
         super().__init__()
         self.config = config
@@ -523,7 +523,7 @@ def __init__(self, config):
             v_channels=config.v_channels,
             num_heads=config.num_cross_attention_heads,
             q_dim=config.d_latents,
-            kv_dim=config.d_model,
+            kv_dim=kv_dim,
             widening_factor=config.cross_attention_widening_factor,
             use_query_residual=config.use_query_residual,
         )
```
```diff
@@ -734,7 +734,9 @@ def __init__(
         self.input_preprocessor = input_preprocessor
         self.output_postprocessor = output_postprocessor
         self.embeddings = PerceiverEmbeddings(config)
-        self.encoder = PerceiverEncoder(config)
+        self.encoder = PerceiverEncoder(
+            config, kv_dim=input_preprocessor.num_channels if input_preprocessor is not None else config.d_model
+        )
         self.decoder = decoder

         # Initialize weights and apply final processing
```
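Taken together, these hunks thread the key/value width from the preprocessor into the encoder's cross-attention instead of hard-wiring it to `config.d_model`. A minimal construction sketch, assuming the internal classes are imported from `transformers.models.perceiver.modeling_perceiver`:

```python
from transformers import PerceiverConfig
from transformers.models.perceiver.modeling_perceiver import (
    PerceiverModel,
    PerceiverTextPreprocessor,
)

config = PerceiverConfig()
preprocessor = PerceiverTextPreprocessor(config)

# The encoder's cross-attention kv_dim now follows the preprocessor: here
# preprocessor.num_channels == config.d_model, but an image or audio
# preprocessor can report a different width without touching config.d_model.
model = PerceiverModel(config, input_preprocessor=preprocessor)
```

Falling back to `config.d_model` when no preprocessor is given keeps the raw-input path working unchanged.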
```diff
@@ -782,16 +784,13 @@ def forward(
         else:
             modality_sizes = None
             inputs_without_pos = None
+            if inputs.size()[-1] != self.config.d_model:
+                raise ValueError(
+                    f"Last dimension of the inputs: {inputs.size()[-1]} doesn't correspond to config.d_model: {self.config.d_model}. "
+                    "Make sure to set config.d_model appropriately."
+                )

-        if inputs.size()[-1] != self.config.d_model:
-            raise ValueError(
-                f"Last dimension of the inputs: {inputs.size()[-1]} doesn't correspond to config.d_model: {self.config.d_model}. "
-                "Please update config.d_model appropriately."
-            )
-        else:
-            input_shape = inputs.size()
-
-        batch_size, seq_length, _ = input_shape
+        batch_size, seq_length, _ = inputs.size()
         device = inputs.device

         # If no attention mask is provided, make them all ones
```
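Moving the check into the `else` branch also tightens the semantics: `config.d_model` is only validated when no preprocessor is supplied, since that is the only case where the caller controls the input width directly. A hedged sketch of that raw-input path (shapes are illustrative):

```python
import torch
from transformers import PerceiverConfig
from transformers.models.perceiver.modeling_perceiver import PerceiverModel

config = PerceiverConfig(d_model=768)
model = PerceiverModel(config)  # no input_preprocessor given

# Raw inputs must already have d_model channels in the last dimension ...
outputs = model(inputs=torch.randn(1, 128, config.d_model))

# ... otherwise the relocated check raises the ValueError above, pointing
# the user at config.d_model instead of failing inside the cross-attention.
```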
```diff
@@ -874,20 +873,22 @@ class PerceiverForMaskedLM(PerceiverPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)

+        text_preprocessor = PerceiverTextPreprocessor(config)
+
         trainable_position_encoding_kwargs_decoder = dict(
-            num_channels=config.d_model, index_dims=config.max_position_embeddings
+            num_channels=text_preprocessor.num_channels, index_dims=config.max_position_embeddings
         )

         self.perceiver = PerceiverModel(
             config,
-            input_preprocessor=PerceiverTextPreprocessor(config),
+            input_preprocessor=text_preprocessor,
             decoder=PerceiverBasicDecoder(
                 config,
                 output_num_channels=config.d_latents,
                 output_index_dims=config.max_position_embeddings,  # we need to define the seq_len of the inputs beforehand
-                num_channels=config.d_model,
+                num_channels=text_preprocessor.num_channels,
                 qk_channels=8 * 32,
-                v_channels=config.d_model,
+                v_channels=text_preprocessor.num_channels,
                 num_heads=8,
                 use_query_residual=False,
                 final_project=False,
```
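Building one `text_preprocessor` and reusing it both as the model input preprocessor and as the source of the decoder's channel counts keeps those dimensions consistent by construction. For orientation, a usage sketch; the `deepmind/language-perceiver` checkpoint name and the `padding="max_length"` step reflect my understanding of the released model's fixed input length:

```python
from transformers import PerceiverTokenizer, PerceiverForMaskedLM

tokenizer = PerceiverTokenizer.from_pretrained("deepmind/language-perceiver")
model = PerceiverForMaskedLM.from_pretrained("deepmind/language-perceiver")

# The decoder queries are defined over max_position_embeddings positions,
# so inputs are padded to the full model length.
inputs = tokenizer(
    "This is an incomplete sentence where some words are missing.",
    padding="max_length",
    return_tensors="pt",
)
outputs = model(**inputs)
print(outputs.logits.shape)  # (batch_size, max_position_embeddings, vocab_size)
```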
```diff
@@ -1502,22 +1503,24 @@ def __init__(self, config):
             concat_pos=True, max_resolution=config.train_size, num_bands=64, sine_only=False
         )

+        image_preprocessor = PerceiverImagePreprocessor(
+            config,
+            prep_type="patches",
+            spatial_downsample=1,
+            conv_after_patching=True,
+            conv_after_patching_in_channels=54,
+            temporal_downsample=2,
+            position_encoding_type="fourier",
+            # position_encoding_kwargs
+            fourier_position_encoding_kwargs=fourier_position_encoding_kwargs_preprocessor,
+        )
+
         self.perceiver = PerceiverModel(
             config,
-            input_preprocessor=PerceiverImagePreprocessor(
-                config,
-                prep_type="patches",
-                spatial_downsample=1,
-                conv_after_patching=True,
-                conv_after_patching_in_channels=54,
-                temporal_downsample=2,
-                position_encoding_type="fourier",
-                # position_encoding_kwargs
-                fourier_position_encoding_kwargs=fourier_position_encoding_kwargs_preprocessor,
-            ),
+            input_preprocessor=image_preprocessor,
             decoder=PerceiverOpticalFlowDecoder(
                 config,
-                num_channels=config.d_model,
+                num_channels=image_preprocessor.num_channels,
                 output_image_shape=config.train_size,
                 rescale_factor=100.0,
                 # decoder kwargs
```
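The same pattern as in `PerceiverForMaskedLM`: hoisting the preprocessor into a local variable lets the optical-flow decoder read `num_channels` from it rather than from `config.d_model` (which the tests previously had to force to 322). A hedged instantiation sketch:

```python
from transformers import PerceiverConfig, PerceiverForOpticalFlow

# train_size is the (height, width) the flow model is configured for;
# [368, 496] is, to my knowledge, the library default.
config = PerceiverConfig(train_size=[368, 496])
model = PerceiverForOpticalFlow(config)  # no manual config.d_model override needed
```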
```diff
@@ -2631,6 +2634,7 @@ class PerceiverTextPreprocessor(AbstractPreprocessor):

     def __init__(self, config):
         super().__init__()
+        self.config = config
         self.embeddings = nn.Embedding(num_embeddings=config.vocab_size, embedding_dim=config.d_model)
         self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.d_model)
```
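Storing the config is what lets the text preprocessor report its output width the way the image and audio preprocessors already do. Presumably it backs a property along these lines (a sketch consistent with the `num_channels` accesses above, not necessarily the verbatim source):

```python
@property
def num_channels(self) -> int:
    # Token and position embeddings both live in d_model channels,
    # so that is the width downstream cross-attention should expect.
    return self.config.d_model
```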

tests/test_modeling_perceiver.py

Lines changed: 0 additions & 16 deletions
```diff
@@ -147,19 +147,14 @@ def prepare_config_and_inputs(self, model_class=None):
             if self.use_input_mask:
                 input_mask = random_attention_mask([self.batch_size, self.seq_length])
         elif model_class.__name__ == "PerceiverForImageClassificationLearned":
-            config.d_model = 512
             inputs = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
         elif model_class.__name__ == "PerceiverForImageClassificationFourier":
-            config.d_model = 261
             inputs = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
         elif model_class.__name__ == "PerceiverForImageClassificationConvProcessing":
-            config.d_model = 322
             inputs = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
         elif model_class.__name__ == "PerceiverForOpticalFlow":
-            config.d_model = 322
             inputs = floats_tensor([self.batch_size, 2, 27, self.train_size[0], self.train_size[1]])
         elif model_class.__name__ == "PerceiverForMultimodalAutoencoding":
-            config.d_model = 409
             images = torch.randn(
                 (self.batch_size, self.num_frames, self.num_channels, self.image_size, self.image_size),
                 device=torch_device,
```
```diff
@@ -211,8 +206,6 @@ def create_and_check_for_masked_lm(self, config, inputs, input_mask, sequence_labels, token_labels):
         self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

     def create_and_check_for_sequence_classification(self, config, inputs, input_mask, sequence_labels, token_labels):
-        # set num_labels
-        config.num_labels = self.num_labels
         model = PerceiverForSequenceClassification(config=config)
         model.to(torch_device)
         model.eval()
@@ -222,9 +215,6 @@ def create_and_check_for_sequence_classification(self, config, inputs, input_mask, sequence_labels, token_labels):
     def create_and_check_for_image_classification_learned(
         self, config, inputs, input_mask, sequence_labels, token_labels
     ):
-        # set d_model and num_labels
-        config.d_model = 512
-        config.num_labels = self.num_labels
         model = PerceiverForImageClassificationLearned(config=config)
         model.to(torch_device)
         model.eval()
@@ -234,9 +224,6 @@ def create_and_check_for_image_classification_learned(
     def create_and_check_for_image_classification_fourier(
         self, config, inputs, input_mask, sequence_labels, token_labels
     ):
-        # set d_model and num_labels
-        config.d_model = 261
-        config.num_labels = self.num_labels
         model = PerceiverForImageClassificationFourier(config=config)
         model.to(torch_device)
         model.eval()
@@ -246,9 +233,6 @@ def create_and_check_for_image_classification_fourier(
     def create_and_check_for_image_classification_conv(
         self, config, inputs, input_mask, sequence_labels, token_labels
     ):
-        # set d_model and num_labels
-        config.d_model = 322
-        config.num_labels = self.num_labels
         model = PerceiverForImageClassificationConvProcessing(config=config)
         model.to(torch_device)
         model.eval()
```
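These deletions are the payoff of the modeling change: each head now derives its input width from its own preprocessor, so the tester no longer needs to force per-model values into `config.d_model` (512, 261, 322 and 409 were simply the respective preprocessors' output widths). A hedged sketch of the invariant the cleanup relies on; the attribute path and the 512 breakdown are my reading of the preprocessor wiring, not something asserted in this diff:

```python
from transformers import PerceiverConfig, PerceiverForImageClassificationLearned

model = PerceiverForImageClassificationLearned(PerceiverConfig())
# The preprocessor advertises its own output width: 512 here, i.e. 256 conv1x1
# channels concatenated with 256 learned position-encoding channels, which is
# exactly the value the tests used to hard-code into config.d_model.
print(model.perceiver.input_preprocessor.num_channels)
```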
