diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index bed3c42473e27..a51709a33910e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -11802,28 +11802,146 @@ SDValue DAGCombiner::foldShiftToAvg(SDNode *N, const SDLoc &DL) {
     return SDValue();
 
   EVT VT = N->getValueType(0);
-  bool IsUnsigned = Opcode == ISD::SRL;
+  SDValue N0 = N->getOperand(0);
 
-  // Captured values.
-  SDValue A, B, Add;
+  if (!sd_match(N->getOperand(1), m_One()))
+    return SDValue();
 
-  // Match floor average as it is common to both floor/ceil avgs.
+  // [TruncVT]
+  // Result type of a single truncate user fed by this shift node (if present).
+  // We always use TruncVT to verify whether the target supports folding to
+  // avgceils. For avgfloor[su], we use TruncVT if present, else VT.
+  //
+  // [NarrowVT]
+  // Semantic source width of the value(s) being averaged when the ops are
+  // SExt/SExtInReg.
+  EVT TruncVT = VT;
+  SDNode *TruncNode = nullptr;
+
+  // If this shift has a single truncate user, use it to decide whether folding
+  // to avg* is legal at the truncated width. Note that the target may only
+  // support the avgceil[su]/avgfloor[su] op at the narrower type, or the
+  // full-width VT, but we check for legality using the truncate node's VT if
+  // present, else this shift's VT.
+  if (N->hasOneUse() && N->user_begin()->getOpcode() == ISD::TRUNCATE) {
+    TruncNode = *N->user_begin();
+    TruncVT = TruncNode->getValueType(0);
+  }
+
+  EVT NarrowVT = VT;
+  SDValue N00 = N0.getOperand(0);
+
+  // For SRL of SExt'd values, bail out if (1) the narrow type isn't legal and
+  // (2) there's no truncate user, because we can't safely fold.
+  if (N00.getOpcode() == ISD::SIGN_EXTEND_INREG) {
+    NarrowVT = cast<VTSDNode>(N0->getOperand(0)->getOperand(1))->getVT();
+    if (Opcode == ISD::SRL && !TLI.isTypeLegal(NarrowVT))
+      return SDValue();
+  }
+
+  unsigned FloorISD = 0;
+  unsigned CeilISD = 0;
+  bool IsUnsigned = false;
+
+  // Decide whether signed or unsigned.
+  switch (Opcode) {
+  case ISD::SRA:
+    FloorISD = ISD::AVGFLOORS;
+    break;
+  case ISD::SRL:
+    IsUnsigned = true;
+    // SRL of a widened signed sub feeding a truncate acts like shadd.
+    if (TruncNode &&
+        (N0.getOpcode() == ISD::ADD || N0.getOpcode() == ISD::SUB) &&
+        (N00.getOpcode() == ISD::SIGN_EXTEND_INREG ||
+         N00.getOpcode() == ISD::SIGN_EXTEND))
+      IsUnsigned = false;
+    FloorISD = (IsUnsigned ? ISD::AVGFLOORU : ISD::AVGFLOORS);
+    break;
+  default:
+    return SDValue();
+  }
+
+  CeilISD = (IsUnsigned ? ISD::AVGCEILU : ISD::AVGCEILS);
+
+  // Bail out if this shift is not truncated and the target doesn't support
+  // the avg* op at this shift's VT (or TruncVT for avgceil[su]).
+  if ((!TruncNode && !TLI.isOperationLegalOrCustom(FloorISD, VT)) ||
+      (!TruncNode && !TLI.isOperationLegalOrCustom(CeilISD, TruncVT)))
+    return SDValue();
+
+  SDValue X, Y, Sub, Xor;
+
+  // (sr[al] (sub x, (xor y, -1)), 1) -> (avgceil[su] x, y)
   if (sd_match(N, m_BinOp(Opcode,
-                          m_AllOf(m_Value(Add), m_Add(m_Value(A), m_Value(B))),
+                          m_AllOf(m_Value(Sub),
+                                  m_Sub(m_Value(X),
+                                        m_AllOf(m_Value(Xor),
+                                                m_Xor(m_Value(Y), m_Value())))),
                           m_One()))) {
-    // Decide whether signed or unsigned.
-    unsigned FloorISD = IsUnsigned ? ISD::AVGFLOORU : ISD::AVGFLOORS;
-    if (!hasOperation(FloorISD, VT))
-      return SDValue();
+    APInt SplatVal;
+    if (ISD::isConstantSplatVector(Xor.getOperand(1).getNode(), SplatVal)) {
+      // - Can't fold if either op is sign/zero-extended for SRL, as SRL
+      //   is unsigned, and shadd patterns are handled elsewhere.
+      //
+      // - Large fixed vectors (>128 bits) on AArch64 will be type-legalized
+      //   into a series of EXTRACT_SUBVECTORs. Folding each subvector does
+      //   not necessarily preserve semantics, so they cannot be folded here.
+      if (TruncNode && VT.isFixedLengthVector()) {
+        if (X.getOpcode() == ISD::SIGN_EXTEND ||
+            X.getOpcode() == ISD::ZERO_EXTEND ||
+            Y.getOpcode() == ISD::SIGN_EXTEND ||
+            Y.getOpcode() == ISD::ZERO_EXTEND)
+          return SDValue();
+        else if (TruncNode && VT.isFixedLengthVector() &&
+                 VT.getSizeInBits() > 128)
+          return SDValue();
+      }
 
-    // Can't optimize adds that may wrap.
-    if ((IsUnsigned && !Add->getFlags().hasNoUnsignedWrap()) ||
-        (!IsUnsigned && !Add->getFlags().hasNoSignedWrap()))
-      return SDValue();
+      // If there is no truncate user, ensure the relevant no-wrap flag is on
+      // the sub so that narrowing the widened result is defined.
+      if (Opcode == ISD::SRA && VT == NarrowVT) {
+        if (!IsUnsigned && !Sub->getFlags().hasNoSignedWrap())
+          return SDValue();
+      } else if (IsUnsigned && !Sub->getFlags().hasNoUnsignedWrap())
+        return SDValue();
 
-    return DAG.getNode(FloorISD, DL, N->getValueType(0), {A, B});
+      // Only fold if the target supports avgceil[su] at the truncated type:
+      // - If there is a single truncate user, we require support at TruncVT.
+      //   We build the avg* at VT (to replace this shift node);
+      //   visitTRUNCATE handles the actual folding to avgceils(x, y).
+      // - Otherwise, we require support at VT (TruncVT == VT).
+      //
+      // AArch64 canonicalizes (x + y + 1) >> 1 -> sub (x, xor (y, -1)). In
+      // order for our fold to be legal, we require support for the avg* op
+      // at the final observable type (TruncVT or VT).
+      if (TLI.isOperationLegalOrCustom(CeilISD, TruncVT))
+        return DAG.getNode(CeilISD, DL, VT, Y, X);
+    }
   }
 
+  // Captured values.
+  SDValue A, B, Add;
+
+  // Match floor average as it is common to both floor/ceil avgs.
+  // (sr[al] (add a, b), 1) -> avgfloor[su](a, b)
+  if (!sd_match(N, m_BinOp(Opcode,
+                           m_AllOf(m_Value(Add), m_Add(m_Value(A), m_Value(B))),
+                           m_One())))
+    return SDValue();
+
+  if (TruncNode && VT.isFixedLengthVector() && VT.getSizeInBits() > 128)
+    return SDValue();
+
+  // Can't optimize adds that may wrap.
+  if ((IsUnsigned && !Add->getFlags().hasNoUnsignedWrap()) ||
+      (!IsUnsigned && !Add->getFlags().hasNoSignedWrap()))
+    return SDValue();
+
+  EVT TargetVT = TruncNode ? TruncVT : VT;
+  if (TLI.isOperationLegalOrCustom(FloorISD, TargetVT))
+    return DAG.getNode(FloorISD, DL, N->getValueType(0), A, B);
+
   return SDValue();
 }
 
@@ -16294,6 +16412,23 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
     }
   }
 
+  // trunc (avgceilu (sext x, sext y)) -> avgceils(x, y)
+  if (N0.getOpcode() == ISD::AVGCEILU) {
+    SDValue SExtX = N0.getOperand(0);
+    SDValue SExtY = N0.getOperand(1);
+    if ((SExtX.getOpcode() == ISD::SIGN_EXTEND &&
+         SExtY.getOpcode() == ISD::SIGN_EXTEND) ||
+        (SExtX.getOpcode() == ISD::SIGN_EXTEND_INREG &&
+         SExtY.getOpcode() == ISD::SIGN_EXTEND_INREG)) {
+      SDValue X = SExtX.getOperand(0);
+      SDValue Y = SExtY.getOperand(0);
+      if (X.getValueType() == VT &&
+          TLI.isOperationLegalOrCustom(ISD::AVGCEILS, VT)) {
+        return DAG.getNode(ISD::AVGCEILS, DL, VT, X, Y);
+      }
+    }
+  }
+
   if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
     return NewVSel;
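
Note on the new (sub x, (xor y, -1)) match: in two's complement, y ^ -1 is ~y, which equals -y - 1, so x - (y ^ -1) == x + y + 1, and shifting that sum right by one is the ceiling average. A minimal standalone C++ sketch (not part of the patch) that exhaustively checks the identity for 8-bit inputs widened to 32 bits, where no wrap can occur:

// Exhaustive check of the identity behind the (sub x, (xor y, -1)) pattern:
// x - (y ^ -1) == x + y + 1, so (x - (y ^ -1)) >> 1 is the ceiling average.
#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
  for (int32_t X = 0; X < 256; ++X) {
    for (int32_t Y = 0; Y < 256; ++Y) {
      int32_t Sub = X - (Y ^ -1);            // matches (sub x, (xor y, -1))
      assert(Sub == X + Y + 1);              // the canonical rounding add
      assert((Sub >> 1) == (X + Y + 1) / 2); // avgceilu, exact at 32 bits
    }
  }
  std::puts("x - (y ^ -1) == x + y + 1 holds for all 8-bit pairs");
  return 0;
}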
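The retained "can't optimize adds that may wrap" bailout is what keeps the avgfloor[su] fold sound: without nuw/nsw, the add may drop the carry bit the average depends on. A small C++ counterexample (illustrative values only, not part of the patch) showing an 8-bit wrapping add producing the wrong floor average:

// Why the avgfloor fold requires no-wrap flags: a wrapping i8 add drops the
// carry, so (srl (add a, b), 1) no longer equals avgflooru(a, b).
#include <cstdint>
#include <cstdio>

int main() {
  uint8_t A = 200, B = 120;                     // A + B = 320, wraps to 64 in i8
  uint8_t WrappedShift = (uint8_t)(A + B) >> 1; // (srl (add a, b), 1) -> 32
  uint8_t Avg = ((uint16_t)A + B) >> 1;         // true avgflooru(a, b) -> 160
  std::printf("wrapped shift = %u, avgflooru = %u\n", WrappedShift, Avg);
  return 0;
}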
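The new visitTRUNCATE combine relies on the low bits of a widened unsigned ceiling average agreeing with the signed ceiling average at the narrow type: sext(x) + sext(y) + 1 is congruent to x + y + 1 modulo 2^16, and truncation discards everything above bit 7 after the shift, including any carry into the high half. A sketch verifying trunc(avgceilu(sext x, sext y)) == avgceils(x, y) exhaustively for i8 through i16 (helper names avgceilu16/avgceils8 are illustrative, not LLVM APIs; it assumes arithmetic right shift and two's-complement narrowing, both guaranteed since C++20):

// Exhaustive i8 check that trunc(avgceilu(sext x, sext y)) == avgceils(x, y).
#include <cassert>
#include <cstdint>
#include <cstdio>

// ISD::AVGCEILU semantics at i16: (a + b + 1) >> 1 computed without overflow.
static uint16_t avgceilu16(uint16_t A, uint16_t B) {
  return (uint16_t)(((uint32_t)A + B + 1) >> 1);
}

// ISD::AVGCEILS semantics at i8: (a + b + 1) >> 1 in wider signed arithmetic.
static int8_t avgceils8(int8_t A, int8_t B) {
  return (int8_t)(((int32_t)A + B + 1) >> 1);
}

int main() {
  for (int X = -128; X < 128; ++X) {
    for (int Y = -128; Y < 128; ++Y) {
      uint16_t Wide = avgceilu16((uint16_t)(int16_t)X,  // sext x to i16
                                 (uint16_t)(int16_t)Y); // sext y to i16
      int8_t Trunc = (int8_t)Wide;                      // trunc i16 -> i8
      assert(Trunc == avgceils8((int8_t)X, (int8_t)Y));
    }
  }
  std::puts("trunc(avgceilu(sext x, sext y)) == avgceils(x, y) for all i8");
  return 0;
}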