Skip to content

Commit fbe1c06

Browse files
committed
[LLVM][Uniformity] Improve detection of uniform registers
The MachineUA now queries the target to determine if a given register holds a uniform value. This is determined using the corresponding register bank if available, or by a combination of the register class and value type. This assumes that the target is optimizing for performance by choosing registers, and the target is responsible for any mismatch with the inferred uniformity.

For example, on AMDGPU, an SGPR is now treated as uniform, except if the register bank is VCC (i.e., the register holds a wave-wide vector of 1-bit values) or equivalently if it has a value type of s1.

- This does not always work with inline asm, where the register bank or the value type might not be present. We assume that the SGPR is uniform, because it is not expected to be s1 in the vast majority of cases.
- The pseudo branch instruction SI_LOOP is now hard-coded to be always divergent, although its condition is an SGPR.

Reviewed By: arsenm

Differential Revision: https://reviews.llvm.org/D150438
1 parent fb7c237 commit fbe1c06

18 files changed

+461
-48
lines changed

llvm/include/llvm/ADT/GenericUniformityImpl.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -359,8 +359,7 @@ template <typename ContextT> class GenericUniformityAnalysisImpl {
359359
/// \returns Whether the tracked divergence state of \p DivVal changed.
360360
bool markDivergent(const InstructionT &I);
361361
bool markDivergent(ConstValueRefT DivVal);
362-
bool markDefsDivergent(const InstructionT &Instr,
363-
bool AllDefsDivergent = true);
362+
bool markDefsDivergent(const InstructionT &Instr);
364363

365364
/// \brief Propagate divergence to all instructions in the region.
366365
/// Divergence is seeded by calls to \p markDivergent.

llvm/include/llvm/CodeGen/RegisterBankInfo.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -587,6 +587,11 @@ class RegisterBankInfo {
587587
/// Get the total number of register banks.
588588
unsigned getNumRegBanks() const { return NumRegBanks; }
589589

590+
/// Returns true if the register bank is considered divergent.
591+
virtual bool isDivergentRegBank(const RegisterBank *RB) const {
592+
return false;
593+
}
594+
590595
/// Get a register bank that covers \p RC.
591596
///
592597
/// \pre \p RC is a user-defined register class (as opposed to one

llvm/include/llvm/CodeGen/TargetRegisterInfo.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
#include "llvm/ADT/StringRef.h"
2121
#include "llvm/ADT/iterator_range.h"
2222
#include "llvm/CodeGen/MachineBasicBlock.h"
23-
#include "llvm/CodeGen/MachineValueType.h"
23+
#include "llvm/CodeGen/RegisterBank.h"
2424
#include "llvm/IR/CallingConv.h"
2525
#include "llvm/MC/LaneBitmask.h"
2626
#include "llvm/MC/MCRegisterInfo.h"
@@ -557,6 +557,12 @@ class TargetRegisterInfo : public MCRegisterInfo {
557557
return false;
558558
}
559559

560+
/// Returns true if the register is considered uniform.
561+
virtual bool isUniformReg(const MachineRegisterInfo &MRI,
562+
const RegisterBankInfo &RBI, Register Reg) const {
563+
return false;
564+
}
565+
560566
/// Physical registers that may be modified within a function but are
561567
/// guaranteed to be restored before any uses. This is useful for targets that
562568
/// have call sequences where a GOT register may be updated by the caller

llvm/lib/Analysis/UniformityAnalysis.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ bool llvm::GenericUniformityAnalysisImpl<SSAContext>::hasDivergentDefs(
2626

2727
template <>
2828
bool llvm::GenericUniformityAnalysisImpl<SSAContext>::markDefsDivergent(
29-
const Instruction &Instr, bool AllDefsDivergent) {
29+
const Instruction &Instr) {
3030
return markDivergent(&Instr);
3131
}
3232

llvm/lib/CodeGen/MachineUniformityAnalysis.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,21 +31,19 @@ bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::hasDivergentDefs(
3131

3232
template <>
3333
bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::markDefsDivergent(
34-
const MachineInstr &Instr, bool AllDefsDivergent) {
34+
const MachineInstr &Instr) {
3535
bool insertedDivergent = false;
3636
const auto &MRI = F.getRegInfo();
37+
const auto &RBI = *F.getSubtarget().getRegBankInfo();
3738
const auto &TRI = *MRI.getTargetRegisterInfo();
3839
for (auto &op : Instr.operands()) {
3940
if (!op.isReg() || !op.isDef())
4041
continue;
4142
if (!op.getReg().isVirtual())
4243
continue;
4344
assert(!op.getSubReg());
44-
if (!AllDefsDivergent) {
45-
auto *RC = MRI.getRegClassOrNull(op.getReg());
46-
if (RC && !TRI.isDivergentRegClass(RC))
47-
continue;
48-
}
45+
if (TRI.isUniformReg(MRI, RBI, op.getReg()))
46+
continue;
4947
insertedDivergent |= markDivergent(op.getReg());
5048
}
5149
return insertedDivergent;
@@ -64,7 +62,8 @@ void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::initialize() {
6462
}
6563

6664
if (uniformity == InstructionUniformity::NeverUniform) {
67-
markDefsDivergent(instr, /* AllDefsDivergent = */ false);
65+
if (markDivergent(instr))
66+
Worklist.push_back(&instr);
6867
}
6968
}
7069
}

llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,10 @@ static bool isVectorRegisterBank(const RegisterBank &Bank) {
215215
return BankID == AMDGPU::VGPRRegBankID || BankID == AMDGPU::AGPRRegBankID;
216216
}
217217

218+
bool AMDGPURegisterBankInfo::isDivergentRegBank(const RegisterBank *RB) const {
219+
return RB != &AMDGPU::SGPRRegBank;
220+
}
221+
218222
unsigned AMDGPURegisterBankInfo::copyCost(const RegisterBank &Dst,
219223
const RegisterBank &Src,
220224
unsigned Size) const {

llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,8 @@ class AMDGPURegisterBankInfo final : public AMDGPUGenRegisterBankInfo {
168168
public:
169169
AMDGPURegisterBankInfo(const GCNSubtarget &STI);
170170

171+
bool isDivergentRegBank(const RegisterBank *RB) const override;
172+
171173
unsigned copyCost(const RegisterBank &A, const RegisterBank &B,
172174
unsigned Size) const override;
173175

llvm/lib/Target/AMDGPU/SIInstructions.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,7 @@ def SI_LOOP : CFPseudoInstSI <
420420
let Size = 8;
421421
let isBranch = 1;
422422
let hasSideEffects = 1;
423+
let IsNeverUniform = 1;
423424
}
424425

425426
} // End isTerminator = 1

llvm/lib/Target/AMDGPU/SIRegisterInfo.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2865,6 +2865,16 @@ MCRegister SIRegisterInfo::findUnusedRegister(const MachineRegisterInfo &MRI,
28652865
return MCRegister();
28662866
}
28672867

2868+
bool SIRegisterInfo::isUniformReg(const MachineRegisterInfo &MRI,
2869+
const RegisterBankInfo &RBI,
2870+
Register Reg) const {
2871+
auto *RB = RBI.getRegBank(Reg, MRI, *MRI.getTargetRegisterInfo());
2872+
if (!RB)
2873+
return false;
2874+
2875+
return !RBI.isDivergentRegBank(RB);
2876+
}
2877+
28682878
ArrayRef<int16_t> SIRegisterInfo::getRegSplitParts(const TargetRegisterClass *RC,
28692879
unsigned EltSize) const {
28702880
const unsigned RegBitWidth = AMDGPU::getRegBitWidth(*RC->MC);

llvm/lib/Target/AMDGPU/SIRegisterInfo.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,10 +286,17 @@ class SIRegisterInfo final : public AMDGPUGenRegisterInfo {
286286
return isVGPR(MRI, Reg) || isAGPR(MRI, Reg);
287287
}
288288

289+
// FIXME: SGPRs are assumed to be uniform, but this is not true for i1 SGPRs
290+
// (such as VCC) which hold a wave-wide vector of boolean values. Examining
291+
// just the register class is not sufficient; it needs to be combined with a
292+
// value type. The next predicate isUniformReg() does this correctly.
289293
bool isDivergentRegClass(const TargetRegisterClass *RC) const override {
290294
return !isSGPRClass(RC);
291295
}
292296

297+
bool isUniformReg(const MachineRegisterInfo &MRI, const RegisterBankInfo &RBI,
298+
Register Reg) const override;
299+
293300
ArrayRef<int16_t> getRegSplitParts(const TargetRegisterClass *RC,
294301
unsigned EltSize) const;
295302

0 commit comments

Comments
 (0)