@@ -499,7 +499,7 @@ def feed_forward_chunk(self, attention_output):
 class PerceiverEncoder(nn.Module):
     """The Perceiver Encoder: a scalable, fully attentional encoder."""
 
-    def __init__(self, config):
+    def __init__(self, config, kv_dim=None):
         super().__init__()
         self.config = config
@@ -523,7 +523,7 @@ def __init__(self, config):
             v_channels=config.v_channels,
             num_heads=config.num_cross_attention_heads,
             q_dim=config.d_latents,
-            kv_dim=config.d_model,
+            kv_dim=kv_dim,
             widening_factor=config.cross_attention_widening_factor,
             use_query_residual=config.use_query_residual,
         )
@@ -734,7 +734,9 @@ def __init__(
         self.input_preprocessor = input_preprocessor
         self.output_postprocessor = output_postprocessor
         self.embeddings = PerceiverEmbeddings(config)
-        self.encoder = PerceiverEncoder(config)
+        self.encoder = PerceiverEncoder(
+            config, kv_dim=input_preprocessor.num_channels if input_preprocessor is not None else config.d_model
+        )
         self.decoder = decoder
 
         # Initialize weights and apply final processing
@@ -782,16 +784,13 @@ def forward(
         else:
             modality_sizes = None
             inputs_without_pos = None
+            if inputs.size()[-1] != self.config.d_model:
+                raise ValueError(
+                    f"Last dimension of the inputs: {inputs.size()[-1]} doesn't correspond to config.d_model: {self.config.d_model}. "
+                    "Make sure to set config.d_model appropriately."
+                )
 
-        if inputs.size()[-1] != self.config.d_model:
-            raise ValueError(
-                f"Last dimension of the inputs: {inputs.size()[-1]} doesn't correspond to config.d_model: {self.config.d_model}. "
-                "Please update config.d_model appropriately."
-            )
-        else:
-            input_shape = inputs.size()
-
-        batch_size, seq_length, _ = input_shape
+        batch_size, seq_length, _ = inputs.size()
         device = inputs.device
 
         # If no attention mask is provided, make them all ones
@@ -874,20 +873,22 @@ class PerceiverForMaskedLM(PerceiverPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
 
+        text_preprocessor = PerceiverTextPreprocessor(config)
+
         trainable_position_encoding_kwargs_decoder = dict(
-            num_channels=config.d_model, index_dims=config.max_position_embeddings
+            num_channels=text_preprocessor.num_channels, index_dims=config.max_position_embeddings
         )
 
         self.perceiver = PerceiverModel(
             config,
-            input_preprocessor=PerceiverTextPreprocessor(config),
+            input_preprocessor=text_preprocessor,
             decoder=PerceiverBasicDecoder(
                 config,
                 output_num_channels=config.d_latents,
                 output_index_dims=config.max_position_embeddings,  # we need to define the seq_len of the inputs beforehand
-                num_channels=config.d_model,
+                num_channels=text_preprocessor.num_channels,
                 qk_channels=8 * 32,
-                v_channels=config.d_model,
+                v_channels=text_preprocessor.num_channels,
                 num_heads=8,
                 use_query_residual=False,
                 final_project=False,
@@ -1502,22 +1503,24 @@ def __init__(self, config):
             concat_pos=True, max_resolution=config.train_size, num_bands=64, sine_only=False
         )
 
+        image_preprocessor = PerceiverImagePreprocessor(
+            config,
+            prep_type="patches",
+            spatial_downsample=1,
+            conv_after_patching=True,
+            conv_after_patching_in_channels=54,
+            temporal_downsample=2,
+            position_encoding_type="fourier",
+            # position_encoding_kwargs
+            fourier_position_encoding_kwargs=fourier_position_encoding_kwargs_preprocessor,
+        )
+
         self.perceiver = PerceiverModel(
             config,
-            input_preprocessor=PerceiverImagePreprocessor(
-                config,
-                prep_type="patches",
-                spatial_downsample=1,
-                conv_after_patching=True,
-                conv_after_patching_in_channels=54,
-                temporal_downsample=2,
-                position_encoding_type="fourier",
-                # position_encoding_kwargs
-                fourier_position_encoding_kwargs=fourier_position_encoding_kwargs_preprocessor,
-            ),
+            input_preprocessor=image_preprocessor,
             decoder=PerceiverOpticalFlowDecoder(
                 config,
-                num_channels=config.d_model,
+                num_channels=image_preprocessor.num_channels,
                 output_image_shape=config.train_size,
                 rescale_factor=100.0,
                 # decoder kwargs
@@ -2631,6 +2634,7 @@ class PerceiverTextPreprocessor(AbstractPreprocessor):
 
     def __init__(self, config):
         super().__init__()
+        self.config = config
         self.embeddings = nn.Embedding(num_embeddings=config.vocab_size, embedding_dim=config.d_model)
         self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.d_model)
0 commit comments