Skip to content

Commit 018dc1b

Browse files
authored
[AMDGPU] Tail call support for whole wave functions (#145860)
Support tail calls to whole wave functions (trivial) and from whole wave functions (slightly more involved, because we need a new pseudo for the tail-call return that patches up the EXEC mask). Move the expansion of whole wave function return pseudos (both regular and tail-call returns) to prologue/epilogue insertion, since that is where we patch up the EXEC mask.
1 parent d140866 commit 018dc1b

14 files changed

+2067
-42
lines changed

llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp

Lines changed: 33 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7974,12 +7974,19 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
79747974
}
79757975
case Intrinsic::amdgcn_call_whole_wave: {
79767976
TargetLowering::ArgListTy Args;
7977+
bool isTailCall = I.isTailCall();
79777978

79787979
// The first argument is the callee. Skip it when assembling the call args.
79797980
for (unsigned Idx = 1; Idx < I.arg_size(); ++Idx) {
79807981
TargetLowering::ArgListEntry Arg(getValue(I.getArgOperand(Idx)),
79817982
I.getArgOperand(Idx)->getType());
79827983
Arg.setAttributes(&I, Idx);
7984+
7985+
// If we have an explicit sret argument that is an Instruction, (i.e., it
7986+
// might point to function-local memory), we can't meaningfully tail-call.
7987+
if (Arg.IsSRet && isa<Instruction>(I.getArgOperand(Idx)))
7988+
isTailCall = false;
7989+
79837990
Args.push_back(Arg);
79847991
}
79857992

@@ -7994,7 +8001,7 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
79948001
.setChain(getRoot())
79958002
.setCallee(CallingConv::AMDGPU_Gfx_WholeWave, I.getType(),
79968003
getValue(I.getArgOperand(0)), std::move(Args))
7997-
.setTailCall(false)
8004+
.setTailCall(isTailCall && canTailCall(I))
79988005
.setIsPreallocated(
79998006
I.countOperandBundlesOfType(LLVMContext::OB_preallocated) != 0)
80008007
.setConvergent(I.isConvergent())
@@ -8929,6 +8936,29 @@ SelectionDAGBuilder::lowerInvokable(TargetLowering::CallLoweringInfo &CLI,
89298936
return Result;
89308937
}
89318938

8939+
bool SelectionDAGBuilder::canTailCall(const CallBase &CB) const {
8940+
bool isMustTailCall = CB.isMustTailCall();
8941+
8942+
// Avoid emitting tail calls in functions with the disable-tail-calls
8943+
// attribute.
8944+
const Function *Caller = CB.getParent()->getParent();
8945+
if (Caller->getFnAttribute("disable-tail-calls").getValueAsString() ==
8946+
"true" &&
8947+
!isMustTailCall)
8948+
return false;
8949+
8950+
// We can't tail call inside a function with a swifterror argument. Lowering
8951+
// does not support this yet. It would have to move into the swifterror
8952+
// register before the call.
8953+
if (DAG.getTargetLoweringInfo().supportSwiftError() &&
8954+
Caller->getAttributes().hasAttrSomewhere(Attribute::SwiftError))
8955+
return false;
8956+
8957+
// Check if target-independent constraints permit a tail call here.
8958+
// Target-dependent constraints are checked within TLI->LowerCallTo.
8959+
return isInTailCallPosition(CB, DAG.getTarget());
8960+
}
8961+
89328962
void SelectionDAGBuilder::LowerCallTo(const CallBase &CB, SDValue Callee,
89338963
bool isTailCall, bool isMustTailCall,
89348964
const BasicBlock *EHPadBB,
@@ -8943,21 +8973,8 @@ void SelectionDAGBuilder::LowerCallTo(const CallBase &CB, SDValue Callee,
89438973
const Value *SwiftErrorVal = nullptr;
89448974
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
89458975

8946-
if (isTailCall) {
8947-
// Avoid emitting tail calls in functions with the disable-tail-calls
8948-
// attribute.
8949-
auto *Caller = CB.getParent()->getParent();
8950-
if (Caller->getFnAttribute("disable-tail-calls").getValueAsString() ==
8951-
"true" && !isMustTailCall)
8952-
isTailCall = false;
8953-
8954-
// We can't tail call inside a function with a swifterror argument. Lowering
8955-
// does not support this yet. It would have to move into the swifterror
8956-
// register before the call.
8957-
if (TLI.supportSwiftError() &&
8958-
Caller->getAttributes().hasAttrSomewhere(Attribute::SwiftError))
8959-
isTailCall = false;
8960-
}
8976+
if (isTailCall)
8977+
isTailCall = canTailCall(CB);
89618978

89628979
for (auto I = CB.arg_begin(), E = CB.arg_end(); I != E; ++I) {
89638980
const Value *V = *I;
@@ -8997,11 +9014,6 @@ void SelectionDAGBuilder::LowerCallTo(const CallBase &CB, SDValue Callee,
89979014
Args.push_back(Entry);
89989015
}
89999016

9000-
// Check if target-independent constraints permit a tail call here.
9001-
// Target-dependent constraints are checked within TLI->LowerCallTo.
9002-
if (isTailCall && !isInTailCallPosition(CB, DAG.getTarget()))
9003-
isTailCall = false;
9004-
90059017
// Disable tail calls if there is an swifterror argument. Targets have not
90069018
// been updated to support tail calls.
90079019
if (TLI.supportSwiftError() && SwiftErrorVal)

llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,10 @@ class SelectionDAGBuilder {
408408
bool IsMustTailCall, const BasicBlock *EHPadBB = nullptr,
409409
const TargetLowering::PtrAuthInfo *PAI = nullptr);
410410

411+
// Check some of the target-independent constraints for tail calls. This does
412+
// not iterate over the call arguments.
413+
bool canTailCall(const CallBase &CB) const;
414+
411415
// Lower range metadata from 0 to N to assert zext to an integer of nearest
412416
// floor power of two.
413417
SDValue lowerRangeToAssertZExt(SelectionDAG &DAG, const Instruction &I,

llvm/lib/Target/AMDGPU/AMDGPUCallLowering.cpp

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1004,8 +1004,14 @@ static unsigned getCallOpcode(const MachineFunction &CallerF, bool IsIndirect,
10041004
return IsWave32 ? AMDGPU::SI_CS_CHAIN_TC_W32 : AMDGPU::SI_CS_CHAIN_TC_W64;
10051005
}
10061006

1007-
return CC == CallingConv::AMDGPU_Gfx ? AMDGPU::SI_TCRETURN_GFX :
1008-
AMDGPU::SI_TCRETURN;
1007+
if (CallerF.getFunction().getCallingConv() ==
1008+
CallingConv::AMDGPU_Gfx_WholeWave)
1009+
return AMDGPU::SI_TCRETURN_GFX_WholeWave;
1010+
1011+
if (CC == CallingConv::AMDGPU_Gfx || CC == CallingConv::AMDGPU_Gfx_WholeWave)
1012+
return AMDGPU::SI_TCRETURN_GFX;
1013+
1014+
return AMDGPU::SI_TCRETURN;
10091015
}
10101016

10111017
// Add operands to call instruction to track the callee.
@@ -1284,6 +1290,13 @@ bool AMDGPUCallLowering::lowerTailCall(
12841290
unsigned Opc = getCallOpcode(MF, Info.Callee.isReg(), /*IsTailCall*/ true,
12851291
ST.isWave32(), CalleeCC, IsDynamicVGPRChainCall);
12861292
auto MIB = MIRBuilder.buildInstrNoInsert(Opc);
1293+
1294+
if (FuncInfo->isWholeWaveFunction())
1295+
addOriginalExecToReturn(MF, MIB);
1296+
1297+
// Keep track of the index of the next operand to be added to the call
1298+
unsigned CalleeIdx = MIB->getNumOperands();
1299+
12871300
if (!addCallTargetOperands(MIB, MIRBuilder, Info, IsDynamicVGPRChainCall))
12881301
return false;
12891302

@@ -1401,7 +1414,7 @@ bool AMDGPUCallLowering::lowerTailCall(
14011414
// If we have -tailcallopt, we need to adjust the stack. We'll do the call
14021415
// sequence start and end here.
14031416
if (!IsSibCall) {
1404-
MIB->getOperand(1).setImm(FPDiff);
1417+
MIB->getOperand(CalleeIdx + 1).setImm(FPDiff);
14051418
CallSeqStart.addImm(NumBytes).addImm(0);
14061419
// End the call sequence *before* emitting the call. Normally, we would
14071420
// tidy the frame up after the call. However, here, we've laid out the
@@ -1413,16 +1426,24 @@ bool AMDGPUCallLowering::lowerTailCall(
14131426
// Now we can add the actual call instruction to the correct basic block.
14141427
MIRBuilder.insertInstr(MIB);
14151428

1429+
// If this is a whole wave tail call, we need to constrain the register for
1430+
// the original EXEC.
1431+
if (MIB->getOpcode() == AMDGPU::SI_TCRETURN_GFX_WholeWave) {
1432+
MIB->getOperand(0).setReg(
1433+
constrainOperandRegClass(MF, *TRI, MRI, *TII, *ST.getRegBankInfo(),
1434+
*MIB, MIB->getDesc(), MIB->getOperand(0), 0));
1435+
}
1436+
14161437
// If Callee is a reg, since it is used by a target specific
14171438
// instruction, it must have a register class matching the
14181439
// constraint of that instruction.
14191440

14201441
// FIXME: We should define regbankselectable call instructions to handle
14211442
// divergent call targets.
1422-
if (MIB->getOperand(0).isReg()) {
1423-
MIB->getOperand(0).setReg(
1424-
constrainOperandRegClass(MF, *TRI, MRI, *TII, *ST.getRegBankInfo(),
1425-
*MIB, MIB->getDesc(), MIB->getOperand(0), 0));
1443+
if (MIB->getOperand(CalleeIdx).isReg()) {
1444+
MIB->getOperand(CalleeIdx).setReg(constrainOperandRegClass(
1445+
MF, *TRI, MRI, *TII, *ST.getRegBankInfo(), *MIB, MIB->getDesc(),
1446+
MIB->getOperand(CalleeIdx), CalleeIdx));
14261447
}
14271448

14281449
MF.getFrameInfo().setHasTailCall();

llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5667,6 +5667,7 @@ const char* AMDGPUTargetLowering::getTargetNodeName(unsigned Opcode) const {
56675667
NODE_NAME_CASE(CALL)
56685668
NODE_NAME_CASE(TC_RETURN)
56695669
NODE_NAME_CASE(TC_RETURN_GFX)
5670+
NODE_NAME_CASE(TC_RETURN_GFX_WholeWave)
56705671
NODE_NAME_CASE(TC_RETURN_CHAIN)
56715672
NODE_NAME_CASE(TC_RETURN_CHAIN_DVGPR)
56725673
NODE_NAME_CASE(TRAP)

llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,7 @@ enum NodeType : unsigned {
418418
CALL,
419419
TC_RETURN,
420420
TC_RETURN_GFX,
421+
TC_RETURN_GFX_WholeWave,
421422
TC_RETURN_CHAIN,
422423
TC_RETURN_CHAIN_DVGPR,
423424
TRAP,

llvm/lib/Target/AMDGPU/AMDGPUInstrInfo.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,10 @@ def AMDGPUtc_return_gfx: SDNode<"AMDGPUISD::TC_RETURN_GFX", AMDGPUTCReturnTP,
9494
[SDNPHasChain, SDNPOptInGlue, SDNPVariadic]
9595
>;
9696

97+
def AMDGPUtc_return_gfx_ww: SDNode<"AMDGPUISD::TC_RETURN_GFX_WholeWave", AMDGPUTCReturnTP,
98+
[SDNPHasChain, SDNPOptInGlue, SDNPVariadic]
99+
>;
100+
97101
def AMDGPUtc_return_chain: SDNode<"AMDGPUISD::TC_RETURN_CHAIN",
98102
SDTypeProfile<0, -1, [SDTCisPtrTy<0>]>,
99103
[SDNPHasChain, SDNPOptInGlue, SDNPVariadic]

llvm/lib/Target/AMDGPU/SIFrameLowering.cpp

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1132,9 +1132,18 @@ void SIFrameLowering::emitCSRSpillRestores(
11321132
RestoreWWMRegisters(WWMCalleeSavedRegs);
11331133

11341134
// The original EXEC is the first operand of the return instruction.
1135-
const MachineInstr &Return = MBB.instr_back();
1136-
assert(Return.getOpcode() == AMDGPU::SI_WHOLE_WAVE_FUNC_RETURN &&
1137-
"Unexpected return inst");
1135+
MachineInstr &Return = MBB.instr_back();
1136+
unsigned Opcode = Return.getOpcode();
1137+
switch (Opcode) {
1138+
case AMDGPU::SI_WHOLE_WAVE_FUNC_RETURN:
1139+
Opcode = AMDGPU::SI_RETURN;
1140+
break;
1141+
case AMDGPU::SI_TCRETURN_GFX_WholeWave:
1142+
Opcode = AMDGPU::SI_TCRETURN_GFX;
1143+
break;
1144+
default:
1145+
llvm_unreachable("Unexpected return inst");
1146+
}
11381147
Register OrigExec = Return.getOperand(0).getReg();
11391148

11401149
if (!WWMScratchRegs.empty()) {
@@ -1148,6 +1157,11 @@ void SIFrameLowering::emitCSRSpillRestores(
11481157
// Restore original EXEC.
11491158
unsigned MovOpc = ST.isWave32() ? AMDGPU::S_MOV_B32 : AMDGPU::S_MOV_B64;
11501159
BuildMI(MBB, MBBI, DL, TII->get(MovOpc), TRI.getExec()).addReg(OrigExec);
1160+
1161+
// Drop the first operand and update the opcode.
1162+
Return.removeOperand(0);
1163+
Return.setDesc(TII->get(Opcode));
1164+
11511165
return;
11521166
}
11531167

llvm/lib/Target/AMDGPU/SIISelLowering.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4278,6 +4278,11 @@ SDValue SITargetLowering::LowerCall(CallLoweringInfo &CLI,
42784278
break;
42794279
}
42804280

4281+
// If the caller is a whole wave function, we need to use a special opcode
4282+
// so we can patch up EXEC.
4283+
if (Info->isWholeWaveFunction())
4284+
OPC = AMDGPUISD::TC_RETURN_GFX_WholeWave;
4285+
42814286
return DAG.getNode(OPC, DL, MVT::Other, Ops);
42824287
}
42834288

@@ -6041,14 +6046,15 @@ SITargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI,
60416046
MI.eraseFromParent();
60426047
return SplitBB;
60436048
}
6049+
case AMDGPU::SI_TCRETURN_GFX_WholeWave:
60446050
case AMDGPU::SI_WHOLE_WAVE_FUNC_RETURN: {
60456051
assert(MFI->isWholeWaveFunction());
60466052

60476053
// During ISel, it's difficult to propagate the original EXEC mask to use as
60486054
// an input to SI_WHOLE_WAVE_FUNC_RETURN. Set it up here instead.
60496055
MachineInstr *Setup = TII->getWholeWaveFunctionSetup(*BB->getParent());
6050-
Register OriginalExec = Setup->getOperand(0).getReg();
60516056
assert(Setup && "Couldn't find SI_SETUP_WHOLE_WAVE_FUNC");
6057+
Register OriginalExec = Setup->getOperand(0).getReg();
60526058
MF->getRegInfo().clearKillFlags(OriginalExec);
60536059
MI.getOperand(0).setReg(OriginalExec);
60546060
return BB;

llvm/lib/Target/AMDGPU/SIInstrInfo.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2493,7 +2493,6 @@ bool SIInstrInfo::expandPostRAPseudo(MachineInstr &MI) const {
24932493
MI.setDesc(get(ST.isWave32() ? AMDGPU::S_MOV_B32 : AMDGPU::S_MOV_B64));
24942494
break;
24952495
}
2496-
case AMDGPU::SI_WHOLE_WAVE_FUNC_RETURN:
24972496
case AMDGPU::SI_RETURN: {
24982497
const MachineFunction *MF = MBB.getParent();
24992498
const GCNSubtarget &ST = MF->getSubtarget<GCNSubtarget>();

llvm/lib/Target/AMDGPU/SIInstructions.td

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -692,6 +692,33 @@ def SI_WHOLE_WAVE_FUNC_RETURN : SPseudoInstSI <
692692
def : GCNPat<
693693
(AMDGPUwhole_wave_return), (SI_WHOLE_WAVE_FUNC_RETURN (i1 (IMPLICIT_DEF)))>;
694694

695+
// Restores the previous EXEC and otherwise behaves entirely like a SI_TCRETURN.
696+
// This is used for tail calls *from* a whole wave function. Tail calls to
697+
// a whole wave function may use the usual opcodes, depending on the calling
698+
// convention of the caller.
699+
def SI_TCRETURN_GFX_WholeWave : SPseudoInstSI <
700+
(outs),
701+
(ins SReg_1:$orig_exec, Gfx_CCR_SGPR_64:$src0, unknown:$callee, i32imm:$fpdiff)> {
702+
let isCall = 1;
703+
let isTerminator = 1;
704+
let isReturn = 1;
705+
let isBarrier = 1;
706+
let UseNamedOperandTable = 1;
707+
let SchedRW = [WriteBranch];
708+
let isConvergent = 1;
709+
710+
// We're going to use custom handling to set the $orig_exec to the correct value.
711+
let usesCustomInserter = 1;
712+
}
713+
714+
// Generate a SI_TCRETURN_GFX_WholeWave pseudo with a placeholder for its
715+
// argument. It will be filled in by the custom inserter.
716+
def : GCNPat<
717+
(AMDGPUtc_return_gfx_ww i64:$src0, tglobaladdr:$callee, i32:$fpdiff),
718+
(SI_TCRETURN_GFX_WholeWave (i1 (IMPLICIT_DEF)), Gfx_CCR_SGPR_64:$src0,
719+
tglobaladdr:$callee, i32:$fpdiff)>;
720+
721+
695722
// Return for returning shaders to a shader variant epilog.
696723
def SI_RETURN_TO_EPILOG : SPseudoInstSI <
697724
(outs), (ins variable_ops), [(AMDGPUreturn_to_epilog)]> {

0 commit comments

Comments
 (0)