@@ -15,9 +15,7 @@ def kmeans(x, k):
15
15
# assign labels to each datapoint based on centroids
16
16
dists , y = torch .abs_ (d .unsqueeze (- 1 ) - c ).min (dim = - 1 )
17
17
# make sure the number of datapoints is at least the number of clusters
18
- if len (d ) < k :
19
- raise AssertionError (f"unable to assign { len (d )} datapoints to "
20
- f"{ k } clusters" )
18
+ assert len (d ) >= k , f"unable to assign { len (d )} datapoints to { k } clusters"
21
19
22
20
while old is None or not c .equal (old ):
23
21
# if an empty cluster is encountered,
@@ -59,27 +57,27 @@ def eisner(scores, mask):
59
57
for w in range (1 , seq_len ):
60
58
n = seq_len - w
61
59
starts = p_i .new_tensor (range (n )).unsqueeze (0 )
62
- # ilr = C(i, r) + C(j, r+1)
60
+ # ilr = C(i-> r) + C(j-> r+1)
63
61
ilr = stripe (s_c , n , w ) + stripe (s_c , n , w , (w , 1 ))
64
62
# [batch_size, n, w]
65
63
ilr = ilr .permute (2 , 0 , 1 )
66
64
il = ilr + scores .diagonal (- w ).unsqueeze (- 1 )
67
- # I(j, i) = max(C(i, r) + C(j, r+1) + S(j, i)), i <= r < j
65
+ # I(j-> i) = max(C(i-> r) + C(j-> r+1) + s(j-> i)), i <= r < j
68
66
il_span , il_path = il .max (- 1 )
69
67
s_i .diagonal (- w ).copy_ (il_span )
70
68
p_i .diagonal (- w ).copy_ (il_path + starts )
71
69
ir = ilr + scores .diagonal (w ).unsqueeze (- 1 )
72
- # I(i, j) = max(C(i, r) + C(j, r+1) + S(i, j)), i <= r < j
70
+ # I(i-> j) = max(C(i-> r) + C(j-> r+1) + s(i-> j)), i <= r < j
73
71
ir_span , ir_path = ir .max (- 1 )
74
72
s_i .diagonal (w ).copy_ (ir_span )
75
73
p_i .diagonal (w ).copy_ (ir_path + starts )
76
74
77
- # C(j, i) = max(C(r, i) + I(j, r)), i <= r < j
78
- cl = stripe (s_c , n , w , dim = 0 ) + stripe (s_i , n , w , (w , 0 ))
75
+ # C(j-> i) = max(C(r-> i) + I(j-> r)), i <= r < j
76
+ cl = stripe (s_c , n , w , ( 0 , 0 ), 0 ) + stripe (s_i , n , w , (w , 0 ))
79
77
cl_span , cl_path = cl .permute (2 , 0 , 1 ).max (- 1 )
80
78
s_c .diagonal (- w ).copy_ (cl_span )
81
79
p_c .diagonal (- w ).copy_ (cl_path + starts )
82
- # C(i, j) = max(I(i, r) + C(r, j)), i < r <= j
80
+ # C(i-> j) = max(I(i-> r) + C(r-> j)), i < r <= j
83
81
cr = stripe (s_i , n , w , (0 , 1 )) + stripe (s_c , n , w , (1 , w ), 0 )
84
82
cr_span , cr_path = cr .permute (2 , 0 , 1 ).max (- 1 )
85
83
s_c .diagonal (w ).copy_ (cr_span )
@@ -136,7 +134,7 @@ def stripe(x, n, w, offset=(0, 0), dim=1):
136
134
tensor([[ 0, 5, 10],
137
135
[ 6, 11, 16]])
138
136
'''
139
- seq_len = x .size (1 )
137
+ x , seq_len = x . contiguous (), x .size (1 )
140
138
stride , numel = list (x .stride ()), x [0 , 0 ].numel ()
141
139
stride [0 ] = (seq_len + 1 ) * numel
142
140
stride [1 ] = (1 if dim == 1 else seq_len ) * numel
0 commit comments