
Commit 854260c

Rocketknight1, sgugger, dwalker62, and aromans authored
TF/Numpy variants for all DataCollator classes (huggingface#13105)
* Adding a TF variant of the DataCollatorForTokenClassification to get feedback
* Added a Numpy variant and a post_init check to fail early if a missing import is found
* Fixed call to Numpy variant
* Added a couple more of the collators
* Update src/transformers/data/data_collator.py
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Fixes, style pass, finished DataCollatorForSeqToSeq
* Added all the LanguageModeling DataCollators, except SOP and PermutationLanguageModeling
* Adding DataCollatorForPermutationLanguageModeling
* Style pass
* Add missing `__call__` for PLM
* Remove `post_init` checks for frameworks because the imports inside them were making us fail code quality checks
* Remove unused imports
* First attempt at some TF tests
* A second attempt to make any of those tests actually work
* TF tests, round three
* TF tests, round four
* TF tests, round five
* TF tests, all enabled!
* Style pass
* Merging tests into `test_data_collator.py`
* Merging tests into `test_data_collator.py`
* Fixing up test imports
* Fixing up test imports
* Trying shuffling the conditionals around
* Commenting out non-functional old tests
* Completed all tests for all three frameworks
* Style pass
* Fixed test typo
* Style pass
* Move standard `__call__` method to mixin
* Rearranged imports for `test_data_collator`
* Fix data collator typo "torch" -> "pt"
* Fixed the most embarrassingly obvious bug
* Update src/transformers/data/data_collator.py
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Renaming mixin
* Updating docs

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Dalton Walker <dalton_walker@icloud.com>
Co-authored-by: Andrew Romans <andrew.romans@hotmail.com>
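To make the change concrete, here is a minimal usage sketch that is not taken from the diff: it assumes the `return_tensors` switch ("pt", "tf", or "np") that the framework-specific collator variants are selected with, and the checkpoint name is only illustrative.

# Sketch only: `return_tensors` selects the PyTorch, TensorFlow, or NumPy code path;
# "bert-base-cased" is an illustrative checkpoint, not something this commit depends on.
from transformers import AutoTokenizer, DataCollatorForTokenClassification

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
collator = DataCollatorForTokenClassification(tokenizer, return_tensors="tf")

features = [
    {"input_ids": [101, 7592, 102], "labels": [0, 1, 0]},
    {"input_ids": [101, 2088, 2003, 2307, 102], "labels": [0, 1, 1, 0, 0]},
]
batch = collator(features)
# input_ids and labels are padded to a common length and returned as tf.Tensor values
# (labels are padded with -100 so padded positions are ignored by the loss).
print(batch["input_ids"].shape, batch["labels"].shape)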
1 parent 74b3344 commit 854260c

File tree

5 files changed: +1370 -139 lines

docs/source/main_classes/data_collator.rst

Lines changed: 3 additions & 3 deletions

@@ -54,18 +54,18 @@ DataCollatorForLanguageModeling
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 .. autoclass:: transformers.data.data_collator.DataCollatorForLanguageModeling
-    :members: mask_tokens
+    :members: numpy_mask_tokens, tf_mask_tokens, torch_mask_tokens
 
 
 DataCollatorForWholeWordMask
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 .. autoclass:: transformers.data.data_collator.DataCollatorForWholeWordMask
-    :members: mask_tokens
+    :members: numpy_mask_tokens, tf_mask_tokens, torch_mask_tokens
 
 
 DataCollatorForPermutationLanguageModeling
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 .. autoclass:: transformers.data.data_collator.DataCollatorForPermutationLanguageModeling
-    :members: mask_tokens
+    :members: numpy_mask_tokens, tf_mask_tokens, torch_mask_tokens
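The members listed above are the per-framework masking implementations; the collator is expected to pick one of them based on the requested tensor type. A minimal sketch, assuming the `return_tensors` argument and an illustrative checkpoint name:

# Sketch: the NumPy path of DataCollatorForLanguageModeling; masking is done by
# numpy_mask_tokens (torch_mask_tokens / tf_mask_tokens for the other frameworks).
from transformers import AutoTokenizer, DataCollatorForLanguageModeling

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
collator = DataCollatorForLanguageModeling(tokenizer, mlm_probability=0.15, return_tensors="np")

examples = [tokenizer("Hello world"), tokenizer("Data collators now speak NumPy")]
batch = collator(examples)
print(type(batch["input_ids"]))  # numpy.ndarray
print(batch["labels"][0])        # -100 everywhere except the randomly masked positions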

src/transformers/__init__.py

Lines changed: 22 additions & 22 deletions

@@ -81,6 +81,17 @@
         "xnli_processors",
         "xnli_tasks_num_labels",
     ],
+    "data.data_collator": [
+        "DataCollator",
+        "DataCollatorForLanguageModeling",
+        "DataCollatorForPermutationLanguageModeling",
+        "DataCollatorForSeq2Seq",
+        "DataCollatorForSOP",
+        "DataCollatorForTokenClassification",
+        "DataCollatorForWholeWordMask",
+        "DataCollatorWithPadding",
+        "default_data_collator",
+    ],
     "feature_extraction_sequence_utils": ["BatchFeature", "SequenceFeatureExtractor"],
     "file_utils": [
         "CONFIG_NAME",
@@ -460,17 +471,6 @@
 if is_torch_available():
     _import_structure["benchmark.benchmark"] = ["PyTorchBenchmark"]
     _import_structure["benchmark.benchmark_args"] = ["PyTorchBenchmarkArguments"]
-    _import_structure["data.data_collator"] = [
-        "DataCollator",
-        "DataCollatorForLanguageModeling",
-        "DataCollatorForPermutationLanguageModeling",
-        "DataCollatorForSeq2Seq",
-        "DataCollatorForSOP",
-        "DataCollatorForTokenClassification",
-        "DataCollatorForWholeWordMask",
-        "DataCollatorWithPadding",
-        "default_data_collator",
-    ]
     _import_structure["data.datasets"] = [
         "GlueDataset",
         "GlueDataTrainingArguments",
@@ -1830,6 +1830,17 @@
         xnli_processors,
         xnli_tasks_num_labels,
     )
+    from .data.data_collator import (
+        DataCollator,
+        DataCollatorForLanguageModeling,
+        DataCollatorForPermutationLanguageModeling,
+        DataCollatorForSeq2Seq,
+        DataCollatorForSOP,
+        DataCollatorForTokenClassification,
+        DataCollatorForWholeWordMask,
+        DataCollatorWithPadding,
+        default_data_collator,
+    )
 
     # Feature Extractor
     from .feature_extraction_utils import BatchFeature, SequenceFeatureExtractor
@@ -2174,17 +2185,6 @@
         # Benchmarks
         from .benchmark.benchmark import PyTorchBenchmark
         from .benchmark.benchmark_args import PyTorchBenchmarkArguments
-        from .data.data_collator import (
-            DataCollator,
-            DataCollatorForLanguageModeling,
-            DataCollatorForPermutationLanguageModeling,
-            DataCollatorForSeq2Seq,
-            DataCollatorForSOP,
-            DataCollatorForTokenClassification,
-            DataCollatorForWholeWordMask,
-            DataCollatorWithPadding,
-            default_data_collator,
-        )
         from .data.datasets import (
             GlueDataset,
             GlueDataTrainingArguments,
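The net effect of these hunks is that the collator classes are registered at the top level of the import structure instead of behind is_torch_available(), so they can be imported in a TensorFlow- or NumPy-only environment. A minimal sketch of what that enables, assuming default_data_collator accepts the same `return_tensors` switch as the class-based collators:

# Sketch: using a collator in a NumPy-only environment; assumes the framework-agnostic
# `return_tensors` argument on default_data_collator.
from transformers import default_data_collator

features = [
    {"input_ids": [1, 2, 3], "label": 0},
    {"input_ids": [4, 5, 6], "label": 1},
]
# default_data_collator stacks equal-length features and renames "label" to "labels".
batch = default_data_collator(features, return_tensors="np")
print(batch["input_ids"].shape, batch["labels"])  # (2, 3) [0 1]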
