extract statistics from attention weights in FlexAttention #159770

@HP2706

🚀 The feature, motivation and pitch

Motivation

The current torch.nn.attention.flex_attention implementation can only return the logsumexp (LSE) of the attention scores, via the return_lse parameter. However, there are many other statistics a user might want to read from the attention weights that can be computed without materialising the entire attention matrix.
For example, implementing MuonClip requires the maximum attention logit per head. This is not possible with the current implementation of flex_attention, so users have to write their own kernel to obtain this information.

I therefore think it would be useful to let users access parts of the QK^T matrix from outside the kernel.
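To make the requested statistic concrete, here is a minimal reference sketch of the per-head max logit, computed one KV tile at a time so the full QK^T matrix is never held in memory. It is plain PyTorch, not a fused kernel and not part of flex_attention. The function name max_logit_per_head, the block_n parameter, and the default 1/sqrt(D) scale are assumptions for illustration, and it ignores any mask or score_mod.

```python
import torch


def max_logit_per_head(q, k, scale=None, block_n=128):
    """Per-head max of (Q @ K^T) * scale, scanning K in tiles of block_n rows.

    q: [B, H, Lq, D], k: [B, H, Lk, D]. Returns a float32 tensor of shape [H].
    Hypothetical helper for illustration only; no mask or score_mod is applied.
    """
    if scale is None:
        scale = q.shape[-1] ** -0.5  # assumed default softmax scale
    num_heads = q.shape[1]
    running_max = torch.full(
        (num_heads,), float("-inf"), dtype=torch.float32, device=q.device
    )
    q32 = q.float()
    for start in range(0, k.shape[2], block_n):
        k_tile = k[:, :, start:start + block_n].float()        # [B, H, n, D]
        scores = torch.matmul(q32, k_tile.transpose(-1, -2)) * scale  # [B, H, Lq, n]
        tile_max = scores.amax(dim=(0, 2, 3))                   # reduce over b, q, kv
        running_max = torch.maximum(running_max, tile_max)      # fold tile into result
    return running_max
```

A fused version would keep the same running accumulator, but update it inside the attention kernel on each [BlockM, BlockN] tile instead of recomputing the scores.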

API Design

One possible approach is to let users write a Triton function that specifies the reduction, together with the dimensions to reduce over, i.e. kv, q, h, b.

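import triton

@triton.jit
def reduction(
    qk_ptr,       # one [BlockM, BlockN] tile of QK^T scores
    accum_ptr,    # output accumulator, shaped by the user-specified dims
    b_idx: int,   # batch index of the tile
    h_idx: int,   # head index of the tile
    q_offs: int,  # offset of the tile along the query dimension
    v_offs: int,  # offset of the tile along the key/value dimension
):
    # Body is supplied by the user: fold the qk tile into accum_ptr.
    pass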

Alternatively, if we do not want users to write Triton code, we could assume they only need a limited set of reductions, such as max, min, mean, and sum.

from enum import Enum

class ReductionType(Enum):
    MAX = "max"
    MIN = "min"
    MEAN = "mean"
    SUM = "sum"
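All four of these reductions can be maintained as a running accumulator that is updated one score tile at a time, which is why they fit inside a tiled attention kernel. The sketch below illustrates this in pure Python. The StreamingReduction class is hypothetical and not an existing PyTorch API, and the enum is repeated only so the block is self-contained.

```python
import math
from enum import Enum


class ReductionType(Enum):
    MAX = "max"
    MIN = "min"
    MEAN = "mean"
    SUM = "sum"


class StreamingReduction:
    """Hypothetical helper: folds score tiles into a single scalar statistic."""

    def __init__(self, kind):
        self.kind = kind
        self.count = 0  # number of scores seen so far (needed for MEAN)
        if kind is ReductionType.MAX:
            self.value = -math.inf
        elif kind is ReductionType.MIN:
            self.value = math.inf
        else:  # SUM and MEAN both accumulate a running sum
            self.value = 0.0

    def update(self, tile):
        """Fold one tile (a list of rows of floats) into the accumulator."""
        for row in tile:
            for score in row:
                if self.kind is ReductionType.MAX:
                    self.value = max(self.value, score)
                elif self.kind is ReductionType.MIN:
                    self.value = min(self.value, score)
                else:
                    self.value += score
                self.count += 1

    def result(self):
        """Return the final statistic; MEAN divides the running sum by the count."""
        if self.kind is ReductionType.MEAN:
            return self.value / self.count if self.count else math.nan
        return self.value
```

Note that MEAN needs two pieces of state, the running sum and the element count, so an enum-based API would have to allocate a slightly different accumulator per reduction type.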

Alternatives

No response

Additional context

No response

cc @chauhang @penguinwu @zou3519 @ydwu4 @bdhirsh @Chillee @drisspg @yanboliang @BoyuanFeng

Labels: enhancement, module: flex attention, module: higher order operators, module: pt2-dispatcher, oncall: pt2, triaged
