-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[SPIRV] Add bfloat support #155645
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[SPIRV] Add bfloat support #155645
Conversation
@llvm/pr-subscribers-llvm-globalisel @llvm/pr-subscribers-backend-spir-v Author: None (YixingZhang007) ChangesFull diff: https://github.com/llvm/llvm-project/pull/155645.diff 6 Files Affected:
diff --git a/llvm/include/llvm/CodeGen/MachineInstr.h b/llvm/include/llvm/CodeGen/MachineInstr.h
index 10a9b1ff1411d..6f692ae32510b 100644
--- a/llvm/include/llvm/CodeGen/MachineInstr.h
+++ b/llvm/include/llvm/CodeGen/MachineInstr.h
@@ -123,8 +123,9 @@ class MachineInstr
NoUSWrap = 1 << 20, // Instruction supports geps
// no unsigned signed wrap.
SameSign = 1 << 21, // Both operands have the same sign.
- InBounds = 1 << 22 // Pointer arithmetic remains inbounds.
+ InBounds = 1 << 22, // Pointer arithmetic remains inbounds.
// Implies NoUSWrap.
+ BFloat16 = 1 << 23 // Instruction with bf16 type
};
private:
diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index 541269ab6bfce..2a6b66984c8ae 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -2765,8 +2765,8 @@ bool IRTranslator::translateCallBase(const CallBase &CB,
}
bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
- if (containsBF16Type(U))
- return false;
+ // if (containsBF16Type(U))
+ // return false;
const CallInst &CI = cast<CallInst>(U);
const Function *F = CI.getCalledFunction();
@@ -2813,6 +2813,11 @@ bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
if (isa<FPMathOperator>(CI))
MIB->copyIRFlags(CI);
+ // If the spirv intrinsic contain bfloat, enable to Bfloat flag in MachineInst
+ if (containsBF16Type(U)) {
+ MIB->setFlag(MachineInstr::MIFlag::BFloat16);
+ }
+
for (const auto &Arg : enumerate(CI.args())) {
// If this is required to be an immediate, don't materialize it in a
// register.
diff --git a/llvm/lib/CodeGen/MachineInstr.cpp b/llvm/lib/CodeGen/MachineInstr.cpp
index 79047f732808a..10ff667bcb522 100644
--- a/llvm/lib/CodeGen/MachineInstr.cpp
+++ b/llvm/lib/CodeGen/MachineInstr.cpp
@@ -632,6 +632,9 @@ uint32_t MachineInstr::copyFlagsFromInstruction(const Instruction &I) {
if (I.getMetadata(LLVMContext::MD_unpredictable))
MIFlags |= MachineInstr::MIFlag::Unpredictable;
+ if (I.getType()->isBFloatTy())
+ MIFlags |= MachineInstr::MIFlag::BFloat16;
+
return MIFlags;
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index e7da5504b2d58..bd13a3bae92cd 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -147,7 +147,9 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
{"SPV_KHR_float_controls2",
SPIRV::Extension::Extension::SPV_KHR_float_controls2},
{"SPV_INTEL_tensor_float32_conversion",
- SPIRV::Extension::Extension::SPV_INTEL_tensor_float32_conversion}};
+ SPIRV::Extension::Extension::SPV_INTEL_tensor_float32_conversion},
+ {"SPV_KHR_bfloat16",
+ SPIRV::Extension::Extension::SPV_KHR_bfloat16}};
bool SPIRVExtensionsParser::parse(cl::Option &O, StringRef ArgName,
StringRef ArgValue,
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 8039cf0c432fa..5bba5cdce3753 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1267,6 +1267,10 @@ void addInstrRequirements(const MachineInstr &MI,
Reqs.addCapability(SPIRV::Capability::Float64);
else if (BitWidth == 16)
Reqs.addCapability(SPIRV::Capability::Float16);
+ if(MI.getFlag(MachineInstr::MIFlag::BFloat16)) {
+ Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16);
+ Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR);
+ }
break;
}
case SPIRV::OpTypeVector: {
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index d2824ee2d2caf..9d630356e8ffb 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -382,6 +382,7 @@ defm SPV_INTEL_2d_block_io : ExtensionOperand<122, [EnvOpenCL]>;
defm SPV_INTEL_int4 : ExtensionOperand<123, [EnvOpenCL]>;
defm SPV_KHR_float_controls2 : ExtensionOperand<124, [EnvVulkan, EnvOpenCL]>;
defm SPV_INTEL_tensor_float32_conversion : ExtensionOperand<125, [EnvOpenCL]>;
+defm SPV_KHR_bfloat16 : ExtensionOperand<126, [EnvOpenCL]>;
//===----------------------------------------------------------------------===//
// Multiclass used to define Capabilities enum values and at the same time
@@ -594,6 +595,9 @@ defm Subgroup2DBlockTransposeINTEL : CapabilityOperand<6230, 0, 0, [SPV_INTEL_2d
defm Int4TypeINTEL : CapabilityOperand<5112, 0, 0, [SPV_INTEL_int4], []>;
defm Int4CooperativeMatrixINTEL : CapabilityOperand<5114, 0, 0, [SPV_INTEL_int4], [Int4TypeINTEL, CooperativeMatrixKHR]>;
defm TensorFloat32RoundingINTEL : CapabilityOperand<6425, 0, 0, [SPV_INTEL_tensor_float32_conversion], []>;
+defm BFloat16TypeKHR : CapabilityOperand<5116, 0, 0, [SPV_KHR_bfloat16], []>;
+defm BFloat16DotProductKHR : CapabilityOperand<5117, 0, 0, [SPV_KHR_bfloat16], [BFloat16TypeKHR]>;
+defm BFloat16CooperativeMatrixKHR : CapabilityOperand<5118, 0, 0, [SPV_KHR_bfloat16], [BFloat16TypeKHR, CooperativeMatrixKHR]>;
//===----------------------------------------------------------------------===//
// Multiclass used to define SourceLanguage enum values and at the same time
|
You can test this locally with the following command:git-clang-format --diff origin/main HEAD --extensions cpp,h -- llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
View the diff from clang-format here.diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index bd13a3bae..993de9e9f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -148,8 +148,7 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
SPIRV::Extension::Extension::SPV_KHR_float_controls2},
{"SPV_INTEL_tensor_float32_conversion",
SPIRV::Extension::Extension::SPV_INTEL_tensor_float32_conversion},
- {"SPV_KHR_bfloat16",
- SPIRV::Extension::Extension::SPV_KHR_bfloat16}};
+ {"SPV_KHR_bfloat16", SPIRV::Extension::Extension::SPV_KHR_bfloat16}};
bool SPIRVExtensionsParser::parse(cl::Option &O, StringRef ArgName,
StringRef ArgValue,
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index e021920f4..c746a0bf2 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -195,7 +195,8 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(unsigned Width,
}
SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
- MachineIRBuilder &MIRBuilder, uint32_t Fptype) {
+ MachineIRBuilder &MIRBuilder,
+ uint32_t Fptype) {
return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
return MIRBuilder.buildInstr(SPIRV::OpTypeFloat)
.addDef(createTypeVReg(MIRBuilder))
@@ -1682,7 +1683,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(unsigned BitWidth,
MachineBasicBlock &DepMBB = I.getMF()->front();
MachineIRBuilder MIRBuilder(DepMBB, DepMBB.getFirstNonPHI());
const MachineInstr *NewMI = nullptr;
- if(SPIRVOPcode == SPIRV::OpTypeFloat) {
+ if (SPIRVOPcode == SPIRV::OpTypeFloat) {
if (Ty->isBFloatTy()) {
NewMI = createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
return BuildMI(MIRBuilder.getMBB(), *MIRBuilder.getInsertPt(),
@@ -1702,11 +1703,11 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(unsigned BitWidth,
}
} else {
NewMI = createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
- return BuildMI(MIRBuilder.getMBB(), *MIRBuilder.getInsertPt(),
- MIRBuilder.getDL(), TII.get(SPIRVOPcode))
- .addDef(createTypeVReg(CurMF->getRegInfo()))
- .addImm(BitWidth);
- });
+ return BuildMI(MIRBuilder.getMBB(), *MIRBuilder.getInsertPt(),
+ MIRBuilder.getDL(), TII.get(SPIRVOPcode))
+ .addDef(createTypeVReg(CurMF->getRegInfo()))
+ .addImm(BitWidth);
+ });
}
add(Ty, false, NewMI);
return finishCreatingSPIRVType(Ty, NewMI);
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index a2a2e7b28..959a65b2b 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -436,7 +436,8 @@ private:
SPIRVType *getOpTypeInt(unsigned Width, MachineIRBuilder &MIRBuilder,
bool IsSigned = false);
- SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder, uint32_t Fptype);
+ SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder,
+ uint32_t Fptype);
SPIRVType *getOpTypeVoid(MachineIRBuilder &MIRBuilder);
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 44024bde3..df5f5caf7 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1266,8 +1266,8 @@ void addInstrRequirements(const MachineInstr &MI,
SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(1).getReg());
if ((TypeDef->getOpcode() == SPIRV::OpTypeFloat) &&
(TypeDef->getOperand(1).getImm() == 16)) {
- Reqs.addCapability(SPIRV::Capability::BFloat16DotProductKHR);
- }
+ Reqs.addCapability(SPIRV::Capability::BFloat16DotProductKHR);
+ }
break;
}
case SPIRV::OpTypeFloat: {
@@ -1276,7 +1276,7 @@ void addInstrRequirements(const MachineInstr &MI,
if (BitWidth == 64)
Reqs.addCapability(SPIRV::Capability::Float64);
else if (BitWidth == 16) {
- if(Fptype == 1) {
+ if (Fptype == 1) {
Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16);
Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR);
} else {
|
No description provided.