Commit 32954c3

richardpaulhudson and Daniël de Kok authored
Fix issues for Mypy 0.950 and Pydantic 1.9.0 (explosion#10786)
* Make changes to typing
* Correction
* Format with black
* Corrections based on review
* Bumped Thinc dependency version
* Bumped blis requirement
* Correction for older Python versions
* Update spacy/ml/models/textcat.py
  Co-authored-by: Daniël de Kok <me@github.danieldk.eu>
* Corrections based on review feedback
* Readd deleted docstring line

Co-authored-by: Daniël de Kok <me@github.danieldk.eu>
1 parent 6be09bb · commit 32954c3

16 files changed: +63 −61 lines

pyproject.toml

Lines changed: 2 additions & 2 deletions
```diff
@@ -5,8 +5,8 @@ requires = [
     "cymem>=2.0.2,<2.1.0",
     "preshed>=3.0.2,<3.1.0",
     "murmurhash>=0.28.0,<1.1.0",
-    "thinc>=8.0.14,<8.1.0",
-    "blis>=0.4.0,<0.8.0",
+    "thinc>=8.1.0.dev0,<8.2.0",
+    "blis>=0.9.0,<0.10.0",
     "pathy",
     "numpy>=1.15.0",
 ]
```

requirements.txt

Lines changed: 4 additions & 4 deletions
```diff
@@ -3,8 +3,8 @@ spacy-legacy>=3.0.9,<3.1.0
 spacy-loggers>=1.0.0,<2.0.0
 cymem>=2.0.2,<2.1.0
 preshed>=3.0.2,<3.1.0
-thinc>=8.0.14,<8.1.0
-blis>=0.4.0,<0.8.0
+thinc>=8.1.0.dev0,<8.2.0
+blis>=0.9.0,<0.10.0
 ml_datasets>=0.2.0,<0.3.0
 murmurhash>=0.28.0,<1.1.0
 wasabi>=0.9.1,<1.1.0
@@ -16,7 +16,7 @@ pathy>=0.3.5
 numpy>=1.15.0
 requests>=2.13.0,<3.0.0
 tqdm>=4.38.0,<5.0.0
-pydantic>=1.7.4,!=1.8,!=1.8.1,<1.9.0
+pydantic>=1.7.4,!=1.8,!=1.8.1,<1.10.0
 jinja2
 langcodes>=3.2.0,<4.0.0
 # Official Python utilities
@@ -31,7 +31,7 @@ pytest-timeout>=1.3.0,<2.0.0
 mock>=2.0.0,<3.0.0
 flake8>=3.8.0,<3.10.0
 hypothesis>=3.27.0,<7.0.0
-mypy==0.910
+mypy>=0.910,<=0.960
 types-dataclasses>=0.1.3; python_version < "3.7"
 types-mock>=0.1.1
 types-requests
```
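
The `.dev0` lower bound is what lets pip resolve Thinc 8.1 pre-releases, the first versions carrying the updated type annotations, while the upper bounds still fence off the next release series. A quick way to sanity-check such PEP 440 ranges, assuming the `packaging` library is available:

```python
from packaging.specifiers import SpecifierSet

spec = SpecifierSet(">=8.1.0.dev0,<8.2.0")

# Pre-releases match only when explicitly permitted:
print(spec.contains("8.1.0.dev0", prereleases=True))  # True
print(spec.contains("8.0.14"))                        # False
print(spec.contains("8.2.0", prereleases=True))       # False
```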

setup.cfg

Lines changed: 4 additions & 4 deletions
```diff
@@ -38,16 +38,16 @@ setup_requires =
     cymem>=2.0.2,<2.1.0
     preshed>=3.0.2,<3.1.0
     murmurhash>=0.28.0,<1.1.0
-    thinc>=8.0.14,<8.1.0
+    thinc>=8.1.0.dev0,<8.2.0
 install_requires =
     # Our libraries
     spacy-legacy>=3.0.9,<3.1.0
     spacy-loggers>=1.0.0,<2.0.0
     murmurhash>=0.28.0,<1.1.0
     cymem>=2.0.2,<2.1.0
     preshed>=3.0.2,<3.1.0
-    thinc>=8.0.14,<8.1.0
-    blis>=0.4.0,<0.8.0
+    thinc>=8.1.0.dev0,<8.2.0
+    blis>=0.9.0,<0.10.0
     wasabi>=0.9.1,<1.1.0
     srsly>=2.4.3,<3.0.0
     catalogue>=2.0.6,<2.1.0
@@ -57,7 +57,7 @@ install_requires =
     tqdm>=4.38.0,<5.0.0
     numpy>=1.15.0
     requests>=2.13.0,<3.0.0
-    pydantic>=1.7.4,!=1.8,!=1.8.1,<1.9.0
+    pydantic>=1.7.4,!=1.8,!=1.8.1,<1.10.0
     jinja2
     # Official Python utilities
     setuptools
```

spacy/errors.py

Lines changed: 5 additions & 1 deletion
```diff
@@ -1,4 +1,5 @@
 import warnings
+from .compat import Literal
 
 
 class ErrorsWithCodes(type):
@@ -26,7 +27,10 @@ def setup_default_warnings():
     filter_warning("once", error_msg="[W114]")
 
 
-def filter_warning(action: str, error_msg: str):
+def filter_warning(
+    action: Literal["default", "error", "ignore", "always", "module", "once"],
+    error_msg: str,
+):
     """Customize how spaCy should handle a certain warning.
 
     error_msg (str): e.g. "W006", or a full error message
```
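
Typing `action` as a `Literal` of the six actions `warnings.filterwarnings` accepts moves the validation from runtime to Mypy. A minimal standalone sketch of the same idea, using `typing.Literal` directly rather than spaCy's `.compat` shim (which exists only to support older Pythons):

```python
import warnings
from typing import Literal

WarningAction = Literal["default", "error", "ignore", "always", "module", "once"]

def filter_warning(action: WarningAction, error_msg: str) -> None:
    # Apply the chosen action to warnings whose message starts with the code.
    warnings.filterwarnings(action, message=f"\\[{error_msg}\\]")

filter_warning("once", "W114")   # accepted
# filter_warning("sometimes", "W114")  # Mypy: Argument 1 has incompatible type
```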

spacy/lookups.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -85,7 +85,7 @@ def __setitem__(self, key: Union[str, int], value: Any) -> None:
         value: The value to set.
         """
         key = get_string_id(key)
-        OrderedDict.__setitem__(self, key, value)
+        OrderedDict.__setitem__(self, key, value)  # type:ignore[assignment]
         self.bloom.add(key)
 
     def set(self, key: Union[str, int], value: Any) -> None:
@@ -104,7 +104,7 @@ def __getitem__(self, key: Union[str, int]) -> Any:
         RETURNS: The value.
         """
         key = get_string_id(key)
-        return OrderedDict.__getitem__(self, key)
+        return OrderedDict.__getitem__(self, key)  # type:ignore[index]
 
     def get(self, key: Union[str, int], default: Optional[Any] = None) -> Any:
         """Get the value for a given key. String keys will be hashed.
@@ -114,7 +114,7 @@ def get(self, key: Union[str, int], default: Optional[Any] = None) -> Any:
         RETURNS: The value.
         """
         key = get_string_id(key)
-        return OrderedDict.get(self, key, default)
+        return OrderedDict.get(self, key, default)  # type:ignore[arg-type]
 
     def __contains__(self, key: Union[str, int]) -> bool:  # type: ignore[override]
         """Check whether a key is in the table. String keys will be hashed.
```

spacy/ml/models/entity_linker.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -23,7 +23,7 @@ def build_nel_encoder(
         ((tok2vec >> list2ragged()) & build_span_maker())
         >> extract_spans()
         >> reduce_mean()
-        >> residual(Maxout(nO=token_width, nI=token_width, nP=2, dropout=0.0))  # type: ignore[arg-type]
+        >> residual(Maxout(nO=token_width, nI=token_width, nP=2, dropout=0.0))
         >> output_layer
     )
     model.set_ref("output_layer", output_layer)
```

spacy/ml/models/parser.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -72,7 +72,7 @@ def build_tb_parser_model(
     t2v_width = tok2vec.get_dim("nO") if tok2vec.has_dim("nO") else None
     tok2vec = chain(
         tok2vec,
-        cast(Model[List["Floats2d"], Floats2d], list2array()),
+        list2array(),
         Linear(hidden_width, t2v_width),
     )
     tok2vec.set_dim("nO", hidden_width)
```

spacy/ml/models/textcat.py

Lines changed: 9 additions & 7 deletions
```diff
@@ -1,5 +1,5 @@
+from typing import Optional, List, cast
 from functools import partial
-from typing import Optional, List
 
 from thinc.types import Floats2d
 from thinc.api import Model, reduce_mean, Linear, list2ragged, Logistic
@@ -59,7 +59,8 @@ def build_simple_cnn_text_classifier(
         resizable_layer=resizable_layer,
     )
     model.set_ref("tok2vec", tok2vec)
-    model.set_dim("nO", nO)  # type: ignore  # TODO: remove type ignore once Thinc has been updated
+    if nO is not None:
+        model.set_dim("nO", cast(int, nO))
     model.attrs["multi_label"] = not exclusive_classes
     return model
 
@@ -85,15 +86,16 @@ def build_bow_text_classifier(
     if not no_output_layer:
         fill_defaults["b"] = NEG_VALUE
         output_layer = softmax_activation() if exclusive_classes else Logistic()
-    resizable_layer = resizable(  # type: ignore[var-annotated]
+    resizable_layer: Model[Floats2d, Floats2d] = resizable(
         sparse_linear,
         resize_layer=partial(resize_linear_weighted, fill_defaults=fill_defaults),
     )
     model = extract_ngrams(ngram_size, attr=ORTH) >> resizable_layer
     model = with_cpu(model, model.ops)
     if output_layer:
         model = model >> with_cpu(output_layer, output_layer.ops)
-    model.set_dim("nO", nO)  # type: ignore[arg-type]
+    if nO is not None:
+        model.set_dim("nO", cast(int, nO))
     model.set_ref("output_layer", sparse_linear)
     model.attrs["multi_label"] = not exclusive_classes
     model.attrs["resize_output"] = partial(
@@ -129,8 +131,8 @@ def build_text_classifier_v2(
     output_layer = Linear(nO=nO, nI=nO_double) >> Logistic()
     model = (linear_model | cnn_model) >> output_layer
     model.set_ref("tok2vec", tok2vec)
-    if model.has_dim("nO") is not False:
-        model.set_dim("nO", nO)  # type: ignore[arg-type]
+    if model.has_dim("nO") is not False and nO is not None:
+        model.set_dim("nO", cast(int, nO))
     model.set_ref("output_layer", linear_model.get_ref("output_layer"))
     model.set_ref("attention_layer", attention_layer)
     model.set_ref("maxout_layer", maxout_layer)
@@ -164,7 +166,7 @@ def build_text_classifier_lowdata(
         >> list2ragged()
         >> ParametricAttention(width)
         >> reduce_sum()
-        >> residual(Relu(width, width)) ** 2  # type: ignore[arg-type]
+        >> residual(Relu(width, width)) ** 2
         >> Linear(nO, width)
     )
     if dropout:
```
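
The recurring change in this file is to narrow `Optional[int]` before calling `set_dim`, since Thinc 8.1 annotates `Model.set_dim(name, value)` as taking a plain `int`. A minimal sketch of the pattern, assuming a Thinc-style model object:

```python
from typing import Optional

from thinc.api import Linear, Model


def finish(model: Model, nO: Optional[int]) -> Model:
    # set_dim() expects a definite int, so the Optional must be narrowed
    # first; Mypy then accepts the guarded call without a cast.
    if nO is not None:
        model.set_dim("nO", nO)
    return model


finish(Linear(), nO=16)
```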

spacy/ml/models/tok2vec.py

Lines changed: 13 additions & 13 deletions
```diff
@@ -1,5 +1,5 @@
 from typing import Optional, List, Union, cast
-from thinc.types import Floats2d, Ints2d, Ragged
+from thinc.types import Floats2d, Ints2d, Ragged, Ints1d
 from thinc.api import chain, clone, concatenate, with_array, with_padded
 from thinc.api import Model, noop, list2ragged, ragged2list, HashEmbed
 from thinc.api import expand_window, residual, Maxout, Mish, PyTorchLSTM
@@ -159,7 +159,7 @@ def make_hash_embed(index):
     embeddings = [make_hash_embed(i) for i in range(len(attrs))]
     concat_size = width * (len(embeddings) + include_static_vectors)
     max_out: Model[Ragged, Ragged] = with_array(
-        Maxout(width, concat_size, nP=3, dropout=0.0, normalize=True)  # type: ignore
+        Maxout(width, concat_size, nP=3, dropout=0.0, normalize=True)
     )
     if include_static_vectors:
         feature_extractor: Model[List[Doc], Ragged] = chain(
@@ -173,15 +173,15 @@ def make_hash_embed(index):
             StaticVectors(width, dropout=0.0),
         ),
         max_out,
-        cast(Model[Ragged, List[Floats2d]], ragged2list()),
+        ragged2list(),
     )
 else:
     model = chain(
         FeatureExtractor(list(attrs)),
         cast(Model[List[Ints2d], Ragged], list2ragged()),
         with_array(concatenate(*embeddings)),
         max_out,
-        cast(Model[Ragged, List[Floats2d]], ragged2list()),
+        ragged2list(),
     )
 return model
 
@@ -232,12 +232,12 @@ def CharacterEmbed(
     feature_extractor: Model[List[Doc], Ragged] = chain(
         FeatureExtractor([feature]),
         cast(Model[List[Ints2d], Ragged], list2ragged()),
-        with_array(HashEmbed(nO=width, nV=rows, column=0, seed=5)),  # type: ignore
+        with_array(HashEmbed(nO=width, nV=rows, column=0, seed=5)),  # type: ignore[misc]
     )
     max_out: Model[Ragged, Ragged]
     if include_static_vectors:
         max_out = with_array(
-            Maxout(width, nM * nC + (2 * width), nP=3, normalize=True, dropout=0.0)  # type: ignore
+            Maxout(width, nM * nC + (2 * width), nP=3, normalize=True, dropout=0.0)
         )
         model = chain(
             concatenate(
@@ -246,19 +246,19 @@ def CharacterEmbed(
                 StaticVectors(width, dropout=0.0),
             ),
             max_out,
-            cast(Model[Ragged, List[Floats2d]], ragged2list()),
+            ragged2list(),
         )
     else:
         max_out = with_array(
-            Maxout(width, nM * nC + width, nP=3, normalize=True, dropout=0.0)  # type: ignore
+            Maxout(width, nM * nC + width, nP=3, normalize=True, dropout=0.0)
         )
         model = chain(
             concatenate(
                 char_embed,
                 feature_extractor,
             ),
             max_out,
-            cast(Model[Ragged, List[Floats2d]], ragged2list()),
+            ragged2list(),
         )
     return model
 
@@ -289,10 +289,10 @@ def MaxoutWindowEncoder(
             normalize=True,
         ),
     )
-    model = clone(residual(cnn), depth)  # type: ignore[arg-type]
+    model = clone(residual(cnn), depth)
     model.set_dim("nO", width)
     receptive_field = window_size * depth
-    return with_array(model, pad=receptive_field)  # type: ignore[arg-type]
+    return with_array(model, pad=receptive_field)
 
 
 @registry.architectures("spacy.MishWindowEncoder.v2")
@@ -313,9 +313,9 @@ def MishWindowEncoder(
         expand_window(window_size=window_size),
         Mish(nO=width, nI=width * ((window_size * 2) + 1), dropout=0.0, normalize=True),
     )
-    model = clone(residual(cnn), depth)  # type: ignore[arg-type]
+    model = clone(residual(cnn), depth)
     model.set_dim("nO", width)
-    return with_array(model)  # type: ignore[arg-type]
+    return with_array(model)
 
 
 @registry.architectures("spacy.TorchBiLSTMEncoder.v1")
```
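
Most of this file's changes delete `cast(...)` wrappers and `# type: ignore` comments that the improved Thinc 8.1 annotations make unnecessary. `typing.cast` is purely a static-analysis hint with no runtime effect, which is why these deletions cannot change behaviour — a quick demonstration:

```python
from typing import List, cast

values = [1, 2, 3]

# cast() returns its argument unchanged; it only tells the type checker
# to treat the value as the given type.
same = cast(List[int], values)
assert same is values
```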

spacy/ml/staticvectors.py

Lines changed: 7 additions & 9 deletions
```diff
@@ -40,17 +40,15 @@ def forward(
     if not token_count:
         return _handle_empty(model.ops, model.get_dim("nO"))
     key_attr: int = model.attrs["key_attr"]
-    keys: Ints1d = model.ops.flatten(
-        cast(Sequence, [doc.to_array(key_attr) for doc in docs])
-    )
+    keys = model.ops.flatten([cast(Ints1d, doc.to_array(key_attr)) for doc in docs])
     vocab: Vocab = docs[0].vocab
     W = cast(Floats2d, model.ops.as_contig(model.get_param("W")))
     if vocab.vectors.mode == Mode.default:
-        V = cast(Floats2d, model.ops.asarray(vocab.vectors.data))
+        V = model.ops.asarray(vocab.vectors.data)
         rows = vocab.vectors.find(keys=keys)
         V = model.ops.as_contig(V[rows])
     elif vocab.vectors.mode == Mode.floret:
-        V = cast(Floats2d, vocab.vectors.get_batch(keys))
+        V = vocab.vectors.get_batch(keys)
         V = model.ops.as_contig(V)
     else:
         raise RuntimeError(Errors.E896)
@@ -62,9 +60,7 @@ def forward(
     # Convert negative indices to 0-vectors
     # TODO: more options for UNK tokens
     vectors_data[rows < 0] = 0
-    output = Ragged(
-        vectors_data, model.ops.asarray([len(doc) for doc in docs], dtype="i")  # type: ignore
-    )
+    output = Ragged(vectors_data, model.ops.asarray1i([len(doc) for doc in docs]))
     mask = None
     if is_train:
         mask = _get_drop_mask(model.ops, W.shape[0], model.attrs.get("dropout_rate"))
@@ -77,7 +73,9 @@ def backprop(d_output: Ragged) -> List[Doc]:
         model.inc_grad(
             "W",
             model.ops.gemm(
-                cast(Floats2d, d_output.data), model.ops.as_contig(V), trans1=True
+                cast(Floats2d, d_output.data),
+                cast(Floats2d, model.ops.as_contig(V)),
+                trans1=True,
             ),
         )
         return []
```
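
`ops.asarray1i` and its siblings (`asarray2f`, used in `spacy/pipeline/legacy/entity_linker.py` below, `asarray1f`, and so on) are Thinc's typed array constructors: they are annotated as returning a concrete `Ints1d` or `Floats2d` rather than the broad union that `ops.asarray(..., dtype=...)` returns, so the surrounding casts and ignores can go. A small sketch, assuming Thinc 8.1 is installed:

```python
from thinc.api import get_current_ops

ops = get_current_ops()

# Typed constructor: annotated as returning Ints1d (a 1-D int32 array),
# so no cast is needed at the call site.
lengths = ops.asarray1i([3, 5, 2])
print(lengths.dtype, lengths.shape)  # int32 (3,)
```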

spacy/pipeline/edit_tree_lemmatizer.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -138,7 +138,7 @@ def get_loss(
 
             truths.append(eg_truths)
 
-        d_scores, loss = loss_func(scores, truths)  # type: ignore
+        d_scores, loss = loss_func(scores, truths)
         if self.model.ops.xp.isnan(loss):
             raise ValueError(Errors.E910.format(name=self.name))
```

spacy/pipeline/entityruler.py

Lines changed: 2 additions & 4 deletions
```diff
@@ -159,10 +159,8 @@ def match(self, doc: Doc):
         self._require_patterns()
         with warnings.catch_warnings():
             warnings.filterwarnings("ignore", message="\\[W036")
-            matches = cast(
-                List[Tuple[int, int, int]],
-                list(self.matcher(doc)) + list(self.phrase_matcher(doc)),
-            )
+            matches = list(self.matcher(doc)) + list(self.phrase_matcher(doc))
+
         final_matches = set(
             [(m_id, start, end) for m_id, start, end in matches if start != end]
         )
```

spacy/pipeline/legacy/entity_linker.py

Lines changed: 3 additions & 4 deletions
```diff
@@ -213,15 +213,14 @@ def get_loss(self, examples: Iterable[Example], sentence_encodings: Floats2d):
             if kb_id:
                 entity_encoding = self.kb.get_vector(kb_id)
                 entity_encodings.append(entity_encoding)
-        entity_encodings = self.model.ops.asarray(entity_encodings, dtype="float32")
+        entity_encodings = self.model.ops.asarray2f(entity_encodings)
         if sentence_encodings.shape != entity_encodings.shape:
             err = Errors.E147.format(
                 method="get_loss", msg="gold entities do not match up"
             )
             raise RuntimeError(err)
-        # TODO: fix typing issue here
-        gradients = self.distance.get_grad(sentence_encodings, entity_encodings)  # type: ignore
-        loss = self.distance.get_loss(sentence_encodings, entity_encodings)  # type: ignore
+        gradients = self.distance.get_grad(sentence_encodings, entity_encodings)
+        loss = self.distance.get_loss(sentence_encodings, entity_encodings)
         loss = loss / len(entity_encodings)
         return float(loss), gradients
```

spacy/pipeline/spancat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def ngram_suggester(docs: Iterable[Doc], *, ops: Optional[Ops] = None) -> Ragged
7575
if spans:
7676
assert spans[-1].ndim == 2, spans[-1].shape
7777
lengths.append(length)
78-
lengths_array = cast(Ints1d, ops.asarray(lengths, dtype="i"))
78+
lengths_array = ops.asarray1i(lengths)
7979
if len(spans) > 0:
8080
output = Ragged(ops.xp.vstack(spans), lengths_array)
8181
else:

spacy/schemas.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -104,7 +104,7 @@ def get_arg_model(
         sig_args[param.name] = (annotation, default)
     is_strict = strict and not has_variable
     sig_args["__config__"] = ArgSchemaConfig if is_strict else ArgSchemaConfigExtra  # type: ignore[assignment]
-    return create_model(name, **sig_args)  # type: ignore[arg-type, return-value]
+    return create_model(name, **sig_args)  # type: ignore[call-overload, arg-type, return-value]
 
 
 def validate_init_settings(
```
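
Mypy 0.950 checks `create_model` against Pydantic's overloaded stubs, and a call whose field definitions are only known at runtime no longer matches any overload — hence the added `call-overload` code. A minimal illustration of the same situation, assuming Pydantic 1.9:

```python
from typing import Any, Dict

from pydantic import create_model

# Field definitions assembled dynamically, as get_arg_model does from a
# function signature, can't be matched against create_model's overloads:
fields: Dict[str, Any] = {"x": (int, ...), "y": (str, "default")}
ArgModel = create_model("ArgModel", **fields)  # type: ignore[call-overload]

print(ArgModel(x=1))  # y falls back to "default"
```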
