Skip to content

Commit b100bb5

Browse files
committed
add mbart
1 parent 749756a commit b100bb5

File tree

2 files changed

+6
-1
lines changed

2 files changed

+6
-1
lines changed

src/transformers/models/mbart/tokenization_mbart.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -315,8 +315,12 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
315315
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
316316
)
317317

318-
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
318+
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
319319
copyfile(self.vocab_file, out_vocab_file)
320+
elif not os.path.isfile(self.vocab_file):
321+
with open(out_vocab_file, "wb") as fi:
322+
content_spiece_model = self.sp_model.serialized_model_proto()
323+
fi.write(content_spiece_model)
320324

321325
return (out_vocab_file,)
322326

tests/test_tokenization_mbart.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -39,6 +39,7 @@ class MBartTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
3939
tokenizer_class = MBartTokenizer
4040
rust_tokenizer_class = MBartTokenizerFast
4141
test_rust_tokenizer = True
42+
test_sentencepiece = True
4243

4344
def setUp(self):
4445
super().setUp()

0 commit comments

Comments
 (0)