Commit 2fcb9dc

Last commit with on-going past commented.

1 parent 57e369e · commit 2fcb9dc
4 files changed: +125 -41 lines changed

src/transformers/models/gpt2/configuration_gpt2.py

Lines changed: 58 additions & 12 deletions
@@ -14,7 +14,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ OpenAI GPT-2 configuration """
-from typing import Mapping
+from collections import OrderedDict
+from typing import Any, Mapping, Optional
+
+from transformers import PreTrainedTokenizer, TensorType, is_torch_available
 
 from ...configuration_utils import PretrainedConfig
 from ...onnx import OnnxConfigWithPast
@@ -202,19 +205,62 @@ def num_hidden_layers(self):
 class GPT2OnnxConfig(OnnxConfigWithPast):
     @property
     def inputs(self) -> Mapping[str, Mapping[int, str]]:
-        return {
-            "input_ids": {0: "batch", 1: "sequence"},
-            "attention_mask": {0: "batch", 1: "sequence"},
-        }
+        if self.use_past:
+            common_inputs = OrderedDict({"input_ids": {0: "batch"}})
+            for i in range(self._config.n_layer * 2):
+                common_inputs[f"past_key_values.{i}"] = {0: "batch", 2: "sequence"}
+
+            common_inputs["attention_mask"] = {0: "batch"}
+        else:
+            common_inputs = OrderedDict({
+                "input_ids": {0: "batch", 1: "sequence"},
+                "attention_mask": {0: "batch", 1: "sequence"}
+            })
+
+        return common_inputs
 
     @property
     def outputs(self) -> Mapping[str, Mapping[int, str]]:
         if self.use_past:
-            return {
-                "last_hidden_state": {0: "batch", 1: "sequence"},
-                "past_keys": {0: "batch", 2: "sequence"},
-            }
+            common_outputs = {"last_hidden_state": {0: "batch", 1: "sequence"}}
+
+            for i in range(self._config.n_layer * 2):
+                common_outputs[f"present.{i}"] = {0: "batch", 2: "sequence"}
+
+            return common_outputs
         else:
-            return {
-                "last_hidden_state": {0: "batch", 1: "sequence"},
-            }
+            return {"last_hidden_state": {0: "batch", 1: "sequence"}}
+
+    def generate_dummy_inputs(
+        self,
+        tokenizer: PreTrainedTokenizer,
+        batch_size: int = -1,
+        seq_length: int = -1,
+        is_pair: bool = False,
+        framework: Optional[TensorType] = None,
+    ) -> Mapping[str, Any]:
+        common_inputs = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework)
+
+        # We need to order the inputs in the way they appear in forward()
+        ordered_inputs = OrderedDict({
+            "input_ids": common_inputs["input_ids"]
+        })
+
+        # Need to add the past_keys
+        if self.use_past:
+            if not is_torch_available():
+                raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
+            else:
+                import torch
+
+                batch = common_inputs["input_ids"].shape[0]
+                ordered_inputs["past_key_values"] = [
+                    (
+                        torch.zeros((batch, self._config.n_head, 1, self._config.hidden_size // self._config.n_head)),
+                        torch.zeros((batch, self._config.n_head, 1, self._config.hidden_size // self._config.n_head)),
+                    )
+                    for _ in range(self._config.n_layer)
+                ]
+
+        ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
+        return ordered_inputs
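
A quick way to see the effect of the inputs property above: the sketch below (not part of the commit) instantiates the config with the cache enabled and prints the dynamic axes it declares. It assumes the default GPT2Config and that the with_past() constructor exercised in the tests further down carries over to this subclass.

    # Illustrative only: inspect the dynamic axes GPT2OnnxConfig declares
    # when past key values are enabled.
    from transformers import GPT2Config
    from transformers.models.gpt2.configuration_gpt2 import GPT2OnnxConfig

    onnx_config = GPT2OnnxConfig.with_past(GPT2Config())  # n_layer=12 by default

    # With use_past, input_ids keeps only a dynamic batch axis (one new token
    # is fed per decoding step) while each of the 2 * n_layer cache tensors
    # gets a dynamic batch axis (dim 0) and a dynamic past-sequence axis (dim 2).
    for name, axes in onnx_config.inputs.items():
        print(name, axes)
    # input_ids {0: 'batch'}
    # past_key_values.0 {0: 'batch', 2: 'sequence'}
    # ...
    # attention_mask {0: 'batch'}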

src/transformers/onnx/config.py

Lines changed: 35 additions & 2 deletions
@@ -30,6 +30,9 @@ class OnnxConfig(ABC):
     Base class for ONNX exportable model describing metadata on how to export the model through the ONNX format.
     """
 
+    DEFAULT_FIXED_BATCH = 2
+    DEFAULT_FIXED_SEQUENCE = 8
+
     def __init__(self, config: PretrainedConfig):
         self._config = config
 
@@ -131,11 +134,15 @@ def generate_dummy_inputs(
         """
 
         # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
-        batch_size = compute_effective_axis_dimension(batch_size, fixed_dimension=2, num_token_to_add=0)
+        batch_size = compute_effective_axis_dimension(
+            batch_size, fixed_dimension=OnnxConfig.DEFAULT_FIXED_BATCH, num_token_to_add=0
+        )
 
         # If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
         token_to_add = tokenizer.num_special_tokens_to_add(is_pair)
-        seq_length = compute_effective_axis_dimension(seq_length, fixed_dimension=8, num_token_to_add=token_to_add)
+        seq_length = compute_effective_axis_dimension(
+            seq_length, fixed_dimension=OnnxConfig.DEFAULT_FIXED_SEQUENCE, num_token_to_add=token_to_add
+        )
 
         # Generate dummy inputs according to compute batch and sequence
         dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size
@@ -166,3 +173,29 @@ def values_override(self) -> Optional[Mapping[str, Any]]:
             return {"use_cache": self.use_past}
 
         return None
+
+    def generate_dummy_inputs(
+        self,
+        tokenizer: PreTrainedTokenizer,
+        batch_size: int = -1,
+        seq_length: int = -1,
+        is_pair: bool = False,
+        framework: Optional[TensorType] = None,
+    ) -> Mapping[str, Any]:
+        # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
+        batch_size = compute_effective_axis_dimension(
+            batch_size, fixed_dimension=OnnxConfig.DEFAULT_FIXED_BATCH, num_token_to_add=0
+        )
+
+        # If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
+        token_to_add = tokenizer.num_special_tokens_to_add(is_pair)
+
+        # When use_past, the caching mechanism requires the input to be a single token
+        fixed_sequence_length = 1 if self.use_past else OnnxConfig.DEFAULT_FIXED_SEQUENCE
+        seq_length = compute_effective_axis_dimension(
+            seq_length, fixed_dimension=fixed_sequence_length, num_token_to_add=token_to_add
+        )
+
+        # Generate dummy inputs according to compute batch and sequence
+        dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size
+        return dict(tokenizer(dummy_input, return_tensors=framework))
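
For context on the helper both versions call: a hedged sketch of what compute_effective_axis_dimension plausibly does, inferred from the call sites and comments above. The real helper lives elsewhere in transformers.onnx; this reimplementation is an assumption, not the library's code.

    def compute_effective_axis_dimension(dimension: int, fixed_dimension: int, num_token_to_add: int = 0) -> int:
        # Assumed behavior: -1 marks a dynamic axis, so fall back to the fixed
        # dimension (2 samples / 8 tokens) to give ONNX tracing a concrete shape,
        # reserving room for the special tokens the tokenizer will add on top.
        if dimension <= 0:
            dimension = fixed_dimension - num_token_to_add
        return dimension

Note how the override makes the two paths diverge only on the sequence axis: with use_past the dummy sequence is pinned to a single token, since the past_key_values tensors carry the history.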

src/transformers/onnx/convert.py

Lines changed: 23 additions & 12 deletions
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from inspect import signature
 from itertools import chain
 from pathlib import Path
 from typing import Iterable, List, Tuple, Union
@@ -97,7 +98,7 @@ def convert_pytorch(
     # Ensure inputs match
     # TODO: Check when exporting QA we provide "is_pair=True"
     model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
-    inputs_match, ordered_onnx_inputs = ensure_model_and_config_inputs_match(model_inputs.keys(), config.inputs.keys())
+    inputs_match, matched_inputs = ensure_model_and_config_inputs_match(model, model_inputs.keys())
     onnx_outputs = list(config.outputs.keys())
 
     if not inputs_match:
@@ -108,7 +109,7 @@
         model,
         (model_inputs,),
         f=output.as_posix(),
-        input_names=ordered_onnx_inputs,
+        input_names=list(config.inputs.keys()),
         output_names=onnx_outputs,
         dynamic_axes={name: axes for name, axes in chain(config.inputs.items(), config.outputs.items())},
         do_constant_folding=True,
@@ -117,7 +118,7 @@
         opset_version=opset,
     )
 
-    return ordered_onnx_inputs, onnx_outputs
+    return matched_inputs, onnx_outputs
 
 
 def validate_model_outputs(
@@ -133,7 +134,6 @@
     logger.info("Validating ONNX model...")
 
     reference_model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
-    onnx_model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.NUMPY)
 
     # Create ONNX Runtime session
     options = SessionOptions()
@@ -151,8 +151,17 @@
         else:
             ref_outputs_dict[name] = value
 
+    # We flatten potential collections of inputs (i.e. past_keys)
+    onnx_inputs = {}
+    for name, value in reference_model_inputs.items():
+        if isinstance(value, (list, tuple)):
+            value = flatten_output_collection_property(name, value)
+            onnx_inputs.update({tensor_name: pt_tensor.numpy() for tensor_name, pt_tensor in value.items()})
+        else:
+            onnx_inputs[name] = value.numpy()
+
     # Compute outputs from the ONNX model
-    onnx_outputs = session.run(onnx_named_outputs, dict(onnx_model_inputs))
+    onnx_outputs = session.run(onnx_named_outputs, onnx_inputs)
 
     # Check we have a subset of the keys into onnx_outputs against ref_outputs
     ref_outputs_set, onnx_outputs_set = set(ref_outputs_dict.keys()), set(onnx_named_outputs)
@@ -195,20 +204,22 @@ def validate_model_outputs(
 
 
 def ensure_model_and_config_inputs_match(
-    model_inputs: Iterable[str], config_inputs: Iterable[str]
+    model: Union[PreTrainedModel, TFPreTrainedModel], model_inputs: Iterable[str]
 ) -> Tuple[bool, List[str]]:
     """
 
     :param model_inputs:
     :param config_inputs:
     :return:
     """
-    model_inputs_set, config_inputs_set = set(model_inputs), set(config_inputs)
+    forward_parameters = signature(model.forward).parameters
+    model_inputs_set = set(model_inputs)
 
     # We are fine if config_inputs has more keys than model_inputs
-    is_ok = model_inputs_set.issubset(config_inputs_set)
+    forward_inputs_set = set(forward_parameters.keys())
+    is_ok = model_inputs_set.issubset(forward_inputs_set)
 
-    # Make sure the input order match
-    matching_inputs = config_inputs_set.intersection(model_inputs_set)
-    ordered_matching_inputs = [config_input for config_input in config_inputs if config_input in matching_inputs]
-    return is_ok, ordered_matching_inputs
+    # Make sure the input order matches (VERY IMPORTANT !!!!)
+    matching_inputs = forward_inputs_set.intersection(model_inputs_set)
+    ordered_inputs = [parameter for parameter in forward_parameters.keys() if parameter in matching_inputs]
+    return is_ok, ordered_inputs
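
The crux of this change is matching the dummy inputs against the model's forward() signature instead of the ONNX config, because torch.onnx.export binds inputs positionally. A minimal, self-contained sketch of the same technique (order_like_forward is an illustrative name, not the library's API):

    from inspect import signature
    from typing import Iterable, List, Tuple

    import torch


    def order_like_forward(model: torch.nn.Module, input_names: Iterable[str]) -> Tuple[bool, List[str]]:
        # The inputs are valid when every provided name is a real forward()
        # parameter; surviving names are re-emitted in forward()'s declaration
        # order so positional binding during export stays correct.
        forward_parameters = signature(model.forward).parameters
        provided = set(input_names)
        is_ok = provided.issubset(forward_parameters.keys())
        ordered = [name for name in forward_parameters if name in provided]
        return is_ok, ordered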

tests/test_onnx_v2.py

Lines changed: 9 additions & 15 deletions
@@ -121,11 +121,7 @@ class OnnxConfigWithPastTestCaseV2(TestCase):
     Cover the tests for model which have use_cache feature (i.e. "with_past" for ONNX)
     """
 
-    SUPPORTED_WITH_PAST_CONFIGS = {
-        ("BART", BartConfig),
-        ("GPT2", GPT2Config),
-        ("T5", T5Config)
-    }
+    SUPPORTED_WITH_PAST_CONFIGS = {("BART", BartConfig), ("GPT2", GPT2Config), ("T5", T5Config)}
 
     @patch.multiple(OnnxConfigWithPast, __abstractmethods__=set())
     def test_use_past(self):
@@ -135,13 +131,11 @@ def test_use_past(self):
         for name, config in OnnxConfigWithPastTestCaseV2.SUPPORTED_WITH_PAST_CONFIGS:
             with self.subTest(name):
                 self.assertFalse(
-                    OnnxConfigWithPast.default(config()).use_past,
-                    "OnnxConfigWithPast.default() should not use_past"
+                    OnnxConfigWithPast.default(config()).use_past, "OnnxConfigWithPast.default() should not use_past"
                 )
 
                 self.assertTrue(
-                    OnnxConfigWithPast.with_past(config()).use_past,
-                    "OnnxConfigWithPast.default() should use_past"
+                    OnnxConfigWithPast.with_past(config()).use_past, "OnnxConfigWithPast.with_past() should use_past"
                 )
 
     @patch.multiple(OnnxConfigWithPast, __abstractmethods__=set())
@@ -157,17 +151,15 @@ def test_values_override(self):
         self.assertIsNotNone(onnx_config_default.values_override, "values_override should not be None")
         self.assertIn("use_cache", onnx_config_default.values_override, "use_cache should be present")
         self.assertFalse(
-            onnx_config_default.values_override["use_cache"],
-            "use_cache should be False if not using past"
+            onnx_config_default.values_override["use_cache"], "use_cache should be False if not using past"
         )
 
         # with past
         onnx_config_default = OnnxConfigWithPast.with_past(config())
         self.assertIsNotNone(onnx_config_default.values_override, "values_override should not be None")
         self.assertIn("use_cache", onnx_config_default.values_override, "use_cache should be present")
         self.assertTrue(
-            onnx_config_default.values_override["use_cache"],
-            "use_cache should be False if not using past"
+            onnx_config_default.values_override["use_cache"], "use_cache should be True if using past"
         )
 
 
@@ -197,6 +189,7 @@ class OnnxExportTestCaseV2(TestCase):
     """
     Integration tests ensuring supported models are correctly exported
     """
+
     @slow
     @require_torch
     def test_pytorch_export_default(self):
@@ -211,8 +204,9 @@ def test_pytorch_export_default(self):
             onnx_config = onnx_config_class.default(model.config)
 
             with NamedTemporaryFile("w") as output:
-                onnx_inputs, onnx_outputs = \
-                    convert_pytorch(tokenizer, model, onnx_config, DEFAULT_ONNX_OPSET, Path(output.name))
+                onnx_inputs, onnx_outputs = convert_pytorch(
+                    tokenizer, model, onnx_config, DEFAULT_ONNX_OPSET, Path(output.name)
+                )
 
             try:
                 validate_model_outputs(onnx_config, tokenizer, model, Path(output.name), onnx_outputs, 1e-5)
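
These tests instantiate the abstract OnnxConfigWithPast directly by patching __abstractmethods__ to an empty set for the duration of each test. A standalone sketch of that trick, using a hypothetical Base class for illustration:

    from abc import ABC, abstractmethod
    from unittest import TestCase
    from unittest.mock import patch


    class Base(ABC):
        @abstractmethod
        def inputs(self):
            ...


    class PatchedAbstractTest(TestCase):
        # Clearing __abstractmethods__ lets the ABC be instantiated so its
        # concrete helpers can be unit-tested without a per-model subclass.
        @patch.multiple(Base, __abstractmethods__=set())
        def test_can_instantiate(self):
            self.assertIsInstance(Base(), Base)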
