diff --git a/llvm/lib/Target/AMDGPU/AMDGPURewriteAGPRCopyMFMA.cpp b/llvm/lib/Target/AMDGPU/AMDGPURewriteAGPRCopyMFMA.cpp
index 21cf9cc6878fb..cab6b7da2dfff 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURewriteAGPRCopyMFMA.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPURewriteAGPRCopyMFMA.cpp
@@ -29,6 +29,7 @@
 #include "llvm/ADT/Statistic.h"
 #include "llvm/CodeGen/LiveIntervals.h"
 #include "llvm/CodeGen/LiveRegMatrix.h"
+#include "llvm/CodeGen/LiveStacks.h"
 #include "llvm/CodeGen/MachineFunctionPass.h"
 #include "llvm/CodeGen/VirtRegMap.h"
 #include "llvm/InitializePasses.h"
@@ -42,6 +43,9 @@ namespace {
 STATISTIC(NumMFMAsRewrittenToAGPR,
           "Number of MFMA instructions rewritten to use AGPR form");
 
+/// Map from spill slot frame index to list of instructions which reference it.
+using SpillReferenceMap = DenseMap<int, SmallVector<MachineInstr *, 4>>;
+
 class AMDGPURewriteAGPRCopyMFMAImpl {
   MachineFunction &MF;
   const GCNSubtarget &ST;
@@ -51,6 +55,7 @@ class AMDGPURewriteAGPRCopyMFMAImpl {
   VirtRegMap &VRM;
   LiveRegMatrix &LRM;
   LiveIntervals &LIS;
+  LiveStacks &LSS;
   const RegisterClassInfo &RegClassInfo;
 
   bool attemptReassignmentsToAGPR(SmallSetVector<Register, 0> &InterferingRegs,
@@ -59,10 +64,11 @@ class AMDGPURewriteAGPRCopyMFMAImpl {
 public:
   AMDGPURewriteAGPRCopyMFMAImpl(MachineFunction &MF, VirtRegMap &VRM,
                                 LiveRegMatrix &LRM, LiveIntervals &LIS,
+                                LiveStacks &LSS,
                                 const RegisterClassInfo &RegClassInfo)
       : MF(MF), ST(MF.getSubtarget<GCNSubtarget>()), TII(*ST.getInstrInfo()),
         TRI(*ST.getRegisterInfo()), MRI(MF.getRegInfo()), VRM(VRM), LRM(LRM),
-        LIS(LIS), RegClassInfo(RegClassInfo) {}
+        LIS(LIS), LSS(LSS), RegClassInfo(RegClassInfo) {}
 
   bool isRewriteCandidate(const MachineInstr &MI) const {
     return TII.isMAI(MI) && AMDGPU::getMFMASrcCVDstAGPROp(MI.getOpcode()) != -1;
@@ -103,6 +109,22 @@ class AMDGPURewriteAGPRCopyMFMAImpl {
 
   bool tryFoldCopiesToAGPR(Register VReg, MCRegister AssignedAGPR) const;
   bool tryFoldCopiesFromAGPR(Register VReg, MCRegister AssignedAGPR) const;
+
+  /// Replace spill instruction \p SpillMI, which loads/stores from/to
+  /// \p SpillFI, with a COPY to or from the replacement register \p VReg.
+  void replaceSpillWithCopyToVReg(MachineInstr &SpillMI, int SpillFI,
+                                  Register VReg) const;
+
+  /// Create a map from frame index to use instructions for spills. If the uses
+  /// of a frame index are not all spill instructions, the index is not
+  /// included in the map.
+  void collectSpillIndexUses(ArrayRef<LiveInterval *> StackIntervals,
+                             SpillReferenceMap &Map) const;
+
+  /// Attempt to unspill VGPRs by finding a free register and replacing the
+  /// spill instructions with copies.
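+  /// Heavier spill slot intervals are tried first; a successfully remapped
+  /// slot is deleted from the frame.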
+  void eliminateSpillsOfReassignedVGPRs() const;
+
   bool run(MachineFunction &MF) const;
 };
@@ -391,6 +413,133 @@ bool AMDGPURewriteAGPRCopyMFMAImpl::tryFoldCopiesFromAGPR(
   return MadeChange;
 }
 
+void AMDGPURewriteAGPRCopyMFMAImpl::replaceSpillWithCopyToVReg(
+    MachineInstr &SpillMI, int SpillFI, Register VReg) const {
+  const DebugLoc &DL = SpillMI.getDebugLoc();
+  MachineBasicBlock &MBB = *SpillMI.getParent();
+  MachineInstr *NewCopy;
+  if (SpillMI.mayStore()) {
+    NewCopy = BuildMI(MBB, SpillMI, DL, TII.get(TargetOpcode::COPY), VReg)
+                  .add(SpillMI.getOperand(0));
+  } else {
+    NewCopy = BuildMI(MBB, SpillMI, DL, TII.get(TargetOpcode::COPY))
+                  .add(SpillMI.getOperand(0))
+                  .addReg(VReg);
+  }
+
+  LIS.ReplaceMachineInstrInMaps(SpillMI, *NewCopy);
+  SpillMI.eraseFromParent();
+}
+
+void AMDGPURewriteAGPRCopyMFMAImpl::collectSpillIndexUses(
+    ArrayRef<LiveInterval *> StackIntervals, SpillReferenceMap &Map) const {
+
+  SmallSet<int, 16> NeededFrameIndexes;
+  for (const LiveInterval *LI : StackIntervals)
+    NeededFrameIndexes.insert(LI->reg().stackSlotIndex());
+
+  for (MachineBasicBlock &MBB : MF) {
+    for (MachineInstr &MI : MBB) {
+      for (MachineOperand &MO : MI.operands()) {
+        if (!MO.isFI() || !NeededFrameIndexes.count(MO.getIndex()))
+          continue;
+
+        SmallVector<MachineInstr *, 4> &References = Map[MO.getIndex()];
+        if (TII.isVGPRSpill(MI)) {
+          References.push_back(&MI);
+          break;
+        }
+
+        // Verify this was really a spill instruction; if it's not, just ignore
+        // all uses of this index.
+
+        // TODO: This should probably be verifier enforced.
+        NeededFrameIndexes.erase(MO.getIndex());
+        Map.erase(MO.getIndex());
+      }
+    }
+  }
+}
+
+void AMDGPURewriteAGPRCopyMFMAImpl::eliminateSpillsOfReassignedVGPRs() const {
+  unsigned NumSlots = LSS.getNumIntervals();
+  if (NumSlots == 0)
+    return;
+
+  MachineFrameInfo &MFI = MF.getFrameInfo();
+
+  SmallVector<LiveInterval *, 32> StackIntervals;
+  StackIntervals.reserve(NumSlots);
+
+  for (auto I = LSS.begin(), E = LSS.end(); I != E; ++I) {
+    int Slot = I->first;
+    if (!MFI.isSpillSlotObjectIndex(Slot) || MFI.isDeadObjectIndex(Slot))
+      continue;
+
+    LiveInterval &LI = I->second;
+    const TargetRegisterClass *RC = LSS.getIntervalRegClass(Slot);
+    if (TRI.hasVGPRs(RC))
+      StackIntervals.push_back(&LI);
+  }
+
+  // Sort heaviest intervals first to prioritize their unspilling.
+  sort(StackIntervals, [](const LiveInterval *A, const LiveInterval *B) {
+    return A->weight() > B->weight();
+  });
+
+  // FIXME: The APIs for dealing with the LiveInterval of a frame index are
+  // cumbersome. LiveStacks owns its LiveIntervals which refer to stack slots.
+  // We cannot use the usual LiveRegMatrix::assign and unassign on these, and
+  // must create a substitute virtual register to do so. This makes incremental
+  // updating here difficult; we need to actually perform the IR mutation to
+  // get the new vreg references in place, compute the register LiveInterval,
+  // and perform an assignment so the new interference is tracked correctly,
+  // and we can't simply migrate the LiveInterval we already have.
+  //
+  // To avoid walking through the entire function for each index, pre-collect
+  // all the instructions which reference each slot.
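+  //
+  // Each map entry holds the spill and reload instructions of one slot; they
+  // are rewritten to COPYs once an interference-free register is found.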
+
+  DenseMap<int, SmallVector<MachineInstr *, 4>> SpillSlotReferences;
+  collectSpillIndexUses(StackIntervals, SpillSlotReferences);
+
+  for (LiveInterval *LI : StackIntervals) {
+    int Slot = LI->reg().stackSlotIndex();
+    auto SpillReferences = SpillSlotReferences.find(Slot);
+    if (SpillReferences == SpillSlotReferences.end())
+      continue;
+
+    const TargetRegisterClass *RC = LSS.getIntervalRegClass(Slot);
+
+    LLVM_DEBUG(dbgs() << "Trying to eliminate " << printReg(LI->reg(), &TRI)
+                      << " by reassigning\n");
+
+    ArrayRef<MCPhysReg> AllocOrder = RegClassInfo.getOrder(RC);
+
+    for (MCPhysReg PhysReg : AllocOrder) {
+      if (LRM.checkInterference(*LI, PhysReg) != LiveRegMatrix::IK_Free)
+        continue;
+
+      LLVM_DEBUG(dbgs() << "Reassigning " << *LI << " to "
+                        << printReg(PhysReg, &TRI) << '\n');
+
+      Register NewVReg = MRI.createVirtualRegister(RC);
+
+      for (MachineInstr *SpillMI : SpillReferences->second)
+        replaceSpillWithCopyToVReg(*SpillMI, Slot, NewVReg);
+
+      // TODO: We should be able to transfer the information from the stack
+      // slot's LiveInterval without recomputing from scratch with the
+      // replacement vreg uses.
+      LiveInterval &NewLI = LIS.createAndComputeVirtRegInterval(NewVReg);
+      VRM.grow();
+      LRM.assign(NewLI, PhysReg);
+      MFI.RemoveStackObject(Slot);
+      break;
+    }
+  }
+}
+
 bool AMDGPURewriteAGPRCopyMFMAImpl::run(MachineFunction &MF) const {
   // This only applies on subtargets that have a configurable AGPR vs. VGPR
   // allocation.
@@ -417,6 +566,12 @@ bool AMDGPURewriteAGPRCopyMFMAImpl::run(MachineFunction &MF) const {
     MadeChange = true;
   }
 
+  // If we've successfully rewritten some MFMAs, we've alleviated some VGPR
+  // pressure. See if we can eliminate some spills now that those registers are
+  // more available.
+  if (MadeChange)
+    eliminateSpillsOfReassignedVGPRs();
+
   return MadeChange;
 }
 
@@ -440,10 +595,13 @@ class AMDGPURewriteAGPRCopyMFMALegacy : public MachineFunctionPass {
     AU.addRequired<LiveIntervalsWrapperPass>();
     AU.addRequired<VirtRegMapWrapperLegacy>();
    AU.addRequired<LiveRegMatrixWrapperLegacy>();
+    AU.addRequired<LiveStacksWrapperLegacy>();
 
     AU.addPreserved<LiveIntervalsWrapperPass>();
     AU.addPreserved<VirtRegMapWrapperLegacy>();
     AU.addPreserved<LiveRegMatrixWrapperLegacy>();
+    AU.addPreserved<LiveStacksWrapperLegacy>();
+    AU.setPreservesAll();
 
     MachineFunctionPass::getAnalysisUsage(AU);
   }
@@ -456,6 +614,7 @@ INITIALIZE_PASS_BEGIN(AMDGPURewriteAGPRCopyMFMALegacy, DEBUG_TYPE,
 INITIALIZE_PASS_DEPENDENCY(LiveIntervalsWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(VirtRegMapWrapperLegacy)
 INITIALIZE_PASS_DEPENDENCY(LiveRegMatrixWrapperLegacy)
+INITIALIZE_PASS_DEPENDENCY(LiveStacksWrapperLegacy)
 INITIALIZE_PASS_END(AMDGPURewriteAGPRCopyMFMALegacy, DEBUG_TYPE,
                     "AMDGPU Rewrite AGPR-Copy-MFMA", false, false)
 
@@ -474,8 +633,8 @@ bool AMDGPURewriteAGPRCopyMFMALegacy::runOnMachineFunction(
   auto &VRM = getAnalysis<VirtRegMapWrapperLegacy>().getVRM();
   auto &LRM = getAnalysis<LiveRegMatrixWrapperLegacy>().getLRM();
   auto &LIS = getAnalysis<LiveIntervalsWrapperPass>().getLIS();
-
-  AMDGPURewriteAGPRCopyMFMAImpl Impl(MF, VRM, LRM, LIS, RegClassInfo);
+  auto &LSS = getAnalysis<LiveStacksWrapperLegacy>().getLS();
+  AMDGPURewriteAGPRCopyMFMAImpl Impl(MF, VRM, LRM, LIS, LSS, RegClassInfo);
   return Impl.run(MF);
 }
 
@@ -485,13 +644,15 @@ AMDGPURewriteAGPRCopyMFMAPass::run(MachineFunction &MF,
   VirtRegMap &VRM = MFAM.getResult<VirtRegMapAnalysis>(MF);
   LiveRegMatrix &LRM = MFAM.getResult<LiveRegMatrixAnalysis>(MF);
   LiveIntervals &LIS = MFAM.getResult<LiveIntervalsAnalysis>(MF);
+  LiveStacks &LSS = MFAM.getResult<LiveStacksAnalysis>(MF);
 
   RegisterClassInfo RegClassInfo;
   RegClassInfo.runOnMachineFunction(MF);
 
-  AMDGPURewriteAGPRCopyMFMAImpl Impl(MF, VRM, LRM, LIS, RegClassInfo);
+  AMDGPURewriteAGPRCopyMFMAImpl Impl(MF, VRM, LRM, LIS, LSS, RegClassInfo);
   if (!Impl.run(MF))
     return PreservedAnalyses::all();
 
   auto PA = getMachineFunctionPassPreservedAnalyses();
   PA.preserveSet<CFGAnalyses>();
+  PA.preserve<LiveStacksAnalysis>();
   return PA;
 }
diff --git a/llvm/test/CodeGen/AMDGPU/unspill-vgpr-after-rewrite-vgpr-mfma.ll b/llvm/test/CodeGen/AMDGPU/unspill-vgpr-after-rewrite-vgpr-mfma.ll
index 122d46b39ff32..3b9e3a00036af 100644
--- a/llvm/test/CodeGen/AMDGPU/unspill-vgpr-after-rewrite-vgpr-mfma.ll
+++ b/llvm/test/CodeGen/AMDGPU/unspill-vgpr-after-rewrite-vgpr-mfma.ll
@@ -101,13 +101,8 @@ define void @eliminate_spill_after_mfma_rewrite(i32 %x, i32 %y, <4 x i32> %arg,
 ; CHECK-NEXT:    v_accvgpr_read_b32 v2, a2
 ; CHECK-NEXT:    v_accvgpr_read_b32 v3, a3
 ; CHECK-NEXT:    ;;#ASMSTART
-; CHECK-NEXT:    ; def v[0:3]
+; CHECK-NEXT:    ; def v[10:13]
 ; CHECK-NEXT:    ;;#ASMEND
-; CHECK-NEXT:    buffer_store_dword v0, off, s[0:3], s32 offset:192 ; 4-byte Folded Spill
-; CHECK-NEXT:    s_nop 0
-; CHECK-NEXT:    buffer_store_dword v1, off, s[0:3], s32 offset:196 ; 4-byte Folded Spill
-; CHECK-NEXT:    buffer_store_dword v2, off, s[0:3], s32 offset:200 ; 4-byte Folded Spill
-; CHECK-NEXT:    buffer_store_dword v3, off, s[0:3], s32 offset:204 ; 4-byte Folded Spill
 ; CHECK-NEXT:    v_mov_b32_e32 v0, 0
 ; CHECK-NEXT:    ;;#ASMSTART
 ; CHECK-NEXT:    ; def a[0:31]
@@ -147,12 +142,7 @@ define void @eliminate_spill_after_mfma_rewrite(i32 %x, i32 %y, <4 x i32> %arg,
 ; CHECK-NEXT:    s_waitcnt vmcnt(0)
 ; CHECK-NEXT:    global_store_dwordx4 v0, a[36:39], s[16:17] offset:16
 ; CHECK-NEXT:    s_waitcnt vmcnt(0)
-; CHECK-NEXT:    buffer_load_dword v2, off, s[0:3], s32 offset:192 ; 4-byte Folded Reload
-; CHECK-NEXT:    buffer_load_dword v3, off, s[0:3], s32 offset:196 ; 4-byte Folded Reload
-; CHECK-NEXT:    buffer_load_dword v4, off, s[0:3], s32 offset:200 ; 4-byte Folded Reload
-; CHECK-NEXT:    buffer_load_dword v5, off, s[0:3], s32 offset:204 ; 4-byte Folded Reload
-; CHECK-NEXT:    s_waitcnt vmcnt(0)
-; CHECK-NEXT:    global_store_dwordx4 v0, v[2:5], s[16:17]
+; CHECK-NEXT:    global_store_dwordx4 v0, v[10:13], s[16:17]
 ; CHECK-NEXT:    s_waitcnt vmcnt(0)
 ; CHECK-NEXT:    buffer_load_dword a63, off, s[0:3], s32 ; 4-byte Folded Reload
 ; CHECK-NEXT:    buffer_load_dword a62, off, s[0:3], s32 offset:4 ; 4-byte Folded Reload
@@ -311,26 +301,16 @@ define void @eliminate_spill_after_mfma_rewrite_x2(i32 %x, i32 %y, <4 x i32> %ar
 ; CHECK-NEXT:    v_accvgpr_write_b32 a33, v1
 ; CHECK-NEXT:    v_accvgpr_write_b32 a32, v0
 ; CHECK-NEXT:    v_accvgpr_read_b32 v7, a3
+; CHECK-NEXT:    v_mov_b32_e32 v0, 0
 ; CHECK-NEXT:    v_accvgpr_read_b32 v6, a2
 ; CHECK-NEXT:    v_accvgpr_read_b32 v5, a1
 ; CHECK-NEXT:    v_accvgpr_read_b32 v4, a0
 ; CHECK-NEXT:    ;;#ASMSTART
-; CHECK-NEXT:    ; def v[0:3]
+; CHECK-NEXT:    ; def v[14:17]
 ; CHECK-NEXT:    ;;#ASMEND
-; CHECK-NEXT:    buffer_store_dword v0, off, s[0:3], s32 offset:192 ; 4-byte Folded Spill
-; CHECK-NEXT:    s_nop 0
-; CHECK-NEXT:    buffer_store_dword v1, off, s[0:3], s32 offset:196 ; 4-byte Folded Spill
-; CHECK-NEXT:    buffer_store_dword v2, off, s[0:3], s32 offset:200 ; 4-byte Folded Spill
-; CHECK-NEXT:    buffer_store_dword v3, off, s[0:3], s32 offset:204 ; 4-byte Folded Spill
 ; CHECK-NEXT:    ;;#ASMSTART
-; CHECK-NEXT:    ; def v[0:3]
+; CHECK-NEXT:    ; def v[10:13]
 ; CHECK-NEXT:    ;;#ASMEND
-; CHECK-NEXT:    buffer_store_dword v0, off, s[0:3], s32 offset:208 ; 4-byte Folded Spill
-; CHECK-NEXT:    s_nop 0
-; CHECK-NEXT:    buffer_store_dword v1, off, s[0:3], s32 offset:212 ; 4-byte Folded Spill
-; CHECK-NEXT:    buffer_store_dword v2, off, s[0:3], s32 offset:216 ; 4-byte Folded Spill
-; CHECK-NEXT:    buffer_store_dword v3, off, s[0:3], s32 offset:220 ; 4-byte Folded Spill
-; CHECK-NEXT:    v_mov_b32_e32 v0, 0
 ; CHECK-NEXT:    ;;#ASMSTART
 ; CHECK-NEXT:    ; def a[0:31]
 ; CHECK-NEXT:    ;;#ASMEND
@@ -369,19 +349,9 @@ define void @eliminate_spill_after_mfma_rewrite_x2(i32 %x, i32 %y, <4 x i32> %ar
 ; CHECK-NEXT:    s_waitcnt vmcnt(0)
 ; CHECK-NEXT:    global_store_dwordx4 v0, a[36:39], s[16:17] offset:16
 ; CHECK-NEXT:    s_waitcnt vmcnt(0)
-; CHECK-NEXT:    buffer_load_dword v2, off, s[0:3], s32 offset:192 ; 4-byte Folded Reload
-; CHECK-NEXT:    buffer_load_dword v3, off, s[0:3], s32 offset:196 ; 4-byte Folded Reload
-; CHECK-NEXT:    buffer_load_dword v4, off, s[0:3], s32 offset:200 ; 4-byte Folded Reload
-; CHECK-NEXT:    buffer_load_dword v5, off, s[0:3], s32 offset:204 ; 4-byte Folded Reload
-; CHECK-NEXT:    s_waitcnt vmcnt(0)
-; CHECK-NEXT:    global_store_dwordx4 v0, v[2:5], s[16:17]
-; CHECK-NEXT:    s_waitcnt vmcnt(0)
-; CHECK-NEXT:    buffer_load_dword v2, off, s[0:3], s32 offset:208 ; 4-byte Folded Reload
-; CHECK-NEXT:    buffer_load_dword v3, off, s[0:3], s32 offset:212 ; 4-byte Folded Reload
-; CHECK-NEXT:    buffer_load_dword v4, off, s[0:3], s32 offset:216 ; 4-byte Folded Reload
-; CHECK-NEXT:    buffer_load_dword v5, off, s[0:3], s32 offset:220 ; 4-byte Folded Reload
+; CHECK-NEXT:    global_store_dwordx4 v0, v[14:17], s[16:17]
 ; CHECK-NEXT:    s_waitcnt vmcnt(0)
-; CHECK-NEXT:    global_store_dwordx4 v0, v[2:5], s[16:17]
+; CHECK-NEXT:    global_store_dwordx4 v0, v[10:13], s[16:17]
 ; CHECK-NEXT:    s_waitcnt vmcnt(0)
 ; CHECK-NEXT:    buffer_load_dword a63, off, s[0:3], s32 ; 4-byte Folded Reload
 ; CHECK-NEXT:    buffer_load_dword a62, off, s[0:3], s32 offset:4 ; 4-byte Folded Reload