
Commit 838f501

Black formatting
1 parent 2a8efda commit 838f501
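
These are mechanical changes, presumably produced by running the Black formatter over the touched files, along the lines of:

    black spacy/ml/models/coref.py spacy/ml/models/coref_util.py spacy/ml/models/span_predictor.py

Black collapses signatures and calls that fit within its line-length limit onto a single line, normalizes spacing, enforces two blank lines between top-level definitions, and adds trailing commas to multi-line argument lists, which is exactly the pattern in the hunks below.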

File tree

3 files changed: +24 / −41 lines

spacy/ml/models/coref.py

Lines changed: 7 additions & 24 deletions
@@ -46,11 +46,7 @@ def build_wl_coref_model(
     return coref_model
 
 
-def convert_coref_clusterer_inputs(
-    model: Model,
-    X: List[Floats2d],
-    is_train: bool
-):
+def convert_coref_clusterer_inputs(model: Model, X: List[Floats2d], is_train: bool):
     # The input here is List[Floats2d], one for each doc
     # just use the first
     # TODO real batching
@@ -62,7 +58,7 @@ def backprop(args: ArgsKwargs) -> List[Floats2d]:
         gradients = torch2xp(args.args[0])
         return [gradients]
 
-    return ArgsKwargs(args=(word_features, ), kwargs={}), backprop
+    return ArgsKwargs(args=(word_features,), kwargs={}), backprop
 
 
 def convert_coref_clusterer_outputs(model: Model, inputs_outputs, is_train: bool):
@@ -115,12 +111,7 @@ def __init__(
         # Modules
         self.pw = DistancePairwiseEncoder(dist_emb_size, dropout)
         pair_emb = dim * 3 + self.pw.shape
-        self.a_scorer = AnaphoricityScorer(
-            pair_emb,
-            hidden_size,
-            n_layers,
-            dropout
-        )
+        self.a_scorer = AnaphoricityScorer(pair_emb, hidden_size, n_layers, dropout)
         self.lstm = torch.nn.LSTM(
             input_size=dim,
             hidden_size=dim,
@@ -129,13 +120,9 @@ def __init__(
         self.rough_scorer = RoughScorer(dim, dropout, roughk)
         self.pw = DistancePairwiseEncoder(dist_emb_size, dropout)
         pair_emb = dim * 3 + self.pw.shape
-        self.a_scorer = AnaphoricityScorer(
-            pair_emb, hidden_size, n_layers, dropout
-        )
+        self.a_scorer = AnaphoricityScorer(pair_emb, hidden_size, n_layers, dropout)
 
-    def forward(
-        self, word_features: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    def forward(self, word_features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         1. LSTM encodes the incoming word_features.
         2. The RoughScorer scores and prunes the candidates.
@@ -350,13 +337,9 @@ def __init__(self, distance_embedding_size, dropout):
         self.dropout = torch.nn.Dropout(dropout)
         self.shape = emb_size
 
-    def forward(
-        self,
-        top_indices: torch.Tensor
-    ) -> torch.Tensor:
+    def forward(self, top_indices: torch.Tensor) -> torch.Tensor:
         word_ids = torch.arange(0, top_indices.size(0))
-        distance = (word_ids.unsqueeze(1) - word_ids[top_indices]
-                    ).clamp_min_(min=1)
+        distance = (word_ids.unsqueeze(1) - word_ids[top_indices]).clamp_min_(min=1)
         log_distance = distance.to(torch.float).log2().floor_()
         log_distance = log_distance.clamp_max_(max=6).to(torch.long)
         distance = torch.where(distance < 5, distance - 1, log_distance + 2)
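
To make the reformatted one-liner in DistancePairwiseEncoder.forward easier to follow, here is a minimal standalone sketch (not part of the commit) of the distance bucketing it performs; the toy top_indices setup is illustrative:

import torch

# Toy setup: 10 words, and every word's single pruned candidate antecedent
# is token 0, so the raw distances are simply 0..9.
n_words = 10
top_indices = torch.zeros((n_words, 1), dtype=torch.long)
word_ids = torch.arange(0, n_words)

# Raw token distance to each candidate, clamped so the minimum is 1.
distance = (word_ids.unsqueeze(1) - word_ids[top_indices]).clamp_min_(min=1)

# Log2 buckets for long distances, capped at bucket 6 (i.e. >= 64 tokens).
log_distance = distance.to(torch.float).log2().floor_()
log_distance = log_distance.clamp_max_(max=6).to(torch.long)

# Distances under 5 each keep their own bucket; longer ones share log2 buckets.
distance = torch.where(distance < 5, distance - 1, log_distance + 2)
print(distance.squeeze(1))  # tensor([0, 0, 1, 2, 3, 4, 4, 4, 5, 5])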

spacy/ml/models/coref_util.py

Lines changed: 6 additions & 3 deletions
@@ -7,6 +7,7 @@
 
 DEFAULT_CLUSTER_PREFIX = "coref_clusters"
 
+
 class GraphNode:
     def __init__(self, node_id: int):
         self.id = node_id
@@ -30,6 +31,7 @@ def get_sentence_ids(doc):
         out.append(sent_id)
     return out
 
+
 def doc2clusters(doc: Doc, prefix=DEFAULT_CLUSTER_PREFIX) -> MentionClusters:
     """Given a doc, give the mention clusters.
 
@@ -100,7 +102,6 @@ def get_predicted_clusters(
     return predicted_clusters
 
 
-
 def select_non_crossing_spans(
     idxs: List[int], starts: List[int], ends: List[int], limit: int
 ) -> List[int]:
@@ -150,23 +151,25 @@ def select_non_crossing_spans(
         # selected.append(selected[0]) # this seems a bit weird?
     return selected
 
+
 def create_head_span_idxs(ops, doclen: int):
     """Helper function to create single-token span indices."""
     aa = ops.xp.arange(0, doclen)
     bb = ops.xp.arange(0, doclen) + 1
     return ops.asarray2i([aa, bb]).T
 
+
 def get_clusters_from_doc(doc) -> List[List[Tuple[int, int]]]:
     """Given a Doc, convert the cluster spans to simple int tuple lists."""
     out = []
     for key, val in doc.spans.items():
         cluster = []
         for span in val:
             # TODO check that there isn't an off-by-one error here
-            #cluster.append((span.start, span.end))
+            # cluster.append((span.start, span.end))
             # TODO This conversion should be happening earlier in processing
             head_i = span.root.i
-            cluster.append( (head_i, head_i + 1) )
+            cluster.append((head_i, head_i + 1))
 
         # don't want duplicates
         cluster = list(set(cluster))
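
As a quick illustration (not from the commit) of what create_head_span_idxs builds, here is a plain-numpy equivalent of the thinc-ops helper; the _np-suffixed name is made up for this sketch:

import numpy as np

def create_head_span_idxs_np(doclen: int) -> np.ndarray:
    # Mirrors coref_util.create_head_span_idxs: one (start, end) = (i, i + 1)
    # row per token, i.e. a single-token span for each position.
    aa = np.arange(0, doclen)
    bb = np.arange(0, doclen) + 1
    return np.asarray([aa, bb], dtype=np.int32).T

print(create_head_span_idxs_np(4))
# [[0 1]
#  [1 2]
#  [2 3]
#  [3 4]]

This matches the head-token convention get_clusters_from_doc uses above: each mention span is reduced to (span.root.i, span.root.i + 1) before deduplication.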

spacy/ml/models/span_predictor.py

Lines changed: 11 additions & 14 deletions
@@ -18,7 +18,7 @@ def build_span_predictor(
     conv_channels: int = 4,
     window_size: int = 1,
     max_distance: int = 128,
-    prefix: str = "coref_head_clusters"
+    prefix: str = "coref_head_clusters",
 ):
     # TODO add model return types
     # TODO fix this
@@ -36,7 +36,7 @@ def build_span_predictor(
             distance_embedding_size,
             conv_channels,
             window_size,
-            max_distance
+            max_distance,
         ),
         convert_inputs=convert_span_predictor_inputs,
     )
@@ -133,20 +133,17 @@ def head_data_forward(model, docs, is_train):
 # TODO this should maybe have a different name from the component
 class SpanPredictor(torch.nn.Module):
     def __init__(
-            self,
-            input_size: int,
-            hidden_size: int,
-            dist_emb_size: int,
-            conv_channels: int,
-            window_size: int,
-            max_distance: int
-
+        self,
+        input_size: int,
+        hidden_size: int,
+        dist_emb_size: int,
+        conv_channels: int,
+        window_size: int,
+        max_distance: int,
     ):
         super().__init__()
         if max_distance % 2 != 0:
-            raise ValueError(
-                "max_distance has to be an even number"
-            )
+            raise ValueError("max_distance has to be an even number")
         # input size = single token size
         # 64 = probably distance emb size
         # TODO check that dist_emb_size use is correct
@@ -164,7 +161,7 @@ def __init__(
         kernel_size = window_size * 2 + 1
         self.conv = torch.nn.Sequential(
             torch.nn.Conv1d(dist_emb_size, conv_channels, kernel_size, 1, 1),
-            torch.nn.Conv1d(conv_channels, 2, kernel_size, 1, 1)
+            torch.nn.Conv1d(conv_channels, 2, kernel_size, 1, 1),
         )
         # TODO make embeddings size a parameter
         self.max_distance = max_distance
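
A standalone shape check (not part of the commit) of the reformatted conv stack in SpanPredictor.__init__; the dist_emb_size, conv_channels, and sequence-length values here are illustrative:

import torch

dist_emb_size, conv_channels, window_size = 64, 4, 1
kernel_size = window_size * 2 + 1  # 3 with the default window_size of 1

# Same structure as the Sequential in the diff above:
# Conv1d(in_channels, out_channels, kernel_size, stride=1, padding=1).
conv = torch.nn.Sequential(
    torch.nn.Conv1d(dist_emb_size, conv_channels, kernel_size, 1, 1),
    torch.nn.Conv1d(conv_channels, 2, kernel_size, 1, 1),
)

x = torch.randn(1, dist_emb_size, 10)  # (batch, channels, sequence_length)
print(conv(x).shape)  # torch.Size([1, 2, 10]): padding=1 preserves the length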

0 commit comments