Skip to content

Commit 637e817

Browse files
authored
[Tests] Fix test (huggingface#15324)
* Fix Swin device * Remove print statement
1 parent e695470 commit 637e817

File tree

3 files changed

+5
-5
lines changed

3 files changed

+5
-5
lines changed

src/transformers/models/swin/modeling_swin.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,9 @@ def forward(self, hidden_states, head_mask=None, output_attentions=False):
437437
hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
438438
hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
439439

440+
if self.attn_mask is not None:
441+
self.attn_mask = self.attn_mask.to(hidden_states_windows.device)
442+
440443
self_attention_outputs = self.attention(
441444
hidden_states_windows,
442445
self.attn_mask,

tests/test_modeling_vilt.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -595,8 +595,8 @@ def test_inference_natural_language_visual_reasoning(self):
595595

596596
# forward pass
597597
outputs = model(
598-
input_ids=encoding_1.input_ids,
599-
pixel_values=pixel_values,
598+
input_ids=encoding_1.input_ids.to(torch_device),
599+
pixel_values=pixel_values.to(torch_device),
600600
)
601601

602602
# verify the logits

tests/test_modeling_vit_mae.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -327,9 +327,6 @@ def test_save_load(self):
327327
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
328328

329329
for model_class in self.all_model_classes:
330-
331-
print("Model class:", model_class)
332-
333330
model = model_class(config)
334331
model.to(torch_device)
335332
model.eval()

0 commit comments

Comments
 (0)