Skip to content

Conversation

YixingZhang007
Copy link
Contributor

No description provided.

@llvmbot
Copy link
Member

llvmbot commented Aug 27, 2025

@llvm/pr-subscribers-llvm-globalisel

@llvm/pr-subscribers-backend-spir-v

Author: None (YixingZhang007)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/155645.diff

6 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/MachineInstr.h (+2-1)
  • (modified) llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp (+7-2)
  • (modified) llvm/lib/CodeGen/MachineInstr.cpp (+3)
  • (modified) llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp (+3-1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp (+4)
  • (modified) llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td (+4)
diff --git a/llvm/include/llvm/CodeGen/MachineInstr.h b/llvm/include/llvm/CodeGen/MachineInstr.h
index 10a9b1ff1411d..6f692ae32510b 100644
--- a/llvm/include/llvm/CodeGen/MachineInstr.h
+++ b/llvm/include/llvm/CodeGen/MachineInstr.h
@@ -123,8 +123,9 @@ class MachineInstr
     NoUSWrap = 1 << 20,      // Instruction supports geps
                              // no unsigned signed wrap.
     SameSign = 1 << 21,      // Both operands have the same sign.
-    InBounds = 1 << 22       // Pointer arithmetic remains inbounds.
+    InBounds = 1 << 22,       // Pointer arithmetic remains inbounds.
                              // Implies NoUSWrap.
+    BFloat16 = 1 << 23      // Instruction with bf16 type
   };
 
 private:
diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
index 541269ab6bfce..2a6b66984c8ae 100644
--- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -2765,8 +2765,8 @@ bool IRTranslator::translateCallBase(const CallBase &CB,
 }
 
 bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
-  if (containsBF16Type(U))
-    return false;
+  // if (containsBF16Type(U))
+  //   return false;
 
   const CallInst &CI = cast<CallInst>(U);
   const Function *F = CI.getCalledFunction();
@@ -2813,6 +2813,11 @@ bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) {
   if (isa<FPMathOperator>(CI))
     MIB->copyIRFlags(CI);
 
+  // If the spirv intrinsic contain bfloat, enable to Bfloat flag in MachineInst
+  if (containsBF16Type(U)) {
+    MIB->setFlag(MachineInstr::MIFlag::BFloat16);
+  }
+
   for (const auto &Arg : enumerate(CI.args())) {
     // If this is required to be an immediate, don't materialize it in a
     // register.
diff --git a/llvm/lib/CodeGen/MachineInstr.cpp b/llvm/lib/CodeGen/MachineInstr.cpp
index 79047f732808a..10ff667bcb522 100644
--- a/llvm/lib/CodeGen/MachineInstr.cpp
+++ b/llvm/lib/CodeGen/MachineInstr.cpp
@@ -632,6 +632,9 @@ uint32_t MachineInstr::copyFlagsFromInstruction(const Instruction &I) {
   if (I.getMetadata(LLVMContext::MD_unpredictable))
     MIFlags |= MachineInstr::MIFlag::Unpredictable;
 
+  if (I.getType()->isBFloatTy())
+    MIFlags |= MachineInstr::MIFlag::BFloat16;
+
   return MIFlags;
 }
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index e7da5504b2d58..bd13a3bae92cd 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -147,7 +147,9 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
         {"SPV_KHR_float_controls2",
          SPIRV::Extension::Extension::SPV_KHR_float_controls2},
         {"SPV_INTEL_tensor_float32_conversion",
-         SPIRV::Extension::Extension::SPV_INTEL_tensor_float32_conversion}};
+         SPIRV::Extension::Extension::SPV_INTEL_tensor_float32_conversion},
+        {"SPV_KHR_bfloat16",
+         SPIRV::Extension::Extension::SPV_KHR_bfloat16}};
 
 bool SPIRVExtensionsParser::parse(cl::Option &O, StringRef ArgName,
                                   StringRef ArgValue,
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 8039cf0c432fa..5bba5cdce3753 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1267,6 +1267,10 @@ void addInstrRequirements(const MachineInstr &MI,
       Reqs.addCapability(SPIRV::Capability::Float64);
     else if (BitWidth == 16)
       Reqs.addCapability(SPIRV::Capability::Float16);
+    if(MI.getFlag(MachineInstr::MIFlag::BFloat16)) {
+      Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16);
+      Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR);
+    }
     break;
   }
   case SPIRV::OpTypeVector: {
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index d2824ee2d2caf..9d630356e8ffb 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -382,6 +382,7 @@ defm SPV_INTEL_2d_block_io : ExtensionOperand<122, [EnvOpenCL]>;
 defm SPV_INTEL_int4 : ExtensionOperand<123, [EnvOpenCL]>;
 defm SPV_KHR_float_controls2 : ExtensionOperand<124, [EnvVulkan, EnvOpenCL]>;
 defm SPV_INTEL_tensor_float32_conversion : ExtensionOperand<125, [EnvOpenCL]>;
+defm SPV_KHR_bfloat16 : ExtensionOperand<126, [EnvOpenCL]>;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define Capabilities enum values and at the same time
@@ -594,6 +595,9 @@ defm Subgroup2DBlockTransposeINTEL : CapabilityOperand<6230, 0, 0, [SPV_INTEL_2d
 defm Int4TypeINTEL : CapabilityOperand<5112, 0, 0, [SPV_INTEL_int4], []>;
 defm Int4CooperativeMatrixINTEL : CapabilityOperand<5114, 0, 0, [SPV_INTEL_int4], [Int4TypeINTEL, CooperativeMatrixKHR]>;
 defm TensorFloat32RoundingINTEL : CapabilityOperand<6425, 0, 0, [SPV_INTEL_tensor_float32_conversion], []>;
+defm BFloat16TypeKHR : CapabilityOperand<5116, 0, 0, [SPV_KHR_bfloat16], []>;
+defm BFloat16DotProductKHR : CapabilityOperand<5117, 0, 0, [SPV_KHR_bfloat16], [BFloat16TypeKHR]>;
+defm BFloat16CooperativeMatrixKHR : CapabilityOperand<5118, 0, 0, [SPV_KHR_bfloat16], [BFloat16TypeKHR, CooperativeMatrixKHR]>;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define SourceLanguage enum values and at the same time

Copy link

github-actions bot commented Aug 27, 2025

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff origin/main HEAD --extensions cpp,h -- llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp

⚠️
The reproduction instructions above might return results for more than one PR
in a stack if you are using a stacked PR workflow. You can limit the results by
changing origin/main to the base branch/commit you want to compare against.
⚠️

View the diff from clang-format here.
diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index bd13a3bae..993de9e9f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -148,8 +148,7 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
          SPIRV::Extension::Extension::SPV_KHR_float_controls2},
         {"SPV_INTEL_tensor_float32_conversion",
          SPIRV::Extension::Extension::SPV_INTEL_tensor_float32_conversion},
-        {"SPV_KHR_bfloat16",
-         SPIRV::Extension::Extension::SPV_KHR_bfloat16}};
+        {"SPV_KHR_bfloat16", SPIRV::Extension::Extension::SPV_KHR_bfloat16}};
 
 bool SPIRVExtensionsParser::parse(cl::Option &O, StringRef ArgName,
                                   StringRef ArgValue,
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index e021920f4..c746a0bf2 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -195,7 +195,8 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(unsigned Width,
 }
 
 SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
-                                               MachineIRBuilder &MIRBuilder, uint32_t Fptype) {
+                                               MachineIRBuilder &MIRBuilder,
+                                               uint32_t Fptype) {
   return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
     return MIRBuilder.buildInstr(SPIRV::OpTypeFloat)
         .addDef(createTypeVReg(MIRBuilder))
@@ -1682,7 +1683,7 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(unsigned BitWidth,
   MachineBasicBlock &DepMBB = I.getMF()->front();
   MachineIRBuilder MIRBuilder(DepMBB, DepMBB.getFirstNonPHI());
   const MachineInstr *NewMI = nullptr;
-  if(SPIRVOPcode == SPIRV::OpTypeFloat) {
+  if (SPIRVOPcode == SPIRV::OpTypeFloat) {
     if (Ty->isBFloatTy()) {
       NewMI = createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
         return BuildMI(MIRBuilder.getMBB(), *MIRBuilder.getInsertPt(),
@@ -1702,11 +1703,11 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(unsigned BitWidth,
     }
   } else {
     NewMI = createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
-        return BuildMI(MIRBuilder.getMBB(), *MIRBuilder.getInsertPt(),
-                       MIRBuilder.getDL(), TII.get(SPIRVOPcode))
-            .addDef(createTypeVReg(CurMF->getRegInfo()))
-            .addImm(BitWidth);
-      });
+      return BuildMI(MIRBuilder.getMBB(), *MIRBuilder.getInsertPt(),
+                     MIRBuilder.getDL(), TII.get(SPIRVOPcode))
+          .addDef(createTypeVReg(CurMF->getRegInfo()))
+          .addImm(BitWidth);
+    });
   }
   add(Ty, false, NewMI);
   return finishCreatingSPIRVType(Ty, NewMI);
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index a2a2e7b28..959a65b2b 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -436,7 +436,8 @@ private:
   SPIRVType *getOpTypeInt(unsigned Width, MachineIRBuilder &MIRBuilder,
                           bool IsSigned = false);
 
-  SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder, uint32_t Fptype);
+  SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder,
+                            uint32_t Fptype);
 
   SPIRVType *getOpTypeVoid(MachineIRBuilder &MIRBuilder);
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 44024bde3..df5f5caf7 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1266,8 +1266,8 @@ void addInstrRequirements(const MachineInstr &MI,
     SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(1).getReg());
     if ((TypeDef->getOpcode() == SPIRV::OpTypeFloat) &&
         (TypeDef->getOperand(1).getImm() == 16)) {
-          Reqs.addCapability(SPIRV::Capability::BFloat16DotProductKHR);
-        }
+      Reqs.addCapability(SPIRV::Capability::BFloat16DotProductKHR);
+    }
     break;
   }
   case SPIRV::OpTypeFloat: {
@@ -1276,7 +1276,7 @@ void addInstrRequirements(const MachineInstr &MI,
     if (BitWidth == 64)
       Reqs.addCapability(SPIRV::Capability::Float64);
     else if (BitWidth == 16) {
-      if(Fptype == 1) {
+      if (Fptype == 1) {
         Reqs.addExtension(SPIRV::Extension::SPV_KHR_bfloat16);
         Reqs.addCapability(SPIRV::Capability::BFloat16TypeKHR);
       } else {

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants