Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2765,9 +2765,6 @@ bool IRTranslator::translateCallBase(const CallBase &CB,
}

bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
if (containsBF16Type(U))
return false;

const CallInst &CI = cast<CallInst>(U);
const Function *F = CI.getCalledFunction();

Expand Down
4 changes: 3 additions & 1 deletion llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,9 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
{"SPV_KHR_float_controls2",
SPIRV::Extension::Extension::SPV_KHR_float_controls2},
{"SPV_INTEL_tensor_float32_conversion",
SPIRV::Extension::Extension::SPV_INTEL_tensor_float32_conversion}};
SPIRV::Extension::Extension::SPV_INTEL_tensor_float32_conversion},
{"SPV_KHR_bfloat16",
SPIRV::Extension::Extension::SPV_KHR_bfloat16}};

bool SPIRVExtensionsParser::parse(cl::Option &O, StringRef ArgName,
StringRef ArgValue,
Expand Down
25 changes: 24 additions & 1 deletion llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1122,7 +1122,19 @@ SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
SPIRVType *SpirvType = createSPIRVType(Ty, MIRBuilder, AccessQual,
ExplicitLayoutRequired, EmitIR);
TypesInProcessing.erase(Ty);
VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType;

MachineFunction *MF = &MIRBuilder.getMF();
Register TypeReg = getSPIRVTypeID(SpirvType);

if (Ty->isFloatingPointTy()) {
if (Ty->isBFloatTy()) {
VRegFPVariantMap[MF][TypeReg] = FPVariant::BRAIN_FLOAT;
} else {
VRegFPVariantMap[MF][TypeReg] = FPVariant::IEEE_FLOAT;
}
}

VRegToTypeMap[MF][TypeReg] = SpirvType;

// TODO: We could end up with two SPIR-V types pointing to the same llvm type.
// Is that a problem?
Expand Down Expand Up @@ -2088,3 +2100,14 @@ bool SPIRVGlobalRegistry::hasBlockDecoration(SPIRVType *Type) const {
}
return false;
}

SPIRVGlobalRegistry::FPVariant SPIRVGlobalRegistry::getFPVariantForVReg(Register VReg,
const MachineFunction *MF) {
auto t = VRegFPVariantMap.find(MF ? MF : CurMF);
if (t != VRegFPVariantMap.end()) {
auto tt = t->second.find(VReg);
if (tt != t->second.end())
return tt->second;
}
return FPVariant::NONE;
}
16 changes: 16 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,14 @@ using SPIRVType = const MachineInstr;
using StructOffsetDecorator = std::function<void(Register)>;

class SPIRVGlobalRegistry : public SPIRVIRMapping {
public:
enum class FPVariant {
NONE,
IEEE_FLOAT,
BRAIN_FLOAT
};

private:
// Registers holding values which have types associated with them.
// Initialized upon VReg definition in IRTranslator.
// Do not confuse this with DuplicatesTracker as DT maps Type* to <MF, Reg>
Expand Down Expand Up @@ -88,6 +96,10 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
// map of aliasing decorations to aliasing metadata
std::unordered_map<const MDNode *, MachineInstr *> AliasInstMDMap;

// Maps floating point Registers to their FPVariant (float type kind), given the MachineFunction.
DenseMap<const MachineFunction *, DenseMap<Register, FPVariant>>
VRegFPVariantMap;

// Add a new OpTypeXXX instruction without checking for duplicates.
SPIRVType *createSPIRVType(const Type *Type, MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier::AccessQualifier AQ,
Expand Down Expand Up @@ -421,6 +433,10 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
// Informs about removal of the machine instruction and invalidates data
// structures referring this instruction.
void invalidateMachineInstr(MachineInstr *MI);

// Return the FPVariant of to the given floating-point regiester.
FPVariant getFPVariantForVReg(Register VReg,
const MachineFunction *MF = nullptr);

private:
SPIRVType *getOpTypeBool(MachineIRBuilder &MIRBuilder);
Expand Down
22 changes: 20 additions & 2 deletions llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1261,12 +1261,30 @@ void addInstrRequirements(const MachineInstr &MI,
Reqs.addCapability(SPIRV::Capability::Int8);
break;
}
case SPIRV::OpDot: {
const MachineFunction *MF = MI.getMF();
SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
SPIRVGlobalRegistry::FPVariant FPV = GR->getFPVariantForVReg(MI.getOperand(1).getReg(), MF);
if (FPV == SPIRVGlobalRegistry::FPVariant::BRAIN_FLOAT) {
Reqs.addCapability(SPIRV::Capability::BFloat16DotProductKHR);
}
break;
}
case SPIRV::OpTypeFloat: {
unsigned BitWidth = MI.getOperand(1).getImm();
if (BitWidth == 64)
Reqs.addCapability(SPIRV::Capability::Float64);
else if (BitWidth == 16)
Reqs.addCapability(SPIRV::Capability::Float16);
else if (BitWidth == 16) {
SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry();
const MachineFunction *MF = MI.getMF();
SPIRVGlobalRegistry::FPVariant FPV = GR->getFPVariantForVReg(MI.getOperand(0).getReg(), MF);
if (FPV == SPIRVGlobalRegistry::FPVariant::BRAIN_FLOAT) {
Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16);
Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR);
} else {
Reqs.addCapability(SPIRV::Capability::Float16);
}
}
break;
}
case SPIRV::OpTypeVector: {
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,7 @@ defm SPV_INTEL_2d_block_io : ExtensionOperand<122, [EnvOpenCL]>;
defm SPV_INTEL_int4 : ExtensionOperand<123, [EnvOpenCL]>;
defm SPV_KHR_float_controls2 : ExtensionOperand<124, [EnvVulkan, EnvOpenCL]>;
defm SPV_INTEL_tensor_float32_conversion : ExtensionOperand<125, [EnvOpenCL]>;
defm SPV_KHR_bfloat16 : ExtensionOperand<126, [EnvOpenCL]>;

//===----------------------------------------------------------------------===//
// Multiclass used to define Capabilities enum values and at the same time
Expand Down Expand Up @@ -594,6 +595,9 @@ defm Subgroup2DBlockTransposeINTEL : CapabilityOperand<6230, 0, 0, [SPV_INTEL_2d
defm Int4TypeINTEL : CapabilityOperand<5112, 0, 0, [SPV_INTEL_int4], []>;
defm Int4CooperativeMatrixINTEL : CapabilityOperand<5114, 0, 0, [SPV_INTEL_int4], [Int4TypeINTEL, CooperativeMatrixKHR]>;
defm TensorFloat32RoundingINTEL : CapabilityOperand<6425, 0, 0, [SPV_INTEL_tensor_float32_conversion], []>;
defm BFloat16TypeKHR : CapabilityOperand<5116, 0, 0, [SPV_KHR_bfloat16], []>;
defm BFloat16DotProductKHR : CapabilityOperand<5117, 0, 0, [SPV_KHR_bfloat16], [BFloat16TypeKHR]>;
defm BFloat16CooperativeMatrixKHR : CapabilityOperand<5118, 0, 0, [SPV_KHR_bfloat16], [BFloat16TypeKHR, CooperativeMatrixKHR]>;

//===----------------------------------------------------------------------===//
// Multiclass used to define SourceLanguage enum values and at the same time
Expand Down
20 changes: 20 additions & 0 deletions llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - -filetype=obj | spirv-val %}
; XFAIL: *
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-ERROR

; CHECK-ERROR: BFloat16TypeKHR requires the following SPIR-V extension: SPV_KHR_subgroup_rotate

; CHECK-DAG: OpCapability BFloat16TypeKHR
; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16 1
; CHECK: %[[#]] = OpTypeVector %[[#BFLOAT]] 2

define spir_kernel void @test() {
entry:
%addr1 = alloca bfloat
%addr2 = alloca <2 x bfloat>
%data1 = load bfloat, ptr %addr1
%data2 = load <2 x bfloat>, ptr %addr2
ret void
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - -filetype=obj | spirv-val %}
; XFAIL: *

; CHECK-SPIRV-DAG: OpCapability BFloat16TypeKHR
; CHECK-SPIRV-DAG: OpCapability BFloat16DotProductKHR
; CHECK-SPIRV-DAG: OpExtension "SPV_KHR_bfloat16"
; CHECK-SPIRV: %[[#BFLOAT:]] = OpTypeFloat 16 0
; CHECK-SPIRV: %[[#]] = OpTypeVector %[[#BFLOAT]] 2
; CHECK-SPIRV: OpDot

declare spir_func bfloat @_Z3dotDv2_u6__bf16Dv2_S_(<2 x bfloat>, <2 x bfloat>)

define spir_kernel void @test() {
entry:
%addrA = alloca <2 x bfloat>
%addrB = alloca <2 x bfloat>
%dataA = load <2 x bfloat>, ptr %addrA
%dataB = load <2 x bfloat>, ptr %addrB
%call = call spir_func bfloat @_Z3dotDv2_u6__bf16Dv2_S_(<2 x bfloat> %dataA, <2 x bfloat> %dataB)
ret void
}
Loading