Skip to content

Commit 2bef343

Browse files
[Flax] Correct all return tensors to numpy (huggingface#13307)
* fix_torch_device_generate_test
* remove @
* finish find and replace
1 parent 8aa67fc commit 2bef343

File tree

10 files changed

+11
-11
lines changed

10 files changed

+11
-11
lines changed

examples/research_projects/jax-projects/big_bird/evaluate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def evaluate(example):
110110
inputs = tokenizer(
111111
example["question"],
112112
example["context"],
113-
return_tensors="jax",
113+
return_tensors="np",
114114
max_length=4096,
115115
padding="max_length",
116116
truncation=True,

src/transformers/file_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1121,7 +1121,7 @@ def _prepare_output_docstrings(output_type, config_class):
11211121
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
11221122
>>> model = {model_class}.from_pretrained('{checkpoint}')
11231123
1124-
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")
1124+
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
11251125
>>> outputs = model(**inputs)
11261126
11271127
>>> # retrieve logts for next token

src/transformers/generation_flax_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ def generate(
231231
>>> model = FlaxAutoModelForCausalLM.from_pretrained("distilgpt2")
232232
>>> input_context = "The dog"
233233
>>> # encode input context
234-
>>> input_ids = tokenizer(input_context, return_tensors="jax").input_ids
234+
>>> input_ids = tokenizer(input_context, return_tensors="np").input_ids
235235
>>> # generate candidates using sampling
236236
>>> outputs = model.generate(input_ids=input_ids, max_length=20, top_k=30, do_sample=True)
237237
>>> print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))

src/transformers/models/bert/modeling_flax_bert.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -757,7 +757,7 @@ class FlaxBertForPreTraining(FlaxBertPreTrainedModel):
757757
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
758758
>>> model = FlaxBertForPreTraining.from_pretrained('bert-base-uncased')
759759
760-
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")
760+
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
761761
>>> outputs = model(**inputs)
762762
763763
>>> prediction_logits = outputs.prediction_logits

src/transformers/models/big_bird/modeling_flax_big_bird.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1567,7 +1567,7 @@ class FlaxBigBirdForPreTraining(FlaxBigBirdPreTrainedModel):
15671567
>>> tokenizer = BigBirdTokenizer.from_pretrained('google/bigbird-roberta-base')
15681568
>>> model = FlaxBigBirdForPreTraining.from_pretrained('google/bigbird-roberta-base')
15691569
1570-
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")
1570+
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
15711571
>>> outputs = model(**inputs)
15721572
15731573
>>> prediction_logits = outputs.prediction_logits

src/transformers/models/electra/modeling_flax_electra.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -761,7 +761,7 @@ class FlaxElectraForPreTraining(FlaxElectraPreTrainedModel):
761761
>>> tokenizer = ElectraTokenizer.from_pretrained('google/electra-small-discriminator')
762762
>>> model = FlaxElectraForPreTraining.from_pretrained('google/electra-small-discriminator')
763763
764-
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")
764+
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
765765
>>> outputs = model(**inputs)
766766
767767
>>> prediction_logits = outputs.logits

src/transformers/models/vit/modeling_flax_vit.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -512,7 +512,7 @@ class FlaxViTModel(FlaxViTPreTrainedModel):
512512
>>> feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
513513
>>> model = FlaxViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
514514
515-
>>> inputs = feature_extractor(images=image, return_tensors="jax")
515+
>>> inputs = feature_extractor(images=image, return_tensors="np")
516516
>>> outputs = model(**inputs)
517517
>>> last_hidden_states = outputs.last_hidden_state
518518
"""
@@ -592,7 +592,7 @@ class FlaxViTForImageClassification(FlaxViTPreTrainedModel):
592592
>>> feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
593593
>>> model = FlaxViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
594594
595-
>>> inputs = feature_extractor(images=image, return_tensors="jax")
595+
>>> inputs = feature_extractor(images=image, return_tensors="np")
596596
>>> outputs = model(**inputs)
597597
>>> logits = outputs.logits
598598

tests/test_modeling_flax_bart.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,7 @@ def test_cnn_summarization_same_as_fairseq(self):
453453
padding="max_length",
454454
truncation_strategy="only_first",
455455
truncation=True,
456-
return_tensors="jax",
456+
return_tensors="np",
457457
)
458458

459459
self.assertEqual(1024, dct["input_ids"].shape[1])

tests/test_modeling_flax_gpt2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ def test_use_cache_forward_with_attn_mask(self):
213213
@slow
214214
def test_batch_generation(self):
215215
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", pad_token="</s>", padding_side="left")
216-
inputs = tokenizer(["Hello this is a long string", "Hey"], return_tensors="jax", padding=True, truncation=True)
216+
inputs = tokenizer(["Hello this is a long string", "Hey"], return_tensors="np", padding=True, truncation=True)
217217

218218
model = FlaxGPT2LMHeadModel.from_pretrained("gpt2")
219219
model.do_sample = False

tests/test_modeling_flax_gpt_neo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ def test_use_cache_forward_with_attn_mask(self):
204204
@slow
205205
def test_batch_generation(self):
206206
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", pad_token="<|endoftext|>", padding_side="left")
207-
inputs = tokenizer(["Hello this is a long string", "Hey"], return_tensors="jax", padding=True, truncation=True)
207+
inputs = tokenizer(["Hello this is a long string", "Hey"], return_tensors="np", padding=True, truncation=True)
208208

209209
model = FlaxGPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-125M")
210210
model.do_sample = False

0 commit comments

Comments (0)