
Conversation

laurenmchin

@laurenmchin laurenmchin commented Aug 31, 2025

  • Match avgceil[su] idiom lowered by AArch64: (sr[al] (sub x, (xor y, -1)), 1) -> avgceil[su](x, y)
  • Keep floor-average fold, but check legality at TruncVT when truncation is present: (sr[al] (add a, b), 1) -> avgfloor[su](a, b)
  • Treat SRL of widened signed add/sub feeding a truncate as signed
  • Add visitTRUNCATE combine: trunc(avgceilu(sext x, sext y)) -> avgceils(x, y)

This patch resolves the regression test failure in llvm/test/CodeGen/AArch64/sve-hadd.ll and fixes #147946.


Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write permissions for the repository; in that case you can instead tag reviewers by name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR with a comment "Ping". The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

@llvmbot llvmbot added the llvm:SelectionDAG SelectionDAGISel as well label Aug 31, 2025
@llvmbot
Member

llvmbot commented Aug 31, 2025

@llvm/pr-subscribers-llvm-selectiondag

Author: Lauren (laurenmchin)

Changes
  • Match avgceil[su] idiom lowered by AArch64: (sr[al] (sub x, (xor y, -1)), 1) -> avgceil[su](x, y)
  • Keep floor-average fold, but check legality at TruncVT when truncation is present: (sr[al] (add a, b), 1) -> avgfloor[su](a, b)
  • Treat SRL of widened signed add/sub feeding a truncate as signed
  • Add visitTRUNCATE combine: trunc(avgceilu(sext x, sext y)) -> avgceils(x, y)

This patch resolves the regression test failure in llvm/test/CodeGen/AArch64/sve-hadd.ll and addresses PR #147946.


Full diff: https://github.com/llvm/llvm-project/pull/156239.diff

1 File Affected:

  • (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+149-14)
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index bed3c42473e27..a51709a33910e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -11802,28 +11802,146 @@ SDValue DAGCombiner::foldShiftToAvg(SDNode *N, const SDLoc &DL) {
     return SDValue();
 
   EVT VT = N->getValueType(0);
-  bool IsUnsigned = Opcode == ISD::SRL;
+  SDValue N0 = N->getOperand(0);
 
-  // Captured values.
-  SDValue A, B, Add;
+  if (!sd_match(N->getOperand(1), m_One()))
+    return SDValue();
 
-  // Match floor average as it is common to both floor/ceil avgs.
+  // [TruncVT]
+  // The result type of a single truncate user fed by this shift node (if
+  // present). We always use TruncVT to verify whether the target supports
+  // folding to avgceils. For avgfloor[su], we use TruncVT if present, else VT.
+  //
+  // [NarrowVT]
+  // The semantic source width of the value(s) being averaged when the ops are
+  // SExt/SExtInReg.
+  EVT TruncVT = VT;
+  SDNode *TruncNode = nullptr;
+
+  // If this shift has a single truncate user, use it to decide whether folding
+  // to avg* is legal at the truncated width. Note that the target may only
+  // support the avgceil[su]/avgfloor[su] op at the narrower type, or the
+  // full-width VT, but we check for legality using the truncate node's VT if
+  // present, else this shift's VT.
+  if (N->hasOneUse() && N->user_begin()->getOpcode() == ISD::TRUNCATE) {
+    TruncNode = *N->user_begin();
+    TruncVT = TruncNode->getValueType(0);
+  }
+
+  EVT NarrowVT = VT;
+  SDValue N00 = N0.getOperand(0);
+
+  // For SRL of SExt'd values, if (1) the type isn't legal, and (2) there's no
+  // truncate user, bail out, because we can't safely fold.
+  if (N00.getOpcode() == ISD::SIGN_EXTEND_INREG) {
+    NarrowVT = cast<VTSDNode>(N0->getOperand(0)->getOperand(1))->getVT();
+    if (Opcode == ISD::SRL && !TLI.isTypeLegal(NarrowVT))
+      return SDValue();
+  }
+
+  unsigned FloorISD = 0;
+  unsigned CeilISD = 0;
+  bool IsUnsigned = false;
+
+  // Decide whether signed or unsigned.
+  switch (Opcode) {
+  case ISD::SRA:
+    FloorISD = ISD::AVGFLOORS;
+    break;
+  case ISD::SRL:
+    IsUnsigned = true;
+    // SRL of a widened signed sub feeding a truncate acts like shadd.
+    if (TruncNode &&
+        (N0.getOpcode() == ISD::ADD || N0.getOpcode() == ISD::SUB) &&
+        (N00.getOpcode() == ISD::SIGN_EXTEND_INREG ||
+         N00.getOpcode() == ISD::SIGN_EXTEND))
+      IsUnsigned = false;
+    FloorISD = (IsUnsigned ? ISD::AVGFLOORU : ISD::AVGFLOORS);
+    break;
+  default:
+    return SDValue();
+  }
+
+  CeilISD = (IsUnsigned ? ISD::AVGCEILU : ISD::AVGCEILS);
+
+  // Bail out if this shift is not truncated and the target doesn't support
+  // the avg* op at this shift's VT (or TruncVT for avgceil[su]).
+  if ((!TruncNode && !TLI.isOperationLegalOrCustom(FloorISD, VT)) ||
+      (!TruncNode && !TLI.isOperationLegalOrCustom(CeilISD, TruncVT)))
+    return SDValue();
+
+  SDValue X, Y, Sub, Xor;
+
+  // (sr[al] (sub x, (xor y, -1)), 1) -> (avgceil[su] x, y)
   if (sd_match(N, m_BinOp(Opcode,
-                          m_AllOf(m_Value(Add), m_Add(m_Value(A), m_Value(B))),
+                          m_AllOf(m_Value(Sub),
+                                  m_Sub(m_Value(X),
+                                        m_AllOf(m_Value(Xor),
+                                                m_Xor(m_Value(Y), m_Value())))),
                           m_One()))) {
-    // Decide whether signed or unsigned.
-    unsigned FloorISD = IsUnsigned ? ISD::AVGFLOORU : ISD::AVGFLOORS;
-    if (!hasOperation(FloorISD, VT))
-      return SDValue();
+    APInt SplatVal;
+    if (ISD::isConstantSplatVector(Xor.getOperand(1).getNode(), SplatVal)) {
+      // - Can't fold if either op is sign/zero-extended for SRL, as SRL
+      //   is unsigned, and shadd patterns are handled elsewhere.
+      //
+      // - Large fixed vectors (>128 bits) on AArch64 will be type-legalized
+      //   into a series of EXTRACT_SUBVECTORs. Folding each subvector does not
+      //   necessarily preserve semantics so they cannot be folded here.
+      if (TruncNode && VT.isFixedLengthVector()) {
+        if (X.getOpcode() == ISD::SIGN_EXTEND ||
+            X.getOpcode() == ISD::ZERO_EXTEND ||
+            Y.getOpcode() == ISD::SIGN_EXTEND ||
+            Y.getOpcode() == ISD::ZERO_EXTEND)
+          return SDValue();
+        else if (TruncNode && VT.isFixedLengthVector() &&
+                 VT.getSizeInBits() > 128)
+          return SDValue();
+      }
 
-    // Can't optimize adds that may wrap.
-    if ((IsUnsigned && !Add->getFlags().hasNoUnsignedWrap()) ||
-        (!IsUnsigned && !Add->getFlags().hasNoSignedWrap()))
-      return SDValue();
+      // If there is no truncate user, ensure the relevant no-wrap flag is on
+      // the sub so that narrowing the widened result is defined.
+      if (Opcode == ISD::SRA && VT == NarrowVT) {
+        if (!IsUnsigned && !Sub->getFlags().hasNoSignedWrap())
+          return SDValue();
+      } else if (IsUnsigned && !Sub->getFlags().hasNoUnsignedWrap())
+        return SDValue();
 
-    return DAG.getNode(FloorISD, DL, N->getValueType(0), {A, B});
+      // Only fold if the target supports avgceil[su] at the truncated type:
+      // - if there is a single truncate user, we require support at TruncVT.
+      //   We build the avg* at VT (to replace this shift node).
+      //   visitTRUNCATE handles the actual folding to avgceils (x, y).
+      // - otherwise, we require support at VT (TruncVT == VT).
+      //
+      // AArch64 canonicalizes (x + y + 1) >> 1 -> sub (x, xor (y, -1)). In
+      // order for our fold to be legal, we require support for the VT at the
+      // final observable type (TruncVT or VT).
+      if (TLI.isOperationLegalOrCustom(CeilISD, TruncVT))
+        return DAG.getNode(CeilISD, DL, VT, Y, X);
+    }
   }
 
+  // Captured values.
+  SDValue A, B, Add;
+
+  // Match floor average as it is common to both floor/ceil avgs.
+  // (sr[al] (add a, b), 1) -> avgfloor[su](a, b)
+  if (!sd_match(N, m_BinOp(Opcode,
+                           m_AllOf(m_Value(Add), m_Add(m_Value(A), m_Value(B))),
+                           m_One())))
+    return SDValue();
+
+  if (TruncNode && VT.isFixedLengthVector() && VT.getSizeInBits() > 128)
+    return SDValue();
+
+  // Can't optimize adds that may wrap.
+  if ((IsUnsigned && !Add->getFlags().hasNoUnsignedWrap()) ||
+      (!IsUnsigned && !Add->getFlags().hasNoSignedWrap()))
+    return SDValue();
+
+  EVT TargetVT = TruncNode ? TruncVT : VT;
+  if (TLI.isOperationLegalOrCustom(FloorISD, TargetVT))
+    return DAG.getNode(FloorISD, DL, N->getValueType(0), A, B);
+
   return SDValue();
 }
 
@@ -16294,6 +16412,23 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
     }
   }
 
+  // trunc (avgceilu (sext (x), sext (y))) -> avgceils(x, y)
+  if (N0.getOpcode() == ISD::AVGCEILU) {
+    SDValue SExtX = N0.getOperand(0);
+    SDValue SExtY = N0.getOperand(1);
+    if ((SExtX.getOpcode() == ISD::SIGN_EXTEND &&
+         SExtY.getOpcode() == ISD::SIGN_EXTEND) ||
+        (SExtX.getOpcode() == ISD::SIGN_EXTEND_INREG &&
+         SExtY.getOpcode() == ISD::SIGN_EXTEND_INREG)) {
+      SDValue X = SExtX.getOperand(0);
+      SDValue Y = SExtY.getOperand(0);
+      if (X.getValueType() == VT &&
+          TLI.isOperationLegalOrCustom(ISD::AVGCEILS, VT)) {
+        return DAG.getNode(ISD::AVGCEILS, DL, VT, X, Y);
+      }
+    }
+  }
+
   if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
     return NewVSel;
 

@laurenmchin
Copy link
Author

@RKSimon requesting a review when you have time, thanks

@RKSimon RKSimon self-requested a review September 2, 2025 10:31
@@ -16294,6 +16412,23 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
}
}

// trunc (avgceilu (sext (x), sext (y))) -> avgceils(x, y)
Collaborator


Please can you move this down to the switch statement below (handle avgceilu first and [[fallthrough]] to share the common code with the other avg/abd opcodes).

You also need to create an alive2 test for this.

Ideally we should add this fold first in a separate PR - you need to create some suitable test coverage and create a separate PR.

Collaborator


A couple of tests along the lines of this in hadd-combine.ll should be enough:

; trunc(avgceilu(sext x, sext y)) -> avgceils(x, y)
define <8 x i8> @trunc_urhadd_sext(<8 x i8> %a0, <8 x i8> %a1) {
  %x0 = sext <8 x i8> %a0 to <8 x i16>
  %x1 = sext <8 x i8> %a1 to <8 x i16>
  %avg = call <8 x i16> @llvm.aarch64.neon.urhadd.v8i16(<8 x i16> %x0, <8 x i16> %x1)
  %res = trunc <8 x i16> %avg to <8 x i8>
  ret <8 x i8> %res
}

You will still need an alive2 test

Author


> Please can you move this down to the switch statement below (handle avgceilu first and [[fallthrough]] to share the common code with the other avg/abd opcodes).
>
> You also need to create an alive2 test for this.
>
> Ideally we should add this fold first in a separate PR - you need to create some suitable test coverage and create a separate PR.

Thanks, will do. I'll work on moving the logic under the switch and submit a separate PR with test coverage plus an alive2 proof for the fold. Will follow up soon.

if ((SExtX.getOpcode() == ISD::SIGN_EXTEND &&
SExtY.getOpcode() == ISD::SIGN_EXTEND) ||
(SExtX.getOpcode() == ISD::SIGN_EXTEND_INREG &&
SExtY.getOpcode() == ISD::SIGN_EXTEND_INREG)) {
Collaborator


Not sure if SIGN_EXTEND_INREG will work?

Successfully merging this pull request may close these issues.

[DAG] combineShiftToAVG - failure to match avgceil patterns on preferIncOfAddToSubOfNot targets