Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 100 additions & 77 deletions llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,18 @@ class AArch64ExpandPseudo : public MachineFunctionPass {
bool expandCALL_BTI(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI);
bool expandStoreSwiftAsyncContext(MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI);
MachineBasicBlock *
expandCommitOrRestoreZASave(MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI);
/// The pair of blocks produced by expandConditionalPseudo() when it splits a
/// block around a conditional pseudo instruction. Both are held by reference.
struct ConditionalBlocks {
/// Block holding only the conditional pseudo. The helper appends a branch to
/// EndBB, so expansions insert their instructions before CondBB.back().
MachineBasicBlock &CondBB;
/// Block that execution continues in after CondBB: either the instructions
/// that followed the pseudo, or CondBB's first successor when the pseudo was
/// the last instruction in its block.
MachineBasicBlock &EndBB;
};
ConditionalBlocks expandConditionalPseudo(MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI,
DebugLoc DL,
MachineInstrBuilder &Branch);
MachineBasicBlock *expandRestoreZASave(MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI);
MachineBasicBlock *expandCommitZASave(MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI);
MachineBasicBlock *expandCondSMToggle(MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI);
};
Expand Down Expand Up @@ -991,72 +1000,97 @@ bool AArch64ExpandPseudo::expandStoreSwiftAsyncContext(
return true;
}

static constexpr unsigned ZERO_ALL_ZA_MASK = 0b11111111;

MachineBasicBlock *AArch64ExpandPseudo::expandCommitOrRestoreZASave(
MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI) {
MachineInstr &MI = *MBBI;
bool IsRestoreZA = MI.getOpcode() == AArch64::RestoreZAPseudo;
assert((MI.getOpcode() == AArch64::RestoreZAPseudo ||
MI.getOpcode() == AArch64::CommitZASavePseudo) &&
"Expected ZA commit or restore");
AArch64ExpandPseudo::ConditionalBlocks
AArch64ExpandPseudo::expandConditionalPseudo(MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI,
DebugLoc DL,
MachineInstrBuilder &Branch) {
assert((std::next(MBBI) != MBB.end() ||
MI.getParent()->successors().begin() !=
MI.getParent()->successors().end()) &&
"Unexpected unreachable in block that restores ZA");

// Compare TPIDR2_EL0 value against 0.
DebugLoc DL = MI.getDebugLoc();
MachineInstrBuilder Branch =
BuildMI(MBB, MBBI, DL,
TII->get(IsRestoreZA ? AArch64::CBZX : AArch64::CBNZX))
.add(MI.getOperand(0));
MBB.successors().begin() != MBB.successors().end()) &&
"Unexpected unreachable in block");

// Split MBB and create two new blocks:
// - MBB now contains all instructions before RestoreZAPseudo.
// - SMBB contains the [Commit|RestoreZA]Pseudo instruction only.
// - EndBB contains all instructions after [Commit|RestoreZA]Pseudo.
// - MBB now contains all instructions before the conditional pseudo.
// - CondBB contains the conditional pseudo instruction only.
// - EndBB contains all instructions after the conditional pseudo.
MachineInstr &PrevMI = *std::prev(MBBI);
MachineBasicBlock *SMBB = MBB.splitAt(PrevMI, /*UpdateLiveIns*/ true);
MachineBasicBlock *EndBB = std::next(MI.getIterator()) == SMBB->end()
? *SMBB->successors().begin()
: SMBB->splitAt(MI, /*UpdateLiveIns*/ true);

// Add the SMBB label to the CB[N]Z instruction & create a branch to EndBB.
Branch.addMBB(SMBB);
MachineBasicBlock *CondBB = MBB.splitAt(PrevMI, /*UpdateLiveIns*/ true);
MachineBasicBlock *EndBB =
std::next(MBBI) == CondBB->end()
? *CondBB->successors().begin()
: CondBB->splitAt(*MBBI, /*UpdateLiveIns*/ true);

// Add the CondBB label to the branch instruction & create a branch to EndBB.
Branch.addMBB(CondBB);
BuildMI(&MBB, DL, TII->get(AArch64::B))
.addMBB(EndBB);
MBB.addSuccessor(EndBB);

// Create branch from CondBB to EndBB. Users of this helper should insert new
// instructions at CondBB.back() -- i.e. before the branch.
BuildMI(CondBB, DL, TII->get(AArch64::B)).addMBB(EndBB);
return {*CondBB, *EndBB};
}

MachineBasicBlock *
AArch64ExpandPseudo::expandRestoreZASave(MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI) {
MachineInstr &MI = *MBBI;
DebugLoc DL = MI.getDebugLoc();

// Compare TPIDR2_EL0 against 0. Restore ZA if TPIDR2_EL0 is zero.
MachineInstrBuilder Branch =
BuildMI(MBB, MBBI, DL, TII->get(AArch64::CBZX)).add(MI.getOperand(0));

auto [CondBB, EndBB] = expandConditionalPseudo(MBB, MBBI, DL, Branch);
// Replace the pseudo with a call (BL).
MachineInstrBuilder MIB =
BuildMI(*SMBB, SMBB->end(), DL, TII->get(AArch64::BL));
BuildMI(CondBB, CondBB.back(), DL, TII->get(AArch64::BL));
// Copy operands (mainly the regmask) from the pseudo.
for (unsigned I = 2; I < MI.getNumOperands(); ++I)
MIB.add(MI.getOperand(I));
// Mark the TPIDR2 block pointer (X0) as an implicit use.
MIB.addReg(MI.getOperand(1).getReg(), RegState::Implicit);

if (IsRestoreZA) {
// Mark the TPIDR2 block pointer (X0) as an implicit use.
MIB.addReg(MI.getOperand(1).getReg(), RegState::Implicit);
} else /*CommitZA*/ {
MI.eraseFromParent();
return &EndBB;
}

static constexpr unsigned ZERO_ALL_ZA_MASK = 0b11111111;

MachineBasicBlock *
AArch64ExpandPseudo::expandCommitZASave(MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI) {
MachineInstr &MI = *MBBI;
DebugLoc DL = MI.getDebugLoc();

// Compare TPIDR2_EL0 against 0. Commit ZA if TPIDR2_EL0 is non-zero.
MachineInstrBuilder Branch =
BuildMI(MBB, MBBI, DL, TII->get(AArch64::CBNZX)).add(MI.getOperand(0));

auto [CondBB, EndBB] = expandConditionalPseudo(MBB, MBBI, DL, Branch);
// Replace the pseudo with a call (BL).
MachineInstrBuilder MIB =
BuildMI(CondBB, CondBB.back(), DL, TII->get(AArch64::BL));
// Copy operands (mainly the regmask) from the pseudo.
for (unsigned I = 2; I < MI.getNumOperands(); ++I)
MIB.add(MI.getOperand(I));
// Clear TPIDR2_EL0.
BuildMI(CondBB, CondBB.back(), DL, TII->get(AArch64::MSR))
.addImm(AArch64SysReg::TPIDR2_EL0)
.addReg(AArch64::XZR);
bool ZeroZA = MI.getOperand(1).getImm() != 0;
if (ZeroZA) {
[[maybe_unused]] auto *TRI =
MBB.getParent()->getSubtarget().getRegisterInfo();
// Clear TPIDR2_EL0.
BuildMI(*SMBB, SMBB->end(), DL, TII->get(AArch64::MSR))
.addImm(AArch64SysReg::TPIDR2_EL0)
.addReg(AArch64::XZR);
bool ZeroZA = MI.getOperand(1).getImm() != 0;
if (ZeroZA) {
assert(MI.definesRegister(AArch64::ZAB0, TRI) && "should define ZA!");
BuildMI(*SMBB, SMBB->end(), DL, TII->get(AArch64::ZERO_M))
.addImm(ZERO_ALL_ZA_MASK)
.addDef(AArch64::ZAB0, RegState::ImplicitDefine);
}
assert(MI.definesRegister(AArch64::ZAB0, TRI) && "should define ZA!");
BuildMI(CondBB, CondBB.back(), DL, TII->get(AArch64::ZERO_M))
.addImm(ZERO_ALL_ZA_MASK)
.addDef(AArch64::ZAB0, RegState::ImplicitDefine);
}

BuildMI(SMBB, DL, TII->get(AArch64::B)).addMBB(EndBB);
MI.eraseFromParent();
return EndBB;
return &EndBB;
}

MachineBasicBlock *
Expand Down Expand Up @@ -1130,24 +1164,9 @@ AArch64ExpandPseudo::expandCondSMToggle(MachineBasicBlock &MBB,
MachineInstrBuilder Tbx =
BuildMI(MBB, MBBI, DL, TII->get(Opc)).addReg(SMReg32).addImm(0);

// Split MBB and create two new blocks:
// - MBB now contains all instructions before MSRcond_pstatesvcrImm1.
// - SMBB contains the MSRcond_pstatesvcrImm1 instruction only.
// - EndBB contains all instructions after MSRcond_pstatesvcrImm1.
MachineInstr &PrevMI = *std::prev(MBBI);
MachineBasicBlock *SMBB = MBB.splitAt(PrevMI, /*UpdateLiveIns*/ true);
MachineBasicBlock *EndBB = std::next(MI.getIterator()) == SMBB->end()
? *SMBB->successors().begin()
: SMBB->splitAt(MI, /*UpdateLiveIns*/ true);

// Add the SMBB label to the TB[N]Z instruction & create a branch to EndBB.
Tbx.addMBB(SMBB);
BuildMI(&MBB, DL, TII->get(AArch64::B))
.addMBB(EndBB);
MBB.addSuccessor(EndBB);

auto [CondBB, EndBB] = expandConditionalPseudo(MBB, MBBI, DL, Tbx);
// Create the SMSTART/SMSTOP (MSRpstatesvcrImm1) instruction in CondBB.
MachineInstrBuilder MIB = BuildMI(*SMBB, SMBB->begin(), MI.getDebugLoc(),
MachineInstrBuilder MIB = BuildMI(CondBB, CondBB.back(), MI.getDebugLoc(),
TII->get(AArch64::MSRpstatesvcrImm1));
// Copy all but the second and third operands of MSRcond_pstatesvcrImm1 (as
// these contain the CopyFromReg for the first argument and the flag to
Expand All @@ -1157,10 +1176,8 @@ AArch64ExpandPseudo::expandCondSMToggle(MachineBasicBlock &MBB,
for (unsigned i = 4; i < MI.getNumOperands(); ++i)
MIB.add(MI.getOperand(i));

BuildMI(SMBB, DL, TII->get(AArch64::B)).addMBB(EndBB);

MI.eraseFromParent();
return EndBB;
return &EndBB;
}

bool AArch64ExpandPseudo::expandMultiVecPseudo(
Expand Down Expand Up @@ -1673,15 +1690,21 @@ bool AArch64ExpandPseudo::expandMI(MachineBasicBlock &MBB,
return expandCALL_BTI(MBB, MBBI);
case AArch64::StoreSwiftAsyncContext:
return expandStoreSwiftAsyncContext(MBB, MBBI);
case AArch64::RestoreZAPseudo:
case AArch64::CommitZASavePseudo:
case AArch64::RestoreZAPseudo: {
auto *NewMBB = expandCommitOrRestoreZASave(MBB, MBBI);
if (NewMBB != &MBB)
NextMBBI = MBB.end(); // The NextMBBI iterator is invalidated.
return true;
}
case AArch64::MSRpstatePseudo: {
auto *NewMBB = expandCondSMToggle(MBB, MBBI);
auto *NewMBB = [&] {
switch (Opcode) {
case AArch64::RestoreZAPseudo:
return expandRestoreZASave(MBB, MBBI);
case AArch64::CommitZASavePseudo:
return expandCommitZASave(MBB, MBBI);
case AArch64::MSRpstatePseudo:
return expandCondSMToggle(MBB, MBBI);
default:
llvm_unreachable("Unexpected conditional pseudo!");
}
}();
if (NewMBB != &MBB)
NextMBBI = MBB.end(); // The NextMBBI iterator is invalidated.
return true;
Expand Down