
Commit 98e409a

albert flax (huggingface#13294)
* albert flax
* year -> 2021
* docstring updated for flax
* removed head_mask
* removed from_pt
* removed passing attention_mask to embedding layer
1 parent ee5b245 commit 98e409a

8 files changed (+1442, -2 lines)


docs/source/index.rst

Lines changed: 1 addition & 1 deletion
@@ -321,7 +321,7 @@ Flax), PyTorch, and/or TensorFlow.
 +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
 | Model                       | Tokenizer slow | Tokenizer fast | PyTorch support | TensorFlow support | Flax Support |
 +=============================+================+================+=================+====================+==============+
-| ALBERT                      | ✅             | ✅             | ✅              | ✅                 | ❌           |
+| ALBERT                      | ✅             | ✅             | ✅              | ✅                 | ✅           |
 +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
 | BART                        | ✅             | ✅             | ✅              | ✅                 | ✅           |
 +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+

docs/source/model_doc/albert.rst

Lines changed: 51 additions & 1 deletion
@@ -43,7 +43,8 @@ Tips:
   similar to a BERT-like architecture with the same number of hidden layers as it has to iterate through the same
   number of (repeating) layers.
 
-This model was contributed by `lysandre <https://huggingface.co/lysandre>`__. The original code can be found `here
+This model was contributed by `lysandre <https://huggingface.co/lysandre>`__. This model jax version was contributed by
+`kamalkraj <https://huggingface.co/kamalkraj>`__. The original code can be found `here
 <https://github.com/google-research/ALBERT>`__.
 
 AlbertConfig

@@ -174,3 +175,52 @@ TFAlbertForQuestionAnswering
 
 .. autoclass:: transformers.TFAlbertForQuestionAnswering
     :members: call
+
+
+FlaxAlbertModel
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.FlaxAlbertModel
+    :members: __call__
+
+
+FlaxAlbertForPreTraining
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.FlaxAlbertForPreTraining
+    :members: __call__
+
+
+FlaxAlbertForMaskedLM
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.FlaxAlbertForMaskedLM
+    :members: __call__
+
+
+FlaxAlbertForSequenceClassification
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.FlaxAlbertForSequenceClassification
+    :members: __call__
+
+
+FlaxAlbertForMultipleChoice
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.FlaxAlbertForMultipleChoice
+    :members: __call__
+
+
+FlaxAlbertForTokenClassification
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.FlaxAlbertForTokenClassification
+    :members: __call__
+
+
+FlaxAlbertForQuestionAnswering
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.FlaxAlbertForQuestionAnswering
+    :members: __call__
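
The new Flax classes are invoked through __call__ rather than PyTorch's forward. A minimal sketch of masked-LM inference with the classes documented above, assuming the standard from_pretrained workflow and the public albert-base-v2 checkpoint (neither is part of this diff):

    # A minimal sketch; assumes the albert-base-v2 checkpoint is available.
    from transformers import AlbertTokenizerFast, FlaxAlbertForMaskedLM

    tokenizer = AlbertTokenizerFast.from_pretrained("albert-base-v2")
    model = FlaxAlbertForMaskedLM.from_pretrained("albert-base-v2")

    # Flax models consume NumPy/JAX arrays, so request "np" tensors.
    inputs = tokenizer("Paris is the [MASK] of France.", return_tensors="np")

    # __call__ is the documented entry point (no forward() in the Flax API).
    outputs = model(**inputs)
    print(outputs.logits.shape)  # (batch_size, sequence_length, vocab_size)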

src/transformers/__init__.py

Lines changed: 22 additions & 0 deletions
@@ -1642,6 +1642,18 @@
         "FlaxTopPLogitsWarper",
     ]
     _import_structure["modeling_flax_utils"] = ["FlaxPreTrainedModel"]
+    _import_structure["models.albert"].extend(
+        [
+            "FlaxAlbertForMaskedLM",
+            "FlaxAlbertForMultipleChoice",
+            "FlaxAlbertForPreTraining",
+            "FlaxAlbertForQuestionAnswering",
+            "FlaxAlbertForSequenceClassification",
+            "FlaxAlbertForTokenClassification",
+            "FlaxAlbertModel",
+            "FlaxAlbertPreTrainedModel",
+        ]
+    )
     _import_structure["models.auto"].extend(
         [
             "FLAX_MODEL_FOR_CAUSAL_LM_MAPPING",

@@ -3152,6 +3164,16 @@
             FlaxTopPLogitsWarper,
         )
         from .modeling_flax_utils import FlaxPreTrainedModel
+        from .models.albert import (
+            FlaxAlbertForMaskedLM,
+            FlaxAlbertForMultipleChoice,
+            FlaxAlbertForPreTraining,
+            FlaxAlbertForQuestionAnswering,
+            FlaxAlbertForSequenceClassification,
+            FlaxAlbertForTokenClassification,
+            FlaxAlbertModel,
+            FlaxAlbertPreTrainedModel,
+        )
         from .models.auto import (
             FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
             FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
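
The two hunks above register the Flax ALBERT names in the library's lazy _import_structure, so jax and flax remain optional dependencies. A condensed, self-contained sketch of that pattern, an assumed simplification rather than the actual _LazyModule machinery:

    # Assumed simplification of the conditional lazy-import pattern above.
    import importlib.util

    def is_flax_available() -> bool:
        # Flax support requires both jax and flax to be importable.
        return (
            importlib.util.find_spec("jax") is not None
            and importlib.util.find_spec("flax") is not None
        )

    # Submodules advertise their public names; imports happen on first access.
    _import_structure = {"models.albert": ["AlbertConfig", "AlbertTokenizer"]}

    if is_flax_available():
        # Flax classes are only listed when the optional backend exists,
        # mirroring the guarded extend() in the diff above.
        _import_structure["models.albert"].extend(
            ["FlaxAlbertModel", "FlaxAlbertForMaskedLM"]
        )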

src/transformers/models/albert/__init__.py

Lines changed: 23 additions & 0 deletions
@@ -20,6 +20,7 @@
 
 from ...file_utils import (
     _LazyModule,
+    is_flax_available,
     is_sentencepiece_available,
     is_tf_available,
     is_tokenizers_available,

@@ -65,6 +66,17 @@
         "TFAlbertPreTrainedModel",
     ]
 
+if is_flax_available():
+    _import_structure["modeling_flax_albert"] = [
+        "FlaxAlbertForMaskedLM",
+        "FlaxAlbertForMultipleChoice",
+        "FlaxAlbertForPreTraining",
+        "FlaxAlbertForQuestionAnswering",
+        "FlaxAlbertForSequenceClassification",
+        "FlaxAlbertForTokenClassification",
+        "FlaxAlbertModel",
+        "FlaxAlbertPreTrainedModel",
+    ]
 
 if TYPE_CHECKING:
     from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig, AlbertOnnxConfig

@@ -103,6 +115,17 @@
             TFAlbertPreTrainedModel,
         )
 
+    if is_flax_available():
+        from .modeling_flax_albert import (
+            FlaxAlbertForMaskedLM,
+            FlaxAlbertForMultipleChoice,
+            FlaxAlbertForPreTraining,
+            FlaxAlbertForQuestionAnswering,
+            FlaxAlbertForSequenceClassification,
+            FlaxAlbertForTokenClassification,
+            FlaxAlbertModel,
+            FlaxAlbertPreTrainedModel,
+        )
 else:
     import sys
 

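Since the submodule only exposes the Flax classes when is_flax_available() returns True, downstream code can mirror the same guard. A hedged user-side sketch (from_pt is a documented from_pretrained kwarg, but whether native Flax weights exist for a given checkpoint is an assumption):

    # User-side sketch mirroring the availability guard in the diff above;
    # is_flax_available comes from the same file_utils module the diff imports.
    from transformers.file_utils import is_flax_available

    if is_flax_available():
        from transformers import FlaxAlbertModel

        # Assumes Flax weights exist for this checkpoint; otherwise pass
        # from_pt=True to convert from the PyTorch weights on the fly.
        model = FlaxAlbertModel.from_pretrained("albert-base-v2")
    else:
        from transformers import AlbertModel  # fall back to the PyTorch class

        model = AlbertModel.from_pretrained("albert-base-v2")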