
Commit 63645b3

Authored by kssteven418, LysandreJik, sgugger and patrickvonplaten
I-BERT model support (huggingface#10153)
* IBertConfig, IBertTokentizer added
* IBert Model names moified
* tokenizer bugfix
* embedding -> QuantEmbedding
* quant utils added
* quant_mode added to configuration
* QuantAct added, Embedding layer + QuantAct addition
* QuantAct added
* unused path removed, QKV quantized
* self attention layer all quantized, except softmax
* temporarl commit
* all liner layers quantized
* quant_utils bugfix
* bugfix: requantization missing
* IntGELU added
* IntSoftmax added
* LayerNorm implemented
* LayerNorm implemented all
* names changed: roberta->ibert
* config not inherit from ROberta
* No support for CausalLM
* static quantization added, quantize_model.py removed
* import modules uncommented
* copyrights fixed
* minor bugfix
* quant_modules, quant_utils merged as one file
* import * fixed
* unused runfile removed
* make style run
* configutration.py docstring fixed
* refactoring: comments removed, function name fixed
* unused dependency removed
* typo fixed
* comments(Copied from), assertion string added
* refactoring: super(..) -> super(), etc.
* refactoring
* refarctoring
* make style
* refactoring
* cuda -> to(x.device)
* weight initialization removed
* QuantLinear set_param removed
* QuantEmbedding set_param removed
* IntLayerNorm set_param removed
* assert string added
* assertion error message fixed
* is_decoder removed
* enc-dec arguments/functions removed
* Converter removed
* quant_modules docstring fixed
* conver_slow_tokenizer rolled back
* quant_utils docstring fixed
* unused aruments e.g. use_cache removed from config
* weight initialization condition fixed
* x_min, x_max initialized with small values to avoid div-zero exceptions
* testing code for ibert
* test emb, linear, gelu, softmax added
* test ln and act added
* style reformatted
* force_dequant added
* error tests overrided
* make style
* Style + Docs
* force dequant tests added
* Fix fast tokenizer in init
* Fix doc
* Remove space
* docstring, IBertConfig, chunk_size
* test_modeling_ibert refactoring
* quant_modules.py refactoring
* e2e integration test added
* tokenizers removed
* IBertConfig added to tokenizer_auto.py
* bugfix
* fix docs & test
* fix style num 2
* final fixes

Co-authored-by: Sehoon Kim <sehoonkim@berkeley.edu>
Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
Co-authored-by: Sylvain Gugger <sylvain.gugger@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
1 parent cb38ffc commit 63645b3

File tree: 12 files changed, +3279 −0 lines

docs/source/index.rst
Lines changed: 3 additions & 0 deletions

@@ -263,6 +263,8 @@ TensorFlow and/or Flax.
 +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
 | Funnel Transformer ||||||
 +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
+| I-BERT ||||||
++-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
 | LED ||||||
 +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
 | LXMERT ||||||

@@ -405,6 +407,7 @@ TensorFlow and/or Flax.
     model_doc/fsmt
     model_doc/funnel
     model_doc/herbert
+    model_doc/ibert
     model_doc/layoutlm
     model_doc/led
     model_doc/longformer

docs/source/model_doc/ibert.rst (new file)
Lines changed: 88 additions & 0 deletions

..
    Copyright 2020 The HuggingFace Team. All rights reserved.

    Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
    the License. You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

    Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
    an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
    specific language governing permissions and limitations under the License.

I-BERT
-----------------------------------------------------------------------------------------------------------------------

Overview
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The I-BERT model was proposed in `I-BERT: Integer-only BERT Quantization <https://arxiv.org/abs/2101.01321>`__ by
Sehoon Kim, Amir Gholami, Zhewei Yao, Michael W. Mahoney and Kurt Keutzer. It is a quantized version of RoBERTa that
runs inference up to four times faster.

The abstract from the paper is the following:

*Transformer based models, like BERT and RoBERTa, have achieved state-of-the-art results in many Natural Language
Processing tasks. However, their memory footprint, inference latency, and power consumption are prohibitive for
efficient inference at the edge, and even at the data center. While quantization can be a viable solution for this,
previous work on quantizing Transformer based models use floating-point arithmetic during inference, which cannot
efficiently utilize integer-only logical units such as the recent Turing Tensor Cores, or traditional integer-only ARM
processors. In this work, we propose I-BERT, a novel quantization scheme for Transformer based models that quantizes
the entire inference with integer-only arithmetic. Based on lightweight integer-only approximation methods for
nonlinear operations, e.g., GELU, Softmax, and Layer Normalization, I-BERT performs an end-to-end integer-only BERT
inference without any floating point calculation. We evaluate our approach on GLUE downstream tasks using
RoBERTa-Base/Large. We show that for both cases, I-BERT achieves similar (and slightly higher) accuracy as compared to
the full-precision baseline. Furthermore, our preliminary implementation of I-BERT shows a speedup of 2.4 - 4.0x for
INT8 inference on a T4 GPU system as compared to FP32 inference. The framework has been developed in PyTorch and has
been open-sourced.*

The original code can be found `here <https://github.com/kssteven418/I-BERT>`__.

IBertConfig
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.IBertConfig
    :members:


IBertModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.IBertModel
    :members: forward


IBertForMaskedLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.IBertForMaskedLM
    :members: forward


IBertForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.IBertForSequenceClassification
    :members: forward


IBertForMultipleChoice
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.IBertForMultipleChoice
    :members: forward


IBertForTokenClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.IBertForTokenClassification
    :members: forward


IBertForQuestionAnswering
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.IBertForQuestionAnswering
    :members: forward
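A minimal usage sketch of the classes documented above (not part of the diff). It assumes the pretrained checkpoint name "kssteven/ibert-roberta-base" is published on the Hub; everything else uses only the classes added by this commit plus the RoBERTa tokenizer that I-BERT reuses.

import torch
from transformers import AutoTokenizer, IBertModel

# Checkpoint name below is an assumption; substitute whichever I-BERT
# checkpoint is actually available.
tokenizer = AutoTokenizer.from_pretrained("kssteven/ibert-roberta-base")
model = IBertModel.from_pretrained("kssteven/ibert-roberta-base")

inputs = tokenizer("I-BERT runs integer-only inference.", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

print(outputs.last_hidden_state.shape)  # (batch_size, sequence_length, hidden_size)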

src/transformers/__init__.py
Lines changed: 25 additions & 0 deletions

@@ -182,6 +182,7 @@
     "models.funnel": ["FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP", "FunnelConfig", "FunnelTokenizer"],
     "models.gpt2": ["GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPT2Config", "GPT2Tokenizer"],
     "models.herbert": ["HerbertTokenizer"],
+    "models.ibert": ["IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "IBertConfig"],
     "models.layoutlm": ["LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "LayoutLMConfig", "LayoutLMTokenizer"],
     "models.led": ["LED_PRETRAINED_CONFIG_ARCHIVE_MAP", "LEDConfig", "LEDTokenizer"],
     "models.longformer": ["LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "LongformerConfig", "LongformerTokenizer"],

@@ -613,6 +614,20 @@
             "load_tf_weights_in_gpt2",
         ]
     )
+    _import_structure["models.ibert"].extend(
+        [
+            "IBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
+            "IBertForMaskedLM",
+            "IBertForMultipleChoice",
+            "IBertForQuestionAnswering",
+            "IBertForSequenceClassification",
+            "IBertForTokenClassification",
+            "IBertLayer",
+            "IBertModel",
+            "IBertPreTrainedModel",
+            "load_tf_weights_in_ibert",
+        ]
+    )
     _import_structure["models.layoutlm"].extend(
         [
             "LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST",

@@ -1328,6 +1343,7 @@
     from .models.funnel import FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP, FunnelConfig, FunnelTokenizer
     from .models.gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config, GPT2Tokenizer
     from .models.herbert import HerbertTokenizer
+    from .models.ibert import IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, IBertConfig
     from .models.layoutlm import LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP, LayoutLMConfig, LayoutLMTokenizer
     from .models.led import LED_PRETRAINED_CONFIG_ARCHIVE_MAP, LEDConfig, LEDTokenizer
     from .models.longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig, LongformerTokenizer

@@ -1710,6 +1726,15 @@
         GPT2PreTrainedModel,
         load_tf_weights_in_gpt2,
     )
+    from .models.ibert import (
+        IBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
+        IBertForMaskedLM,
+        IBertForMultipleChoice,
+        IBertForQuestionAnswering,
+        IBertForSequenceClassification,
+        IBertForTokenClassification,
+        IBertModel,
+    )
     from .models.layoutlm import (
         LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST,
         LayoutLMForMaskedLM,
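As a quick illustration of the import surface these entries create, the sketch below builds an I-BERT model from a fresh config at the top level of the library. IBertConfig is always importable, while the model classes in the extend() block require PyTorch. The quant_mode and force_dequant arguments are taken from the commit message ("quant_mode added to configuration", "force_dequant added"); the values passed here are illustrative, not documented defaults.

from transformers import IBertConfig, IBertForMaskedLM  # model class needs torch installed

# quant_mode toggles integer-only execution; force_dequant can keep selected
# nonlinear ops in floating point. Both names come from the commit message.
config = IBertConfig(quant_mode=False, force_dequant="none")
model = IBertForMaskedLM(config)  # randomly initialized, just to show the wiring

print(model.config.quant_mode)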

src/transformers/models/auto/configuration_auto.py
Lines changed: 4 additions & 0 deletions

@@ -40,6 +40,7 @@
 from ..fsmt.configuration_fsmt import FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP, FSMTConfig
 from ..funnel.configuration_funnel import FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP, FunnelConfig
 from ..gpt2.configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
+from ..ibert.configuration_ibert import IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, IBertConfig
 from ..layoutlm.configuration_layoutlm import LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP, LayoutLMConfig
 from ..led.configuration_led import LED_PRETRAINED_CONFIG_ARCHIVE_MAP, LEDConfig
 from ..longformer.configuration_longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig

@@ -110,6 +111,7 @@
         PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
         MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
         TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP,
+        IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
     ]
     for key, value, in pretrained_map.items()
 )

@@ -123,6 +125,7 @@
         ("led", LEDConfig),
         ("blenderbot-small", BlenderbotSmallConfig),
         ("retribert", RetriBertConfig),
+        ("ibert", IBertConfig),
         ("mt5", MT5Config),
         ("t5", T5Config),
         ("mobilebert", MobileBertConfig),

@@ -173,6 +176,7 @@
         ("led", "LED"),
         ("blenderbot-small", "BlenderbotSmall"),
         ("retribert", "RetriBERT"),
+        ("ibert", "I-BERT"),
         ("t5", "T5"),
         ("mobilebert", "MobileBERT"),
         ("distilbert", "DistilBERT"),

src/transformers/models/auto/modeling_auto.py
Lines changed: 17 additions & 0 deletions

@@ -129,6 +129,14 @@
     FunnelModel,
 )
 from ..gpt2.modeling_gpt2 import GPT2ForSequenceClassification, GPT2LMHeadModel, GPT2Model
+from ..ibert.modeling_ibert import (
+    IBertForMaskedLM,
+    IBertForMultipleChoice,
+    IBertForQuestionAnswering,
+    IBertForSequenceClassification,
+    IBertForTokenClassification,
+    IBertModel,
+)
 from ..layoutlm.modeling_layoutlm import (
     LayoutLMForMaskedLM,
     LayoutLMForSequenceClassification,

@@ -270,6 +278,7 @@
     FSMTConfig,
     FunnelConfig,
     GPT2Config,
+    IBertConfig,
     LayoutLMConfig,
     LEDConfig,
     LongformerConfig,

@@ -347,6 +356,7 @@
         (MPNetConfig, MPNetModel),
         (TapasConfig, TapasModel),
         (MarianConfig, MarianModel),
+        (IBertConfig, IBertModel),
     ]
 )

@@ -379,6 +389,7 @@
         (FunnelConfig, FunnelForPreTraining),
         (MPNetConfig, MPNetForMaskedLM),
         (TapasConfig, TapasForMaskedLM),
+        (IBertConfig, IBertForMaskedLM),
     ]
 )

@@ -418,6 +429,7 @@
         (TapasConfig, TapasForMaskedLM),
         (DebertaConfig, DebertaForMaskedLM),
         (DebertaV2Config, DebertaV2ForMaskedLM),
+        (IBertConfig, IBertForMaskedLM),
     ]
 )

@@ -476,6 +488,7 @@
         (TapasConfig, TapasForMaskedLM),
         (DebertaConfig, DebertaForMaskedLM),
         (DebertaV2Config, DebertaV2ForMaskedLM),
+        (IBertConfig, IBertForMaskedLM),
     ]
 )

@@ -529,6 +542,7 @@
         (TransfoXLConfig, TransfoXLForSequenceClassification),
         (MPNetConfig, MPNetForSequenceClassification),
         (TapasConfig, TapasForSequenceClassification),
+        (IBertConfig, IBertForSequenceClassification),
     ]
 )

@@ -558,6 +572,7 @@
         (MPNetConfig, MPNetForQuestionAnswering),
         (DebertaConfig, DebertaForQuestionAnswering),
         (DebertaV2Config, DebertaV2ForQuestionAnswering),
+        (IBertConfig, IBertForQuestionAnswering),
     ]
 )

@@ -591,6 +606,7 @@
         (MPNetConfig, MPNetForTokenClassification),
         (DebertaConfig, DebertaForTokenClassification),
         (DebertaV2Config, DebertaV2ForTokenClassification),
+        (IBertConfig, IBertForTokenClassification),
     ]
 )

@@ -613,6 +629,7 @@
         (FlaubertConfig, FlaubertForMultipleChoice),
         (FunnelConfig, FunnelForMultipleChoice),
         (MPNetConfig, MPNetForMultipleChoice),
+        (IBertConfig, IBertForMultipleChoice),
     ]
 )
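The mappings above let the task-specific auto classes dispatch on IBertConfig. A brief sketch of that dispatch using from_config, so no checkpoint download is needed; the weights are randomly initialized.

from transformers import AutoModel, AutoModelForSequenceClassification, IBertConfig

config = IBertConfig()

model = AutoModel.from_config(config)  # resolves to IBertModel via MODEL_MAPPING
classifier = AutoModelForSequenceClassification.from_config(config)  # IBertForSequenceClassification

print(type(model).__name__, type(classifier).__name__)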

src/transformers/models/auto/tokenization_auto.py
Lines changed: 2 additions & 0 deletions

@@ -75,6 +75,7 @@
     FSMTConfig,
     FunnelConfig,
     GPT2Config,
+    IBertConfig,
     LayoutLMConfig,
     LEDConfig,
     LongformerConfig,

@@ -244,6 +245,7 @@
         (TapasConfig, (TapasTokenizer, None)),
         (LEDConfig, (LEDTokenizer, LEDTokenizerFast)),
         (ConvBertConfig, (ConvBertTokenizer, ConvBertTokenizerFast)),
+        (IBertConfig, (RobertaTokenizer, RobertaTokenizerFast)),
         (Wav2Vec2Config, (Wav2Vec2CTCTokenizer, None)),
     ]
 )
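Because I-BERT introduces no tokenizer of its own, the entry above points IBertConfig at the existing RoBERTa tokenizers. A sketch of the effect; the checkpoint name is an assumption.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("kssteven/ibert-roberta-base")
# Per the mapping above this is a RobertaTokenizerFast, or RobertaTokenizer
# when the fast-tokenizers backend is unavailable.
print(type(tokenizer).__name__)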
src/transformers/models/ibert/__init__.py (new file)
Lines changed: 69 additions & 0 deletions

# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from ...file_utils import _BaseLazyModule, is_tokenizers_available, is_torch_available


_import_structure = {
    "configuration_ibert": ["IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "IBertConfig"],
}

if is_torch_available():
    _import_structure["modeling_ibert"] = [
        "IBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
        "IBertForMaskedLM",
        "IBertForMultipleChoice",
        "IBertForQuestionAnswering",
        "IBertForSequenceClassification",
        "IBertForTokenClassification",
        "IBertModel",
    ]

if TYPE_CHECKING:
    from .configuration_ibert import IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, IBertConfig

    if is_torch_available():
        from .modeling_ibert import (
            IBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
            IBertForMaskedLM,
            IBertForMultipleChoice,
            IBertForQuestionAnswering,
            IBertForSequenceClassification,
            IBertForTokenClassification,
            IBertModel,
        )

else:
    import importlib
    import os
    import sys

    class _LazyModule(_BaseLazyModule):
        """
        Module class that surfaces all objects but only performs associated imports when the objects are requested.
        """

        __file__ = globals()["__file__"]
        __path__ = [os.path.dirname(__file__)]

        def _get_module(self, module_name: str):
            return importlib.import_module("." + module_name, self.__name__)

    sys.modules[__name__] = _LazyModule(__name__, _import_structure)
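The _LazyModule at the bottom keeps importing transformers.models.ibert cheap: submodules are only imported when one of their names is first accessed, and the torch-dependent modeling module is only registered when is_torch_available() is true. A sketch of that behavior, assuming the standard attribute lookup that _BaseLazyModule provides in file_utils.

import importlib

ibert = importlib.import_module("transformers.models.ibert")

# At runtime (outside TYPE_CHECKING) the package object is the _LazyModule defined above.
config_cls = ibert.IBertConfig   # triggers the import of configuration_ibert only
model_cls = ibert.IBertModel     # imports modeling_ibert on first access (requires torch)

print(config_cls.__name__, model_cls.__name__)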
