Member

@tgymnich tgymnich commented Jan 10, 2025

Draft implementation of floating-point support for LLT.

Makes progress on #119667.

TODO:

  • Add syntactic sugar to the legalizer DSL for easy type comparison based on size
  • Fix the remainder of the broken tests
  • Update MIR tests
  • Improve the type inference pass so that MIR tests are easy to update
  • Add one coherent flag that enables/disables FPInfo for a given backend
  • Format code
  • Fix combiner

@tgymnich tgymnich self-assigned this Jan 10, 2025

github-actions bot commented Jan 10, 2025

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff 90e421516ee29ed475a052d0f5e96c2447e700d7 055634d7f4fe9d81c95ff4d0cd720456069e34bf --extensions h,cpp -- llvm/include/llvm/CodeGen/GlobalISel/InferTypeInfoPass.h llvm/lib/CodeGen/GlobalISel/InferTypeInfoPass.cpp llvm/include/llvm/CodeGen/GlobalISel/LegalizationArtifactCombiner.h llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h llvm/include/llvm/CodeGen/GlobalISel/MachineIRBuilder.h llvm/include/llvm/CodeGen/GlobalISel/Utils.h llvm/include/llvm/CodeGen/LowLevelTypeUtils.h llvm/include/llvm/CodeGenTypes/LowLevelType.h llvm/include/llvm/InitializePasses.h llvm/lib/CodeGen/GlobalISel/CallLowering.cpp llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp llvm/lib/CodeGen/GlobalISel/GlobalISel.cpp llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp llvm/lib/CodeGen/GlobalISel/LegalizeMutations.cpp llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp llvm/lib/CodeGen/GlobalISel/RegBankSelect.cpp llvm/lib/CodeGen/GlobalISel/Utils.cpp llvm/lib/CodeGen/LowLevelTypeUtils.cpp llvm/lib/CodeGen/MIRParser/MIParser.cpp llvm/lib/CodeGen/MachineVerifier.cpp llvm/lib/CodeGen/RegisterBankInfo.cpp llvm/lib/CodeGenTypes/LowLevelType.cpp llvm/lib/Target/AMDGPU/AMDGPUCallLowering.cpp llvm/lib/Target/AMDGPU/AMDGPUCombinerHelper.cpp llvm/lib/Target/AMDGPU/AMDGPUGlobalISelDivergenceLowering.cpp llvm/lib/Target/AMDGPU/AMDGPUGlobalISelUtils.cpp llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.h llvm/lib/Target/AMDGPU/AMDGPUPostLegalizerCombiner.cpp llvm/lib/Target/AMDGPU/AMDGPUPreLegalizerCombiner.cpp llvm/lib/Target/AMDGPU/AMDGPURegBankCombiner.cpp llvm/lib/Target/AMDGPU/AMDGPURegBankLegalize.cpp llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.h llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp 
llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp llvm/unittests/CodeGen/GlobalISel/LegalizerInfoTest.cpp llvm/utils/TableGen/Common/GlobalISel/GlobalISelMatchTable.cpp
View the diff from clang-format here.
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/LegalizationArtifactCombiner.h b/llvm/include/llvm/CodeGen/GlobalISel/LegalizationArtifactCombiner.h
index dda353626c..0f7668db04 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/LegalizationArtifactCombiner.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/LegalizationArtifactCombiner.h
@@ -520,7 +520,8 @@ public:
       //  %9:_(i32) = G_BITCAST %7
       //  %4:_(i16), %5:_(i16) = G_UNMERGE_VALUES %9
 
-      if (CastSrcTy.isScalar() || SrcTy.isScalar() || DestTy.isVector() || DestTy == SrcTy.getScalarType())
+      if (CastSrcTy.isScalar() || SrcTy.isScalar() || DestTy.isVector() ||
+          DestTy == SrcTy.getScalarType())
         return false;
 
       const unsigned NewNumDefs1 = CastSrcTy.getNumElements();
@@ -531,27 +532,26 @@ public:
 
       SmallVector<Register, 8> NewUnmergeRegs(NewNumDefs1);
       for (unsigned Idx = 0; Idx < NewNumDefs1; ++Idx)
-        NewUnmergeRegs[Idx] = MRI.createGenericVirtualRegister(CastSrcTy.getElementType());
+        NewUnmergeRegs[Idx] =
+            MRI.createGenericVirtualRegister(CastSrcTy.getElementType());
 
       Builder.setInstr(MI);
       auto NewUnmerge = Builder.buildUnmerge(NewUnmergeRegs, CastSrcReg);
 
-
       SmallVector<Register, 8> DstRegs(NumDefs);
       for (unsigned Idx = 0; Idx < NumDefs; ++Idx)
-          DstRegs[Idx] = MI.getOperand(Idx).getReg();
+        DstRegs[Idx] = MI.getOperand(Idx).getReg();
 
+      auto *It = DstRegs.begin();
 
-      auto* It = DstRegs.begin();
-      
-      for (auto& Def : NewUnmerge->all_defs()) {
+      for (auto &Def : NewUnmerge->all_defs()) {
         auto Bitcast = Builder.buildBitcast(SrcTy.getElementType(), Def);
-        auto* Begin = It;
+        auto *Begin = It;
         It += NewNumDefs2;
         ArrayRef Regs(Begin, It);
         Builder.buildUnmerge(Regs, Bitcast);
       }
-      
+
       UpdatedDefs.append(NewUnmergeRegs.begin(), NewUnmergeRegs.end());
       UpdatedDefs.append(DstRegs.begin(), DstRegs.end());
       markInstAndDefDead(MI, CastMI, DeadInsts);
@@ -1215,7 +1215,7 @@ public:
           DstRegs.push_back(MI.getReg(DefIdx));
 
         LLT MergeDstTy = MRI.getType(SrcDef->getOperand(0).getReg());
-        
+
         if (ConvertOp && DestTy != MergeDstTy) {
 
           // This is a vector that is being split and casted. Extract to the
@@ -1237,7 +1237,8 @@ public:
           // %7(<2 x s16>), %7(<2 x s16>) = G_UNMERGE_VALUES %9
 
           Register TmpReg = MRI.createGenericVirtualRegister(MergeEltTy);
-          assert(MRI.getType(TmpReg) != MRI.getType(MergeI->getOperand(Idx + 1).getReg()));
+          assert(MRI.getType(TmpReg) !=
+                 MRI.getType(MergeI->getOperand(Idx + 1).getReg()));
           Builder.buildInstr(ConvertOp, {TmpReg},
                              {MergeI->getOperand(Idx + 1).getReg()});
           Builder.buildUnmerge(DstRegs, TmpReg);
diff --git a/llvm/lib/CodeGen/GlobalISel/CallLowering.cpp b/llvm/lib/CodeGen/GlobalISel/CallLowering.cpp
index c5142b84d8..8ae817eec6 100644
--- a/llvm/lib/CodeGen/GlobalISel/CallLowering.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CallLowering.cpp
@@ -432,9 +432,10 @@ static void buildCopyFromRegs(MachineIRBuilder &B, ArrayRef<Register> OrigRegs,
         B.buildBitcast(OrigRegs[0], Merge);
       } else if (!OrigTy.isFloat() && PartLLT.isFloat()) {
         SmallVector<Register> CastRegs(Regs.size());
-        for (auto&& [Idx, Reg]: enumerate(Regs))
-          CastRegs[Idx] = B.buildBitcast(PartLLT.changeToInteger(), Reg).getReg(0);
-        
+        for (auto &&[Idx, Reg] : enumerate(Regs))
+          CastRegs[Idx] =
+              B.buildBitcast(PartLLT.changeToInteger(), Reg).getReg(0);
+
         B.buildMergeValues(OrigRegs[0], CastRegs);
       } else {
         B.buildMergeValues(OrigRegs[0], Regs);
@@ -607,7 +608,9 @@ static void buildCopyToRegs(MachineIRBuilder &B, ArrayRef<Register> DstRegs,
       }
 
       if (SrcTy.isFloatVector() && !PartTy.isFloat())
-        Unmerge = B.buildBitcast(SrcTy.getElementType().changeToInteger(), Unmerge).getReg(0);
+        Unmerge =
+            B.buildBitcast(SrcTy.getElementType().changeToInteger(), Unmerge)
+                .getReg(0);
 
       B.buildAnyExt(DstRegs[i], Unmerge);
     }
@@ -642,7 +645,7 @@ static void buildCopyToRegs(MachineIRBuilder &B, ArrayRef<Register> DstRegs,
                                  SrcTy.getNumElements()));
     if (SrcTy.isFloatVector())
       SrcReg = B.buildBitcast(SrcTy.changeToInteger(), SrcReg).getReg(0);
-    
+
     auto Ext = B.buildAnyExt(ExtTy, SrcReg);
     B.buildUnmerge(DstRegs, Ext);
     return;
diff --git a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
index 4e7e77a30e..180ec474fb 100644
--- a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
@@ -3670,7 +3670,8 @@ LegalizerHelper::bitcastExtractVectorElt(MachineInstr &MI, unsigned TypeIdx,
     // Shift the wide element to get the target element.
     auto ExtractedBits = MIRBuilder.buildLShr(NewEltTy, WideElt, OffsetBits);
     if (DstTy.isFloat()) {
-      auto Trunc = MIRBuilder.buildTrunc(DstTy.changeToInteger(), ExtractedBits);
+      auto Trunc =
+          MIRBuilder.buildTrunc(DstTy.changeToInteger(), ExtractedBits);
       MIRBuilder.buildBitcast(Dst, Trunc);
     } else {
       MIRBuilder.buildTrunc(Dst, ExtractedBits);
diff --git a/llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp b/llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp
index ffe8b06647..2d2fe674ce 100644
--- a/llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp
@@ -570,7 +570,8 @@ MachineInstrBuilder MachineIRBuilder::buildExtOrTrunc(unsigned ExtOpc,
   return buildInstr(Opcode, Res, Op);
 }
 
-MachineInstrBuilder MachineIRBuilder::buildTruncLike(const DstOp &Res, const SrcOp &Op) {
+MachineInstrBuilder MachineIRBuilder::buildTruncLike(const DstOp &Res,
+                                                     const SrcOp &Op) {
   LLT DstTy = Res.getLLTTy(*getMRI());
   LLT SrcTy = Op.getLLTTy(*getMRI());
 
@@ -582,14 +583,15 @@ MachineInstrBuilder MachineIRBuilder::buildTruncLike(const DstOp &Res, const Src
   if (DstIsFloat && !SrcIsFloat) {
     auto Trunc = buildTrunc(DstTy.changeToInteger(), Op);
     return buildBitcast(Res, Trunc);
-  } 
-  
+  }
+
   if (!DstIsFloat && SrcIsFloat) {
     auto Bitcast = buildBitcast(SrcTy.changeToInteger(), Op);
     return buildTrunc(Res, Bitcast);
   }
 
-  unsigned Opcode = DstIsFloat ? TargetOpcode::G_FPTRUNC : TargetOpcode::G_TRUNC;
+  unsigned Opcode =
+      DstIsFloat ? TargetOpcode::G_FPTRUNC : TargetOpcode::G_TRUNC;
   return buildInstr(Opcode, Res, Op);
 }
 
diff --git a/llvm/lib/CodeGen/GlobalISel/RegBankSelect.cpp b/llvm/lib/CodeGen/GlobalISel/RegBankSelect.cpp
index e504722539..3cc7f73fc4 100644
--- a/llvm/lib/CodeGen/GlobalISel/RegBankSelect.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/RegBankSelect.cpp
@@ -216,7 +216,8 @@ bool RegBankSelect::repairReg(
       if (RegTy.isFloat()) {
         const RegisterBank *Bank = ValMapping.BreakDown[0].RegBank;
         MergeReg =
-            MIRBuilder.buildBitcast({Bank, RegTy.changeToInteger()}, MO.getReg())
+            MIRBuilder
+                .buildBitcast({Bank, RegTy.changeToInteger()}, MO.getReg())
                 .getReg(0);
       }
 
diff --git a/llvm/lib/CodeGen/GlobalISel/Utils.cpp b/llvm/lib/CodeGen/GlobalISel/Utils.cpp
index 8ed8f62678..a12fa788ae 100644
--- a/llvm/lib/CodeGen/GlobalISel/Utils.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/Utils.cpp
@@ -484,14 +484,16 @@ llvm::getDefSrcRegIgnoringCopies(Register Reg, const MachineRegisterInfo &MRI) {
 }
 
 std::optional<DefinitionAndSourceRegister>
-llvm::getDefSrcRegIgnoringBitcasts(Register Reg, const MachineRegisterInfo &MRI) {
+llvm::getDefSrcRegIgnoringBitcasts(Register Reg,
+                                   const MachineRegisterInfo &MRI) {
   Register DefSrcReg = Reg;
   auto *DefMI = MRI.getVRegDef(Reg);
   auto DstTy = MRI.getType(DefMI->getOperand(0).getReg());
   if (!DstTy.isValid())
     return std::nullopt;
   unsigned Opc = DefMI->getOpcode();
-  while (Opc == TargetOpcode::G_BITCAST || isPreISelGenericOptimizationHint(Opc)) {
+  while (Opc == TargetOpcode::G_BITCAST ||
+         isPreISelGenericOptimizationHint(Opc)) {
     Register SrcReg = DefMI->getOperand(1).getReg();
     auto SrcTy = MRI.getType(SrcReg);
     if (!SrcTy.isValid())
@@ -504,10 +506,10 @@ llvm::getDefSrcRegIgnoringBitcasts(Register Reg, const MachineRegisterInfo &MRI)
 }
 
 MachineInstr *llvm::getDefIgnoringBitcasts(Register Reg,
-  const MachineRegisterInfo &MRI) {
-std::optional<DefinitionAndSourceRegister> DefSrcReg =
-getDefSrcRegIgnoringBitcasts(Reg, MRI);
-return DefSrcReg ? DefSrcReg->MI : nullptr;
+                                           const MachineRegisterInfo &MRI) {
+  std::optional<DefinitionAndSourceRegister> DefSrcReg =
+      getDefSrcRegIgnoringBitcasts(Reg, MRI);
+  return DefSrcReg ? DefSrcReg->MI : nullptr;
 }
 
 MachineInstr *llvm::getDefIgnoringCopies(Register Reg,
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUGlobalISelUtils.cpp b/llvm/lib/Target/AMDGPU/AMDGPUGlobalISelUtils.cpp
index bee4553f57..1a3f88a885 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUGlobalISelUtils.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUGlobalISelUtils.cpp
@@ -158,7 +158,8 @@ static Register buildReadAnyLane(MachineIRBuilder &B, Register VgprSrc,
   unmergeReadAnyLane(B, SgprDstParts, getReadAnyLaneSplitTy(Ty), VgprSrc, RBI);
 
   if (Ty.getScalarType().isFloat()) {
-    auto Merge = B.buildMergeLikeInstr({SgprRB, Ty.changeToInteger()}, SgprDstParts);
+    auto Merge =
+        B.buildMergeLikeInstr({SgprRB, Ty.changeToInteger()}, SgprDstParts);
     return B.buildBitcast({SgprRB, Ty}, Merge).getReg(0);
   }
 
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp b/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp
index 18db3d4fab..0294b15baf 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp
@@ -4287,7 +4287,8 @@ AMDGPUInstructionSelector::selectVOP3PModsImpl(
   if (MI->getOpcode() == AMDGPU::G_FNEG &&
       // It's possible to see an f32 fneg here, but unlikely.
       // TODO: Treat f32 fneg as only high bit.
-      MRI.getType(MI->getOperand(0).getReg()) == LLT::fixed_vector(2, LLT::float16())) {
+      MRI.getType(MI->getOperand(0).getReg()) ==
+          LLT::fixed_vector(2, LLT::float16())) {
     Mods ^= (SISrcMods::NEG | SISrcMods::NEG_HI);
     Src = MI->getOperand(1).getReg();
     MI = MRI.getVRegDef(Src);
diff --git a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
index 79173b2107..9e5d241d5a 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
@@ -1916,10 +1916,11 @@ AMDGPULegalizerInfo::AMDGPULegalizerInfo(const GCNSubtarget &ST_,
                        scalarOrEltNarrowerThan(VecTypeIdx, 32)),
                    bitcastToVectorElementI32(VecTypeIdx))
         //.bitcastIf(vectorSmallerThan(1, 32), bitcastToScalar(1))
-        .bitcastIf(all(sizeIsMultipleOf32(VecTypeIdx),
-        scalarOrEltWiderThan(VecTypeIdx, 64)),
+        .bitcastIf(
+            all(sizeIsMultipleOf32(VecTypeIdx),
+                scalarOrEltWiderThan(VecTypeIdx, 64)),
             [=](const LegalityQuery &Query) {
-              // For > 64-bit element types, try to turn this into a 
+              // For > 64-bit element types, try to turn this into a
               // 64-bit element vector since we may be able to do better
               // indexing if this is scalar. If not, fall back to 32.
               const LLT EltTy = Query.Types[EltTypeIdx];
@@ -2025,15 +2026,15 @@ AMDGPULegalizerInfo::AMDGPULegalizerInfo(const GCNSubtarget &ST_,
 
   // FIXME: Clamp maximum size
   getActionDefinitionsBuilder(G_CONCAT_VECTORS)
-        .legalIf(all(isRegisterType(ST, 0), isRegisterType(ST, 1)))
+      .legalIf(all(isRegisterType(ST, 0), isRegisterType(ST, 1)))
       .clampMaxNumElements(0, I32, 32)
-        .clampMaxNumElements(0, F32, 32)
+      .clampMaxNumElements(0, F32, 32)
       .clampMaxNumElements(1, I16, 2)  // TODO: Make 4?
       .clampMaxNumElements(1, F16, 2)  // TODO: Make 4?
-        .clampMaxNumElements(1, BF16, 2) // TODO: Make 4?
+      .clampMaxNumElements(1, BF16, 2) // TODO: Make 4?
       .clampMaxNumElements(0, I16, 64)
       .clampMaxNumElements(0, F16, 64)
-        .clampMaxNumElements(0, BF16, 64);
+      .clampMaxNumElements(0, BF16, 64);
 
   getActionDefinitionsBuilder(G_SHUFFLE_VECTOR).lower();
 
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
index c227b859b9..213358b99a 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
@@ -1998,7 +1998,8 @@ bool AMDGPURegisterBankInfo::foldExtractEltToCmpSelect(
 
   if (VecTy.isFloatVector() && !EltTy.isFloat()) {
     auto ClassOrBank = MRI.getRegClassOrRegBank(VecReg);
-    VecReg = B.buildBitcast({ClassOrBank, VecTy.changeToInteger()}, VecReg).getReg(0);
+    VecReg = B.buildBitcast({ClassOrBank, VecTy.changeToInteger()}, VecReg)
+                 .getReg(0);
   }
 
   auto UnmergeToEltTy = B.buildUnmerge(EltTy, VecReg);
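
The `buildTruncLike` hunk in the diff above carries the pattern that recurs throughout this PR: when only one side of a truncation is a float type, the value goes through a same-sized integer type and a bitcast. A minimal standalone sketch of that case analysis follows. It does not use LLVM headers; `ToyLLT` and `truncLikePlan` are hypothetical names, and how the PR actually computes `DstIsFloat` and `SrcIsFloat` is not visible in the hunk, so the two flags are simply taken from the toy type here.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for LLT: only tracks float-ness and bit width.
struct ToyLLT {
  bool IsFloat;
  unsigned Bits;
};

// Returns the opcode sequence that the case analysis in the hunk would emit.
// This models the control flow only; no instructions are built.
inline std::vector<std::string> truncLikePlan(ToyLLT Dst, ToyLLT Src) {
  std::vector<std::string> Plan;
  if (Dst.IsFloat && !Src.IsFloat) {
    // Integer source, float destination: truncate as integer, then bitcast.
    Plan.push_back("G_TRUNC");
    Plan.push_back("G_BITCAST");
  } else if (!Dst.IsFloat && Src.IsFloat) {
    // Float source, integer destination: bitcast to integer, then truncate.
    Plan.push_back("G_BITCAST");
    Plan.push_back("G_TRUNC");
  } else {
    // Both sides agree: a single truncation of the matching kind.
    Plan.push_back(Dst.IsFloat ? "G_FPTRUNC" : "G_TRUNC");
  }
  return Plan;
}
```

For example, `truncLikePlan(ToyLLT{true, 16}, ToyLLT{false, 32})` yields `G_TRUNC` followed by `G_BITCAST`.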

@tgymnich
Member Author

@arsenm could you please have a look at this PR and provide some early feedback?

@tgymnich tgymnich requested a review from arsenm February 26, 2025 12:07
Contributor

@arsenm arsenm left a comment

My browser is dying on this PR. Can you start with the initial LLT representation and parsing patch?

@@ -0,0 +1,43 @@
#ifndef LLVM_CODEGEN_GLOBALISEL_INFERTYPEINFOPASS_H
Contributor

Not sure why this is necessary. Can we keep the initial patch to just the LLT changes?

enum class FPVariant {
IEEE_FLOAT = 0x0,
BRAIN_FLOAT = 0x1, // BRAIN_FLOAT
VARIANT_FLOAT_2 = 0x2, // PPC_FLOAT
Contributor

Use the commented name instead of VARIANT_*?

Comment on lines 72 to 73
assert(false && "Type is already a vector type");
return Ty;
Contributor

Dead return. Also, use llvm_unreachable instead of assert(false).
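
The concern can be illustrated with a small standalone sketch. `assert(false)` compiles away when `NDEBUG` is defined, so the `return Ty;` that follows it becomes the live path in release builds and silently hands back the wrong value. With an unreachable marker the trailing return is dead and can be dropped. The sketch avoids LLVM headers: `unreachableStandIn`, `ToyType`, and `toVector` are hypothetical names, and the stand-in always aborts, whereas the real `llvm_unreachable` may turn into an optimizer hint in release builds.

```cpp
#include <cassert>
#include <cstdio>
#include <cstdlib>

// Hypothetical stand-in for llvm_unreachable: prints a message and aborts.
// Marked [[noreturn]] so no return statement is needed after a call to it.
[[noreturn]] inline void unreachableStandIn(const char *Msg) {
  std::fprintf(stderr, "UNREACHABLE: %s\n", Msg);
  std::abort();
}

// Toy type: only tracks the element count; 1 means scalar.
struct ToyType {
  unsigned NumElements;
  bool isVector() const { return NumElements > 1; }
};

// Turns a scalar toy type into a vector with the given element count.
// Passing a type that is already a vector is a programming error.
inline ToyType toVector(ToyType Ty, unsigned NumElements) {
  if (Ty.isVector())
    unreachableStandIn("Type is already a vector type"); // no dead return
  return ToyType{NumElements};
}
```

Calling `toVector(ToyType{1}, 4)` returns a four-element vector type, while calling it on a type that is already a vector aborts in every build mode instead of returning the input unchanged.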

Comment on lines 572 to 580
if (BVType.isFloatVector()) {
BuildVec = B.buildBitcast(BVType.dropType(), BuildVec).getReg(0);
}

LLT OrigTy = MRI.getType(OrigRegs[0]);
if (OrigTy.isFloatVector()) {
B.buildBitcast(OrigRegs[0], B.buildTrunc(OrigTy.dropType(), BuildVec));
} else {
B.buildTrunc(OrigRegs[0], BuildVec);
Contributor

Can you cast the elements to avoid introducing another vector cast? Also, we should probably have buildBitcast skip the cast when the source and destination types are the same, like the DAG does.
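
The second suggestion can be sketched with a toy builder. `SelectionDAG::getBitcast` returns its operand unchanged when the requested type already matches, and the sketch gives a builder the same early exit so that callers need not guard each cast with an `isFloatVector()` check. `ToyTy`, `ToyValue`, and `ToyBuilder` are hypothetical and only stand in for `LLT`, `Register`, and `MachineIRBuilder`; this is an illustration of the idea, not the PR's implementation.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for LLT: float-ness plus bit width.
struct ToyTy {
  bool IsFloat;
  unsigned Bits;
  bool operator==(const ToyTy &O) const {
    return IsFloat == O.IsFloat && Bits == O.Bits;
  }
};

// Hypothetical stand-in for a typed virtual register.
struct ToyValue {
  ToyTy Ty;
  unsigned Id;
};

// Toy builder that logs the opcodes it emits.
struct ToyBuilder {
  std::vector<std::string> Emitted;
  unsigned NextId = 0;

  ToyValue buildBitcast(ToyTy DstTy, ToyValue Src) {
    assert(DstTy.Bits == Src.Ty.Bits && "bitcast must preserve size");
    // Same-type request: hand back the source, emit nothing.
    if (DstTy == Src.Ty)
      return Src;
    Emitted.push_back("G_BITCAST");
    return ToyValue{DstTy, NextId++};
  }
};
```

With this shape, a call such as `B.buildBitcast(IntTy, V)` on a value that is already of `IntTy` is a no-op, so the unconditional form can replace the `if (...isFloatVector())` guard at the call site.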

Comment on lines 1959 to 1960
} else if (Token.range().front() == 'f') {
auto ScalarSize = APSInt(Token.range().drop_front()).getZExtValue();
Contributor

Can you use consume_front? These also need parsing tests, particularly for the error cases.
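
A rough sketch of the suggested parsing shape, including the error cases a parsing test would want to cover. It avoids LLVM headers so that it stands alone: `consumeFront` is a local helper that mimics `StringRef::consume_front`, and `ToyKind`, `ToyScalar`, and `parseToyScalar` are hypothetical names. The accepted spellings (`f`, `bf`, or `i` followed by a bit width) are assumptions for illustration, not the PR's final syntax.

```cpp
#include <cassert>
#include <cctype>
#include <string>

// Local helper mimicking llvm::StringRef::consume_front: if S starts with
// Prefix, strips it and returns true; otherwise leaves S untouched.
inline bool consumeFront(std::string &S, const std::string &Prefix) {
  if (S.compare(0, Prefix.size(), Prefix) != 0)
    return false;
  S.erase(0, Prefix.size());
  return true;
}

enum class ToyKind { Integer, IEEEFloat, BFloat };

struct ToyScalar {
  ToyKind Kind;
  unsigned Bits;
};

// Parses tokens such as "i32", "f16", or "bf16" (assumed spellings).
// Returns false and leaves Out untouched on malformed input.
inline bool parseToyScalar(std::string Tok, ToyScalar &Out) {
  ToyKind Kind;
  if (consumeFront(Tok, "bf"))
    Kind = ToyKind::BFloat;
  else if (consumeFront(Tok, "f"))
    Kind = ToyKind::IEEEFloat;
  else if (consumeFront(Tok, "i"))
    Kind = ToyKind::Integer;
  else
    return false; // unknown prefix

  if (Tok.empty() || Tok.size() > 6)
    return false; // missing width, or too long to be a sensible width

  unsigned Bits = 0;
  for (char C : Tok) {
    if (!std::isdigit(static_cast<unsigned char>(C)))
      return false; // trailing garbage such as "f1x"
    Bits = Bits * 10 + static_cast<unsigned>(C - '0');
  }
  if (Bits == 0)
    return false; // zero-width scalars are rejected

  Out = ToyScalar{Kind, Bits};
  return true;
}
```

The error paths (unknown prefix, missing width, non-digit characters, zero width) are the kind of inputs the requested parsing tests would exercise alongside the accepting cases.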

2 participants