Commit 2020dd0

Reorganize modules
1 parent d7d3e8f commit 2020dd0

File tree: 8 files changed (+108 -107 lines)


supar/utils/alg.py renamed to supar/structs/fn.py

Lines changed: 21 additions & 73 deletions
@@ -3,79 +3,7 @@
 import torch
 from supar.utils.common import MIN
 from supar.utils.fn import pad
-
-
-def kmeans(x, k, max_it=32):
-    r"""
-    KMeans algorithm for clustering sentences by length.
-
-    Args:
-        x (list[int]):
-            The list of sentence lengths.
-        k (int):
-            The number of clusters.
-            This is an approximate value; the final number of clusters can be less than or equal to `k`.
-        max_it (int):
-            The maximum number of iterations.
-            If the centroids do not converge after several iterations, the algorithm stops early.
-
-    Returns:
-        list[float], list[list[int]]:
-            The first list contains the average lengths of the sentences in each cluster.
-            The second is the list of clusters holding the indices of the datapoints.
-
-    Examples:
-        >>> x = torch.randint(10, 20, (10,)).tolist()
-        >>> x
-        [15, 10, 17, 11, 18, 13, 17, 19, 18, 14]
-        >>> centroids, clusters = kmeans(x, 3)
-        >>> centroids
-        [10.5, 14.0, 17.799999237060547]
-        >>> clusters
-        [[1, 3], [0, 5, 9], [2, 4, 6, 7, 8]]
-    """
-
-    # the number of clusters must not be greater than the number of datapoints
-    x, k = torch.tensor(x, dtype=torch.float), min(len(x), k)
-    # collect unique datapoints
-    d = x.unique()
-    # initialize k centroids randomly
-    c = d[torch.randperm(len(d))[:k]]
-    # assign each datapoint to the cluster with the closest centroid
-    dists, y = torch.abs_(x.unsqueeze(-1) - c).min(-1)
-
-    for _ in range(max_it):
-        # if an empty cluster is encountered,
-        # choose the farthest datapoint in the biggest cluster and move it to the empty one
-        mask = torch.arange(k).unsqueeze(-1).eq(y)
-        none = torch.where(~mask.any(-1))[0].tolist()
-        while len(none) > 0:
-            for i in none:
-                # the biggest cluster
-                b = torch.where(mask[mask.sum(-1).argmax()])[0]
-                # the datapoint farthest from the centroid of cluster b
-                f = dists[b].argmax()
-                # update the assigned cluster of f
-                y[b[f]] = i
-                # re-calculate the mask
-                mask = torch.arange(k).unsqueeze(-1).eq(y)
-            none = torch.where(~mask.any(-1))[0].tolist()
-        # update the centroids
-        c, old = (x * mask).sum(-1) / mask.sum(-1), c
-        # re-assign all datapoints to clusters
-        dists, y = torch.abs_(x.unsqueeze(-1) - c).min(-1)
-        # stop early if the centroids have converged
-        if c.equal(old):
-            break
-    # assign all datapoints to the newly generated clusters;
-    # the empty ones are discarded
-    assigned = y.unique().tolist()
-    # get the centroids of the assigned clusters
-    centroids = c[assigned].tolist()
-    # map all datapoints to buckets
-    clusters = [torch.where(y.eq(i))[0].tolist() for i in assigned]
-
-    return centroids, clusters
+from torch.autograd import Function
 
 
 def tarjan(sequence):
@@ -283,3 +211,23 @@ def mst(scores, mask, multiroot=False):
         preds.append(tree)
 
     return pad(preds, total_length=seq_len).to(mask.device)
+
+
+class SampledLogsumexp(Function):
+
+    @staticmethod
+    def forward(ctx, x, dim=-1):
+        # dim is plain Python state, so it lives on ctx rather than in save_for_backward
+        ctx.dim = dim
+        ctx.save_for_backward(x)
+        return x.logsumexp(dim=dim)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        from torch.distributions import OneHotCategorical
+        x, dim = ctx.saved_tensors[0], ctx.dim
+        if ctx.needs_input_grad[0]:
+            return grad_output.unsqueeze(dim).mul(OneHotCategorical(logits=x.movedim(dim, -1)).sample().movedim(-1, dim)), None
+        return None, None
+
+
+sampled_logsumexp = SampledLogsumexp.apply
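
The relocated `SampledLogsumexp` computes an ordinary `logsumexp` in the forward pass but, in the backward pass, replaces the usual softmax gradient with a one-hot sample drawn from that softmax, so gradients flowing through it pick single entries at random. A minimal sketch of the behavior after this commit (shapes and values are illustrative only):

    import torch
    from supar.structs.fn import sampled_logsumexp  # new home as of this commit

    x = torch.randn(2, 5, requires_grad=True)

    # forward: identical to torch.logsumexp
    out = sampled_logsumexp(x, -1)
    assert torch.allclose(out, x.logsumexp(-1))

    # backward: one one-hot sample per row, drawn from Categorical(logits=x);
    # in expectation this equals softmax(x), the true gradient of logsumexp
    out.sum().backward()
    print(x.grad)  # each row contains a single 1.0, zeros elsewhere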

supar/structs/semiring.py

Lines changed: 8 additions & 26 deletions
@@ -4,6 +4,7 @@
 
 import torch
 from supar.utils.common import MIN
+from supar.structs.fn import sampled_logsumexp
 
 
 class Semiring(object):
@@ -140,8 +141,8 @@ def one_(cls, x):
 
 
 class EntropySemiring(LogSemiring):
-    """
-    Entropy expectation semiring: :math:`<\oplus, +, [-\infty, 0], [0, 0]>`,
+    r"""
+    Entropy expectation semiring :math:`<\oplus, +, [-\infty, 0], [0, 0]>`,
     where :math:`\oplus` computes the log-values and the running distributional entropy :math:`H[p]`
     :cite:`li-eisner-2009-first,hwa-2000-sample,kim-etal-2019-unsupervised`.
     """
@@ -177,8 +178,8 @@ def one_(cls, x):
 
 
 class CrossEntropySemiring(LogSemiring):
-    """
-    Cross Entropy expectation semiring: :math:`<\oplus, +, [-\infty, -\infty, 0], [0, 0, 0]>`,
+    r"""
+    Cross Entropy expectation semiring :math:`<\oplus, +, [-\infty, -\infty, 0], [0, 0, 0]>`,
     where :math:`\oplus` computes the log-values and the running distributional cross entropy :math:`H[p,q]`
     of the two distributions :cite:`li-eisner-2009-first`.
     """
@@ -214,14 +215,11 @@ def one_(cls, x):
 
 
 class KLDivergenceSemiring(LogSemiring):
-    """
-    KL divergence expectation semiring: :math:`<\oplus, +, [-\infty, -\infty, 0], [0, 0, 0]>`,
+    r"""
+    KL divergence expectation semiring :math:`<\oplus, +, [-\infty, -\infty, 0], [0, 0, 0]>`,
     where :math:`\oplus` computes the log-values and the running distributional KL divergence :math:`KL[p \parallel q]`
     of the two distributions :cite:`li-eisner-2009-first`.
     """
-    """
-    KL divergence expectation semiring: `<logsumexp, +, -inf, 0>` :cite:`li-eisner-2009-first`.
-    """
 
     @classmethod
     def convert(cls, x):
@@ -261,20 +259,4 @@ class SampledSemiring(LogSemiring):
 
     @classmethod
     def sum(cls, x, dim=-1):
-        return SampledLogsumexp.apply(x, dim)
-
-
-class SampledLogsumexp(torch.autograd.Function):
-
-    @staticmethod
-    def forward(ctx, x, dim=-1):
-        ctx.save_for_backward(x, torch.tensor(dim))
-        return x.logsumexp(dim=dim)
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        from torch.distributions import OneHotCategorical
-        x, dim = ctx.saved_tensors
-        if ctx.needs_input_grad[0]:
-            return grad_output.unsqueeze(dim).mul(OneHotCategorical(logits=x.movedim(dim, -1)).sample().movedim(-1, dim)), None
-        return None, None
+        return sampled_logsumexp(x, dim)
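
Note the changed bookkeeping: the old implementation smuggled `dim` through `save_for_backward` as `torch.tensor(dim)`, but `save_for_backward` is meant for tensors that participate in autograd; plain Python values are better stored as attributes on `ctx`, which is what the relocated version does. A toy sketch of that convention (`Scale` is a made-up `Function`, not part of supar):

    import torch
    from torch.autograd import Function

    class Scale(Function):

        @staticmethod
        def forward(ctx, x, factor=2.0):
            ctx.factor = factor       # non-tensor state: plain ctx attribute
            ctx.save_for_backward(x)  # tensor state: save_for_backward
            return x * factor

        @staticmethod
        def backward(ctx, grad_output):
            # d(x * factor)/dx is just factor; no gradient for factor itself
            return grad_output * ctx.factor, None

    x = torch.ones(3, requires_grad=True)
    Scale.apply(x, 3.0).sum().backward()
    assert torch.equal(x.grad, torch.full((3,), 3.0))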

supar/structs/tree.py

Lines changed: 1 addition & 1 deletion
@@ -3,8 +3,8 @@
 import torch
 import torch.nn as nn
 from supar.structs.dist import StructuredDistribution
+from supar.structs.fn import mst
 from supar.structs.semiring import LogSemiring
-from supar.utils.alg import mst
 from supar.utils.fn import stripe
 from torch.distributions.utils import lazy_property
 
supar/utils/__init__.py

Lines changed: 2 additions & 4 deletions
@@ -1,7 +1,6 @@
 # -*- coding: utf-8 -*-
 
-from . import alg, field, fn, metric, transform
-from .alg import chuliu_edmonds, kmeans, mst, tarjan
+from . import field, fn, metric, transform
 from .config import Config
 from .data import Dataset
 from .embedding import Embedding
@@ -10,5 +9,4 @@
 from .vocab import Vocab
 
 __all__ = ['ChartField', 'CoNLL', 'Config', 'Dataset', 'Embedding', 'Field',
-           'RawField', 'SubwordField', 'Transform', 'Tree', 'Vocab',
-           'alg', 'field', 'fn', 'metric', 'chuliu_edmonds', 'kmeans', 'mst', 'tarjan', 'transform']
+           'RawField', 'SubwordField', 'Transform', 'Tree', 'Vocab', 'field', 'fn', 'metric', 'transform']

supar/utils/data.py

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 
 import torch
 import torch.distributed as dist
-from supar.utils.alg import kmeans
+from supar.utils.fn import kmeans
 from supar.utils.transform import Batch
 from torch.utils.data import DataLoader
 

supar/utils/fn.py

Lines changed: 73 additions & 0 deletions
@@ -29,6 +29,79 @@ def tohalfwidth(token):
     return unicodedata.normalize('NFKC', token)
 
 
+def kmeans(x, k, max_it=32):
+    r"""
+    KMeans algorithm for clustering sentences by length.
+
+    Args:
+        x (list[int]):
+            The list of sentence lengths.
+        k (int):
+            The number of clusters.
+            This is an approximate value; the final number of clusters can be less than or equal to `k`.
+        max_it (int):
+            The maximum number of iterations.
+            If the centroids do not converge after several iterations, the algorithm stops early.
+
+    Returns:
+        list[float], list[list[int]]:
+            The first list contains the average lengths of the sentences in each cluster.
+            The second is the list of clusters holding the indices of the datapoints.
+
+    Examples:
+        >>> x = torch.randint(10, 20, (10,)).tolist()
+        >>> x
+        [15, 10, 17, 11, 18, 13, 17, 19, 18, 14]
+        >>> centroids, clusters = kmeans(x, 3)
+        >>> centroids
+        [10.5, 14.0, 17.799999237060547]
+        >>> clusters
+        [[1, 3], [0, 5, 9], [2, 4, 6, 7, 8]]
+    """
+
+    # the number of clusters must not be greater than the number of datapoints
+    x, k = torch.tensor(x, dtype=torch.float), min(len(x), k)
+    # collect unique datapoints
+    d = x.unique()
+    # initialize k centroids randomly
+    c = d[torch.randperm(len(d))[:k]]
+    # assign each datapoint to the cluster with the closest centroid
+    dists, y = torch.abs_(x.unsqueeze(-1) - c).min(-1)
+
+    for _ in range(max_it):
+        # if an empty cluster is encountered,
+        # choose the farthest datapoint in the biggest cluster and move it to the empty one
+        mask = torch.arange(k).unsqueeze(-1).eq(y)
+        none = torch.where(~mask.any(-1))[0].tolist()
+        while len(none) > 0:
+            for i in none:
+                # the biggest cluster
+                b = torch.where(mask[mask.sum(-1).argmax()])[0]
+                # the datapoint farthest from the centroid of cluster b
+                f = dists[b].argmax()
+                # update the assigned cluster of f
+                y[b[f]] = i
+                # re-calculate the mask
+                mask = torch.arange(k).unsqueeze(-1).eq(y)
+            none = torch.where(~mask.any(-1))[0].tolist()
+        # update the centroids
+        c, old = (x * mask).sum(-1) / mask.sum(-1), c
+        # re-assign all datapoints to clusters
+        dists, y = torch.abs_(x.unsqueeze(-1) - c).min(-1)
+        # stop early if the centroids have converged
+        if c.equal(old):
+            break
+    # assign all datapoints to the newly generated clusters;
+    # the empty ones are discarded
+    assigned = y.unique().tolist()
+    # get the centroids of the assigned clusters
+    centroids = c[assigned].tolist()
+    # map all datapoints to buckets
+    clusters = [torch.where(y.eq(i))[0].tolist() for i in assigned]
+
+    return centroids, clusters
+
+
 def stripe(x, n, w, offset=(0, 0), dim=1):
     r"""
     Returns a diagonal stripe of the tensor.
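
With `kmeans` now exposed from `supar.utils.fn`, callers such as `supar/utils/data.py` (shown earlier) import it from there to bucket training sentences by length. A quick usage sketch, reusing the lengths from the docstring example above:

    from supar.utils.fn import kmeans

    lengths = [15, 10, 17, 11, 18, 13, 17, 19, 18, 14]
    centroids, clusters = kmeans(lengths, 3)
    # centroids: the average sentence length of each bucket
    # clusters: the indices of the sentences assigned to each bucket
    for c, idx in zip(centroids, clusters):
        print(f"bucket ~{c:.1f} tokens: sentences {idx}")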

supar/utils/transform.py

Lines changed: 1 addition & 1 deletion
@@ -292,7 +292,7 @@ def istree(cls, sequence, proj=False, multiroot=False):
             False
         """
 
-        from supar.utils.alg import tarjan
+        from supar.structs.fn import tarjan
         if proj and not cls.isprojective(sequence):
             return False
         n_roots = sum(head == 0 for head in sequence)
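
Only the import location changes here; `CoNLL.istree` still validates a sequence of 1-based head indices (0 marks the root) by counting roots and running `tarjan` to detect cycles. A hedged sketch of typical calls (the example head sequences are illustrative, not taken from the source):

    from supar.utils.transform import CoNLL

    # token 2 heads the sentence, tokens 1 and 3 attach to it: a valid tree
    assert CoNLL.istree([2, 0, 2])
    # two tokens attach directly to the root: rejected unless multiroot=True
    assert not CoNLL.istree([3, 0, 0, 3])
    assert CoNLL.istree([3, 0, 0, 3], multiroot=True)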

tests/test_alg.py renamed to tests/test_fn.py

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-
 
-from supar.utils import tarjan
+from supar.structs.fn import tarjan
 
 
 def test_tarjan():
