
Commit aa682bb (1 parent: f4d57f2)

Avoid CUDA stream sync

Signed-off-by: cyy <cyyever@outlook.com>

File tree

1 file changed: +3 −3 lines


src/transformers/modeling_flash_attention_utils.py

Lines changed: 3 additions & 3 deletions
@@ -243,12 +243,12 @@ def _prepare_from_posids(query, key, value, position_ids, query_length):
         cu_seq_lens_q = torch.cat([torch.zeros(1, **tensor_kws), q_len.cumsum(0).to(torch.int32)], 0)
         max_length_q = int(q_len.max())
     else:
-        position_ids = position_ids.flatten()
-        indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)
+        position_ids = position_ids.view(-1)
+        indices_q = (position_ids == 0).nonzero().view(-1)
 
         cu_seq_lens_q = torch.cat(
             (
-                indices_q[position_ids == 0],
+                indices_q,
                 torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
             )
         )
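
For context, here is a minimal runnable sketch of what the rewritten branch computes; the toy position_ids values and the print are illustrative only, not part of the commit. With packed (flattened) sequences, each sequence restarts its positions at 0, so the zero positions mark sequence starts and, together with the total length, give the cumulative sequence lengths (cu_seqlens) that FlashAttention's varlen kernels consume:

    import torch

    # Three packed sequences of lengths 3, 2, and 4; position_ids restarts
    # at 0 at the start of each sequence (toy data, for illustration only).
    position_ids = torch.tensor([0, 1, 2, 0, 1, 0, 1, 2, 3])

    position_ids = position_ids.view(-1)                # flatten without copying
    indices_q = (position_ids == 0).nonzero().view(-1)  # sequence starts: [0, 3, 5]

    cu_seq_lens_q = torch.cat(
        (
            indices_q,
            torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
        )
    )
    print(cu_seq_lens_q)  # tensor([0, 3, 5, 9]): boundaries of the three sequences

The result is the same as before the change: the old code materialized a full torch.arange over every position and then boolean-indexed it with `position_ids == 0`, while the new code obtains the start offsets directly from a single nonzero() call, dropping the arange allocation and the extra masked gather.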
