Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,9 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
{"SPV_KHR_float_controls2",
SPIRV::Extension::Extension::SPV_KHR_float_controls2},
{"SPV_INTEL_tensor_float32_conversion",
SPIRV::Extension::Extension::SPV_INTEL_tensor_float32_conversion}};
SPIRV::Extension::Extension::SPV_INTEL_tensor_float32_conversion},
{"SPV_KHR_bfloat16",
SPIRV::Extension::Extension::SPV_KHR_bfloat16}};

bool SPIRVExtensionsParser::parse(cl::Option &O, StringRef ArgName,
StringRef ArgValue,
Expand Down
35 changes: 29 additions & 6 deletions llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,11 +195,12 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(unsigned Width,
}

SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
MachineIRBuilder &MIRBuilder) {
MachineIRBuilder &MIRBuilder, uint32_t Fptype) {
return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
return MIRBuilder.buildInstr(SPIRV::OpTypeFloat)
.addDef(createTypeVReg(MIRBuilder))
.addImm(Width);
.addImm(Width)
.addImm(Fptype);
});
}

Expand Down Expand Up @@ -1041,8 +1042,11 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
return Width == 1 ? getOpTypeBool(MIRBuilder)
: getOpTypeInt(Width, MIRBuilder, false);
}
if (Ty->isFloatingPointTy())
return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder);
if (Ty->isFloatingPointTy()) {
if (Ty->isBFloatTy())
return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder, 1);
return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder, 0);
}
if (Ty->isVoidTy())
return getOpTypeVoid(MIRBuilder);
if (Ty->isVectorTy()) {
Expand Down Expand Up @@ -1677,14 +1681,33 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(unsigned BitWidth,
return MI;
MachineBasicBlock &DepMBB = I.getMF()->front();
MachineIRBuilder MIRBuilder(DepMBB, DepMBB.getFirstNonPHI());
const MachineInstr *NewMI =
createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
const MachineInstr *NewMI = nullptr;
if(SPIRVOPcode == SPIRV::OpTypeFloat) {
if (Ty->isBFloatTy()) {
NewMI = createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
return BuildMI(MIRBuilder.getMBB(), *MIRBuilder.getInsertPt(),
MIRBuilder.getDL(), TII.get(SPIRVOPcode))
.addDef(createTypeVReg(CurMF->getRegInfo()))
.addImm(BitWidth)
.addImm(1);
});
} else {
NewMI = createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
return BuildMI(MIRBuilder.getMBB(), *MIRBuilder.getInsertPt(),
MIRBuilder.getDL(), TII.get(SPIRVOPcode))
.addDef(createTypeVReg(CurMF->getRegInfo()))
.addImm(BitWidth)
.addImm(0);
});
}
} else {
NewMI = createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
return BuildMI(MIRBuilder.getMBB(), *MIRBuilder.getInsertPt(),
MIRBuilder.getDL(), TII.get(SPIRVOPcode))
.addDef(createTypeVReg(CurMF->getRegInfo()))
.addImm(BitWidth);
});
}
add(Ty, false, NewMI);
return finishCreatingSPIRVType(Ty, NewMI);
}
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
SPIRVType *getOpTypeInt(unsigned Width, MachineIRBuilder &MIRBuilder,
bool IsSigned = false);

SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder);
SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder, uint32_t Fptype);

SPIRVType *getOpTypeVoid(MachineIRBuilder &MIRBuilder);

Expand Down
3 changes: 1 addition & 2 deletions llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,7 @@ def OpTypeVoid: Op<19, (outs TYPE:$type), (ins), "$type = OpTypeVoid">;
def OpTypeBool: Op<20, (outs TYPE:$type), (ins), "$type = OpTypeBool">;
def OpTypeInt: Op<21, (outs TYPE:$type), (ins i32imm:$width, i32imm:$signedness),
"$type = OpTypeInt $width $signedness">;
def OpTypeFloat: Op<22, (outs TYPE:$type), (ins i32imm:$width),
"$type = OpTypeFloat $width">;
def OpTypeFloat: Op<22, (outs TYPE:$type), (ins i32imm:$width, i32imm:$fpencode), "$type = OpTypeFloat $width $fpencode">;
def OpTypeVector: Op<23, (outs TYPE:$type), (ins TYPE:$compType, i32imm:$compCount),
"$type = OpTypeVector $compType $compCount">;
def OpTypeMatrix: Op<24, (outs TYPE:$type), (ins TYPE:$colType, i32imm:$colCount),
Expand Down
20 changes: 18 additions & 2 deletions llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1261,12 +1261,28 @@ void addInstrRequirements(const MachineInstr &MI,
Reqs.addCapability(SPIRV::Capability::Int8);
break;
}
case SPIRV::OpDot: {
const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(1).getReg());
if ((TypeDef->getOpcode() == SPIRV::OpTypeFloat) &&
(TypeDef->getOperand(1).getImm() == 16)) {
Reqs.addCapability(SPIRV::Capability::BFloat16DotProductKHR);
}
break;
}
case SPIRV::OpTypeFloat: {
unsigned BitWidth = MI.getOperand(1).getImm();
unsigned Fptype = MI.getOperand(2).getImm();
if (BitWidth == 64)
Reqs.addCapability(SPIRV::Capability::Float64);
else if (BitWidth == 16)
Reqs.addCapability(SPIRV::Capability::Float16);
else if (BitWidth == 16) {
if(Fptype == 1) {
Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16);
Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR);
} else {
Reqs.addCapability(SPIRV::Capability::Float16);
}
}
break;
}
case SPIRV::OpTypeVector: {
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,7 @@ defm SPV_INTEL_2d_block_io : ExtensionOperand<122, [EnvOpenCL]>;
defm SPV_INTEL_int4 : ExtensionOperand<123, [EnvOpenCL]>;
defm SPV_KHR_float_controls2 : ExtensionOperand<124, [EnvVulkan, EnvOpenCL]>;
defm SPV_INTEL_tensor_float32_conversion : ExtensionOperand<125, [EnvOpenCL]>;
defm SPV_KHR_bfloat16 : ExtensionOperand<126, [EnvOpenCL]>;

//===----------------------------------------------------------------------===//
// Multiclass used to define Capabilities enum values and at the same time
Expand Down Expand Up @@ -594,6 +595,9 @@ defm Subgroup2DBlockTransposeINTEL : CapabilityOperand<6230, 0, 0, [SPV_INTEL_2d
defm Int4TypeINTEL : CapabilityOperand<5112, 0, 0, [SPV_INTEL_int4], []>;
defm Int4CooperativeMatrixINTEL : CapabilityOperand<5114, 0, 0, [SPV_INTEL_int4], [Int4TypeINTEL, CooperativeMatrixKHR]>;
defm TensorFloat32RoundingINTEL : CapabilityOperand<6425, 0, 0, [SPV_INTEL_tensor_float32_conversion], []>;
defm BFloat16TypeKHR : CapabilityOperand<5116, 0, 0, [SPV_KHR_bfloat16], []>;
defm BFloat16DotProductKHR : CapabilityOperand<5117, 0, 0, [SPV_KHR_bfloat16], [BFloat16TypeKHR]>;
defm BFloat16CooperativeMatrixKHR : CapabilityOperand<5118, 0, 0, [SPV_KHR_bfloat16], [BFloat16TypeKHR, CooperativeMatrixKHR]>;

//===----------------------------------------------------------------------===//
// Multiclass used to define SourceLanguage enum values and at the same time
Expand Down
20 changes: 20 additions & 0 deletions llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_bfloat16/bfloat16.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - -filetype=obj | spirv-val %}
; XFAIL: *
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-ERROR

; CHECK-ERROR: BFloat16TypeKHR requires the following SPIR-V extension: SPV_KHR_subgroup_rotate

; CHECK-DAG: OpCapability BFloat16TypeKHR
; CHECK-DAG: OpExtension "SPV_KHR_bfloat16"
; CHECK: %[[#BFLOAT:]] = OpTypeFloat 16 1
; CHECK: %[[#]] = OpTypeVector %[[#BFLOAT]] 2

define spir_kernel void @test() {
entry:
%addr1 = alloca bfloat
%addr2 = alloca <2 x bfloat>
%data1 = load bfloat, ptr %addr1
%data2 = load <2 x bfloat>, ptr %addr2
ret void
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_KHR_bfloat16 %s -o - -filetype=obj | spirv-val %}
; XFAIL: *

; CHECK-SPIRV-DAG: OpCapability BFloat16TypeKHR
; CHECK-SPIRV-DAG: OpCapability BFloat16DotProductKHR
; CHECK-SPIRV-DAG: OpExtension "SPV_KHR_bfloat16"
; CHECK-SPIRV: %[[#BFLOAT:]] = OpTypeFloat 16 0
; CHECK-SPIRV: %[[#]] = OpTypeVector %[[#BFLOAT]] 2
; CHECK-SPIRV: OpDot

declare spir_func bfloat @_Z3dotDv2_u6__bf16Dv2_S_(<2 x bfloat>, <2 x bfloat>)

define spir_kernel void @test() {
entry:
%addrA = alloca <2 x bfloat>
%addrB = alloca <2 x bfloat>
%dataA = load <2 x bfloat>, ptr %addrA
%dataB = load <2 x bfloat>, ptr %addrB
%call = call spir_func bfloat @_Z3dotDv2_u6__bf16Dv2_S_(<2 x bfloat> %dataA, <2 x bfloat> %dataB)
ret void
}
Loading