diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp index 541269ab6bfce..8a4b1afec3d26 100644 --- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp +++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp @@ -2765,9 +2765,6 @@ bool IRTranslator::translateCallBase(const CallBase &CB, } bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) { - if (containsBF16Type(U)) - return false; - const CallInst &CI = cast<CallInst>(U); const Function *F = CI.getCalledFunction(); diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp index e7da5504b2d58..bd13a3bae92cd 100644 --- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp @@ -147,7 +147,9 @@ static const std::map> {"SPV_KHR_float_controls2", SPIRV::Extension::Extension::SPV_KHR_float_controls2}, {"SPV_INTEL_tensor_float32_conversion", - SPIRV::Extension::Extension::SPV_INTEL_tensor_float32_conversion}}; + SPIRV::Extension::Extension::SPV_INTEL_tensor_float32_conversion}, + {"SPV_KHR_bfloat16", + SPIRV::Extension::Extension::SPV_KHR_bfloat16}}; bool SPIRVExtensionsParser::parse(cl::Option &O, StringRef ArgName, StringRef ArgValue, diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp index cfe24c84941a9..61a8ba19b28e4 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp @@ -1122,7 +1122,19 @@ SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType( SPIRVType *SpirvType = createSPIRVType(Ty, MIRBuilder, AccessQual, ExplicitLayoutRequired, EmitIR); TypesInProcessing.erase(Ty); - VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType; + + MachineFunction *MF = &MIRBuilder.getMF(); + Register TypeReg = getSPIRVTypeID(SpirvType); + + if (Ty->isFloatingPointTy()) { + if (Ty->isBFloatTy()) { + VRegFPVariantMap[MF][TypeReg] = FPVariant::BRAIN_FLOAT; + } else { 
+ VRegFPVariantMap[MF][TypeReg] = FPVariant::IEEE_FLOAT; + } + } + + VRegToTypeMap[MF][TypeReg] = SpirvType; // TODO: We could end up with two SPIR-V types pointing to the same llvm type. // Is that a problem? @@ -2088,3 +2100,14 @@ bool SPIRVGlobalRegistry::hasBlockDecoration(SPIRVType *Type) const { } return false; } + +SPIRVGlobalRegistry::FPVariant SPIRVGlobalRegistry::getFPVariantForVReg(Register VReg, + const MachineFunction *MF) { + auto t = VRegFPVariantMap.find(MF ? MF : CurMF); + if (t != VRegFPVariantMap.end()) { + auto tt = t->second.find(VReg); + if (tt != t->second.end()) + return tt->second; + } + return FPVariant::NONE; +} diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h index 7ef812828b7cc..8f9a5d3cfc68a 100644 --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h @@ -29,6 +29,14 @@ using SPIRVType = const MachineInstr; using StructOffsetDecorator = std::function; class SPIRVGlobalRegistry : public SPIRVIRMapping { + public: + enum class FPVariant { + NONE, + IEEE_FLOAT, + BRAIN_FLOAT + }; + + private: // Registers holding values which have types associated with them. // Initialized upon VReg definition in IRTranslator. // Do not confuse this with DuplicatesTracker as DT maps Type* to @@ -88,6 +96,10 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping { // map of aliasing decorations to aliasing metadata std::unordered_map AliasInstMDMap; + // Per-MachineFunction map from a floating-point type's ID register (as recorded by restOfCreateSPIRVType) to its FPVariant. + DenseMap<const MachineFunction *, DenseMap<Register, FPVariant>> + VRegFPVariantMap; + // Add a new OpTypeXXX instruction without checking for duplicates. SPIRVType *createSPIRVType(const Type *Type, MachineIRBuilder &MIRBuilder, SPIRV::AccessQualifier::AccessQualifier AQ, @@ -421,6 +433,10 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping { // Informs about removal of the machine instruction and invalidates data // structures referring this instruction. 
void invalidateMachineInstr(MachineInstr *MI); + + // Return the FPVariant recorded for the given floating-point register, or FPVariant::NONE if none is recorded; a null MF falls back to CurMF. + FPVariant getFPVariantForVReg(Register VReg, + const MachineFunction *MF = nullptr); private: SPIRVType *getOpTypeBool(MachineIRBuilder &MIRBuilder); diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp index 8039cf0c432fa..653b67e464908 100644 --- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp @@ -1261,12 +1261,30 @@ void addInstrRequirements(const MachineInstr &MI, Reqs.addCapability(SPIRV::Capability::Int8); break; } + case SPIRV::OpDot: { + const MachineFunction *MF = MI.getMF(); + SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry(); + SPIRVGlobalRegistry::FPVariant FPV = GR->getFPVariantForVReg(MI.getOperand(1).getReg(), MF); + if (FPV == SPIRVGlobalRegistry::FPVariant::BRAIN_FLOAT) { + Reqs.addCapability(SPIRV::Capability::BFloat16DotProductKHR); + } + break; + } case SPIRV::OpTypeFloat: { unsigned BitWidth = MI.getOperand(1).getImm(); if (BitWidth == 64) Reqs.addCapability(SPIRV::Capability::Float64); - else if (BitWidth == 16) - Reqs.addCapability(SPIRV::Capability::Float16); + else if (BitWidth == 16) { + SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry(); + const MachineFunction *MF = MI.getMF(); + SPIRVGlobalRegistry::FPVariant FPV = GR->getFPVariantForVReg(MI.getOperand(0).getReg(), MF); + if (FPV == SPIRVGlobalRegistry::FPVariant::BRAIN_FLOAT) { + Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16); + Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR); + } else { + Reqs.addCapability(SPIRV::Capability::Float16); + } + } break; } case SPIRV::OpTypeVector: { diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td index d2824ee2d2caf..9d630356e8ffb 100644 --- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td +++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td @@ -382,6 
+382,7 @@ defm SPV_INTEL_2d_block_io : ExtensionOperand<122, [EnvOpenCL]>; defm SPV_INTEL_int4 : ExtensionOperand<123, [EnvOpenCL]>; defm SPV_KHR_float_controls2 : ExtensionOperand<124, [EnvVulkan, EnvOpenCL]>; defm SPV_INTEL_tensor_float32_conversion : ExtensionOperand<125, [EnvOpenCL]>; +defm SPV_KHR_bfloat16 : ExtensionOperand<126, [EnvOpenCL]>; //===----------------------------------------------------------------------===// // Multiclass used to define Capabilities enum values and at the same time @@ -594,6 +595,9 @@ defm Subgroup2DBlockTransposeINTEL : CapabilityOperand<6230, 0, 0, [SPV_INTEL_2d defm Int4TypeINTEL : CapabilityOperand<5112, 0, 0, [SPV_INTEL_int4], []>; defm Int4CooperativeMatrixINTEL : CapabilityOperand<5114, 0, 0, [SPV_INTEL_int4], [Int4TypeINTEL, CooperativeMatrixKHR]>; defm TensorFloat32RoundingINTEL : CapabilityOperand<6425, 0, 0, [SPV_INTEL_tensor_float32_conversion], []>; +defm BFloat16TypeKHR : CapabilityOperand<5116, 0, 0, [SPV_KHR_bfloat16], []>; +defm BFloat16DotProductKHR : CapabilityOperand<5117, 0, 0, [SPV_KHR_bfloat16], [BFloat16TypeKHR]>; +defm BFloat16CooperativeMatrixKHR : CapabilityOperand<5118, 0, 0, [SPV_KHR_bfloat16], [BFloat16TypeKHR, CooperativeMatrixKHR]>; //===----------------------------------------------------------------------===// // Multiclass used to define SourceLanguage enum values and at the same time diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll new file mode 100644 index 0000000000000..95eb5a8348d67 --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll @@ -0,0 +1,20 @@ +; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - | FileCheck %s +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - -filetype=obj | spirv-val %} +; XFAIL: * +; RUN: not llc -verify-machineinstrs -O0 
-mtriple=spirv32-unknown-unknown %s -o - 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR + +; CHECK-ERROR: BFloat16TypeKHR requires the following SPIR-V extension: SPV_KHR_bfloat16 + +; CHECK-DAG: OpCapability BFloat16TypeKHR +; CHECK-DAG: OpExtension "SPV_KHR_bfloat16" +; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16 0 +; CHECK: %[[#]] = OpTypeVector %[[#BFLOAT]] 2 + +define spir_kernel void @test() { +entry: + %addr1 = alloca bfloat + %addr2 = alloca <2 x bfloat> + %data1 = load bfloat, ptr %addr1 + %data2 = load <2 x bfloat>, ptr %addr2 + ret void +} diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll new file mode 100644 index 0000000000000..337eda538c196 --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16_dot.ll @@ -0,0 +1,22 @@ +; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - -filetype=obj | spirv-val %} +; XFAIL: * + +; CHECK-SPIRV-DAG: OpCapability BFloat16TypeKHR +; CHECK-SPIRV-DAG: OpCapability BFloat16DotProductKHR +; CHECK-SPIRV-DAG: OpExtension "SPV_KHR_bfloat16" +; CHECK-SPIRV: %[[#BFLOAT:]] = OpTypeFloat 16 0 +; CHECK-SPIRV: %[[#]] = OpTypeVector %[[#BFLOAT]] 2 +; CHECK-SPIRV: OpDot + +declare spir_func bfloat @_Z3dotDv2_u6__bf16Dv2_S_(<2 x bfloat>, <2 x bfloat>) + +define spir_kernel void @test() { +entry: + %addrA = alloca <2 x bfloat> + %addrB = alloca <2 x bfloat> + %dataA = load <2 x bfloat>, ptr %addrA + %dataB = load <2 x bfloat>, ptr %addrB + %call = call spir_func bfloat @_Z3dotDv2_u6__bf16Dv2_S_(<2 x bfloat> %dataA, <2 x bfloat> %dataB) + ret void +}