Skip to content

Commit cabb0f0

Browse files
committed
Add sparsemax and SparsemaxSemiring
1 parent 2020dd0 commit cabb0f0

File tree

2 files changed

+37
-1
lines changed

2 files changed

+37
-1
lines changed

supar/structs/fn.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,4 +230,28 @@ def backward(ctx, grad_output):
230230
return None, None
231231

232232

233+
class SparsemaxFunction(Function):

    r"""
    Autograd implementation of sparsemax :cite:`martins-etal-2016-sparsemax`:
    the Euclidean projection of the inputs onto the probability simplex,
    which can assign exactly zero probability to low-scoring entries.
    """

    @staticmethod
    def forward(ctx, x, dim=-1):
        r"""
        Projects ``x`` onto the simplex along ``dim`` and returns the
        resulting (possibly sparse) probability tensor.
        """
        # Remember the reduction dim for the backward pass.
        ctx.dim = dim
        n = x.size(dim)
        # Scores sorted in descending order along `dim`.
        s = x.sort(dim, True)[0]
        # Cumulative sums shifted by 1, enforcing the constraint sum(p) == 1.
        cssv = s.cumsum(dim) - 1
        # Ranks 1..n, shaped to broadcast along `dim`.
        ranks = x.new_tensor(range(1, n + 1)).view(-1, *[1] * (x.dim() - 1)).transpose(0, dim)
        # Support size: the number of ranks r satisfying r * s_r > cssv_r.
        # (bool.sum() yields a long tensor, so `support` is a valid index.)
        support = (ranks * s).gt(cssv).sum(dim, True)
        # Threshold tau such that sum(clamp(x - tau, 0)) == 1 on the support.
        tau = cssv.gather(dim, support - 1) / support
        out = (x - tau).clamp(0)
        ctx.save_for_backward(support, out)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        support, out = ctx.saved_tensors
        dim = ctx.dim
        # Gradients vanish outside the support of the output.
        grad = grad_output.masked_fill(out.eq(0), 0)
        # On the support, subtract the mean gradient over the support
        # (the Jacobian of the simplex projection); None for the `dim` arg.
        return torch.where(out.ne(0), grad - grad.sum(dim, True) / support, grad), None
253+
254+
233255
# Convenience aliases: expose the autograd Functions as plain callables.
sampled_logsumexp = SampledLogsumexp.apply
sparsemax = SparsemaxFunction.apply

supar/structs/semiring.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import torch
66
from supar.utils.common import MIN
7-
from supar.structs.fn import sampled_logsumexp
7+
from supar.structs.fn import sampled_logsumexp, sparsemax
88

99

1010
class Semiring(object):
@@ -260,3 +260,15 @@ class SampledSemiring(LogSemiring):
260260
@classmethod
261261
def sum(cls, x, dim=-1):
262262
return sampled_logsumexp(x, dim)
263+
264+
265+
class SparsemaxSemiring(LogSemiring):

    r"""
    Sparsemax semiring :math:`<\mathrm{sparsemax}, +, -\infty, 0>`
    :cite:`martins-etal-2016-sparsemax,mensch-etal-2018-dp,correia-etal-2020-efficient`.
    """

    @staticmethod
    def sum(x, dim=-1):
        # Sparsemax-smoothed max over `dim`: the inner product <p, x> minus
        # the L2 norm of the sparsemax distribution p, which keeps the
        # reduction differentiable while p itself stays sparse.
        p = sparsemax(x, dim)
        return (p * x).sum(dim) - p.norm(p=2, dim=dim)

0 commit comments

Comments
 (0)