-
Notifications
You must be signed in to change notification settings - Fork 14.9k
Add the support of Bfloat in LLT #154199
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
YixingZhang007
wants to merge
7
commits into
llvm:main
Choose a base branch
from
YixingZhang007:add_llt_bfloat16_support
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Add the support of Bfloat in LLT #154199
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@llvm/pr-subscribers-backend-spir-v Author: None (YixingZhang007) ChangesFull diff: https://github.com/llvm/llvm-project/pull/154199.diff 4 Files Affected:
diff --git a/llvm/include/llvm/CodeGenTypes/LowLevelType.h b/llvm/include/llvm/CodeGenTypes/LowLevelType.h
index d8e0848aff84d..0aec3d2537d9c 100644
--- a/llvm/include/llvm/CodeGenTypes/LowLevelType.h
+++ b/llvm/include/llvm/CodeGenTypes/LowLevelType.h
@@ -30,6 +30,7 @@
#include "llvm/CodeGenTypes/MachineValueType.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
#include <cassert>
namespace llvm {
@@ -41,7 +42,13 @@ class LLT {
public:
/// Get a low-level scalar or aggregate "bag of bits".
static constexpr LLT scalar(unsigned SizeInBits) {
- return LLT{/*isPointer=*/false, /*isVector=*/false, /*isScalar=*/true,
+ return LLT{/*isPointer=*/false, /*isVector=*/false, /*isScalar=*/true, /*isBfloat=*/false,
+ ElementCount::getFixed(0), SizeInBits,
+ /*AddressSpace=*/0};
+ }
+
+ static constexpr LLT scalar_bfloat(unsigned SizeInBits) {
+ return LLT{/*isPointer=*/false, /*isVector=*/false, /*isScalar=*/true, /*isBfloat=*/true,
ElementCount::getFixed(0), SizeInBits,
/*AddressSpace=*/0};
}
@@ -49,7 +56,7 @@ class LLT {
/// Get a low-level token; just a scalar with zero bits (or no size).
static constexpr LLT token() {
return LLT{/*isPointer=*/false, /*isVector=*/false,
- /*isScalar=*/true, ElementCount::getFixed(0),
+ /*isScalar=*/true, /*isBfloat=*/false, ElementCount::getFixed(0),
/*SizeInBits=*/0,
/*AddressSpace=*/0};
}
@@ -57,14 +64,14 @@ class LLT {
/// Get a low-level pointer in the given address space.
static constexpr LLT pointer(unsigned AddressSpace, unsigned SizeInBits) {
assert(SizeInBits > 0 && "invalid pointer size");
- return LLT{/*isPointer=*/true, /*isVector=*/false, /*isScalar=*/false,
+ return LLT{/*isPointer=*/true, /*isVector=*/false, /*isScalar=*/false, /*isBfloat=*/false,
ElementCount::getFixed(0), SizeInBits, AddressSpace};
}
/// Get a low-level vector of some number of elements and element width.
static constexpr LLT vector(ElementCount EC, unsigned ScalarSizeInBits) {
assert(!EC.isScalar() && "invalid number of vector elements");
- return LLT{/*isPointer=*/false, /*isVector=*/true, /*isScalar=*/false,
+ return LLT{/*isPointer=*/false, /*isVector=*/true, /*isScalar=*/false, /*isBfloat=*/false,
EC, ScalarSizeInBits, /*AddressSpace=*/0};
}
@@ -75,11 +82,17 @@ class LLT {
return LLT{ScalarTy.isPointer(),
/*isVector=*/true,
/*isScalar=*/false,
+ /*isBfloat=*/false,
EC,
ScalarTy.getSizeInBits().getFixedValue(),
ScalarTy.isPointer() ? ScalarTy.getAddressSpace() : 0};
}
+ // Get a 16-bit brain float value.
+ static constexpr LLT bfloat16() {
+ return scalar_bfloat(16);
+ }
+
/// Get a 16-bit IEEE half value.
/// TODO: Add IEEE semantics to type - This currently returns a simple `scalar(16)`.
static constexpr LLT float16() {
@@ -132,14 +145,14 @@ class LLT {
return scalarOrVector(EC, LLT::scalar(static_cast<unsigned>(ScalarSize)));
}
- explicit constexpr LLT(bool isPointer, bool isVector, bool isScalar,
+ explicit constexpr LLT(bool isPointer, bool isVector, bool isScalar, bool isBfloat,
ElementCount EC, uint64_t SizeInBits,
unsigned AddressSpace)
: LLT() {
- init(isPointer, isVector, isScalar, EC, SizeInBits, AddressSpace);
+ init(isPointer, isVector, isScalar, isBfloat, EC, SizeInBits, AddressSpace);
}
explicit constexpr LLT()
- : IsScalar(false), IsPointer(false), IsVector(false), RawData(0) {}
+ : IsScalar(false), IsPointer(false), IsVector(false), IsBfloat(false), RawData(0) {}
LLVM_ABI explicit LLT(MVT VT);
@@ -154,6 +167,7 @@ class LLT {
constexpr bool isPointerOrPointerVector() const {
return IsPointer && isValid();
}
+ constexpr bool isBfloat() const { return IsBfloat; }
/// Returns the number of elements in a vector LLT. Must only be called on
/// vector types.
@@ -304,32 +318,35 @@ class LLT {
/// isScalar : 1
/// isPointer : 1
/// isVector : 1
- /// with 61 bits remaining for Kind-specific data, packed in bitfields
+ /// isBfloat : 1
+ /// with 60 bits remaining for Kind-specific data, packed in bitfields
/// as described below. As there isn't a simple portable way to pack bits
/// into bitfields, here the different fields in the packed structure is
/// described in static const *Field variables. Each of these variables
/// is a 2-element array, with the first element describing the bitfield size
/// and the second element describing the bitfield offset.
///
- /// +--------+---------+--------+----------+----------------------+
- /// |isScalar|isPointer|isVector| RawData |Notes |
- /// +--------+---------+--------+----------+----------------------+
- /// | 0 | 0 | 0 | 0 |Invalid |
- /// +--------+---------+--------+----------+----------------------+
- /// | 0 | 0 | 1 | 0 |Tombstone Key |
- /// +--------+---------+--------+----------+----------------------+
- /// | 0 | 1 | 0 | 0 |Empty Key |
- /// +--------+---------+--------+----------+----------------------+
- /// | 1 | 0 | 0 | 0 |Token |
- /// +--------+---------+--------+----------+----------------------+
- /// | 1 | 0 | 0 | non-zero |Scalar |
- /// +--------+---------+--------+----------+----------------------+
- /// | 0 | 1 | 0 | non-zero |Pointer |
- /// +--------+---------+--------+----------+----------------------+
- /// | 0 | 0 | 1 | non-zero |Vector of non-pointer |
- /// +--------+---------+--------+----------+----------------------+
- /// | 0 | 1 | 1 | non-zero |Vector of pointer |
- /// +--------+---------+--------+----------+----------------------+
+ /// +--------+---------+--------+----------+----------+----------------------+
+ /// |isScalar|isPointer|isVector| isBfloat | RawData |Notes |
+ /// +--------+---------+--------+----------+----------+----------------------+
+ /// | 0 | 0 | 0 | 0 | 0 |Invalid |
+ /// +--------+---------+--------+----------+----------+----------------------+
+ /// | 0 | 0 | 1 | 0 | 0 |Tombstone Key |
+ /// +--------+---------+--------+----------+----------+----------------------+
+ /// | 0 | 1 | 0 | 0 | 0 |Empty Key |
+ /// +--------+---------+--------+----------+----------+----------------------+
+ /// | 1 | 0 | 0 | 0 | 0 |Token |
+ /// +--------+---------+--------+----------+----------+----------------------+
+ /// | 1 | 0 | 0 | 0 | non-zero |Scalar |
+ /// +--------+---------+--------+----------+----------+----------------------+
+ /// | 1 | 0 | 0 | 1 | non-zero |Scalar (Bfloat 16) |
+ /// +--------+---------+--------+----------+----------+----------------------+
+ /// | 0 | 1 | 0 | 0 | non-zero |Pointer |
+ /// +--------+---------+--------+----------+----------+----------------------+
+ /// | 0 | 0 | 1 | 0 | non-zero |Vector of non-pointer |
+ /// +--------+---------+--------+----------+----------+----------------------+
+ /// | 0 | 1 | 1 0 | non-zero |Vector of pointer |
+ /// +--------+---------+--------+----------+----------+----------------------+
///
/// Everything else is reserved.
typedef int BitFieldInfo[2];
@@ -340,12 +357,12 @@ class LLT {
/// valid encodings, SizeInBits/SizeOfElement must be larger than 0.
/// * Non-pointer scalar (isPointer == 0 && isVector == 0):
/// SizeInBits: 32;
- static const constexpr BitFieldInfo ScalarSizeFieldInfo{32, 29};
+ static const constexpr BitFieldInfo ScalarSizeFieldInfo{32, 28};
/// * Pointer (isPointer == 1 && isVector == 0):
/// SizeInBits: 16;
/// AddressSpace: 24;
- static const constexpr BitFieldInfo PointerSizeFieldInfo{16, 45};
- static const constexpr BitFieldInfo PointerAddressSpaceFieldInfo{24, 21};
+ static const constexpr BitFieldInfo PointerSizeFieldInfo{16, 44};
+ static const constexpr BitFieldInfo PointerAddressSpaceFieldInfo{24, 20};
/// * Vector-of-non-pointer (isPointer == 0 && isVector == 1):
/// NumElements: 16;
/// SizeOfElement: 32;
@@ -361,7 +378,8 @@ class LLT {
uint64_t IsScalar : 1;
uint64_t IsPointer : 1;
uint64_t IsVector : 1;
- uint64_t RawData : 61;
+ uint64_t IsBfloat : 1;
+ uint64_t RawData : 60;
static constexpr uint64_t getMask(const BitFieldInfo FieldInfo) {
const int FieldSizeInBits = FieldInfo[0];
@@ -381,7 +399,7 @@ class LLT {
return getMask(FieldInfo) & (RawData >> FieldInfo[1]);
}
- constexpr void init(bool IsPointer, bool IsVector, bool IsScalar,
+ constexpr void init(bool IsPointer, bool IsVector, bool IsScalar, bool IsBfloat,
ElementCount EC, uint64_t SizeInBits,
unsigned AddressSpace) {
assert(SizeInBits <= std::numeric_limits<unsigned>::max() &&
@@ -389,6 +407,7 @@ class LLT {
this->IsPointer = IsPointer;
this->IsVector = IsVector;
this->IsScalar = IsScalar;
+ this->IsBfloat = IsBfloat;
if (IsPointer) {
RawData = maskAndShift(SizeInBits, PointerSizeFieldInfo) |
maskAndShift(AddressSpace, PointerAddressSpaceFieldInfo);
@@ -403,7 +422,7 @@ class LLT {
public:
constexpr uint64_t getUniqueRAWLLTData() const {
- return ((uint64_t)RawData) << 3 | ((uint64_t)IsScalar) << 2 |
+ return ((uint64_t)RawData) << 4 | ((uint64_t)IsBfloat) << 3 | ((uint64_t)IsScalar) << 2 |
((uint64_t)IsPointer) << 1 | ((uint64_t)IsVector);
}
};
diff --git a/llvm/lib/CodeGen/LowLevelTypeUtils.cpp b/llvm/lib/CodeGen/LowLevelTypeUtils.cpp
index 936c9fbb2fff0..9fe3d3e43a47f 100644
--- a/llvm/lib/CodeGen/LowLevelTypeUtils.cpp
+++ b/llvm/lib/CodeGen/LowLevelTypeUtils.cpp
@@ -27,7 +27,7 @@ LLT llvm::getLLTForType(Type &Ty, const DataLayout &DL) {
}
if (auto PTy = dyn_cast<PointerType>(&Ty)) {
- unsigned AddrSpace = PTy->getAddressSpace();
+ unsigned AddrSpace = PTy->isTokenTy();
return LLT::pointer(AddrSpace, DL.getPointerSizeInBits(AddrSpace));
}
@@ -36,6 +36,9 @@ LLT llvm::getLLTForType(Type &Ty, const DataLayout &DL) {
// concerned.
auto SizeInBits = DL.getTypeSizeInBits(&Ty);
assert(SizeInBits != 0 && "invalid zero-sized type");
+ if (Ty.isBFloatTy()) {
+ return LLT::scalar_bfloat(SizeInBits);
+ }
return LLT::scalar(SizeInBits);
}
diff --git a/llvm/lib/CodeGenTypes/LowLevelType.cpp b/llvm/lib/CodeGenTypes/LowLevelType.cpp
index 4785f2652b00e..a13363a337472 100644
--- a/llvm/lib/CodeGenTypes/LowLevelType.cpp
+++ b/llvm/lib/CodeGenTypes/LowLevelType.cpp
@@ -19,18 +19,21 @@ using namespace llvm;
LLT::LLT(MVT VT) {
if (VT.isVector()) {
bool asVector = VT.getVectorMinNumElements() > 1 || VT.isScalableVector();
- init(/*IsPointer=*/false, asVector, /*IsScalar=*/!asVector,
+ init(/*IsPointer=*/false, asVector, /*IsScalar=*/!asVector, /*isBfloat=*/false,
VT.getVectorElementCount(), VT.getVectorElementType().getSizeInBits(),
/*AddressSpace=*/0);
} else if (VT.isValid() && !VT.isScalableTargetExtVT()) {
// Aggregates are no different from real scalars as far as GlobalISel is
// concerned.
- init(/*IsPointer=*/false, /*IsVector=*/false, /*IsScalar=*/true,
+ MVT ElemVT = VT.getVectorElementType();
+ bool isElemBfloat = (ElemVT == MVT::bf16);
+ init(/*IsPointer=*/false, /*IsVector=*/false, /*IsScalar=*/true, /*isBfloat=*/isElemBfloat,
ElementCount::getFixed(0), VT.getSizeInBits(), /*AddressSpace=*/0);
} else {
IsScalar = false;
IsPointer = false;
IsVector = false;
+ IsBfloat = false;
RawData = 0;
}
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 8039cf0c432fa..51fc45446d50e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1262,6 +1262,17 @@ void addInstrRequirements(const MachineInstr &MI,
break;
}
case SPIRV::OpTypeFloat: {
+ // const MachineBasicBlock *MBB = MI.getParent();
+ // const MachineFunction *MF = MBB->getParent();
+ // const MachineRegisterInfo &MRI = MF->getRegInfo();
+ // const MachineOperand &MO = MI.getOperand(1);
+ // if (MO.isReg()) {
+ // LLT Ty = MRI.getType(MO.getReg());
+ // if(!Ty.isScalar()) {
+ // assert(1 && "hola, ur wrong");
+ // }
+ // }
+
unsigned BitWidth = MI.getOperand(1).getImm();
if (BitWidth == 64)
Reqs.addCapability(SPIRV::Capability::Float64);
|
You can test this locally with the following command:git-clang-format --diff HEAD~1 HEAD --extensions cpp,h -- llvm/include/llvm/CodeGenTypes/LowLevelType.h llvm/lib/CodeGen/LowLevelTypeUtils.cpp llvm/lib/CodeGenTypes/LowLevelType.cpp llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp View the diff from clang-format here.diff --git a/llvm/include/llvm/CodeGenTypes/LowLevelType.h b/llvm/include/llvm/CodeGenTypes/LowLevelType.h
index 0aec3d253..ac207ad39 100644
--- a/llvm/include/llvm/CodeGenTypes/LowLevelType.h
+++ b/llvm/include/llvm/CodeGenTypes/LowLevelType.h
@@ -42,21 +42,24 @@ class LLT {
public:
/// Get a low-level scalar or aggregate "bag of bits".
static constexpr LLT scalar(unsigned SizeInBits) {
- return LLT{/*isPointer=*/false, /*isVector=*/false, /*isScalar=*/true, /*isBfloat=*/false,
+ return LLT{/*isPointer=*/false, /*isVector=*/false,
+ /*isScalar=*/true, /*isBfloat=*/false,
ElementCount::getFixed(0), SizeInBits,
/*AddressSpace=*/0};
}
static constexpr LLT scalar_bfloat(unsigned SizeInBits) {
- return LLT{/*isPointer=*/false, /*isVector=*/false, /*isScalar=*/true, /*isBfloat=*/true,
+ return LLT{/*isPointer=*/false, /*isVector=*/false,
+ /*isScalar=*/true, /*isBfloat=*/true,
ElementCount::getFixed(0), SizeInBits,
/*AddressSpace=*/0};
}
/// Get a low-level token; just a scalar with zero bits (or no size).
static constexpr LLT token() {
- return LLT{/*isPointer=*/false, /*isVector=*/false,
- /*isScalar=*/true, /*isBfloat=*/false, ElementCount::getFixed(0),
+ return LLT{/*isPointer=*/false, /*isVector=*/false,
+ /*isScalar=*/true, /*isBfloat=*/false,
+ ElementCount::getFixed(0),
/*SizeInBits=*/0,
/*AddressSpace=*/0};
}
@@ -64,15 +67,25 @@ public:
/// Get a low-level pointer in the given address space.
static constexpr LLT pointer(unsigned AddressSpace, unsigned SizeInBits) {
assert(SizeInBits > 0 && "invalid pointer size");
- return LLT{/*isPointer=*/true, /*isVector=*/false, /*isScalar=*/false, /*isBfloat=*/false,
- ElementCount::getFixed(0), SizeInBits, AddressSpace};
+ return LLT{/*isPointer=*/true,
+ /*isVector=*/false,
+ /*isScalar=*/false,
+ /*isBfloat=*/false,
+ ElementCount::getFixed(0),
+ SizeInBits,
+ AddressSpace};
}
/// Get a low-level vector of some number of elements and element width.
static constexpr LLT vector(ElementCount EC, unsigned ScalarSizeInBits) {
assert(!EC.isScalar() && "invalid number of vector elements");
- return LLT{/*isPointer=*/false, /*isVector=*/true, /*isScalar=*/false, /*isBfloat=*/false,
- EC, ScalarSizeInBits, /*AddressSpace=*/0};
+ return LLT{/*isPointer=*/false,
+ /*isVector=*/true,
+ /*isScalar=*/false,
+ /*isBfloat=*/false,
+ EC,
+ ScalarSizeInBits,
+ /*AddressSpace=*/0};
}
/// Get a low-level vector of some number of elements and element type.
@@ -88,10 +101,8 @@ public:
ScalarTy.isPointer() ? ScalarTy.getAddressSpace() : 0};
}
- // Get a 16-bit brain float value.
- static constexpr LLT bfloat16() {
- return scalar_bfloat(16);
- }
+ // Get a 16-bit brain float value.
+ static constexpr LLT bfloat16() { return scalar_bfloat(16); }
/// Get a 16-bit IEEE half value.
/// TODO: Add IEEE semantics to type - This currently returns a simple `scalar(16)`.
@@ -145,14 +156,15 @@ public:
return scalarOrVector(EC, LLT::scalar(static_cast<unsigned>(ScalarSize)));
}
- explicit constexpr LLT(bool isPointer, bool isVector, bool isScalar, bool isBfloat,
- ElementCount EC, uint64_t SizeInBits,
+ explicit constexpr LLT(bool isPointer, bool isVector, bool isScalar,
+ bool isBfloat, ElementCount EC, uint64_t SizeInBits,
unsigned AddressSpace)
: LLT() {
init(isPointer, isVector, isScalar, isBfloat, EC, SizeInBits, AddressSpace);
}
explicit constexpr LLT()
- : IsScalar(false), IsPointer(false), IsVector(false), IsBfloat(false), RawData(0) {}
+ : IsScalar(false), IsPointer(false), IsVector(false), IsBfloat(false),
+ RawData(0) {}
LLVM_ABI explicit LLT(MVT VT);
@@ -327,7 +339,7 @@ private:
/// and the second element describing the bitfield offset.
///
/// +--------+---------+--------+----------+----------+----------------------+
- /// |isScalar|isPointer|isVector| isBfloat | RawData |Notes |
+ /// |isScalar|isPointer|isVector| isBfloat | RawData |Notes |
/// +--------+---------+--------+----------+----------+----------------------+
/// | 0 | 0 | 0 | 0 | 0 |Invalid |
/// +--------+---------+--------+----------+----------+----------------------+
@@ -399,8 +411,8 @@ private:
return getMask(FieldInfo) & (RawData >> FieldInfo[1]);
}
- constexpr void init(bool IsPointer, bool IsVector, bool IsScalar, bool IsBfloat,
- ElementCount EC, uint64_t SizeInBits,
+ constexpr void init(bool IsPointer, bool IsVector, bool IsScalar,
+ bool IsBfloat, ElementCount EC, uint64_t SizeInBits,
unsigned AddressSpace) {
assert(SizeInBits <= std::numeric_limits<unsigned>::max() &&
"Not enough bits in LLT to represent size");
@@ -422,8 +434,9 @@ private:
public:
constexpr uint64_t getUniqueRAWLLTData() const {
- return ((uint64_t)RawData) << 4 | ((uint64_t)IsBfloat) << 3 | ((uint64_t)IsScalar) << 2 |
- ((uint64_t)IsPointer) << 1 | ((uint64_t)IsVector);
+ return ((uint64_t)RawData) << 4 | ((uint64_t)IsBfloat) << 3 |
+ ((uint64_t)IsScalar) << 2 | ((uint64_t)IsPointer) << 1 |
+ ((uint64_t)IsVector);
}
};
diff --git a/llvm/lib/CodeGenTypes/LowLevelType.cpp b/llvm/lib/CodeGenTypes/LowLevelType.cpp
index 8828135fc..9bc2e0f9e 100644
--- a/llvm/lib/CodeGenTypes/LowLevelType.cpp
+++ b/llvm/lib/CodeGenTypes/LowLevelType.cpp
@@ -19,16 +19,18 @@ using namespace llvm;
LLT::LLT(MVT VT) {
if (VT.isVector()) {
bool asVector = VT.getVectorMinNumElements() > 1 || VT.isScalableVector();
- init(/*IsPointer=*/false, asVector, /*IsScalar=*/!asVector, /*isBfloat=*/false,
- VT.getVectorElementCount(), VT.getVectorElementType().getSizeInBits(),
+ init(/*IsPointer=*/false, asVector, /*IsScalar=*/!asVector,
+ /*isBfloat=*/false, VT.getVectorElementCount(),
+ VT.getVectorElementType().getSizeInBits(),
/*AddressSpace=*/0);
} else if (VT.isValid() && !VT.isScalableTargetExtVT()) {
// Aggregates are no different from real scalars as far as GlobalISel is
// concerned.
MVT ElemVT = VT.getVectorElementType();
bool isElemBfloat = (ElemVT == MVT::bf16);
- init(/*IsPointer=*/false, /*IsVector=*/false, /*IsScalar=*/true, /*isBfloat=*/false,
- ElementCount::getFixed(0), VT.getSizeInBits(), /*AddressSpace=*/0);
+ init(/*IsPointer=*/false, /*IsVector=*/false, /*IsScalar=*/true,
+ /*isBfloat=*/false, ElementCount::getFixed(0), VT.getSizeInBits(),
+ /*AddressSpace=*/0);
} else {
IsScalar = false;
IsPointer = false;
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index dc00d97e2..4cf23963c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1268,7 +1268,7 @@ void addInstrRequirements(const MachineInstr &MI,
const MachineOperand &MO = MI.getOperand(1);
if (MO.isReg()) {
LLT Ty = MRI.getType(MO.getReg());
- if(Ty.isBfloat()) {
+ if (Ty.isBfloat()) {
assert(1 && "hola, ur wrong");
}
}
|
I think right after bfloat16 you may want to have vectorized bf16x2 or any other short fp types, and this solution do not scaling well for such cases |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
No description provided.