
Commit ade7371

SaulLu and sgugger authored
improve saving strategy of sentencepiece tokenizer (huggingface#15328)
* add new test * add a feature to same the sentencepiece tokenizer model when the init file was deleted * update marian * update m2m_100 * fix marian * update speech to text * override test for layoutxlm * fix saving bartpho * remove harcoded values bartpho * special token string version * finish bartpho * override layoutxml test * add mbart * move special tokens list * format * Revert "format" This reverts commit 37a40df. * simplify list of string of special tokens * Re-write `self.fairseq_tokens_to_ids ` initialization logic with special tokens Co-authored-by: Sylvain Gugger <sylvain.gugger@gmail.com> Co-authored-by: Sylvain Gugger <sylvain.gugger@gmail.com>
1 parent 196cce6 commit ade7371

21 files changed (+204 −38 lines)
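
Every tokenizer touched below applies the same save_vocabulary pattern: copy the original SentencePiece .model file if it is still on disk, otherwise re-serialize the in-memory sp_model. A minimal standalone sketch of that logic (the helper name save_spm_vocab and its parameters are illustrative only, not part of this commit):

import os
from shutil import copyfile

import sentencepiece as spm


def save_spm_vocab(sp_model: spm.SentencePieceProcessor, vocab_file: str, out_vocab_file: str) -> str:
    # Illustrative helper, not from the commit: copy the original .model file if it is still present...
    if os.path.abspath(vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(vocab_file):
        copyfile(vocab_file, out_vocab_file)
    # ...otherwise write out the serialized protobuf held by the loaded processor.
    elif not os.path.isfile(vocab_file):
        with open(out_vocab_file, "wb") as fi:
            fi.write(sp_model.serialized_model_proto())
    return out_vocab_file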

src/transformers/models/albert/tokenization_albert.py

Lines changed: 5 additions & 1 deletion
@@ -343,7 +343,11 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )
 
-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
 
         return (out_vocab_file,)

src/transformers/models/bartpho/tokenization_bartpho.py

Lines changed: 25 additions & 6 deletions
@@ -157,12 +157,20 @@ def __init__(
         self.sp_model.Load(str(vocab_file))
 
         # Load the reduced vocab
-        self.fairseq_tokens_to_ids = {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3}
+
+        # Keep order of special tokens for backward compatibility
+        self.fairseq_tokens_to_ids = {}
+        cnt = 0
+        for token in [bos_token, pad_token, eos_token, unk_token, sep_token, cls_token]:
+            if str(token) not in self.fairseq_tokens_to_ids:
+                self.fairseq_tokens_to_ids[str(token)] = cnt
+                cnt += 1
         with open(monolingual_vocab_file, "r", encoding="utf-8") as f:
             for line in f.readlines():
                 token = line.strip().split()[0]
                 self.fairseq_tokens_to_ids[token] = len(self.fairseq_tokens_to_ids)
-        self.fairseq_tokens_to_ids["<mask>"] = len(self.fairseq_tokens_to_ids)
+        if str(mask_token) not in self.fairseq_tokens_to_ids:
+            self.fairseq_tokens_to_ids[str(mask_token)] = len(self.fairseq_tokens_to_ids)
 
         self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
 
@@ -278,7 +286,7 @@ def _convert_token_to_id(self, token):
         if token in self.fairseq_tokens_to_ids:
             return self.fairseq_tokens_to_ids[token]
         else:
-            return self.fairseq_tokens_to_ids["<unk>"]
+            return self.unk_token_id
 
     def _convert_id_to_token(self, index):
         """Converts an index (integer) in a token (str) using the vocab."""
@@ -301,10 +309,21 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
             (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["monolingual_vocab_file"],
         )
 
-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
-
-        if os.path.abspath(self.monolingual_vocab_file) != os.path.abspath(out_monolingual_vocab_file):
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
+
+        if os.path.abspath(self.monolingual_vocab_file) != os.path.abspath(
+            out_monolingual_vocab_file
+        ) and os.path.isfile(self.monolingual_vocab_file):
             copyfile(self.monolingual_vocab_file, out_monolingual_vocab_file)
+        elif not os.path.isfile(self.monolingual_vocab_file):
+            with open(out_monolingual_vocab_file, "w", encoding="utf-8") as fp:
+                for token in self.fairseq_tokens_to_ids:
+                    if token not in self.all_special_tokens:
+                        fp.write(f"{str(token)} \n")
 
         return out_vocab_file, out_monolingual_vocab_file
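
The "Keep order of special tokens for backward compatibility" comment above works because, with BARTpho's usual defaults (bos "<s>", pad "<pad>", eos "</s>", unk "<unk>", sep "</s>", cls "<s>"), the new loop reproduces the previously hardcoded mapping: duplicate strings are simply skipped. A small standalone check under that assumption:

# Assumed default special tokens, in the order used by the new __init__ loop.
special_tokens = ["<s>", "<pad>", "</s>", "<unk>", "</s>", "<s>"]

fairseq_tokens_to_ids = {}
cnt = 0
for token in special_tokens:
    if token not in fairseq_tokens_to_ids:
        fairseq_tokens_to_ids[token] = cnt
        cnt += 1

# Duplicates (sep reuses eos, cls reuses bos) are skipped, so the result matches
# the old hardcoded dict.
assert fairseq_tokens_to_ids == {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3}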

src/transformers/models/bert_generation/tokenization_bert_generation.py

Lines changed: 5 additions & 1 deletion
@@ -160,7 +160,11 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )
 
-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
 
         return (out_vocab_file,)

src/transformers/models/big_bird/tokenization_big_bird.py

Lines changed: 5 additions & 1 deletion
@@ -189,8 +189,12 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )
 
-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
 
         return (out_vocab_file,)
 

src/transformers/models/camembert/tokenization_camembert.py

Lines changed: 5 additions & 1 deletion
@@ -288,7 +288,11 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )
 
-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
 
         return (out_vocab_file,)

src/transformers/models/fnet/tokenization_fnet.py

Lines changed: 5 additions & 1 deletion
@@ -305,7 +305,11 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )
 
-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
 
         return (out_vocab_file,)

src/transformers/models/layoutxlm/tokenization_layoutxlm.py

Lines changed: 5 additions & 1 deletion
@@ -331,8 +331,12 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )
 
-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
 
         return (out_vocab_file,)
 
src/transformers/models/m2m_100/tokenization_m2m_100.py

Lines changed: 6 additions & 1 deletion
@@ -13,6 +13,7 @@
 # limitations under the License.
 """Tokenization classes for M2M100."""
 import json
+import os
 from contextlib import contextmanager
 from pathlib import Path
 from shutil import copyfile
@@ -312,8 +313,12 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
 
         save_json(self.encoder, vocab_save_path)
 
-        if not spm_save_path.exists():
+        if os.path.abspath(self.spm_file) != os.path.abspath(spm_save_path) and os.path.isfile(self.spm_file):
             copyfile(self.spm_file, spm_save_path)
+        elif not os.path.isfile(self.spm_file):
+            with open(spm_save_path, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
 
         return (str(vocab_save_path), str(spm_save_path))
 
src/transformers/models/marian/tokenization_marian.py

Lines changed: 31 additions & 14 deletions
@@ -11,8 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import json
+import os
 import re
 import warnings
 from contextlib import contextmanager
@@ -23,7 +23,10 @@
 import sentencepiece
 
 from ...tokenization_utils import PreTrainedTokenizer
+from ...utils import logging
+
 
+logger = logging.get_logger(__name__)
 
 VOCAB_FILES_NAMES = {
     "source_spm": "source.spm",
@@ -277,21 +280,35 @@ def vocab_size(self) -> int:
         return len(self.encoder)
 
     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
-        save_dir = Path(save_directory)
-        assert save_dir.is_dir(), f"{save_directory} should be a directory"
-        save_json(
-            self.encoder,
-            save_dir / ((filename_prefix + "-" if filename_prefix else "") + self.vocab_files_names["vocab"]),
+        if not os.path.isdir(save_directory):
+            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+            return
+        saved_files = []
+        out_vocab_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab"]
         )
 
-        for orig, f in zip(["source.spm", "target.spm"], self.spm_files):
-            dest_path = save_dir / ((filename_prefix + "-" if filename_prefix else "") + Path(f).name)
-            if not dest_path.exists():
-                copyfile(f, save_dir / orig)
-
-        return tuple(
-            save_dir / ((filename_prefix + "-" if filename_prefix else "") + f) for f in self.vocab_files_names
-        )
+        save_json(self.encoder, out_vocab_file)
+        saved_files.append(out_vocab_file)
+
+        for spm_save_filename, spm_orig_path, spm_model in zip(
+            [VOCAB_FILES_NAMES["source_spm"], VOCAB_FILES_NAMES["target_spm"]],
+            self.spm_files,
+            [self.spm_source, self.spm_target],
+        ):
+            spm_save_path = os.path.join(
+                save_directory, (filename_prefix + "-" if filename_prefix else "") + spm_save_filename
+            )
+            if os.path.abspath(spm_orig_path) != os.path.abspath(spm_save_path) and os.path.isfile(spm_orig_path):
+                copyfile(spm_orig_path, spm_save_path)
+                saved_files.append(spm_save_path)
+            elif not os.path.isfile(spm_orig_path):
+                with open(spm_save_path, "wb") as fi:
+                    content_spiece_model = spm_model.serialized_model_proto()
+                    fi.write(content_spiece_model)
+                saved_files.append(spm_save_path)
+
+        return tuple(saved_files)
 
     def get_vocab(self) -> Dict:
         vocab = self.encoder.copy()

src/transformers/models/mbart/tokenization_mbart.py

Lines changed: 5 additions & 1 deletion
@@ -315,8 +315,12 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )
 
-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
 
         return (out_vocab_file,)
 
src/transformers/models/mbart50/tokenization_mbart50.py

Lines changed: 5 additions & 1 deletion
@@ -245,8 +245,12 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )
 
-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
 
         return (out_vocab_file,)
 
src/transformers/models/pegasus/tokenization_pegasus.py

Lines changed: 5 additions & 1 deletion
@@ -285,7 +285,11 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )
 
-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
 
         return (out_vocab_file,)

src/transformers/models/reformer/tokenization_reformer.py

Lines changed: 5 additions & 1 deletion
@@ -167,7 +167,11 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )
 
-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
 
         return (out_vocab_file,)

src/transformers/models/speech_to_text/tokenization_speech_to_text.py

Lines changed: 6 additions & 2 deletions
@@ -13,8 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Tokenization classes for Speech2Text."""
-
 import json
+import os
 from pathlib import Path
 from shutil import copyfile
 from typing import Any, Dict, List, Optional, Tuple, Union
@@ -260,8 +260,12 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
 
         save_json(self.encoder, vocab_save_path)
 
-        if not spm_save_path.exists():
+        if os.path.abspath(self.spm_file) != os.path.abspath(spm_save_path) and os.path.isfile(self.spm_file):
             copyfile(self.spm_file, spm_save_path)
+        elif not os.path.isfile(self.spm_file):
+            with open(spm_save_path, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
 
         return (str(vocab_save_path), str(spm_save_path))
 

src/transformers/models/t5/tokenization_t5.py

Lines changed: 5 additions & 2 deletions
@@ -303,8 +303,11 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )
 
-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
-            logger.info(f"Copy vocab file to {out_vocab_file}")
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
 
         return (out_vocab_file,)

src/transformers/models/xlm_prophetnet/tokenization_xlm_prophetnet.py

Lines changed: 5 additions & 1 deletion
@@ -302,8 +302,12 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )
 
-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
 
         return (out_vocab_file,)

src/transformers/models/xlm_roberta/tokenization_xlm_roberta.py

Lines changed: 5 additions & 1 deletion
@@ -310,7 +310,11 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )
 
-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
 
         return (out_vocab_file,)

src/transformers/models/xlnet/tokenization_xlnet.py

Lines changed: 5 additions & 1 deletion
@@ -342,7 +342,11 @@ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] =
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )
 
-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
 
         return (out_vocab_file,)

tests/test_tokenization_common.py

Lines changed: 27 additions & 0 deletions
@@ -394,6 +394,33 @@ def test_pickle_subword_regularization_tokenizer(self) -> None:
         self.assertEqual(tokenizer_new.sp_model_kwargs, sp_model_kwargs)
         self.check_subword_sampling(tokenizer_new)
 
+    def test_save_sentencepiece_tokenizer(self) -> None:
+        if not self.test_sentencepiece or not self.test_slow_tokenizer:
+            return
+        # We want to verify that we will be able to save the tokenizer even if the original files that were used to
+        # build the tokenizer have been deleted in the meantime.
+        text = "This is text to test the tokenizer."
+
+        tokenizer_slow_1 = self.get_tokenizer()
+        encoding_tokenizer_slow_1 = tokenizer_slow_1(text)
+
+        tmpdirname_1 = tempfile.mkdtemp()
+        tmpdirname_2 = tempfile.mkdtemp()
+
+        tokenizer_slow_1.save_pretrained(tmpdirname_1)
+        tokenizer_slow_2 = self.tokenizer_class.from_pretrained(tmpdirname_1)
+        encoding_tokenizer_slow_2 = tokenizer_slow_2(text)
+
+        shutil.rmtree(tmpdirname_1)
+        tokenizer_slow_2.save_pretrained(tmpdirname_2)
+
+        tokenizer_slow_3 = self.tokenizer_class.from_pretrained(tmpdirname_2)
+        encoding_tokenizer_slow_3 = tokenizer_slow_3(text)
+        shutil.rmtree(tmpdirname_2)
+
+        self.assertEqual(encoding_tokenizer_slow_1, encoding_tokenizer_slow_2)
+        self.assertEqual(encoding_tokenizer_slow_1, encoding_tokenizer_slow_3)
+
     def test_model_input_names_signature(self):
         accepted_model_main_input_names = [
             "input_ids",  # nlp models
