Skip to content
4 changes: 2 additions & 2 deletions llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2765,8 +2765,8 @@ bool IRTranslator::translateCallBase(const CallBase &CB,
}

bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
if (containsBF16Type(U))
return false;
// if (containsBF16Type(U))
// return false;

const CallInst &CI = cast<CallInst>(U);
const Function *F = CI.getCalledFunction();
Expand Down
4 changes: 3 additions & 1 deletion llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,9 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
{"SPV_KHR_float_controls2",
SPIRV::Extension::Extension::SPV_KHR_float_controls2},
{"SPV_INTEL_tensor_float32_conversion",
SPIRV::Extension::Extension::SPV_INTEL_tensor_float32_conversion}};
SPIRV::Extension::Extension::SPV_INTEL_tensor_float32_conversion},
{"SPV_KHR_bfloat16",
SPIRV::Extension::Extension::SPV_KHR_bfloat16}};

bool SPIRVExtensionsParser::parse(cl::Option &O, StringRef ArgName,
StringRef ArgValue,
Expand Down
54 changes: 47 additions & 7 deletions llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,11 +195,12 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(unsigned Width,
}

SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
MachineIRBuilder &MIRBuilder) {
MachineIRBuilder &MIRBuilder, uint32_t Fptype) {
return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
return MIRBuilder.buildInstr(SPIRV::OpTypeFloat)
.addDef(createTypeVReg(MIRBuilder))
.addImm(Width);
.addImm(Width)
.addImm(Fptype);
});
}

Expand Down Expand Up @@ -1041,8 +1042,11 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
return Width == 1 ? getOpTypeBool(MIRBuilder)
: getOpTypeInt(Width, MIRBuilder, false);
}
if (Ty->isFloatingPointTy())
return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder);
if (Ty->isFloatingPointTy()) {
if (Ty->isBFloatTy())
return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder, 1);
return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder, 0);
}
if (Ty->isVoidTy())
return getOpTypeVoid(MIRBuilder);
if (Ty->isVectorTy()) {
Expand Down Expand Up @@ -1121,9 +1125,15 @@ SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
TypesInProcessing.insert(Ty);
SPIRVType *SpirvType = createSPIRVType(Ty, MIRBuilder, AccessQual,
ExplicitLayoutRequired, EmitIR);
TypesInProcessing.erase(Ty);

if(Ty->isFloatingPointTy() && Ty->isBFloatTy()){
VRegIsBFloatMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = true;
}

VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType;

TypesInProcessing.erase(Ty);

// TODO: We could end up with two SPIR-V types pointing to the same llvm type.
// Is that a problem?
SPIRVToLLVMType[SpirvType] = unifyPtrType(Ty);
Expand Down Expand Up @@ -1158,6 +1168,17 @@ SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg,
return nullptr;
}

bool SPIRVGlobalRegistry::getIsBFloatForVReg(Register VReg,
const MachineFunction *MF) {
auto t = VRegToTypeMap.find(MF ? MF : CurMF);
if (t != VRegToTypeMap.end()) {
auto tt = t->second.find(VReg);
if (tt != t->second.end())
return tt->second;
}
return false;
}

SPIRVType *SPIRVGlobalRegistry::getResultType(Register VReg,
MachineFunction *MF) {
if (!MF)
Expand Down Expand Up @@ -1677,14 +1698,33 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(unsigned BitWidth,
return MI;
MachineBasicBlock &DepMBB = I.getMF()->front();
MachineIRBuilder MIRBuilder(DepMBB, DepMBB.getFirstNonPHI());
const MachineInstr *NewMI =
createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
const MachineInstr *NewMI = nullptr;
if(SPIRVOPcode == SPIRV::OpTypeFloat) {
if (Ty->isBFloatTy()) {
NewMI = createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
return BuildMI(MIRBuilder.getMBB(), *MIRBuilder.getInsertPt(),
MIRBuilder.getDL(), TII.get(SPIRVOPcode))
.addDef(createTypeVReg(CurMF->getRegInfo()))
.addImm(BitWidth)
.addImm(1);
});
} else {
NewMI = createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
return BuildMI(MIRBuilder.getMBB(), *MIRBuilder.getInsertPt(),
MIRBuilder.getDL(), TII.get(SPIRVOPcode))
.addDef(createTypeVReg(CurMF->getRegInfo()))
.addImm(BitWidth)
.addImm(0);
});
}
} else {
NewMI = createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
return BuildMI(MIRBuilder.getMBB(), *MIRBuilder.getInsertPt(),
MIRBuilder.getDL(), TII.get(SPIRVOPcode))
.addDef(createTypeVReg(CurMF->getRegInfo()))
.addImm(BitWidth);
});
}
add(Ty, false, NewMI);
return finishCreatingSPIRVType(Ty, NewMI);
}
Expand Down
8 changes: 7 additions & 1 deletion llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
// type-declaring ones).
DenseMap<const MachineFunction *, DenseMap<Register, SPIRVType *>>
VRegToTypeMap;

DenseMap<const MachineFunction *, DenseMap<Register, bool>>
VRegIsBFloatMap;

DenseMap<SPIRVType *, const Type *> SPIRVToLLVMType;

Expand Down Expand Up @@ -340,6 +343,9 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
// functions.
SPIRVType *getSPIRVTypeForVReg(Register VReg,
const MachineFunction *MF = nullptr) const;

bool getIsBFloatForVReg(Register VReg,
const MachineFunction *MF = nullptr);

// Return the result type of the instruction defining the register.
SPIRVType *getResultType(Register VReg, MachineFunction *MF = nullptr);
Expand Down Expand Up @@ -436,7 +442,7 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
SPIRVType *getOpTypeInt(unsigned Width, MachineIRBuilder &MIRBuilder,
bool IsSigned = false);

SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder);
SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder, uint32_t Fptype);

SPIRVType *getOpTypeVoid(MachineIRBuilder &MIRBuilder);

Expand Down
3 changes: 1 addition & 2 deletions llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,7 @@ def OpTypeVoid: Op<19, (outs TYPE:$type), (ins), "$type = OpTypeVoid">;
def OpTypeBool: Op<20, (outs TYPE:$type), (ins), "$type = OpTypeBool">;
def OpTypeInt: Op<21, (outs TYPE:$type), (ins i32imm:$width, i32imm:$signedness),
"$type = OpTypeInt $width $signedness">;
def OpTypeFloat: Op<22, (outs TYPE:$type), (ins i32imm:$width),
"$type = OpTypeFloat $width">;
def OpTypeFloat: Op<22, (outs TYPE:$type), (ins i32imm:$width, i32imm:$fpencode), "$type = OpTypeFloat $width $fpencode">;
def OpTypeVector: Op<23, (outs TYPE:$type), (ins TYPE:$compType, i32imm:$compCount),
"$type = OpTypeVector $compType $compCount">;
def OpTypeMatrix: Op<24, (outs TYPE:$type), (ins TYPE:$colType, i32imm:$colCount),
Expand Down
22 changes: 20 additions & 2 deletions llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1261,12 +1261,30 @@ void addInstrRequirements(const MachineInstr &MI,
Reqs.addCapability(SPIRV::Capability::Int8);
break;
}
case SPIRV::OpDot: {
const MachineFunction *MF = MI.getMF();
SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
bool isBfloat = GR->getIsBFloatForVReg(MI.getOperand(1).getReg(), MF);
if (isBfloat) {
Reqs.addCapability(SPIRV::Capability::BFloat16DotProductKHR);
}
break;
}
case SPIRV::OpTypeFloat: {
unsigned BitWidth = MI.getOperand(1).getImm();
SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
const MachineFunction *MF = MI.getMF();
bool isBfloat = GR->getIsBFloatForVReg(MI.getOperand(0).getReg(), MF);
if (BitWidth == 64)
Reqs.addCapability(SPIRV::Capability::Float64);
else if (BitWidth == 16)
Reqs.addCapability(SPIRV::Capability::Float16);
else if (BitWidth == 16) {
if(isBfloat) {
Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16);
Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR);
} else {
Reqs.addCapability(SPIRV::Capability::Float16);
}
}
break;
}
case SPIRV::OpTypeVector: {
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,7 @@ defm SPV_INTEL_2d_block_io : ExtensionOperand<122, [EnvOpenCL]>;
defm SPV_INTEL_int4 : ExtensionOperand<123, [EnvOpenCL]>;
defm SPV_KHR_float_controls2 : ExtensionOperand<124, [EnvVulkan, EnvOpenCL]>;
defm SPV_INTEL_tensor_float32_conversion : ExtensionOperand<125, [EnvOpenCL]>;
defm SPV_KHR_bfloat16 : ExtensionOperand<126, [EnvOpenCL]>;

//===----------------------------------------------------------------------===//
// Multiclass used to define Capabilities enum values and at the same time
Expand Down Expand Up @@ -594,6 +595,9 @@ defm Subgroup2DBlockTransposeINTEL : CapabilityOperand<6230, 0, 0, [SPV_INTEL_2d
defm Int4TypeINTEL : CapabilityOperand<5112, 0, 0, [SPV_INTEL_int4], []>;
defm Int4CooperativeMatrixINTEL : CapabilityOperand<5114, 0, 0, [SPV_INTEL_int4], [Int4TypeINTEL, CooperativeMatrixKHR]>;
defm TensorFloat32RoundingINTEL : CapabilityOperand<6425, 0, 0, [SPV_INTEL_tensor_float32_conversion], []>;
defm BFloat16TypeKHR : CapabilityOperand<5116, 0, 0, [SPV_KHR_bfloat16], []>;
defm BFloat16DotProductKHR : CapabilityOperand<5117, 0, 0, [SPV_KHR_bfloat16], [BFloat16TypeKHR]>;
defm BFloat16CooperativeMatrixKHR : CapabilityOperand<5118, 0, 0, [SPV_KHR_bfloat16], [BFloat16TypeKHR, CooperativeMatrixKHR]>;

//===----------------------------------------------------------------------===//
// Multiclass used to define SourceLanguage enum values and at the same time
Expand Down
20 changes: 20 additions & 0 deletions llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - -filetype=obj | spirv-val %}
; XFAIL: *
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-ERROR

; CHECK-ERROR: BFloat16TypeKHR requires the following SPIR-V extension: SPV_KHR_subgroup_rotate

; CHECK-DAG: OpCapability BFloat16TypeKHR
; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16 1
; CHECK: %[[#]] = OpTypeVector %[[#BFLOAT]] 2

define spir_kernel void @test() {
entry:
%addr1 = alloca bfloat
%addr2 = alloca <2 x bfloat>
%data1 = load bfloat, ptr %addr1
%data2 = load <2 x bfloat>, ptr %addr2
ret void
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - -filetype=obj | spirv-val %}
; XFAIL: *

; CHECK-SPIRV-DAG: OpCapability BFloat16TypeKHR
; CHECK-SPIRV-DAG: OpCapability BFloat16DotProductKHR
; CHECK-SPIRV-DAG: OpExtension "SPV_KHR_bfloat16"
; CHECK-SPIRV: %[[#BFLOAT:]] = OpTypeFloat 16 0
; CHECK-SPIRV: %[[#]] = OpTypeVector %[[#BFLOAT]] 2
; CHECK-SPIRV: OpDot

declare spir_func bfloat @_Z3dotDv2_u6__bf16Dv2_S_(<2 x bfloat>, <2 x bfloat>)

define spir_kernel void @test() {
entry:
%addrA = alloca <2 x bfloat>
%addrB = alloca <2 x bfloat>
%dataA = load <2 x bfloat>, ptr %addrA
%dataB = load <2 x bfloat>, ptr %addrB
%call = call spir_func bfloat @_Z3dotDv2_u6__bf16Dv2_S_(<2 x bfloat> %dataA, <2 x bfloat> %dataB)
ret void
}
Loading