diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index a248eb7444b20..72594c7f9783c 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -3249,10 +3249,11 @@ class LLVM_ABI TargetLoweringBase {
   /// Return true on success. Currently only supports
   /// llvm.vector.deinterleave{2,3,5,7}
   ///
-  /// \p LI is the accompanying load instruction.
+  /// \p Load is the accompanying load instruction. Can be either a plain load
+  /// instruction or a vp.load intrinsic.
   /// \p DeinterleaveValues contains the deinterleaved values.
   virtual bool
-  lowerDeinterleaveIntrinsicToLoad(LoadInst *LI,
+  lowerDeinterleaveIntrinsicToLoad(Instruction *Load, Value *Mask,
                                    ArrayRef<Value *> DeinterleaveValues) const {
     return false;
   }
diff --git a/llvm/lib/CodeGen/InterleavedAccessPass.cpp b/llvm/lib/CodeGen/InterleavedAccessPass.cpp
index 7259834975cf4..95599837e1bfc 100644
--- a/llvm/lib/CodeGen/InterleavedAccessPass.cpp
+++ b/llvm/lib/CodeGen/InterleavedAccessPass.cpp
@@ -634,24 +634,18 @@ bool InterleavedAccessImpl::lowerDeinterleaveIntrinsic(
   if (!LastFactor)
     return false;
 
+  Value *Mask = nullptr;
   if (auto *VPLoad = dyn_cast<VPIntrinsic>(LoadedVal)) {
     if (VPLoad->getIntrinsicID() != Intrinsic::vp_load)
       return false;
     // Check mask operand. Handle both all-true/false and interleaved mask.
     Value *WideMask = VPLoad->getOperand(1);
-    Value *Mask =
-        getMask(WideMask, Factor, cast<VectorType>(LastFactor->getType()));
+    Mask = getMask(WideMask, Factor, cast<VectorType>(LastFactor->getType()));
     if (!Mask)
       return false;
 
     LLVM_DEBUG(dbgs() << "IA: Found a vp.load with deinterleave intrinsic "
                       << *DI << " and factor = " << Factor << "\n");
-
-    // Since lowerInterleavedLoad expects Shuffles and LoadInst, use special
-    // TLI function to emit target-specific interleaved instruction.
-    if (!TLI->lowerInterleavedVPLoad(VPLoad, Mask, DeinterleaveValues))
-      return false;
-
   } else {
     auto *LI = cast<LoadInst>(LoadedVal);
     if (!LI->isSimple())
@@ -659,12 +653,13 @@ bool InterleavedAccessImpl::lowerDeinterleaveIntrinsic(
 
     LLVM_DEBUG(dbgs() << "IA: Found a load with deinterleave intrinsic "
                       << *DI << " and factor = " << Factor << "\n");
-
-    // Try and match this with target specific intrinsics.
-    if (!TLI->lowerDeinterleaveIntrinsicToLoad(LI, DeinterleaveValues))
-      return false;
   }
 
+  // Try and match this with target specific intrinsics.
+  if (!TLI->lowerDeinterleaveIntrinsicToLoad(cast<Instruction>(LoadedVal), Mask,
+                                             DeinterleaveValues))
+    return false;
+
   for (Value *V : DeinterleaveValues)
     if (V)
       DeadInsts.insert(cast<Instruction>(V));
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index bde4ba993f69e..235df9022c6fb 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -17476,12 +17476,17 @@ bool AArch64TargetLowering::lowerInterleavedStore(StoreInst *SI,
 }
 
 bool AArch64TargetLowering::lowerDeinterleaveIntrinsicToLoad(
-    LoadInst *LI, ArrayRef<Value *> DeinterleavedValues) const {
+    Instruction *Load, Value *Mask,
+    ArrayRef<Value *> DeinterleavedValues) const {
   unsigned Factor = DeinterleavedValues.size();
   if (Factor != 2 && Factor != 4) {
     LLVM_DEBUG(dbgs() << "Matching ld2 and ld4 patterns failed\n");
     return false;
   }
+  auto *LI = dyn_cast<LoadInst>(Load);
+  if (!LI)
+    return false;
+  assert(!Mask && "Unexpected mask on a load\n");
 
   Value *FirstActive = *llvm::find_if(DeinterleavedValues,
                                       [](Value *V) { return V != nullptr; });
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 65fe08e92c235..6afb3c330d25b 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -219,7 +219,8 @@ class AArch64TargetLowering : public TargetLowering {
                              unsigned Factor) const override;
 
   bool lowerDeinterleaveIntrinsicToLoad(
-      LoadInst *LI, ArrayRef<Value *> DeinterleaveValues) const override;
+      Instruction *Load, Value *Mask,
+      ArrayRef<Value *> DeinterleaveValues) const override;
 
   bool lowerInterleaveIntrinsicToStore(
       StoreInst *SI, ArrayRef<Value *> InterleaveValues) const override;
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 00e969056df7d..41bbf6b9dcf2e 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -438,7 +438,8 @@ class RISCVTargetLowering : public TargetLowering {
                              unsigned Factor) const override;
 
   bool lowerDeinterleaveIntrinsicToLoad(
-      LoadInst *LI, ArrayRef<Value *> DeinterleaveValues) const override;
+      Instruction *Load, Value *Mask,
+      ArrayRef<Value *> DeinterleaveValues) const override;
 
   bool lowerInterleaveIntrinsicToStore(
       StoreInst *SI, ArrayRef<Value *> InterleaveValues) const override;
diff --git a/llvm/lib/Target/RISCV/RISCVInterleavedAccess.cpp b/llvm/lib/Target/RISCV/RISCVInterleavedAccess.cpp
index 440857f831fa6..ddfacd970e950 100644
--- a/llvm/lib/Target/RISCV/RISCVInterleavedAccess.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInterleavedAccess.cpp
@@ -234,53 +234,100 @@ bool RISCVTargetLowering::lowerInterleavedStore(StoreInst *SI,
   return true;
 }
 
+static bool isMultipleOfN(const Value *V, const DataLayout &DL, unsigned N) {
+  assert(N);
+  if (N == 1)
+    return true;
+
+  using namespace PatternMatch;
+  // Right now we're only recognizing the simplest pattern.
+  uint64_t C;
+  if (match(V, m_CombineOr(m_ConstantInt(C),
+                           m_c_Mul(m_Value(), m_ConstantInt(C)))) &&
+      C && C % N == 0)
+    return true;
+
+  if (isPowerOf2_32(N)) {
+    KnownBits KB = llvm::computeKnownBits(V, DL);
+    return KB.countMinTrailingZeros() >= Log2_32(N);
+  }
+
+  return false;
+}
+
 bool RISCVTargetLowering::lowerDeinterleaveIntrinsicToLoad(
-    LoadInst *LI, ArrayRef<Value *> DeinterleaveValues) const {
+    Instruction *Load, Value *Mask,
+    ArrayRef<Value *> DeinterleaveValues) const {
   const unsigned Factor = DeinterleaveValues.size();
   if (Factor > 8)
     return false;
 
-  assert(LI->isSimple());
-  IRBuilder<> Builder(LI);
+  IRBuilder<> Builder(Load);
 
   Value *FirstActive =
       *llvm::find_if(DeinterleaveValues, [](Value *V) { return V != nullptr; });
   VectorType *ResVTy = cast<VectorType>(FirstActive->getType());
 
-  const DataLayout &DL = LI->getDataLayout();
+  const DataLayout &DL = Load->getDataLayout();
+  auto *XLenTy = Type::getIntNTy(Load->getContext(), Subtarget.getXLen());
 
-  if (!isLegalInterleavedAccessType(ResVTy, Factor, LI->getAlign(),
-                                    LI->getPointerAddressSpace(), DL))
+  Value *Ptr, *VL;
+  Align Alignment;
+  if (auto *LI = dyn_cast<LoadInst>(Load)) {
+    assert(LI->isSimple());
+    Ptr = LI->getPointerOperand();
+    Alignment = LI->getAlign();
+    assert(!Mask && "Unexpected mask on a load\n");
+    Mask = Builder.getAllOnesMask(ResVTy->getElementCount());
+    VL = isa<FixedVectorType>(ResVTy)
+             ? Builder.CreateElementCount(XLenTy, ResVTy->getElementCount())
+             : Constant::getAllOnesValue(XLenTy);
+  } else {
+    auto *VPLoad = cast<VPIntrinsic>(Load);
+    assert(VPLoad->getIntrinsicID() == Intrinsic::vp_load &&
+           "Unexpected intrinsic");
+    Ptr = VPLoad->getMemoryPointerParam();
+    Alignment = VPLoad->getPointerAlignment().value_or(
+        DL.getABITypeAlign(ResVTy->getElementType()));
+
+    assert(Mask && "vp.load needs a mask!");
+
+    Value *WideEVL = VPLoad->getVectorLengthParam();
+    // Conservatively check if EVL is a multiple of factor, otherwise some
+    // (trailing) elements might be lost after the transformation.
+    if (!isMultipleOfN(WideEVL, Load->getDataLayout(), Factor))
+      return false;
+
+    VL = Builder.CreateZExt(
+        Builder.CreateUDiv(WideEVL,
+                           ConstantInt::get(WideEVL->getType(), Factor)),
+        XLenTy);
+  }
+
+  Type *PtrTy = Ptr->getType();
+  unsigned AS = PtrTy->getPointerAddressSpace();
+  if (!isLegalInterleavedAccessType(ResVTy, Factor, Alignment, AS, DL))
     return false;
 
   Value *Return;
-  Type *PtrTy = LI->getPointerOperandType();
-  Type *XLenTy = Type::getIntNTy(LI->getContext(), Subtarget.getXLen());
-
   if (isa<FixedVectorType>(ResVTy)) {
-    Value *VL = Builder.CreateElementCount(XLenTy, ResVTy->getElementCount());
-    Value *Mask = Builder.getAllOnesMask(ResVTy->getElementCount());
     Return = Builder.CreateIntrinsic(FixedVlsegIntrIds[Factor - 2],
-                                     {ResVTy, PtrTy, XLenTy},
-                                     {LI->getPointerOperand(), Mask, VL});
+                                     {ResVTy, PtrTy, XLenTy}, {Ptr, Mask, VL});
   } else {
     unsigned SEW = DL.getTypeSizeInBits(ResVTy->getElementType());
     unsigned NumElts = ResVTy->getElementCount().getKnownMinValue();
     Type *VecTupTy = TargetExtType::get(
-        LI->getContext(), "riscv.vector.tuple",
-        ScalableVectorType::get(Type::getInt8Ty(LI->getContext()),
+        Load->getContext(), "riscv.vector.tuple",
+        ScalableVectorType::get(Type::getInt8Ty(Load->getContext()),
                                 NumElts * SEW / 8),
         Factor);
-    Value *VL = Constant::getAllOnesValue(XLenTy);
-    Value *Mask = Builder.getAllOnesMask(ResVTy->getElementCount());
-
     Function *VlsegNFunc = Intrinsic::getOrInsertDeclaration(
-        LI->getModule(), ScalableVlsegIntrIds[Factor - 2],
+        Load->getModule(), ScalableVlsegIntrIds[Factor - 2],
         {VecTupTy, PtrTy, Mask->getType(), VL->getType()});
 
     Value *Operands[] = {
         PoisonValue::get(VecTupTy),
-        LI->getPointerOperand(),
+        Ptr,
         Mask,
         VL,
         ConstantInt::get(XLenTy,
@@ -290,7 +337,7 @@ bool RISCVTargetLowering::lowerDeinterleaveIntrinsicToLoad(
     CallInst *Vlseg = Builder.CreateCall(VlsegNFunc, Operands);
 
     SmallVector<Type *, 2> AggrTypes{Factor, ResVTy};
-    Return = PoisonValue::get(StructType::get(LI->getContext(), AggrTypes));
+    Return = PoisonValue::get(StructType::get(Load->getContext(), AggrTypes));
     for (unsigned i = 0; i < Factor; ++i) {
       Value *VecExtract = Builder.CreateIntrinsic(
           Intrinsic::riscv_tuple_extract, {ResVTy, VecTupTy},
@@ -370,27 +417,6 @@ bool RISCVTargetLowering::lowerInterleaveIntrinsicToStore(
   return true;
 }
 
-static bool isMultipleOfN(const Value *V, const DataLayout &DL, unsigned N) {
-  assert(N);
-  if (N == 1)
-    return true;
-
-  using namespace PatternMatch;
-  // Right now we're only recognizing the simplest pattern.
-  uint64_t C;
-  if (match(V, m_CombineOr(m_ConstantInt(C),
-                           m_c_Mul(m_Value(), m_ConstantInt(C)))) &&
-      C && C % N == 0)
-    return true;
-
-  if (isPowerOf2_32(N)) {
-    KnownBits KB = llvm::computeKnownBits(V, DL);
-    return KB.countMinTrailingZeros() >= Log2_32(N);
-  }
-
-  return false;
-}
-
 /// Lower an interleaved vp.load into a vlsegN intrinsic.
 ///
 /// E.g. Lower an interleaved vp.load (Factor = 2):