Skip to content

Commit 774760e

Browse files
authored
distilbert-flax (huggingface#13324)
* distilbert-flax * added missing self * docs fix * removed tied kernal extra init * updated docs * x -> hidden states * removed head_mask * removed from_pt, +FLAX * updated year
1 parent 0197746 commit 774760e

File tree

8 files changed

+1205
-4
lines changed

8 files changed

+1205
-4
lines changed

docs/source/index.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,7 @@ Flax), PyTorch, and/or TensorFlow.
357357
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
358358
| DETR ||||||
359359
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
360-
| DistilBERT                  | ✅             | ✅             | ✅              | ✅                 | ❌           |
360+
| DistilBERT                  | ✅             | ✅             | ✅              | ✅                 | ✅           |
361361
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
362362
| DPR ||||||
363363
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+

docs/source/model_doc/distilbert.rst

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,9 @@ Tips:
4444
- DistilBERT doesn't have options to select the input positions (:obj:`position_ids` input). This could be added if
4545
necessary though, just let us know if you need this option.
4646

47-
This model was contributed by `victorsanh <https://huggingface.co/victorsanh>`__. The original code can be found
48-
:prefix_link:`here <examples/research-projects/distillation>`.
47+
This model was contributed by `victorsanh <https://huggingface.co/victorsanh>`__. The JAX version of this model was
48+
contributed by `kamalkraj <https://huggingface.co/kamalkraj>`__. The original code can be found :prefix_link:`here
49+
<examples/research-projects/distillation>`.
4950

5051

5152
DistilBertConfig
@@ -152,3 +153,45 @@ TFDistilBertForQuestionAnswering
152153

153154
.. autoclass:: transformers.TFDistilBertForQuestionAnswering
154155
:members: call
156+
157+
158+
FlaxDistilBertModel
159+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
160+
161+
.. autoclass:: transformers.FlaxDistilBertModel
162+
:members: __call__
163+
164+
165+
FlaxDistilBertForMaskedLM
166+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
167+
168+
.. autoclass:: transformers.FlaxDistilBertForMaskedLM
169+
:members: __call__
170+
171+
172+
FlaxDistilBertForSequenceClassification
173+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
174+
175+
.. autoclass:: transformers.FlaxDistilBertForSequenceClassification
176+
:members: __call__
177+
178+
179+
FlaxDistilBertForMultipleChoice
180+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
181+
182+
.. autoclass:: transformers.FlaxDistilBertForMultipleChoice
183+
:members: __call__
184+
185+
186+
FlaxDistilBertForTokenClassification
187+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
188+
189+
.. autoclass:: transformers.FlaxDistilBertForTokenClassification
190+
:members: __call__
191+
192+
193+
FlaxDistilBertForQuestionAnswering
194+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
195+
196+
.. autoclass:: transformers.FlaxDistilBertForQuestionAnswering
197+
:members: __call__

src/transformers/__init__.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1712,6 +1712,17 @@
17121712
"FlaxCLIPVisionPreTrainedModel",
17131713
]
17141714
)
1715+
_import_structure["models.distilbert"].extend(
1716+
[
1717+
"FlaxDistilBertForMaskedLM",
1718+
"FlaxDistilBertForMultipleChoice",
1719+
"FlaxDistilBertForQuestionAnswering",
1720+
"FlaxDistilBertForSequenceClassification",
1721+
"FlaxDistilBertForTokenClassification",
1722+
"FlaxDistilBertModel",
1723+
"FlaxDistilBertPreTrainedModel",
1724+
]
1725+
)
17151726
_import_structure["models.electra"].extend(
17161727
[
17171728
"FlaxElectraForMaskedLM",
@@ -3201,6 +3212,15 @@
32013212
FlaxCLIPVisionModel,
32023213
FlaxCLIPVisionPreTrainedModel,
32033214
)
3215+
from .models.distilbert import (
3216+
FlaxDistilBertForMaskedLM,
3217+
FlaxDistilBertForMultipleChoice,
3218+
FlaxDistilBertForQuestionAnswering,
3219+
FlaxDistilBertForSequenceClassification,
3220+
FlaxDistilBertForTokenClassification,
3221+
FlaxDistilBertModel,
3222+
FlaxDistilBertPreTrainedModel,
3223+
)
32043224
from .models.electra import (
32053225
FlaxElectraForMaskedLM,
32063226
FlaxElectraForMultipleChoice,

src/transformers/models/auto/modeling_flax_auto.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
FLAX_MODEL_MAPPING_NAMES = OrderedDict(
2929
[
3030
# Base model mapping
31+
("distilbert", "FlaxDistilBertModel"),
3132
("roberta", "FlaxRobertaModel"),
3233
("bert", "FlaxBertModel"),
3334
("big_bird", "FlaxBigBirdModel"),
@@ -63,6 +64,7 @@
6364
FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
6465
[
6566
# Model for Masked LM mapping
67+
("distilbert", "FlaxDistilBertForMaskedLM"),
6668
("roberta", "FlaxRobertaForMaskedLM"),
6769
("bert", "FlaxBertForMaskedLM"),
6870
("big_bird", "FlaxBigBirdForMaskedLM"),
@@ -101,6 +103,7 @@
101103
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
102104
[
103105
# Model for Sequence Classification mapping
106+
("distilbert", "FlaxDistilBertForSequenceClassification"),
104107
("roberta", "FlaxRobertaForSequenceClassification"),
105108
("bert", "FlaxBertForSequenceClassification"),
106109
("big_bird", "FlaxBigBirdForSequenceClassification"),
@@ -113,6 +116,7 @@
113116
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
114117
[
115118
# Model for Question Answering mapping
119+
("distilbert", "FlaxDistilBertForQuestionAnswering"),
116120
("roberta", "FlaxRobertaForQuestionAnswering"),
117121
("bert", "FlaxBertForQuestionAnswering"),
118122
("big_bird", "FlaxBigBirdForQuestionAnswering"),
@@ -125,6 +129,7 @@
125129
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
126130
[
127131
# Model for Token Classification mapping
132+
("distilbert", "FlaxDistilBertForTokenClassification"),
128133
("roberta", "FlaxRobertaForTokenClassification"),
129134
("bert", "FlaxBertForTokenClassification"),
130135
("big_bird", "FlaxBigBirdForTokenClassification"),
@@ -135,6 +140,7 @@
135140
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
136141
[
137142
# Model for Multiple Choice mapping
143+
("distilbert", "FlaxDistilBertForMultipleChoice"),
138144
("roberta", "FlaxRobertaForMultipleChoice"),
139145
("bert", "FlaxBertForMultipleChoice"),
140146
("big_bird", "FlaxBigBirdForMultipleChoice"),

src/transformers/models/distilbert/__init__.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from typing import TYPE_CHECKING
2020

21-
from ...file_utils import _LazyModule, is_tf_available, is_tokenizers_available, is_torch_available
21+
from ...file_utils import _LazyModule, is_flax_available, is_tf_available, is_tokenizers_available, is_torch_available
2222

2323

2424
_import_structure = {
@@ -58,6 +58,17 @@
5858
"TFDistilBertPreTrainedModel",
5959
]
6060

61+
if is_flax_available():
62+
_import_structure["modeling_flax_distilbert"] = [
63+
"FlaxDistilBertForMaskedLM",
64+
"FlaxDistilBertForMultipleChoice",
65+
"FlaxDistilBertForQuestionAnswering",
66+
"FlaxDistilBertForSequenceClassification",
67+
"FlaxDistilBertForTokenClassification",
68+
"FlaxDistilBertModel",
69+
"FlaxDistilBertPreTrainedModel",
70+
]
71+
6172

6273
if TYPE_CHECKING:
6374
from .configuration_distilbert import (
@@ -95,6 +106,17 @@
95106
TFDistilBertPreTrainedModel,
96107
)
97108

109+
if is_flax_available():
110+
from .modeling_flax_distilbert import (
111+
FlaxDistilBertForMaskedLM,
112+
FlaxDistilBertForMultipleChoice,
113+
FlaxDistilBertForQuestionAnswering,
114+
FlaxDistilBertForSequenceClassification,
115+
FlaxDistilBertForTokenClassification,
116+
FlaxDistilBertModel,
117+
FlaxDistilBertPreTrainedModel,
118+
)
119+
98120
else:
99121
import sys
100122

0 commit comments

Comments
 (0)