Skip to content

Commit f71de10

Browse files
authored
Merge pull request explosion#10346 from adrianeboyd/chore/v3.0-backport-10324
Fix Tok2Vec for empty batches (explosion#10324)
2 parents 034ac0a + 5caccbd commit f71de10

File tree

4 files changed

+52
-19
lines changed

4 files changed

+52
-19
lines changed

azure-pipelines.yml

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ jobs:
2222
# defined in .flake8 and overwrites the selected codes.
2323
- job: "Validate"
2424
pool:
25-
vmImage: "ubuntu-18.04"
25+
vmImage: "ubuntu-latest"
2626
steps:
2727
- task: UsePythonVersion@0
2828
inputs:
@@ -38,41 +38,50 @@ jobs:
3838
matrix:
3939
# We're only running one platform per Python version to speed up builds
4040
Python36Linux:
41-
imageName: "ubuntu-18.04"
41+
imageName: "ubuntu-latest"
4242
python.version: "3.6"
4343
# Python36Windows:
44-
# imageName: "vs2017-win2016"
44+
# imageName: "windows-latest"
4545
# python.version: "3.6"
4646
# Python36Mac:
47-
# imageName: "macos-10.14"
47+
# imageName: "macos-latest"
4848
# python.version: "3.6"
4949
# Python37Linux:
50-
# imageName: "ubuntu-18.04"
50+
# imageName: "ubuntu-latest"
5151
# python.version: "3.7"
5252
Python37Windows:
53-
imageName: "vs2017-win2016"
53+
imageName: "windows-latest"
5454
python.version: "3.7"
5555
# Python37Mac:
56-
# imageName: "macos-10.14"
56+
# imageName: "macos-latest"
5757
# python.version: "3.7"
5858
# Python38Linux:
59-
# imageName: "ubuntu-18.04"
59+
# imageName: "ubuntu-latest"
6060
# python.version: "3.8"
6161
# Python38Windows:
62-
# imageName: "vs2017-win2016"
62+
# imageName: "windows-latest"
6363
# python.version: "3.8"
6464
Python38Mac:
65-
imageName: "macos-10.14"
65+
imageName: "macos-latest"
6666
python.version: "3.8"
6767
Python39Linux:
68-
imageName: "ubuntu-18.04"
69-
python.version: "3.9"
70-
Python39Windows:
71-
imageName: "vs2017-win2016"
72-
python.version: "3.9"
73-
Python39Mac:
74-
imageName: "macos-10.14"
68+
imageName: "ubuntu-latest"
7569
python.version: "3.9"
70+
# Python39Windows:
71+
# imageName: "windows-latest"
72+
# python.version: "3.9"
73+
# Python39Mac:
74+
# imageName: "macos-latest"
75+
# python.version: "3.9"
76+
Python310Linux:
77+
imageName: "ubuntu-latest"
78+
python.version: "3.10"
79+
Python310Windows:
80+
imageName: "windows-latest"
81+
python.version: "3.10"
82+
Python310Mac:
83+
imageName: "macos-latest"
84+
python.version: "3.10"
7685
maxParallel: 4
7786
pool:
7887
vmImage: $(imageName)

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,4 @@ pytest-timeout>=1.3.0,<2.0.0
2828
mock>=2.0.0,<3.0.0
2929
flake8>=3.5.0,<3.6.0
3030
hypothesis>=3.27.0,<7.0.0
31+
mypy==0.910

spacy/pipeline/tok2vec.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,10 @@ def predict(self, docs: Iterable[Doc]):
118118
119119
DOCS: https://spacy.io/api/tok2vec#predict
120120
"""
121+
if not any(len(doc) for doc in docs):
122+
# Handle cases where there are no tokens in any docs.
123+
width = self.model.get_dim("nO")
124+
return [self.model.ops.alloc((0, width)) for doc in docs]
121125
tokvecs = self.model.predict(docs)
122126
batch_id = Tok2VecListener.get_batch_id(docs)
123127
for listener in self.listeners:

spacy/tests/pipeline/test_tok2vec.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from thinc.api import Config, get_current_ops
1212
from numpy.testing import assert_array_equal
1313

14-
from ..util import get_batch, make_tempdir
14+
from ..util import get_batch, make_tempdir, add_vecs_to_vocab
1515

1616

1717
def test_empty_doc():
@@ -134,9 +134,25 @@ def test_init_tok2vec():
134134
]
135135

136136

137-
def test_tok2vec_listener():
137+
@pytest.mark.parametrize("with_vectors", (False, True))
138+
def test_tok2vec_listener(with_vectors):
138139
orig_config = Config().from_str(cfg_string)
140+
orig_config["components"]["tok2vec"]["model"]["embed"][
141+
"include_static_vectors"
142+
] = with_vectors
139143
nlp = util.load_model_from_config(orig_config, auto_fill=True, validate=True)
144+
145+
if with_vectors:
146+
ops = get_current_ops()
147+
vectors = [
148+
("apple", ops.asarray([1, 2, 3])),
149+
("orange", ops.asarray([-1, -2, -3])),
150+
("and", ops.asarray([-1, -1, -1])),
151+
("juice", ops.asarray([5, 5, 10])),
152+
("pie", ops.asarray([7, 6.3, 8.9])),
153+
]
154+
add_vecs_to_vocab(nlp.vocab, vectors)
155+
140156
assert nlp.pipe_names == ["tok2vec", "tagger"]
141157
tagger = nlp.get_pipe("tagger")
142158
tok2vec = nlp.get_pipe("tok2vec")
@@ -163,6 +179,9 @@ def test_tok2vec_listener():
163179
ops = get_current_ops()
164180
assert_array_equal(ops.to_numpy(doc.tensor), ops.to_numpy(doc_tensor))
165181

182+
# test with empty doc
183+
doc = nlp("")
184+
166185
# TODO: should this warn or error?
167186
nlp.select_pipes(disable="tok2vec")
168187
assert nlp.pipe_names == ["tagger"]

0 commit comments

Comments
 (0)