diff --git a/llvm/include/llvm/CodeGen/Analysis.h b/llvm/include/llvm/CodeGen/Analysis.h index 98b52579d03b7..5c36ecd26719e 100644 --- a/llvm/include/llvm/CodeGen/Analysis.h +++ b/llvm/include/llvm/CodeGen/Analysis.h @@ -101,7 +101,7 @@ inline void ComputeValueVTs(const TargetLowering &TLI, const DataLayout &DL, /// If Offsets is non-null, it points to a vector to be filled in /// with the in-memory offsets of each of the individual values. /// -void computeValueLLTs(const DataLayout &DL, Type &Ty, +void computeValueLLTs(const TargetLowering &TLI, const DataLayout &DL, Type &Ty, SmallVectorImpl<LLT> &ValueTys, SmallVectorImpl<uint64_t> *Offsets = nullptr, uint64_t StartingOffset = 0); diff --git a/llvm/include/llvm/CodeGen/GlobalISel/IRTranslator.h b/llvm/include/llvm/CodeGen/GlobalISel/IRTranslator.h index 3d7ccd55ee042..5673fd5168477 100644 --- a/llvm/include/llvm/CodeGen/GlobalISel/IRTranslator.h +++ b/llvm/include/llvm/CodeGen/GlobalISel/IRTranslator.h @@ -642,6 +642,8 @@ class IRTranslator : public MachineFunctionPass { StackProtectorDescriptor SPDescriptor; + bool mayTranslateUserTypes(const User &U) const; + /// Switch analysis and optimization. class GISelSwitchLowering : public SwitchCG::SwitchLowering { public: diff --git a/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h b/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h index fd72a3898562e..7b6b5dbf0be05 100644 --- a/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h +++ b/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h @@ -380,7 +380,8 @@ LLVM_ABI LegalizeMutation changeElementCountTo(unsigned TypeIdx, /// Keep the same scalar or element type as \p TypeIdx, but take the number of /// elements from \p Ty. -LLVM_ABI LegalizeMutation changeElementCountTo(unsigned TypeIdx, LLT Ty); +LLVM_ABI LegalizeMutation changeElementCountTo(unsigned TypeIdx, + ElementCount EC); /// Change the scalar size or element size to have the same scalar size as type /// index \p FromIndex. Unlike changeElementTo, this discards pointer types and diff --git a/llvm/include/llvm/CodeGen/LowLevelTypeUtils.h b/llvm/include/llvm/CodeGen/LowLevelTypeUtils.h index 51a298eb8b247..4092509948ae6 100644 --- a/llvm/include/llvm/CodeGen/LowLevelTypeUtils.h +++ b/llvm/include/llvm/CodeGen/LowLevelTypeUtils.h @@ -27,7 +27,8 @@ class Type; struct fltSemantics; /// Construct a low-level type based on an LLVM type. -LLVM_ABI LLT getLLTForType(Type &Ty, const DataLayout &DL); +LLVM_ABI LLT getLLTForType(Type &Ty, const DataLayout &DL, + bool AllowExtendedLLT = false); /// Get a rough equivalent of an MVT for a given LLT. MVT can't distinguish /// pointers, so these will convert to a plain integer. @@ -36,11 +37,11 @@ LLVM_ABI EVT getApproximateEVTForLLT(LLT Ty, LLVMContext &Ctx); /// Get a rough equivalent of an LLT for a given MVT. LLT does not yet support /// scalarable vector types, and will assert if used. -LLVM_ABI LLT getLLTForMVT(MVT Ty); +LLVM_ABI LLT getLLTForMVT(MVT Ty, bool AllowExtendedLLT = false); /// Get the appropriate floating point arithmetic semantic based on the bit size /// of the given scalar LLT. LLVM_ABI const llvm::fltSemantics &getFltSemanticForLLT(LLT Ty); -} +} // namespace llvm #endif // LLVM_CODEGEN_LOWLEVELTYPEUTILS_H diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h index 3a51b9e6917c3..3224630e2b341 100644 --- a/llvm/include/llvm/CodeGen/TargetLowering.h +++ b/llvm/include/llvm/CodeGen/TargetLowering.h @@ -416,6 +416,9 @@ class LLVM_ABI TargetLoweringBase { /// amounts, returns MVT::i32. 
EVT getShiftAmountTy(EVT LHSTy, const DataLayout &DL) const; + virtual LLT getLLTForType(Type &Ty, const DataLayout &DL) const; + virtual LLT getLLTForMVT(MVT Ty) const; + /// Return the preferred type to use for a shift opcode, given the shifted /// amount type is \p ShiftValueTy. LLVM_READONLY diff --git a/llvm/include/llvm/CodeGenTypes/LowLevelType.h b/llvm/include/llvm/CodeGenTypes/LowLevelType.h index d8e0848aff84d..3378baf680230 100644 --- a/llvm/include/llvm/CodeGenTypes/LowLevelType.h +++ b/llvm/include/llvm/CodeGenTypes/LowLevelType.h @@ -39,68 +39,149 @@ class raw_ostream; class LLT { public: + enum class FPVariant { + IEEE_FLOAT = 0x0, + BRAIN_FLOAT = 0x1, // bfloat + PPC128_FLOAT = 0x2, // PowerPC double-double + EXTENDED_FP80 = 0x3, // x87 80-bit extended + TENSOR_FLOAT32 = 0x4, // TF32 + VARIANT_FLOAT_5 = 0x5, // UNASSIGNED + VARIANT_FLOAT_6 = 0x6, // UNASSIGNED + VARIANT_FLOAT_7 = 0x7, // UNASSIGNED + }; + + enum class Kind : uint64_t { + INVALID = 0b0000, + ANY_SCALAR = 0b0001, + INTEGER = 0b0010, + FLOAT = 0b0011, + POINTER = 0b0100, + VECTOR_ANY = 0b0101, + VECTOR_INTEGER = 0b0110, + VECTOR_FLOAT = 0b0111, + VECTOR_POINTER = 0b1000, + }; + + constexpr static Kind toVector(Kind Ty) { + if (Ty == Kind::POINTER) + return Kind::VECTOR_POINTER; + + if (Ty == Kind::INTEGER) + return Kind::VECTOR_INTEGER; + + if (Ty == Kind::FLOAT) + return Kind::VECTOR_FLOAT; + + return Kind::VECTOR_ANY; + } + + constexpr static Kind toScalar(Kind Ty) { + if (Ty == Kind::VECTOR_POINTER) + return Kind::POINTER; + + if (Ty == Kind::VECTOR_INTEGER) + return Kind::INTEGER; + + if (Ty == Kind::VECTOR_FLOAT) + return Kind::FLOAT; + + return Kind::ANY_SCALAR; + } + /// Get a low-level scalar or aggregate "bag of bits". static constexpr LLT scalar(unsigned SizeInBits) { - return LLT{/*isPointer=*/false, /*isVector=*/false, /*isScalar=*/true, - ElementCount::getFixed(0), SizeInBits, - /*AddressSpace=*/0}; + return LLT{Kind::ANY_SCALAR, ElementCount::getFixed(0), SizeInBits, + /*AddressSpace=*/0, static_cast<FPVariant>(0)}; + } + + static constexpr LLT integer(unsigned SizeInBits) { + return LLT{Kind::INTEGER, ElementCount::getFixed(0), SizeInBits, + /*AddressSpace=*/0, static_cast<FPVariant>(0)}; + } + + static constexpr LLT floatingPoint(unsigned SizeInBits, FPVariant FP) { + return LLT{Kind::FLOAT, ElementCount::getFixed(0), SizeInBits, + /*AddressSpace=*/0, FP}; } /// Get a low-level token; just a scalar with zero bits (or no size). static constexpr LLT token() { - return LLT{/*isPointer=*/false, /*isVector=*/false, - /*isScalar=*/true, ElementCount::getFixed(0), + return LLT{Kind::ANY_SCALAR, ElementCount::getFixed(0), /*SizeInBits=*/0, - /*AddressSpace=*/0}; + /*AddressSpace=*/0, static_cast<FPVariant>(0)}; } /// Get a low-level pointer in the given address space. static constexpr LLT pointer(unsigned AddressSpace, unsigned SizeInBits) { assert(SizeInBits > 0 && "invalid pointer size"); - return LLT{/*isPointer=*/true, /*isVector=*/false, /*isScalar=*/false, - ElementCount::getFixed(0), SizeInBits, AddressSpace}; + return LLT{Kind::POINTER, ElementCount::getFixed(0), SizeInBits, + AddressSpace, static_cast<FPVariant>(0)}; } /// Get a low-level vector of some number of elements and element width. 
static constexpr LLT vector(ElementCount EC, unsigned ScalarSizeInBits) { assert(!EC.isScalar() && "invalid number of vector elements"); - return LLT{/*isPointer=*/false, /*isVector=*/true, /*isScalar=*/false, - EC, ScalarSizeInBits, /*AddressSpace=*/0}; + return LLT{Kind::VECTOR_ANY, EC, ScalarSizeInBits, + /*AddressSpace=*/0, static_cast<FPVariant>(0)}; } /// Get a low-level vector of some number of elements and element type. static constexpr LLT vector(ElementCount EC, LLT ScalarTy) { assert(!EC.isScalar() && "invalid number of vector elements"); assert(!ScalarTy.isVector() && "invalid vector element type"); - return LLT{ScalarTy.isPointer(), - /*isVector=*/true, - /*isScalar=*/false, - EC, - ScalarTy.getSizeInBits().getFixedValue(), - ScalarTy.isPointer() ? ScalarTy.getAddressSpace() : 0}; + + Kind Info = toVector(ScalarTy.Info); + return LLT{Info, EC, ScalarTy.getSizeInBits().getFixedValue(), + ScalarTy.isPointer() ? ScalarTy.getAddressSpace() : 0, + ScalarTy.isFloat() ? ScalarTy.getFPVariant() + : static_cast<FPVariant>(0)}; + } + /// Get an 8-bit brain float value. + static constexpr LLT bfloat8() { + return floatingPoint(8, FPVariant::BRAIN_FLOAT); + } + + /// Get a 16-bit brain float value. + static constexpr LLT bfloat16() { + return floatingPoint(16, FPVariant::BRAIN_FLOAT); } /// Get a 16-bit IEEE half value. - /// TODO: Add IEEE semantics to type - This currently returns a simple `scalar(16)`. static constexpr LLT float16() { - return scalar(16); + return floatingPoint(16, FPVariant::IEEE_FLOAT); } /// Get a 32-bit IEEE float value. static constexpr LLT float32() { - return scalar(32); + return floatingPoint(32, FPVariant::IEEE_FLOAT); } /// Get a 64-bit IEEE double value. static constexpr LLT float64() { - return scalar(64); + return floatingPoint(64, FPVariant::IEEE_FLOAT); + } + + /// Get an 80-bit X86 floating point value. + static constexpr LLT x86fp80() { + return floatingPoint(80, FPVariant::EXTENDED_FP80); + } + + /// Get a 128-bit IEEE quad value. + static constexpr LLT float128() { + return floatingPoint(128, FPVariant::IEEE_FLOAT); + } + + /// Get a 128-bit PowerPC double double value. + static constexpr LLT ppcf128() { + return floatingPoint(128, FPVariant::PPC128_FLOAT); } /// Get a low-level fixed-width vector of some number of elements and element /// width. static constexpr LLT fixed_vector(unsigned NumElements, unsigned ScalarSizeInBits) { - return vector(ElementCount::getFixed(NumElements), ScalarSizeInBits); + return vector(ElementCount::getFixed(NumElements), + LLT::scalar(ScalarSizeInBits)); } /// Get a low-level fixed-width vector of some number of elements and element @@ -113,7 +194,8 @@ class LLT { /// width. 
static constexpr LLT scalable_vector(unsigned MinNumElements, unsigned ScalarSizeInBits) { - return vector(ElementCount::getScalable(MinNumElements), ScalarSizeInBits); + return vector(ElementCount::getScalable(MinNumElements), + LLT::scalar(ScalarSizeInBits)); } /// Get a low-level scalable vector of some number of elements and element @@ -132,27 +214,82 @@ class LLT { return scalarOrVector(EC, LLT::scalar(static_cast<unsigned>(ScalarSize))); } - explicit constexpr LLT(bool isPointer, bool isVector, bool isScalar, - ElementCount EC, uint64_t SizeInBits, - unsigned AddressSpace) + explicit constexpr LLT(Kind Info, ElementCount EC, uint64_t SizeInBits, + unsigned AddressSpace, FPVariant FP) : LLT() { - init(isPointer, isVector, isScalar, EC, SizeInBits, AddressSpace); + init(Info, EC, SizeInBits, AddressSpace, FP); } - explicit constexpr LLT() - : IsScalar(false), IsPointer(false), IsVector(false), RawData(0) {} - LLVM_ABI explicit LLT(MVT VT); + LLVM_ABI explicit LLT(MVT VT, bool AllowExtendedLLT = false); + explicit constexpr LLT() : Info(static_cast<Kind>(0)), RawData(0) {} - constexpr bool isValid() const { return IsScalar || RawData != 0; } - constexpr bool isScalar() const { return IsScalar; } - constexpr bool isToken() const { return IsScalar && RawData == 0; }; - constexpr bool isVector() const { return isValid() && IsVector; } + constexpr bool isValid() const { return isToken() || RawData != 0; } + constexpr bool isScalar() const { + return Info == Kind::ANY_SCALAR || Info == Kind::INTEGER || + Info == Kind::FLOAT; + } + constexpr bool isScalar(unsigned Size) const { + return isScalar() && getScalarSizeInBits() == Size; + } + constexpr bool isFloat() const { return isValid() && Info == Kind::FLOAT; } + constexpr bool isFloat(unsigned Size) const { + return isFloat() && getScalarSizeInBits() == Size; + } + constexpr bool isVariantFloat() const { + return isFloat() && getFPVariant() != FPVariant::IEEE_FLOAT; + } + constexpr bool isVariantFloat(FPVariant Variant) const { + return isFloat() && getFPVariant() == Variant; + } + constexpr bool isVariantFloat(unsigned Size, FPVariant Variant) const { + return isVariantFloat(Variant) && getScalarSizeInBits() == Size; + } + constexpr bool isFloatVector() const { + return isVector() && Info == Kind::VECTOR_FLOAT; + } + constexpr bool isIEEEFloat(unsigned Size) const { + return isVariantFloat(Size, FPVariant::IEEE_FLOAT); + } + constexpr bool isBFloat(unsigned Size) const { + return isVariantFloat(Size, FPVariant::BRAIN_FLOAT); + } + constexpr bool isX86FP80() const { + return isVariantFloat(80, FPVariant::EXTENDED_FP80); + } + constexpr bool isPPCF128() const { + return isVariantFloat(128, FPVariant::PPC128_FLOAT); + } + constexpr bool isToken() const { + return Info == Kind::ANY_SCALAR && RawData == 0; + } + constexpr bool isAnyScalar() const { + return isValid() && Info == Kind::ANY_SCALAR; + } + constexpr bool isVectorAny() const { + return isVector() && Info == Kind::VECTOR_ANY; + } + constexpr bool isInteger() const { + return isValid() && Info == Kind::INTEGER; + } + constexpr bool isInteger(unsigned Size) const { + return isInteger() && getScalarSizeInBits() == Size; + } + constexpr bool isIntegerVector() const { + return isVector() && Info == Kind::VECTOR_INTEGER; + } + constexpr bool isVector() const { + return isValid() && + (Info == Kind::VECTOR_ANY || Info == Kind::VECTOR_INTEGER || + Info == Kind::VECTOR_FLOAT || Info == Kind::VECTOR_POINTER); + } constexpr bool isPointer() const { - return isValid() && IsPointer && !IsVector; + return isValid() 
&& Info == Kind::POINTER; + } + constexpr bool isPointerVector() const { + return isVector() && Info == Kind::VECTOR_POINTER; } - constexpr bool isPointerVector() const { return IsPointer && isVector(); } constexpr bool isPointerOrPointerVector() const { - return IsPointer && isValid(); + return isPointer() || isPointerVector(); } /// Returns the number of elements in a vector LLT. Must only be called on @@ -177,12 +314,18 @@ class LLT { /// if the LLT is not a vector type. constexpr bool isFixedVector() const { return isVector() && !isScalable(); } + constexpr bool isFixedVector(unsigned NumElements, + unsigned ScalarSize) const { + return isFixedVector() && getNumElements() == NumElements && + getScalarSizeInBits() == ScalarSize; + } + /// Returns true if the LLT is a scalable vector. Returns false otherwise, /// even if the LLT is not a vector type. constexpr bool isScalableVector() const { return isVector() && isScalable(); } constexpr ElementCount getElementCount() const { - assert(IsVector && "cannot get number of elements on scalar/aggregate"); + assert(isVector() && "cannot get number of elements on scalar/aggregate"); return ElementCount::get(getFieldValue(VectorElementsFieldInfo), isScalable()); } @@ -207,6 +350,15 @@ class LLT { return isVector() ? getElementType() : *this; } + constexpr FPVariant getFPVariant() const { + assert((isFloat() || isFloatVector()) && + "cannot get FP info for non float type"); + + return FPVariant(getFieldValue(FPFieldInfo)); + } + + constexpr Kind getKind() const { return Info; } + /// If this type is a vector, return a vector with the same number of elements /// but the new element type. Otherwise, return the new element type. constexpr LLT changeElementType(LLT NewEltTy) const { @@ -217,10 +369,14 @@ class LLT { /// but the new element size. Otherwise, return the new element type. Invalid /// for pointer types. For pointer types, use changeElementType. constexpr LLT changeElementSize(unsigned NewEltSize) const { - assert(!isPointerOrPointerVector() && - "invalid to directly change element size for pointers"); + assert(!isPointerOrPointerVector() && !(isFloat() || isFloatVector()) && + "invalid to directly change element size for pointers or floats"); - return isVector() ? LLT::vector(getElementCount(), NewEltSize) - : LLT::scalar(NewEltSize); + return isVector() + ? LLT::vector(getElementCount(), getElementType().isInteger() + ? LLT::integer(NewEltSize) + : LLT::scalar(NewEltSize)) + : isInteger() ? LLT::integer(NewEltSize) + : LLT::scalar(NewEltSize); } /// Return a vector or scalar with the same element type and the new element @@ -229,6 +385,10 @@ class LLT { return LLT::scalarOrVector(EC, getScalarType()); } + constexpr LLT changeElementCount(unsigned NumElements) const { + return changeElementCount(ElementCount::getFixed(NumElements)); + } + /// Return a type that is \p Factor times smaller. Reduces the number of /// elements if this is a vector, or the bitwidth for scalar/pointers. Does /// not attempt to handle cases that aren't evenly divisible. @@ -243,6 +403,9 @@ class LLT { } assert(getScalarSizeInBits() % Factor == 0); + if (isInteger()) + return integer(getScalarSizeInBits() / Factor); + return scalar(getScalarSizeInBits() / Factor); } @@ -277,10 +440,26 @@ class LLT { /// Returns the vector's element type. Only valid for vector types. 
constexpr LLT getElementType() const { assert(isVector() && "cannot get element type of scalar/aggregate"); - if (IsPointer) + if (isPointerVector()) return pointer(getAddressSpace(), getScalarSizeInBits()); - else - return scalar(getScalarSizeInBits()); + + if (isFloatVector()) + return floatingPoint(getScalarSizeInBits(), getFPVariant()); + + if (isIntegerVector()) + return integer(getScalarSizeInBits()); + + return scalar(getScalarSizeInBits()); + } + + constexpr LLT changeToInteger() const { + if (isPointer() || isPointerVector()) + return *this; + + if (isVector()) + return vector(getElementCount(), LLT::integer(getScalarSizeInBits())); + + return integer(getSizeInBits()); } LLVM_ABI void print(raw_ostream &OS) const; @@ -290,8 +469,14 @@ class LLT { #endif constexpr bool operator==(const LLT &RHS) const { - return IsPointer == RHS.IsPointer && IsVector == RHS.IsVector && - IsScalar == RHS.IsScalar && RHS.RawData == RawData; + if (isAnyScalar() || RHS.isAnyScalar()) { + return isScalar() == RHS.isScalar() && RawData == RHS.RawData; + } + if (isVector() && RHS.isVector()) { + return getElementType() == RHS.getElementType() && + getElementCount() == RHS.getElementCount(); + } + return Info == RHS.Info && RawData == RHS.RawData; } constexpr bool operator!=(const LLT &RHS) const { return !(*this == RHS); } @@ -301,37 +486,33 @@ private: /// LLT is packed into 64 bits as follows: - /// isScalar : 1 - /// isPointer : 1 - /// isVector : 1 - /// with 61 bits remaining for Kind-specific data, packed in bitfields - /// as described below. As there isn't a simple portable way to pack bits - /// into bitfields, here the different fields in the packed structure is + /// Info : 4 + /// RawData : 60 + /// with 60 bits of RawData remaining for Kind-specific data, packed in + /// bitfields as described below. As there isn't a simple portable way to pack + /// bits into bitfields, here the different fields in the packed structure are /// described in static const *Field variables. Each of these variables /// is a 2-element array, with the first element describing the bitfield size /// and the second element describing the bitfield offset. /// - /// +--------+---------+--------+----------+----------------------+ - /// |isScalar|isPointer|isVector| RawData |Notes | - /// +--------+---------+--------+----------+----------------------+ - /// | 0 | 0 | 0 | 0 |Invalid | - /// +--------+---------+--------+----------+----------------------+ - /// | 0 | 0 | 1 | 0 |Tombstone Key | - /// +--------+---------+--------+----------+----------------------+ - /// | 0 | 1 | 0 | 0 |Empty Key | - /// +--------+---------+--------+----------+----------------------+ - /// | 1 | 0 | 0 | 0 |Token | - /// +--------+---------+--------+----------+----------------------+ - /// | 1 | 0 | 0 | non-zero |Scalar | - /// +--------+---------+--------+----------+----------------------+ - /// | 0 | 1 | 0 | non-zero |Pointer | - /// +--------+---------+--------+----------+----------------------+ - /// | 0 | 0 | 1 | non-zero |Vector of non-pointer | - /// +--------+---------+--------+----------+----------------------+ - /// | 0 | 1 | 1 | non-zero |Vector of pointer | - /// +--------+---------+--------+----------+----------------------+ - /// - /// Everything else is reserved. + /* + --- LLT --- + + 63 56 47 39 31 23 15 7 0 + | | | | | | | | | + |xxxxxxxx|xxxxxxxx|xxxxxxxx|xxxxxxxx|xxxxxxxx|xxxxxxxx|xxxxxxxx|xxxxxxxx| + ........ ........ ........ ........ 
(1) + ******** ******** (2) + ~~~~~~~~ ~~~~~~~~ ~~~~~~~~ (3) + ^^^^^^^^ ^^^^^^^^ (4) + @ (5) + ### (6) + %%%% (7) + + (1) ScalarSize (2) PointerSize (3) PointerAddressSpace + (4) VectorElements (5) VectorScalable (6) FPVariant (7) Kind + + */ typedef int BitFieldInfo[2]; /// /// This is how the bitfields are packed per Kind: @@ -340,17 +521,20 @@ class LLT { /// valid encodings, SizeInBits/SizeOfElement must be larger than 0. /// * Non-pointer scalar (isPointer == 0 && isVector == 0): /// SizeInBits: 32; - static const constexpr BitFieldInfo ScalarSizeFieldInfo{32, 29}; + /// FPInfoField: 3; + static const constexpr BitFieldInfo ScalarSizeFieldInfo{32, 28}; + static const constexpr BitFieldInfo FPFieldInfo{3, 20}; /// * Pointer (isPointer == 1 && isVector == 0): /// SizeInBits: 16; /// AddressSpace: 24; - static const constexpr BitFieldInfo PointerSizeFieldInfo{16, 45}; - static const constexpr BitFieldInfo PointerAddressSpaceFieldInfo{24, 21}; + static const constexpr BitFieldInfo PointerSizeFieldInfo{16, 44}; + static const constexpr BitFieldInfo PointerAddressSpaceFieldInfo{24, 20}; /// * Vector-of-non-pointer (isPointer == 0 && isVector == 1): /// NumElements: 16; /// SizeOfElement: 32; + /// FPInfoField: 3; /// Scalable: 1; - static const constexpr BitFieldInfo VectorElementsFieldInfo{16, 5}; + static const constexpr BitFieldInfo VectorElementsFieldInfo{16, 4}; static const constexpr BitFieldInfo VectorScalableFieldInfo{1, 0}; /// * Vector-of-pointer (isPointer == 1 && isVector == 1): /// NumElements: 16; @@ -358,10 +542,8 @@ class LLT { /// AddressSpace: 24; /// Scalable: 1; - uint64_t IsScalar : 1; - uint64_t IsPointer : 1; - uint64_t IsVector : 1; - uint64_t RawData : 61; + Kind Info : 4; + uint64_t RawData : 60; static constexpr uint64_t getMask(const BitFieldInfo FieldInfo) { const int FieldSizeInBits = FieldInfo[0]; @@ -381,21 +563,21 @@ class LLT { return getMask(FieldInfo) & (RawData >> FieldInfo[1]); } - constexpr void init(bool IsPointer, bool IsVector, bool IsScalar, - ElementCount EC, uint64_t SizeInBits, - unsigned AddressSpace) { + constexpr void init(Kind Info, ElementCount EC, uint64_t SizeInBits, + unsigned AddressSpace, FPVariant FP) { assert(SizeInBits <= std::numeric_limits<unsigned>::max() && "Not enough bits in LLT to represent size"); - this->IsPointer = IsPointer; - this->IsVector = IsVector; - this->IsScalar = IsScalar; - if (IsPointer) { + this->Info = Info; + if (Info == Kind::POINTER || Info == Kind::VECTOR_POINTER) { RawData = maskAndShift(SizeInBits, PointerSizeFieldInfo) | maskAndShift(AddressSpace, PointerAddressSpaceFieldInfo); } else { - RawData = maskAndShift(SizeInBits, ScalarSizeFieldInfo); + RawData = maskAndShift(SizeInBits, ScalarSizeFieldInfo) | + maskAndShift((uint64_t)FP, FPFieldInfo); } - if (IsVector) { + + if (Info == Kind::VECTOR_ANY || Info == Kind::VECTOR_INTEGER || + Info == Kind::VECTOR_FLOAT || Info == Kind::VECTOR_POINTER) { RawData |= maskAndShift(EC.getKnownMinValue(), VectorElementsFieldInfo) | maskAndShift(EC.isScalable() ? 
1 : 0, VectorScalableFieldInfo); } @@ -403,8 +585,7 @@ class LLT { public: constexpr uint64_t getUniqueRAWLLTData() const { - return ((uint64_t)RawData) << 3 | ((uint64_t)IsScalar) << 2 | - ((uint64_t)IsPointer) << 1 | ((uint64_t)IsVector); + return ((uint64_t)RawData) << 4 | ((uint64_t)Info); } }; @@ -413,15 +594,15 @@ inline raw_ostream& operator<<(raw_ostream &OS, const LLT &Ty) { return OS; } -template<> struct DenseMapInfo<LLT> { +template <> struct DenseMapInfo<LLT> { static inline LLT getEmptyKey() { LLT Invalid; - Invalid.IsPointer = true; + Invalid.Info = LLT::Kind::POINTER; return Invalid; } static inline LLT getTombstoneKey() { LLT Invalid; - Invalid.IsVector = true; + Invalid.Info = LLT::Kind::VECTOR_ANY; return Invalid; } static inline unsigned getHashValue(const LLT &Ty) { @@ -433,6 +614,6 @@ template<> struct DenseMapInfo<LLT> { } }; -} +} // namespace llvm #endif // LLVM_CODEGEN_LOWLEVELTYPE_H diff --git a/llvm/include/llvm/Target/TargetMachine.h b/llvm/include/llvm/Target/TargetMachine.h index bf4e490554723..a503199856c5a 100644 --- a/llvm/include/llvm/Target/TargetMachine.h +++ b/llvm/include/llvm/Target/TargetMachine.h @@ -295,6 +295,9 @@ class LLVM_ABI TargetMachine { bool getO0WantsFastISel() { return O0WantsFastISel; } void setO0WantsFastISel(bool Enable) { O0WantsFastISel = Enable; } void setGlobalISel(bool Enable) { Options.EnableGlobalISel = Enable; } + void setGlobalISelExtendedLLT(bool Enable) { + Options.EnableGlobalISelExtendedLLT = Enable; + } void setGlobalISelAbort(GlobalISelAbortMode Mode) { Options.GlobalISelAbort = Mode; } diff --git a/llvm/include/llvm/Target/TargetOptions.h b/llvm/include/llvm/Target/TargetOptions.h index db90f2e4cc7cc..b167d4e13a342 100644 --- a/llvm/include/llvm/Target/TargetOptions.h +++ b/llvm/include/llvm/Target/TargetOptions.h @@ -123,7 +123,8 @@ class TargetOptions { ApproxFuncFPMath(false), EnableAIXExtendedAltivecABI(false), HonorSignDependentRoundingFPMathOption(false), NoZerosInBSS(false), GuaranteedTailCallOpt(false), StackSymbolOrdering(true), - EnableFastISel(false), EnableGlobalISel(false), UseInitArray(false), + EnableFastISel(false), EnableGlobalISel(false), + EnableGlobalISelExtendedLLT(false), UseInitArray(false), DisableIntegratedAS(false), FunctionSections(false), DataSections(false), IgnoreXCOFFVisibility(false), XCOFFTracebackTable(true), UniqueSectionNames(true), @@ -235,6 +236,10 @@ class TargetOptions { /// EnableGlobalISel - This flag enables global instruction selection. unsigned EnableGlobalISel : 1; + /// EnableGlobalISelExtendedLLT - This flag enables LLTs with extended type + /// info so targets may distinguish different formats of equal-sized scalars. + unsigned EnableGlobalISelExtendedLLT : 1; + /// EnableGlobalISelAbort - Control abort behaviour when global instruction /// selection fails to lower/select an instruction. GlobalISelAbortMode GlobalISelAbort = GlobalISelAbortMode::Enable; diff --git a/llvm/lib/CodeGen/Analysis.cpp b/llvm/lib/CodeGen/Analysis.cpp index 2ef96cc4400f7..61d509db7f14a 100644 --- a/llvm/lib/CodeGen/Analysis.cpp +++ b/llvm/lib/CodeGen/Analysis.cpp @@ -146,8 +146,8 @@ void llvm::ComputeValueVTs(const TargetLowering &TLI, const DataLayout &DL, } } -void llvm::computeValueLLTs(const DataLayout &DL, Type &Ty, - SmallVectorImpl<LLT> &ValueTys, +void llvm::computeValueLLTs(const TargetLowering &TLI, const DataLayout &DL, + Type &Ty, SmallVectorImpl<LLT> &ValueTys, SmallVectorImpl<uint64_t> *Offsets, uint64_t StartingOffset) { // Given a struct type, recursively traverse the elements. 
@@ -158,7 +158,7 @@ void llvm::computeValueLLTs(const DataLayout &DL, Type &Ty, const StructLayout *SL = Offsets ? DL.getStructLayout(STy) : nullptr; for (unsigned I = 0, E = STy->getNumElements(); I != E; ++I) { uint64_t EltOffset = SL ? SL->getElementOffset(I) : 0; - computeValueLLTs(DL, *STy->getElementType(I), ValueTys, Offsets, + computeValueLLTs(TLI, DL, *STy->getElementType(I), ValueTys, Offsets, StartingOffset + EltOffset); } return; @@ -168,7 +168,7 @@ void llvm::computeValueLLTs(const DataLayout &DL, Type &Ty, Type *EltTy = ATy->getElementType(); uint64_t EltSize = DL.getTypeAllocSize(EltTy).getFixedValue(); for (unsigned i = 0, e = ATy->getNumElements(); i != e; ++i) - computeValueLLTs(DL, *EltTy, ValueTys, Offsets, + computeValueLLTs(TLI, DL, *EltTy, ValueTys, Offsets, StartingOffset + i * EltSize); return; } @@ -176,7 +176,7 @@ void llvm::computeValueLLTs(const DataLayout &DL, Type &Ty, if (Ty.isVoidTy()) return; // Base case: we can get an LLT for this LLVM IR type. - ValueTys.push_back(getLLTForType(Ty, DL)); + ValueTys.push_back(TLI.getLLTForType(Ty, DL)); if (Offsets != nullptr) Offsets->push_back(StartingOffset * 8); } diff --git a/llvm/lib/CodeGen/GlobalISel/CallLowering.cpp b/llvm/lib/CodeGen/GlobalISel/CallLowering.cpp index 90a18b86c1b1f..41609dae1b156 100644 --- a/llvm/lib/CodeGen/GlobalISel/CallLowering.cpp +++ b/llvm/lib/CodeGen/GlobalISel/CallLowering.cpp @@ -159,7 +159,7 @@ bool CallLowering::lowerCall(MachineIRBuilder &MIRBuilder, const CallBase &CB, if (const Function *F = dyn_cast<Function>(CalleeV)) { if (F->hasFnAttribute(Attribute::NonLazyBind)) { - LLT Ty = getLLTForType(*F->getType(), DL); + LLT Ty = TLI->getLLTForType(*F->getType(), DL); Register Reg = MIRBuilder.buildGlobalValue(Ty, F).getReg(0); Info.Callee = MachineOperand::CreateReg(Reg, false); } else { @@ -781,11 +781,11 @@ bool CallLowering::handleAssignments(ValueHandler &Handler, const MVT ValVT = VA.getValVT(); const MVT LocVT = VA.getLocVT(); - const LLT LocTy(LocVT); - const LLT ValTy(ValVT); + const LLT LocTy = TLI->getLLTForMVT(LocVT); + const LLT ValTy = TLI->getLLTForMVT(ValVT); const LLT NewLLT = Handler.isIncomingArgumentHandler() ? 
LocTy : ValTy; const EVT OrigVT = EVT::getEVT(Args[i].Ty); - const LLT OrigTy = getLLTForType(*Args[i].Ty, DL); + const LLT OrigTy = TLI->getLLTForType(*Args[i].Ty, DL); const LLT PointerTy = LLT::pointer( AllocaAddressSpace, DL.getPointerSizeInBits(AllocaAddressSpace)); @@ -1004,7 +1004,7 @@ void CallLowering::insertSRetLoads(MachineIRBuilder &MIRBuilder, Type *RetTy, Align BaseAlign = DL.getPrefTypeAlign(RetTy); Type *RetPtrTy = PointerType::get(RetTy->getContext(), DL.getAllocaAddrSpace()); - LLT OffsetLLTy = getLLTForType(*DL.getIndexType(RetPtrTy), DL); + LLT OffsetLLTy = TLI->getLLTForType(*DL.getIndexType(RetPtrTy), DL); MachinePointerInfo PtrInfo = MachinePointerInfo::getFixedStack(MF, FI); @@ -1035,7 +1035,8 @@ void CallLowering::insertSRetStores(MachineIRBuilder &MIRBuilder, Type *RetTy, unsigned NumValues = SplitVTs.size(); Align BaseAlign = DL.getPrefTypeAlign(RetTy); unsigned AS = DL.getAllocaAddrSpace(); - LLT OffsetLLTy = getLLTForType(*DL.getIndexType(RetTy->getContext(), AS), DL); + LLT OffsetLLTy = + TLI->getLLTForType(*DL.getIndexType(RetTy->getContext(), AS), DL); MachinePointerInfo PtrInfo(AS); diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp index 541269ab6bfce..8dac12fc4597a 100644 --- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp +++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp @@ -192,7 +192,7 @@ IRTranslator::allocateVRegs(const Value &Val) { auto *Regs = VMap.getVRegs(Val); auto *Offsets = VMap.getOffsets(Val); SmallVector<LLT, 4> SplitTys; - computeValueLLTs(*DL, *Val.getType(), SplitTys, + computeValueLLTs(*TLI, *DL, *Val.getType(), SplitTys, Offsets->empty() ? Offsets : nullptr); for (unsigned i = 0; i < SplitTys.size(); ++i) Regs->push_back(0); @@ -216,7 +216,7 @@ ArrayRef<Register> IRTranslator::getOrCreateVRegs(const Value &Val) { "Don't know how to create an empty vreg"); SmallVector<LLT, 4> SplitTys; - computeValueLLTs(*DL, *Val.getType(), SplitTys, + computeValueLLTs(*TLI, *DL, *Val.getType(), SplitTys, Offsets->empty() ? Offsets : nullptr); if (!isa<Constant>(Val)) { @@ -294,19 +294,9 @@ void IRTranslator::addMachineCFGPred(CFGEdge Edge, MachineBasicBlock *NewPred) { MachinePreds[Edge].push_back(NewPred); } -static bool containsBF16Type(const User &U) { - // BF16 cannot currently be represented by LLT, to avoid miscompiles we - // prevent any instructions using them. FIXME: This can be removed once LLT - // supports bfloat. - return U.getType()->getScalarType()->isBFloatTy() || - any_of(U.operands(), [](Value *V) { - return V->getType()->getScalarType()->isBFloatTy(); - }); -} - bool IRTranslator::translateBinaryOp(unsigned Opcode, const User &U, MachineIRBuilder &MIRBuilder) { - if (containsBF16Type(U)) + if (!mayTranslateUserTypes(U)) return false; // Get or create a virtual register for each value. 
@@ -328,7 +318,7 @@ bool IRTranslator::translateBinaryOp(unsigned Opcode, const User &U, bool IRTranslator::translateUnaryOp(unsigned Opcode, const User &U, MachineIRBuilder &MIRBuilder) { - if (containsBF16Type(U)) + if (!mayTranslateUserTypes(U)) return false; Register Op0 = getOrCreateVReg(*U.getOperand(0)); @@ -348,7 +338,7 @@ bool IRTranslator::translateFNeg(const User &U, MachineIRBuilder &MIRBuilder) { bool IRTranslator::translateCompare(const User &U, MachineIRBuilder &MIRBuilder) { - if (containsBF16Type(U)) + if (!mayTranslateUserTypes(U)) return false; auto *CI = cast<CmpInst>(&U); @@ -857,7 +847,7 @@ void IRTranslator::emitJumpTable(SwitchCG::JumpTable &JT, MIB.setDebugLoc(CurBuilder->getDebugLoc()); Type *PtrIRTy = PointerType::getUnqual(MF->getFunction().getContext()); - const LLT PtrTy = getLLTForType(*PtrIRTy, *DL); + const LLT PtrTy = TLI->getLLTForType(*PtrIRTy, *DL); auto Table = MIB.buildJumpTable(PtrTy, JT.JTI); MIB.buildBrJT(Table.getReg(0), JT.JTI, JT.Reg); @@ -872,7 +862,7 @@ bool IRTranslator::emitJumpTableHeader(SwitchCG::JumpTable &JT, const Value &SValue = *JTH.SValue; // Subtract the lowest switch case value from the value being switched on. - const LLT SwitchTy = getLLTForType(*SValue.getType(), *DL); + const LLT SwitchTy = TLI->getLLTForType(*SValue.getType(), *DL); Register SwitchOpReg = getOrCreateVReg(SValue); auto FirstCst = MIB.buildConstant(SwitchTy, JTH.First); auto Sub = MIB.buildSub({SwitchTy}, SwitchOpReg, FirstCst); @@ -1105,7 +1095,7 @@ void IRTranslator::emitBitTestHeader(SwitchCG::BitTestBlock &B, auto RangeSub = MIB.buildSub(SwitchOpTy, SwitchOpReg, MinValReg); Type *PtrIRTy = PointerType::getUnqual(MF->getFunction().getContext()); - const LLT PtrTy = getLLTForType(*PtrIRTy, *DL); + const LLT PtrTy = TLI->getLLTForType(*PtrIRTy, *DL); LLT MaskTy = SwitchOpTy; if (MaskTy.getSizeInBits() > PtrTy.getSizeInBits() || @@ -1158,7 +1148,7 @@ void IRTranslator::emitBitTestCase(SwitchCG::BitTestBlock &BB, MachineIRBuilder &MIB = *CurBuilder; MIB.setMBB(*SwitchBB); - LLT SwitchTy = getLLTForMVT(BB.RegVT); + LLT SwitchTy = TLI->getLLTForMVT(BB.RegVT); Register Cmp; unsigned PopCount = llvm::popcount(B.Mask); if (PopCount == 1) { @@ -1385,7 +1375,7 @@ bool IRTranslator::translateLoad(const User &U, MachineIRBuilder &MIRBuilder) { const Value *Ptr = LI.getPointerOperand(); Type *OffsetIRTy = DL->getIndexType(Ptr->getType()); - LLT OffsetTy = getLLTForType(*OffsetIRTy, *DL); + LLT OffsetTy = TLI->getLLTForType(*OffsetIRTy, *DL); if (CLI->supportSwiftError() && isSwiftError(Ptr)) { assert(Regs.size() == 1 && "swifterror should be single pointer"); @@ -1432,7 +1422,7 @@ bool IRTranslator::translateStore(const User &U, MachineIRBuilder &MIRBuilder) { Register Base = getOrCreateVReg(*SI.getPointerOperand()); Type *OffsetIRTy = DL->getIndexType(SI.getPointerOperandType()); - LLT OffsetTy = getLLTForType(*OffsetIRTy, *DL); + LLT OffsetTy = TLI->getLLTForType(*OffsetIRTy, *DL); if (CLI->supportSwiftError() && isSwiftError(SI.getPointerOperand())) { assert(Vals.size() == 1 && "swifterror should be single pointer"); @@ -1554,8 +1544,8 @@ bool IRTranslator::translateCopy(const User &U, const Value &V, bool IRTranslator::translateBitCast(const User &U, MachineIRBuilder &MIRBuilder) { // If we're bitcasting to the source type, we can reuse the source vreg. 
- if (getLLTForType(*U.getOperand(0)->getType(), *DL) == - getLLTForType(*U.getType(), *DL)) { + if (TLI->getLLTForType(*U.getOperand(0)->getType(), *DL) == + TLI->getLLTForType(*U.getType(), *DL)) { // If the source is a ConstantInt then it was probably created by // ConstantHoisting and we should leave it alone. if (isa<ConstantInt>(U.getOperand(0))) @@ -1569,7 +1559,7 @@ bool IRTranslator::translateBitCast(const User &U, bool IRTranslator::translateCast(unsigned Opcode, const User &U, MachineIRBuilder &MIRBuilder) { - if (containsBF16Type(U)) + if (!mayTranslateUserTypes(U)) return false; uint32_t Flags = 0; @@ -1587,9 +1577,9 @@ bool IRTranslator::translateGetElementPtr(const User &U, Value &Op0 = *U.getOperand(0); Register BaseReg = getOrCreateVReg(Op0); Type *PtrIRTy = Op0.getType(); - LLT PtrTy = getLLTForType(*PtrIRTy, *DL); + LLT PtrTy = TLI->getLLTForType(*PtrIRTy, *DL); Type *OffsetIRTy = DL->getIndexType(PtrIRTy); - LLT OffsetTy = getLLTForType(*OffsetIRTy, *DL); + LLT OffsetTy = TLI->getLLTForType(*OffsetIRTy, *DL); uint32_t PtrAddFlags = 0; // Each PtrAdd generated to implement the GEP inherits its nuw, nusw, inbounds @@ -1626,9 +1616,9 @@ bool IRTranslator::translateGetElementPtr(const User &U, BaseReg) .getReg(0); PtrIRTy = FixedVectorType::get(PtrIRTy, VectorWidth); - PtrTy = getLLTForType(*PtrIRTy, *DL); + PtrTy = TLI->getLLTForType(*PtrIRTy, *DL); OffsetIRTy = DL->getIndexType(PtrIRTy); - OffsetTy = getLLTForType(*OffsetIRTy, *DL); + OffsetTy = TLI->getLLTForType(*OffsetIRTy, *DL); } int64_t Offset = 0; @@ -1678,7 +1668,7 @@ bool IRTranslator::translateGetElementPtr(const User &U, Register GepOffsetReg; if (ElementSize != 1) { auto ElementSizeMIB = MIRBuilder.buildConstant( - getLLTForType(*OffsetIRTy, *DL), ElementSize); + TLI->getLLTForType(*OffsetIRTy, *DL), ElementSize); // The multiplication is NUW if the GEP is NUW and NSW if the GEP is // NUSW. @@ -2335,7 +2325,7 @@ bool IRTranslator::translateKnownIntrinsic(const CallInst &CI, Intrinsic::ID ID, MIRBuilder.buildFMA(Dst, Op0, Op1, Op2, MachineInstr::copyFlagsFromInstruction(CI)); } else { - LLT Ty = getLLTForType(*CI.getType(), *DL); + LLT Ty = TLI->getLLTForType(*CI.getType(), *DL); auto FMul = MIRBuilder.buildFMul( Ty, Op0, Op1, MachineInstr::copyFlagsFromInstruction(CI)); MIRBuilder.buildFAdd(Dst, FMul, Op2, @@ -2402,7 +2392,7 @@ bool IRTranslator::translateKnownIntrinsic(const CallInst &CI, Intrinsic::ID ID, getStackGuard(getOrCreateVReg(CI), MIRBuilder); return true; case Intrinsic::stackprotector: { - LLT PtrTy = getLLTForType(*CI.getArgOperand(0)->getType(), *DL); + LLT PtrTy = TLI->getLLTForType(*CI.getArgOperand(0)->getType(), *DL); Register GuardVal; if (TLI->useLoadStackGuardNode(*CI.getModule())) { GuardVal = MRI->createGenericVirtualRegister(PtrTy); @@ -2649,7 +2639,7 @@ bool IRTranslator::translateKnownIntrinsic(const CallInst &CI, Intrinsic::ID ID, case Intrinsic::vector_deinterleave2: { // Both intrinsics have at least one operand. 
Value *Op0 = CI.getOperand(0); - LLT ResTy = getLLTForType(*Op0->getType(), MIRBuilder.getDataLayout()); + LLT ResTy = TLI->getLLTForType(*Op0->getType(), MIRBuilder.getDataLayout()); if (!ResTy.isFixedVector()) return false; @@ -2674,7 +2664,7 @@ bool IRTranslator::translateKnownIntrinsic(const CallInst &CI, Intrinsic::ID ID, bool IRTranslator::translateInlineAsm(const CallBase &CB, MachineIRBuilder &MIRBuilder) { - if (containsBF16Type(CB)) + if (!mayTranslateUserTypes(CB)) return false; const InlineAsmLowering *ALI = MF->getSubtarget().getInlineAsmLowering(); @@ -2699,7 +2689,7 @@ bool IRTranslator::translateCallBase(const CallBase &CB, for (const auto &Arg : CB.args()) { if (CLI->supportSwiftError() && isSwiftError(Arg)) { assert(SwiftInVReg == 0 && "Expected only one swift error argument"); - LLT Ty = getLLTForType(*Arg->getType(), *DL); + LLT Ty = TLI->getLLTForType(*Arg->getType(), *DL); SwiftInVReg = MRI->createGenericVirtualRegister(Ty); MIRBuilder.buildCopy(SwiftInVReg, SwiftError.getOrCreateVRegUseAt( &CB, &MIRBuilder.getMBB(), Arg)); @@ -2765,7 +2755,7 @@ bool IRTranslator::translateCallBase(const CallBase &CB, } bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) { - if (containsBF16Type(U)) + if (!mayTranslateUserTypes(U)) return false; const CallInst &CI = cast<CallInst>(U); @@ -2851,7 +2841,7 @@ bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) { Align Alignment = Info.align.value_or( DL->getABITypeAlign(Info.memVT.getTypeForEVT(F->getContext()))); LLT MemTy = Info.memVT.isSimple() - ? getLLTForMVT(Info.memVT.getSimpleVT()) + ? TLI->getLLTForMVT(Info.memVT.getSimpleVT()) : LLT::scalar(Info.memVT.getStoreSizeInBits()); // TODO: We currently just fallback to address space 0 if getTgtMemIntrinsic @@ -3059,13 +3049,13 @@ bool IRTranslator::translateLandingPad(const User &U, if (auto *RegMask = TRI.getCustomEHPadPreservedMask(*MF)) MF->getRegInfo().addPhysRegsUsedFromRegMask(RegMask); - LLT Ty = getLLTForType(*LP.getType(), *DL); + LLT Ty = TLI->getLLTForType(*LP.getType(), *DL); Register Undef = MRI->createGenericVirtualRegister(Ty); MIRBuilder.buildUndef(Undef); SmallVector<LLT, 2> Tys; for (Type *Ty : cast<StructType>(LP.getType())->elements()) - Tys.push_back(getLLTForType(*Ty, *DL)); + Tys.push_back(TLI->getLLTForType(*Ty, *DL)); assert(Tys.size() == 2 && "Only two-valued landingpads are supported"); // Mark exception register as live in. @@ -3110,7 +3100,7 @@ bool IRTranslator::translateAlloca(const User &U, // Now we're in the harder dynamic case. 
Register NumElts = getOrCreateVReg(*AI.getArraySize()); Type *IntPtrIRTy = DL->getIntPtrType(AI.getType()); - LLT IntPtrTy = getLLTForType(*IntPtrIRTy, *DL); + LLT IntPtrTy = TLI->getLLTForType(*IntPtrIRTy, *DL); if (MRI->getType(NumElts) != IntPtrTy) { Register ExtElts = MRI->createGenericVirtualRegister(IntPtrTy); MIRBuilder.buildZExtOrTrunc(ExtElts, NumElts); @@ -3391,7 +3381,7 @@ bool IRTranslator::translateAtomicCmpXchg(const User &U, bool IRTranslator::translateAtomicRMW(const User &U, MachineIRBuilder &MIRBuilder) { - if (containsBF16Type(U)) + if (!mayTranslateUserTypes(U)) return false; const AtomicRMWInst &I = cast<AtomicRMWInst>(U); @@ -3732,6 +3722,20 @@ bool IRTranslator::translate(const Constant &C, Register Reg) { return true; } +bool IRTranslator::mayTranslateUserTypes(const User &U) const { + if (TLI->getTargetMachine().Options.EnableGlobalISelExtendedLLT) { + return true; + } + + // BF16 cannot currently be represented by the default LLTs, so to avoid + // miscompiles we reject any instruction that uses it when + // TargetOptions::EnableGlobalISelExtendedLLT is disabled. + return !U.getType()->getScalarType()->isBFloatTy() && + !any_of(U.operands(), [](Value *V) { + return V->getType()->getScalarType()->isBFloatTy(); + }); +} + bool IRTranslator::finalizeBasicBlock(const BasicBlock &BB, MachineBasicBlock &MBB) { for (auto &BTB : SL->BitTestCases) { @@ -3853,8 +3857,8 @@ bool IRTranslator::emitSPDescriptorParent(StackProtectorDescriptor &SPD, CurBuilder->setInsertPt(*ParentBB, ParentBB->end()); // First create the loads to the guard/stack slot for the comparison. Type *PtrIRTy = PointerType::getUnqual(MF->getFunction().getContext()); - const LLT PtrTy = getLLTForType(*PtrIRTy, *DL); - LLT PtrMemTy = getLLTForMVT(TLI->getPointerMemTy(*DL)); + const LLT PtrTy = TLI->getLLTForType(*PtrIRTy, *DL); + LLT PtrMemTy = TLI->getLLTForMVT(TLI->getPointerMemTy(*DL)); MachineFrameInfo &MFI = ParentBB->getParent()->getFrameInfo(); int FI = MFI.getStackProtectorIndex(); diff --git a/llvm/lib/CodeGen/GlobalISel/LegalizeMutations.cpp b/llvm/lib/CodeGen/GlobalISel/LegalizeMutations.cpp index 25c1db91b05d8..ded4df4edc14c 100644 --- a/llvm/lib/CodeGen/GlobalISel/LegalizeMutations.cpp +++ b/llvm/lib/CodeGen/GlobalISel/LegalizeMutations.cpp @@ -55,12 +55,10 @@ LegalizeMutation LegalizeMutations::changeElementCountTo(unsigned TypeIdx, } LegalizeMutation LegalizeMutations::changeElementCountTo(unsigned TypeIdx, - LLT NewEltTy) { + ElementCount EC) { return [=](const LegalityQuery &Query) { const LLT OldTy = Query.Types[TypeIdx]; - ElementCount NewEltCount = NewEltTy.isVector() ? NewEltTy.getElementCount() - : ElementCount::getFixed(1); - return std::make_pair(TypeIdx, OldTy.changeElementCount(NewEltCount)); + return std::make_pair(TypeIdx, OldTy.changeElementCount(EC)); }; } diff --git a/llvm/lib/CodeGen/GlobalISel/LoadStoreOpt.cpp b/llvm/lib/CodeGen/GlobalISel/LoadStoreOpt.cpp index b2f84351c9a95..b1e134ce5ef9e 100644 --- a/llvm/lib/CodeGen/GlobalISel/LoadStoreOpt.cpp +++ b/llvm/lib/CodeGen/GlobalISel/LoadStoreOpt.cpp @@ -952,7 +952,7 @@ void LoadStoreOpt::initializeStoreMergeTargetInfo(unsigned AddrSpace) { const auto &LI = *MF->getSubtarget().getLegalizerInfo(); const auto &DL = MF->getFunction().getDataLayout(); Type *IRPtrTy = PointerType::get(MF->getFunction().getContext(), AddrSpace); - LLT PtrTy = getLLTForType(*IRPtrTy, DL); + LLT PtrTy = TLI->getLLTForType(*IRPtrTy, DL); // We assume that we're not going to be generating any stores wider than // MaxStoreSizeToForm bits for now. 
for (unsigned Size = 2; Size <= MaxStoreSizeToForm; Size *= 2) { diff --git a/llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp b/llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp index 27df7e369436a..211809eacf4a7 100644 --- a/llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp +++ b/llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp @@ -581,7 +581,8 @@ MachineInstrBuilder MachineIRBuilder::buildExtOrTrunc(unsigned ExtOpc, Op.getLLTTy(*getMRI()).getSizeInBits()) Opcode = TargetOpcode::G_TRUNC; else - assert(Res.getLLTTy(*getMRI()) == Op.getLLTTy(*getMRI())); + assert(Res.getLLTTy(*getMRI()).getSizeInBits() == + Op.getLLTTy(*getMRI()).getSizeInBits()); return buildInstr(Opcode, Res, Op); } diff --git a/llvm/lib/CodeGen/LowLevelTypeUtils.cpp b/llvm/lib/CodeGen/LowLevelTypeUtils.cpp index 936c9fbb2fff0..6a15b1748630d 100644 --- a/llvm/lib/CodeGen/LowLevelTypeUtils.cpp +++ b/llvm/lib/CodeGen/LowLevelTypeUtils.cpp @@ -17,16 +17,16 @@ #include "llvm/IR/DerivedTypes.h" using namespace llvm; -LLT llvm::getLLTForType(Type &Ty, const DataLayout &DL) { - if (auto VTy = dyn_cast<VectorType>(&Ty)) { +LLT llvm::getLLTForType(Type &Ty, const DataLayout &DL, bool AllowExtendedLLT) { + if (auto *VTy = dyn_cast<VectorType>(&Ty)) { auto EC = VTy->getElementCount(); - LLT ScalarTy = getLLTForType(*VTy->getElementType(), DL); + LLT ScalarTy = getLLTForType(*VTy->getElementType(), DL, AllowExtendedLLT); if (EC.isScalar()) return ScalarTy; return LLT::vector(EC, ScalarTy); } - if (auto PTy = dyn_cast<PointerType>(&Ty)) { + if (auto *PTy = dyn_cast<PointerType>(&Ty)) { unsigned AddrSpace = PTy->getAddressSpace(); return LLT::pointer(AddrSpace, DL.getPointerSizeInBits(AddrSpace)); } @@ -36,6 +36,41 @@ LLT llvm::getLLTForType(Type &Ty, const DataLayout &DL) { // concerned. auto SizeInBits = DL.getTypeSizeInBits(&Ty); assert(SizeInBits != 0 && "invalid zero-sized type"); + + // Return a simple scalar. + if (!AllowExtendedLLT) + return LLT::scalar(SizeInBits); + + // Choose a more precise LLT variant. + if (Ty.isFloatingPointTy()) { + if (Ty.isHalfTy()) + return LLT::float16(); + + if (Ty.isBFloatTy()) + return LLT::bfloat16(); + + if (Ty.isFloatTy()) + return LLT::float32(); + + if (Ty.isDoubleTy()) + return LLT::float64(); + + if (Ty.isX86_FP80Ty()) + return LLT::x86fp80(); + + if (Ty.isFP128Ty()) + return LLT::float128(); + + if (Ty.isPPC_FP128Ty()) + return LLT::ppcf128(); + + llvm_unreachable("Unhandled LLVM IR floating point type"); + } + + if (Ty.isIntegerTy()) { + return LLT::integer(SizeInBits); + } + return LLT::scalar(SizeInBits); } @@ -46,12 +81,25 @@ LLT llvm::getLLTForType(Type &Ty, const DataLayout &DL) { } MVT llvm::getMVTForLLT(LLT Ty) { - if (!Ty.isVector()) - return MVT::getIntegerVT(Ty.getSizeInBits()); + if (Ty.isVector()) { + return MVT::getVectorVT(getMVTForLLT(Ty.getElementType()), + Ty.getElementCount()); + } + + if (Ty.isFloat()) { + if (Ty == LLT::bfloat16()) + return MVT::bf16; - return MVT::getVectorVT( - MVT::getIntegerVT(Ty.getElementType().getSizeInBits()), - Ty.getElementCount()); + if (Ty == LLT::x86fp80()) + return MVT::f80; + + if (Ty == LLT::ppcf128()) + return MVT::ppcf128; + + return MVT::getFloatingPointVT(Ty.getSizeInBits()); + } + + return MVT::getIntegerVT(Ty.getSizeInBits()); } EVT llvm::getApproximateEVTForLLT(LLT Ty, LLVMContext &Ctx) { @@ -63,16 +111,23 @@ EVT llvm::getApproximateEVTForLLT(LLT Ty, LLVMContext &Ctx) { return EVT::getIntegerVT(Ctx, Ty.getSizeInBits()); } -LLT llvm::getLLTForMVT(MVT Ty) { - if (!Ty.isVector()) - return LLT::scalar(Ty.getSizeInBits()); - - return LLT::scalarOrVector(Ty.getVectorElementCount(), - 
Ty.getVectorElementType().getSizeInBits()); +LLT llvm::getLLTForMVT(MVT VT, bool AllowExtendedLLT) { + return LLT(VT, AllowExtendedLLT); } const llvm::fltSemantics &llvm::getFltSemanticForLLT(LLT Ty) { - assert(Ty.isScalar() && "Expected a scalar type."); + assert((Ty.isAnyScalar() || Ty.isFloat()) && + "Expected an any-scalar or float type."); + + if (Ty.isBFloat(16)) + return APFloat::BFloat(); + if (Ty.isX86FP80()) + return APFloat::x87DoubleExtended(); + if (Ty.isPPCF128()) + return APFloat::PPCDoubleDouble(); + + assert(!Ty.isVariantFloat() && "Unhandled variant float type"); + switch (Ty.getSizeInBits()) { case 16: return APFloat::IEEEhalf(); diff --git a/llvm/lib/CodeGen/MIRParser/MIParser.cpp b/llvm/lib/CodeGen/MIRParser/MIParser.cpp index 6a464d9dd6886..94205698c5038 100644 --- a/llvm/lib/CodeGen/MIRParser/MIParser.cpp +++ b/llvm/lib/CodeGen/MIRParser/MIParser.cpp @@ -1929,26 +1929,35 @@ static bool verifyAddrSpace(uint64_t AddrSpace) { } bool MIParser::parseLowLevelType(StringRef::iterator Loc, LLT &Ty) { - if (Token.range().front() == 's' || Token.range().front() == 'p') { - StringRef SizeStr = Token.range().drop_front(); - if (SizeStr.size() == 0 || !llvm::all_of(SizeStr, isdigit)) - return error("expected integers after 's'/'p' type character"); - } - - if (Token.range().front() == 's') { - auto ScalarSize = APSInt(Token.range().drop_front()).getZExtValue(); - if (ScalarSize) { - if (!verifyScalarSize(ScalarSize)) - return error("invalid size for scalar type"); - Ty = LLT::scalar(ScalarSize); - } else { + StringRef TypeDigits = Token.range(); + if (TypeDigits.consume_front("s") || TypeDigits.consume_front("i") || + TypeDigits.consume_front("f") || TypeDigits.consume_front("p") || + TypeDigits.consume_front("bf")) { + if (TypeDigits.empty() || !llvm::all_of(TypeDigits, isdigit)) + return error( + "expected integers after 's'/'i'/'f'/'bf'/'p' type identifier"); + } + + if (Token.range().starts_with("s") || Token.range().starts_with("i")) { + bool ScalarOrInt = Token.range().starts_with("s"); + auto ScalarSize = APSInt(TypeDigits).getZExtValue(); + if (!ScalarSize) { Ty = LLT::token(); + lex(); + return false; } + + if (!verifyScalarSize(ScalarSize)) + return error("invalid size for scalar type"); + + Ty = ScalarOrInt ? LLT::scalar(ScalarSize) : LLT::integer(ScalarSize); lex(); return false; - } else if (Token.range().front() == 'p') { + } + + if (Token.range().starts_with("p")) { const DataLayout &DL = MF.getDataLayout(); - uint64_t AS = APSInt(Token.range().drop_front()).getZExtValue(); + uint64_t AS = APSInt(TypeDigits).getZExtValue(); if (!verifyAddrSpace(AS)) return error("invalid address space number"); @@ -1957,10 +1966,28 @@ bool MIParser::parseLowLevelType(StringRef::iterator Loc, LLT &Ty) { return false; } + if (Token.range().starts_with("f") || Token.range().starts_with("bf")) { + LLT::FPVariant FPVariant; + if (Token.range().starts_with("f")) { + FPVariant = LLT::FPVariant::IEEE_FLOAT; + } else if (Token.range().starts_with("bf")) { + FPVariant = LLT::FPVariant::BRAIN_FLOAT; + } else { + return error("unknown floating point type identifier"); + } + auto ScalarSize = APSInt(TypeDigits).getZExtValue(); + if (!ScalarSize || !verifyScalarSize(ScalarSize)) + return error("invalid size for scalar type"); + Ty = LLT::floatingPoint(ScalarSize, FPVariant); + lex(); + return false; + } + // Now we're looking for a vector. 
if (Token.isNot(MIToken::less)) - return error(Loc, "expected sN, pA, <M x sN>, <M x pA>, <vscale x M x sN>, " - "or <vscale x M x pA> for GlobalISel type"); + return error(Loc, "expected tN, pA, <M x tN>, <M x pA>, <vscale x M x tN>, " + "or <vscale x M x pA> for GlobalISel type, " + "where t = {'s', 'i', 'f', 'bf'}"); lex(); bool HasVScale = (Token.is(MIToken::Identifier) && Token.stringValue() == "vscale"); if (HasVScale) { lex(); if (Token.isNot(MIToken::Identifier) || Token.stringValue() != "x") - return error("expected <vscale x M x sN> or <vscale x M x pA>"); + return error( + "expected <vscale x M x tN>, where t = {'s', 'i', 'f', 'bf', 'p'}"); lex(); } auto GetError = [this, &HasVScale, Loc]() { if (HasVScale) - return error( - Loc, "expected <vscale x M x sN> or <vscale x M x pA> for vector type"); - return error(Loc, "expected <M x sN> or <M x pA> for vector type"); + return error(Loc, "expected <vscale x M x tN> for vector type, where t = " + "{'s', 'i', 'f', 'bf', 'p'}"); + return error(Loc, "expected <M x tN> for vector type, where t = {'s', 'i', " + "'f', 'bf', 'p'}"); }; if (Token.isNot(MIToken::IntegerLiteral)) return GetError(); @@ -1991,25 +2020,40 @@ bool MIParser::parseLowLevelType(StringRef::iterator Loc, LLT &Ty) { return GetError(); lex(); - if (Token.range().front() != 's' && Token.range().front() != 'p') + StringRef VectorTyDigits = Token.range(); + if (!VectorTyDigits.consume_front("s") && + !VectorTyDigits.consume_front("i") && + !VectorTyDigits.consume_front("f") && + !VectorTyDigits.consume_front("p") && !VectorTyDigits.consume_front("bf")) return GetError(); - StringRef SizeStr = Token.range().drop_front(); - if (SizeStr.size() == 0 || !llvm::all_of(SizeStr, isdigit)) - return error("expected integers after 's'/'p' type character"); + if (VectorTyDigits.empty() || !llvm::all_of(VectorTyDigits, isdigit)) + return error( + "expected integers after 's'/'i'/'f'/'bf'/'p' type identifier"); - if (Token.range().front() == 's') { - auto ScalarSize = APSInt(Token.range().drop_front()).getZExtValue(); + if (Token.range().starts_with("s") || Token.range().starts_with("i")) { + bool ScalarOrInt = Token.range().starts_with("s"); + auto ScalarSize = APSInt(VectorTyDigits).getZExtValue(); if (!verifyScalarSize(ScalarSize)) return error("invalid size for scalar element in vector"); - Ty = LLT::scalar(ScalarSize); - } else if (Token.range().front() == 'p') { + Ty = ScalarOrInt ? 
LLT::scalar(ScalarSize) : LLT::integer(ScalarSize); + } else if (Token.range().starts_with("p")) { const DataLayout &DL = MF.getDataLayout(); - uint64_t AS = APSInt(Token.range().drop_front()).getZExtValue(); + uint64_t AS = APSInt(VectorTyDigits).getZExtValue(); if (!verifyAddrSpace(AS)) return error("invalid address space number"); Ty = LLT::pointer(AS, DL.getPointerSizeInBits(AS)); + } else if (Token.range().starts_with("f")) { + auto ScalarSize = APSInt(VectorTyDigits).getZExtValue(); + if (!verifyScalarSize(ScalarSize)) + return error("invalid size for float element in vector"); + Ty = LLT::floatingPoint(ScalarSize, LLT::FPVariant::IEEE_FLOAT); + } else if (Token.range().starts_with("bf")) { + auto ScalarSize = APSInt(VectorTyDigits).getZExtValue(); + if (!verifyScalarSize(ScalarSize)) + return error("invalid size for bfloat element in vector"); + Ty = LLT::floatingPoint(ScalarSize, LLT::FPVariant::BRAIN_FLOAT); } else return GetError(); lex(); @@ -2025,14 +2069,15 @@ bool MIParser::parseLowLevelType(StringRef::iterator Loc, LLT &Ty) { bool MIParser::parseTypedImmediateOperand(MachineOperand &Dest) { assert(Token.is(MIToken::Identifier)); - StringRef TypeStr = Token.range(); - if (TypeStr.front() != 'i' && TypeStr.front() != 's' && - TypeStr.front() != 'p') + StringRef TypeDigits = Token.range(); + if (!TypeDigits.consume_front("i") && !TypeDigits.consume_front("s") && + !TypeDigits.consume_front("p") && !TypeDigits.consume_front("f") && + !TypeDigits.consume_front("bf")) + return error("a typed immediate operand should start with one of 'i', " + "'s', 'f', 'bf', or 'p'"); + if (TypeDigits.empty() || !llvm::all_of(TypeDigits, isdigit)) return error( - "a typed immediate operand should start with one of 'i', 's', or 'p'"); - StringRef SizeStr = Token.range().drop_front(); - if (SizeStr.size() == 0 || !llvm::all_of(SizeStr, isdigit)) - return error("expected integers after 'i'/'s'/'p' type character"); + "expected integers after 'i'/'s'/'f'/'bf'/'p' type identifier"); auto Loc = Token.location(); lex(); diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp index ece50ed95fc49..ca75c42ee0a72 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp @@ -2457,7 +2457,7 @@ void SelectionDAGISel::Select_READ_REGISTER(SDNode *Op) { const MDString *RegStr = cast<MDString>(MD->getMD()->getOperand(0)); EVT VT = Op->getValueType(0); - LLT Ty = VT.isSimple() ? getLLTForMVT(VT.getSimpleVT()) : LLT(); + LLT Ty = VT.isSimple() ? TLI->getLLTForMVT(VT.getSimpleVT()) : LLT(); const MachineFunction &MF = CurDAG->getMachineFunction(); Register Reg = TLI->getRegisterByName(RegStr->getString().data(), Ty, MF); @@ -2488,7 +2488,7 @@ void SelectionDAGISel::Select_WRITE_REGISTER(SDNode *Op) { const MDString *RegStr = cast<MDString>(MD->getMD()->getOperand(0)); EVT VT = Op->getOperand(2).getValueType(); - LLT Ty = VT.isSimple() ? 
TLI->getLLTForMVT(VT.getSimpleVT()) : LLT(); const MachineFunction &MF = CurDAG->getMachineFunction(); Register Reg = TLI->getRegisterByName(RegStr->getString().data(), Ty, MF); diff --git a/llvm/lib/CodeGen/TargetLoweringBase.cpp b/llvm/lib/CodeGen/TargetLoweringBase.cpp index 9ffced80b07f1..3233435c0828b 100644 --- a/llvm/lib/CodeGen/TargetLoweringBase.cpp +++ b/llvm/lib/CodeGen/TargetLoweringBase.cpp @@ -978,6 +978,14 @@ EVT TargetLoweringBase::getShiftAmountTy(EVT LHSTy, return ShiftVT; } +LLT TargetLoweringBase::getLLTForType(Type &Ty, const DataLayout &DL) const { + return llvm::getLLTForType(Ty, DL, TM.Options.EnableGlobalISelExtendedLLT); +} + +LLT TargetLoweringBase::getLLTForMVT(MVT Ty) const { + return llvm::getLLTForMVT(Ty, TM.Options.EnableGlobalISelExtendedLLT); +} + bool TargetLoweringBase::canOpTrap(unsigned Op, EVT VT) const { assert(isTypeLegal(VT)); switch (Op) { diff --git a/llvm/lib/CodeGenTypes/LowLevelType.cpp b/llvm/lib/CodeGenTypes/LowLevelType.cpp index 4785f2652b00e..6025f3fedafa1 100644 --- a/llvm/lib/CodeGenTypes/LowLevelType.cpp +++ b/llvm/lib/CodeGenTypes/LowLevelType.cpp @@ -16,36 +16,86 @@ #include "llvm/Support/raw_ostream.h" using namespace llvm; -LLT::LLT(MVT VT) { +static std::optional<LLT::FPVariant> deriveFPInfo(MVT VT) { + if (!VT.isFloatingPoint()) + return std::nullopt; + + switch (VT.getScalarType().SimpleTy) { + case MVT::bf16: + return LLT::FPVariant::BRAIN_FLOAT; + case MVT::f80: + return LLT::FPVariant::EXTENDED_FP80; + case MVT::ppcf128: + return LLT::FPVariant::PPC128_FLOAT; + default: + return LLT::FPVariant::IEEE_FLOAT; + } +} + +LLT::LLT(MVT VT, bool AllowExtendedLLT) { + if (!AllowExtendedLLT) { + if (VT.isVector()) { + bool AsVector = VT.getVectorMinNumElements() > 1 || VT.isScalableVector(); + Kind Info = AsVector ? Kind::VECTOR_ANY : Kind::ANY_SCALAR; + init(Info, VT.getVectorElementCount(), + VT.getVectorElementType().getSizeInBits(), 0, FPVariant::IEEE_FLOAT); + } else if (VT.isValid() && !VT.isScalableTargetExtVT()) { + init(Kind::ANY_SCALAR, ElementCount::getFixed(0), VT.getSizeInBits(), 0, + FPVariant::IEEE_FLOAT); + } else { + this->Info = Kind::INVALID; + this->RawData = 0; + } + return; + } + + auto FP = deriveFPInfo(VT); + bool AsVector = VT.isVector() && + (VT.getVectorMinNumElements() > 1 || VT.isScalableVector()); + + LLT::Kind Info; + if (FP.has_value()) + Info = AsVector ? LLT::Kind::VECTOR_FLOAT : LLT::Kind::FLOAT; + else + Info = AsVector ? LLT::Kind::VECTOR_INTEGER : LLT::Kind::INTEGER; + if (VT.isVector()) { - bool asVector = VT.getVectorMinNumElements() > 1 || VT.isScalableVector(); - init(/*IsPointer=*/false, asVector, /*IsScalar=*/!asVector, - VT.getVectorElementCount(), VT.getVectorElementType().getSizeInBits(), - /*AddressSpace=*/0); + init(Info, VT.getVectorElementCount(), + VT.getVectorElementType().getSizeInBits(), 0, + FP.value_or(LLT::FPVariant::IEEE_FLOAT)); } else if (VT.isValid() && !VT.isScalableTargetExtVT()) { // Aggregates are no different from real scalars as far as GlobalISel is // concerned. 
- init(/*IsPointer=*/false, /*IsVector=*/false, /*IsScalar=*/true, - ElementCount::getFixed(0), VT.getSizeInBits(), /*AddressSpace=*/0); + init(Info, ElementCount::getFixed(0), VT.getSizeInBits(), 0, + FP.value_or(LLT::FPVariant::IEEE_FLOAT)); } else { - IsScalar = false; - IsPointer = false; - IsVector = false; - RawData = 0; + this->Info = Kind::INVALID; + this->RawData = 0; } + return; } void LLT::print(raw_ostream &OS) const { if (isVector()) { OS << "<"; OS << getElementCount() << " x " << getElementType() << ">"; - } else if (isPointer()) + } else if (isPointer()) { OS << "p" << getAddressSpace(); - else if (isValid()) { + } else if (isBFloat(16)) { + OS << "bf16"; + } else if (isPPCF128()) { + OS << "ppcf128"; + } else if (isFloat()) { + assert(!isVariantFloat() && "unknown float variant"); + OS << "f" << getScalarSizeInBits(); + } else if (isInteger()) { + OS << "i" << getScalarSizeInBits(); + } else if (isValid()) { assert(isScalar() && "unexpected type"); OS << "s" << getScalarSizeInBits(); - } else + } else { OS << "LLT_invalid"; + } } #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) @@ -56,6 +106,7 @@ LLVM_DUMP_METHOD void LLT::dump() const { #endif const constexpr LLT::BitFieldInfo LLT::ScalarSizeFieldInfo; +const constexpr LLT::BitFieldInfo LLT::FPFieldInfo; const constexpr LLT::BitFieldInfo LLT::PointerSizeFieldInfo; const constexpr LLT::BitFieldInfo LLT::PointerAddressSpaceFieldInfo; const constexpr LLT::BitFieldInfo LLT::VectorElementsFieldInfo; diff --git a/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp b/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp index 79bef76cf4c4f..c897a0f65084f 100644 --- a/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp +++ b/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp @@ -457,7 +457,7 @@ bool AArch64CallLowering::lowerReturn(MachineIRBuilder &MIRBuilder, ExtendOp = TargetOpcode::G_ZEXT; LLT NewLLT(NewVT); - LLT OldLLT = getLLTForType(*CurArgInfo.Ty, DL); + LLT OldLLT = TLI.getLLTForType(*CurArgInfo.Ty, DL); CurArgInfo.Ty = EVT(NewVT).getTypeForEVT(Ctx); // Instead of an extend, we might have a vector type which needs // padding with more elements, e.g. <2 x half> -> <4 x half>. @@ -1397,7 +1397,7 @@ bool AArch64CallLowering::lowerCall(MachineIRBuilder &MIRBuilder, // is set. if (Info.Callee.isSymbol() && F.getParent()->getRtLibUseGOT()) { auto MIB = MIRBuilder.buildInstr(TargetOpcode::G_GLOBAL_VALUE); - DstOp(getLLTForType(*F.getType(), DL)).addDefToMIB(MRI, MIB); + DstOp(TLI.getLLTForType(*F.getType(), DL)).addDefToMIB(MRI, MIB); MIB.addExternalSymbol(Info.Callee.getSymbolName(), AArch64II::MO_GOT); Info.Callee = MachineOperand::CreateReg(MIB.getReg(0), false); } diff --git a/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp b/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp index 1b919abd222e4..29d5cec49d579 100644 --- a/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp +++ b/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp @@ -570,7 +570,8 @@ bool AArch64RegisterBankInfo::onlyUsesFP(const MachineInstr &MI, case Intrinsic::aarch64_neon_fcvtpu: // Force FPR register bank for half types, as those types otherwise // don't get legalized correctly resulting in fp16 <-> gpr32 COPY's. 
-      return MRI.getType(MI.getOperand(2).getReg()) == LLT::float16();
+      return MRI.getType(MI.getOperand(2).getReg()) ==
+             LLT::scalar(16); // TODO: Expected LLT::float16()
     default:
       break;
     }
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUCallLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUCallLowering.cpp
index d1a5b4e85da46..b5ef1a87a5abb 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUCallLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUCallLowering.cpp
@@ -332,7 +332,7 @@ bool AMDGPUCallLowering::lowerReturnVal(MachineIRBuilder &B,
                                          extOpcodeToISDExtOpcode(ExtendOp));
       if (ExtVT != VT) {
         RetInfo.Ty = ExtVT.getTypeForEVT(Ctx);
-        LLT ExtTy = getLLTForType(*RetInfo.Ty, DL);
+        LLT ExtTy = TLI.getLLTForType(*RetInfo.Ty, DL);
         Reg = B.buildInstr(ExtendOp, {ExtTy}, {Reg}).getReg(0);
       }
     }
@@ -414,6 +414,7 @@ void AMDGPUCallLowering::lowerParameter(MachineIRBuilder &B, ArgInfo &OrigArg,
   MachineFunction &MF = B.getMF();
   const Function &F = MF.getFunction();
   const DataLayout &DL = F.getDataLayout();
+  const SITargetLowering &TLI = *getTLI<SITargetLowering>();
   MachinePointerInfo PtrInfo(AMDGPUAS::CONSTANT_ADDRESS);
   LLT PtrTy = LLT::pointer(AMDGPUAS::CONSTANT_ADDRESS, 64);
@@ -427,7 +428,7 @@ void AMDGPUCallLowering::lowerParameter(MachineIRBuilder &B, ArgInfo &OrigArg,
     Register PtrReg = B.getMRI()->createGenericVirtualRegister(PtrTy);
     lowerParameterPtr(PtrReg, B, Offset + FieldOffsets[Idx]);
 
-    LLT ArgTy = getLLTForType(*SplitArg.Ty, DL);
+    LLT ArgTy = TLI.getLLTForType(*SplitArg.Ty, DL);
     if (SplitArg.Flags[0].isPointer()) {
       // Compensate for losing pointeriness in splitValueTypes.
       LLT PtrTy = LLT::pointer(SplitArg.Flags[0].getPointerAddrSpace(),
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUCombinerHelper.cpp b/llvm/lib/Target/AMDGPU/AMDGPUCombinerHelper.cpp
index d23521c87e202..853f55278c667 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUCombinerHelper.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUCombinerHelper.cpp
@@ -464,8 +464,9 @@ bool AMDGPUCombinerHelper::matchCombineFmulWithSelectToFldexp(
   LLT DestTy = MRI.getType(Dst);
   LLT ScalarDestTy = DestTy.getScalarType();
 
-  if ((ScalarDestTy != LLT::float64() && ScalarDestTy != LLT::float32() &&
-       ScalarDestTy != LLT::float16()) ||
+  // TODO: Expected float type in ScalarDestTy
+  if ((ScalarDestTy != LLT::scalar(64) && ScalarDestTy != LLT::scalar(32) &&
+       ScalarDestTy != LLT::scalar(16)) ||
       !MRI.hasOneNonDBGUse(Sel.getOperand(0).getReg()))
     return false;
@@ -486,7 +487,8 @@ bool AMDGPUCombinerHelper::matchCombineFmulWithSelectToFldexp(
     return false;
 
   // For f32, only non-inline constants should be transformed.
-  if (ScalarDestTy == LLT::float32() && TII.isInlineConstant(*SelectTrueVal) &&
+  // TODO: Expected float32
+  if (ScalarDestTy == LLT::scalar(32) && TII.isInlineConstant(*SelectTrueVal) &&
       TII.isInlineConstant(*SelectFalseVal))
     return false;
diff --git a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
index 600a13096f55d..a4e2f321be91d 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
@@ -295,9 +295,9 @@ constexpr LLT S1 = LLT::scalar(1);
 constexpr LLT S8 = LLT::scalar(8);
 constexpr LLT S16 = LLT::scalar(16);
 constexpr LLT S32 = LLT::scalar(32);
-constexpr LLT F32 = LLT::float32();
+constexpr LLT F32 = LLT::scalar(32); // TODO: Expected float32
 constexpr LLT S64 = LLT::scalar(64);
-constexpr LLT F64 = LLT::float64();
+constexpr LLT F64 = LLT::scalar(64); // TODO: Expected float64
 constexpr LLT S96 = LLT::scalar(96);
 constexpr LLT S128 = LLT::scalar(128);
 constexpr LLT S160 = LLT::scalar(160);
@@ -317,8 +317,9 @@ constexpr LLT V10S16 = LLT::fixed_vector(10, 16);
 constexpr LLT V12S16 = LLT::fixed_vector(12, 16);
 constexpr LLT V16S16 = LLT::fixed_vector(16, 16);
 
-constexpr LLT V2F16 = LLT::fixed_vector(2, LLT::float16());
-constexpr LLT V2BF16 = V2F16; // FIXME
+constexpr LLT V2F16 =
+    LLT::fixed_vector(2, LLT::scalar(16)); // TODO: Expected float16
+constexpr LLT V2BF16 = V2F16;              // FIXME
 constexpr LLT V2S32 = LLT::fixed_vector(2, 32);
 constexpr LLT V3S32 = LLT::fixed_vector(3, 32);
@@ -3344,11 +3345,12 @@ bool AMDGPULegalizerInfo::legalizeFMad(
   const SIMachineFunctionInfo *MFI = MF.getInfo<SIMachineFunctionInfo>();
 
   // TODO: Always legal with future ftz flag.
+  // TODO: Type is expected to be float32/float16
   // FIXME: Do we need just output?
-  if (Ty == LLT::float32() &&
+  if (Ty == LLT::scalar(32) &&
       MFI->getMode().FP32Denormals == DenormalMode::getPreserveSign())
     return true;
-  if (Ty == LLT::float16() &&
+  if (Ty == LLT::scalar(16) &&
       MFI->getMode().FP64FP16Denormals == DenormalMode::getPreserveSign())
     return true;
@@ -3900,8 +3902,8 @@ bool AMDGPULegalizerInfo::legalizeFPow(MachineInstr &MI,
   Register Src1 = MI.getOperand(2).getReg();
   unsigned Flags = MI.getFlags();
   LLT Ty = B.getMRI()->getType(Dst);
-  const LLT F16 = LLT::float16();
-  const LLT F32 = LLT::float32();
+  const LLT F16 = LLT::scalar(16); // TODO: Expected float16
+  const LLT F32 = LLT::scalar(32); // TODO: Expected float32
 
   if (Ty == F32) {
     auto Log = B.buildFLog2(F32, Src0, Flags);
@@ -3944,7 +3946,7 @@ bool AMDGPULegalizerInfo::legalizeFFloor(MachineInstr &MI,
                                          MachineIRBuilder &B) const {
   const LLT S1 = LLT::scalar(1);
-  const LLT F64 = LLT::float64();
+  const LLT F64 = LLT::scalar(64); // TODO: Expected float64
   Register Dst = MI.getOperand(0).getReg();
   Register OrigSrc = MI.getOperand(1).getReg();
   unsigned Flags = MI.getFlags();
diff --git a/llvm/test/CodeGen/MIR/AArch64/parse-low-level-type-invalid0.mir b/llvm/test/CodeGen/MIR/AArch64/parse-low-level-type-invalid0.mir
index cece3601dc1b2..95fd669a5d971 100644
--- a/llvm/test/CodeGen/MIR/AArch64/parse-low-level-type-invalid0.mir
+++ b/llvm/test/CodeGen/MIR/AArch64/parse-low-level-type-invalid0.mir
@@ -1,10 +1,10 @@
 # RUN: not llc -mtriple=aarch64-- -run-pass none -o /dev/null %s 2>&1 | FileCheck %s
-# When a low-level type is only a single 's'/'p' character
+# When a low-level type is only a single 's'/'i'/'f'/'bf'/'p' type identifier
 ---
 name: test_low_level_type_is_single_s_p
 body: |
   bb.0:
     liveins: $x0
-    ; CHECK: [[@LINE+1]]:10: expected integers after 's'/'p' type character
+    ; CHECK: [[@LINE+1]]:10: expected integers after 's'/'i'/'f'/'bf'/'p' type identifier
     %0:_(s) = COPY $x0
...
diff --git a/llvm/test/CodeGen/MIR/AArch64/parse-low-level-type-invalid1.mir b/llvm/test/CodeGen/MIR/AArch64/parse-low-level-type-invalid1.mir
index 4a7b68dab623a..dd0bb73a6cf42 100644
--- a/llvm/test/CodeGen/MIR/AArch64/parse-low-level-type-invalid1.mir
+++ b/llvm/test/CodeGen/MIR/AArch64/parse-low-level-type-invalid1.mir
@@ -5,6 +5,6 @@ name: test_low_level_type_does_not_start_with_s_p_lt
 body: |
   bb.0:
     liveins: $x0
-    ; CHECK: [[@LINE+1]]:10: expected sN, pA, <M x sN>, <M x pA>, <vscale x M x sN>, or <vscale x M x pA> for GlobalISel type
-    %0:_(i64) = COPY $x0
+    ; CHECK: [[@LINE+1]]:10: expected tN, pA, <M x tN>, <M x pA>, <vscale x M x tN>, or <vscale x M x pA> for GlobalISel type, where t = {'s', 'i', 'f', 'bf'}
+    %0:_(n64) = COPY $x0
...
diff --git a/llvm/test/CodeGen/MIR/AArch64/parse-low-level-type-invalid2.mir b/llvm/test/CodeGen/MIR/AArch64/parse-low-level-type-invalid2.mir
index 1bff7a5ec9ced..6277d24aacab3 100644
--- a/llvm/test/CodeGen/MIR/AArch64/parse-low-level-type-invalid2.mir
+++ b/llvm/test/CodeGen/MIR/AArch64/parse-low-level-type-invalid2.mir
@@ -5,6 +5,6 @@ name: test_low_level_type_is_single_s_p
 body: |
   bb.0:
     liveins: $q0
-    ; CHECK: [[@LINE+1]]:15: expected integers after 's'/'p' type character
+    ; CHECK: [[@LINE+1]]:15: expected integers after 's'/'i'/'f'/'bf'/'p' type identifier
     %0:_(<2 x p>) = COPY $q0
...
diff --git a/llvm/test/CodeGen/MIR/AArch64/parse-low-level-type-invalid3.mir b/llvm/test/CodeGen/MIR/AArch64/parse-low-level-type-invalid3.mir
index ebb3d37f9dfa1..ba652329c5337 100644
--- a/llvm/test/CodeGen/MIR/AArch64/parse-low-level-type-invalid3.mir
+++ b/llvm/test/CodeGen/MIR/AArch64/parse-low-level-type-invalid3.mir
@@ -1,10 +1,10 @@
 # RUN: not llc -mtriple=aarch64-- -run-pass none -o /dev/null %s 2>&1 | FileCheck %s
-# When a low-level type is a vector which element type does not start with 's' or 'p'
+# When a low-level type is a vector whose element type does not start with a valid type identifier
 ---
 name: test_low_level_type_does_not_start_with_s_p
 body: |
   bb.0:
     liveins: $q0
-    ; CHECK: [[@LINE+1]]:10: expected <M x sN> or <M x pA> for vector type
-    %0:_(<2 x i64>) = COPY $q0
+    ; CHECK: [[@LINE+1]]:10: expected <M x tN> for vector type, where t = {'s', 'i', 'f', 'bf', 'p'}
+    %0:_(<2 x n64>) = COPY $q0
...
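A note for readers: the following is a minimal standalone sketch (not part of this patch's tests) of how the distinct scalar kinds exercised by the MIR spellings above behave from C++, assuming the LLT API added earlier in this diff (LLT::integer, LLT::floatingPoint, LLT::FPVariant, and the isInteger/isFloat predicates):

// Illustration only: type-agnostic s32 vs. typed i32/f32/bf16 under the
// extended-LLT scheme. Uses the API added in this patch.
#include "llvm/CodeGenTypes/LowLevelType.h"
#include <cassert>
using namespace llvm;

void extendedLLTExample() {
  LLT S32 = LLT::scalar(32);  // type-agnostic scalar, prints as "s32"
  LLT I32 = LLT::integer(32); // integer scalar, prints as "i32"
  LLT F32 = LLT::floatingPoint(32, LLT::FPVariant::IEEE_FLOAT);   // "f32"
  LLT BF16 = LLT::floatingPoint(16, LLT::FPVariant::BRAIN_FLOAT); // "bf16"

  // Same bit width, different kinds: the typed variants compare unequal,
  // which is why the MIR parser treats s32/i32/f32 as separate tokens.
  assert(S32 != I32 && I32 != F32);
  assert(I32.getScalarSizeInBits() == F32.getScalarSizeInBits());
  assert(BF16.isFloat() && !BF16.isInteger());
}

The key point is that the kind is part of the type identity, so legalization rules and the parser can tell integers and floats of the same width apart.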
diff --git a/llvm/test/CodeGen/MIR/Generic/scalable-vector-type-err0.mir b/llvm/test/CodeGen/MIR/Generic/scalable-vector-type-err0.mir
index 5553d97acd003..6fc6640ea422a 100644
--- a/llvm/test/CodeGen/MIR/Generic/scalable-vector-type-err0.mir
+++ b/llvm/test/CodeGen/MIR/Generic/scalable-vector-type-err0.mir
@@ -6,5 +6,5 @@
 body: |
   %0:_(
...
-# CHECK: expected <vscale x M x sN> or <vscale x M x pA>
+# CHECK: expected <vscale x M x tN>, where t = {'s', 'i', 'f', 'bf', 'p'}
diff --git a/llvm/test/CodeGen/MIR/Generic/scalable-vector-type-err1.mir b/llvm/test/CodeGen/MIR/Generic/scalable-vector-type-err1.mir
index 12bfb82ebcd12..0017e01be3b04 100644
--- a/llvm/test/CodeGen/MIR/Generic/scalable-vector-type-err1.mir
+++ b/llvm/test/CodeGen/MIR/Generic/scalable-vector-type-err1.mir
@@ -6,4 +6,4 @@
 body: |
   %0:_(
...
-# CHECK: expected <vscale x M x sN> or <vscale x M x pA>
+# CHECK: expected <vscale x M x tN>, where t = {'s', 'i', 'f', 'bf', 'p'}
diff --git a/llvm/test/CodeGen/MIR/Generic/scalable-vector-type-err10.mir b/llvm/test/CodeGen/MIR/Generic/scalable-vector-type-err10.mir
index 7d7d7e49f23fe..37a9a0d2c159b 100644
--- a/llvm/test/CodeGen/MIR/Generic/scalable-vector-type-err10.mir
+++ b/llvm/test/CodeGen/MIR/Generic/scalable-vector-type-err10.mir
@@ -7,4 +7,4 @@
 body: |
   %0:_(
...
-# CHECK: expected <M x sN> or <M x pA> for vector type
+# CHECK: expected <M x tN> for vector type, where t = {'s', 'i', 'f', 'bf', 'p'}
diff --git a/llvm/test/CodeGen/MIR/Generic/scalable-vector-type-err12.mir b/llvm/test/CodeGen/MIR/Generic/scalable-vector-type-err12.mir
index 5ced1aea30c08..9c840d92ac967 100644
--- a/llvm/test/CodeGen/MIR/Generic/scalable-vector-type-err12.mir
+++ b/llvm/test/CodeGen/MIR/Generic/scalable-vector-type-err12.mir
@@ -7,4 +7,4 @@
 body: |
   %0:_(
...
-# CHECK: expected <M x sN> or <M x pA> for vector type
+# CHECK: expected <M x tN> for vector type, where t = {'s', 'i', 'f', 'bf', 'p'}
diff --git a/llvm/test/CodeGen/MIR/Generic/scalable-vector-type-err13.mir b/llvm/test/CodeGen/MIR/Generic/scalable-vector-type-err13.mir
index 94b8230233fa6..982adfdf792c2 100644
--- a/llvm/test/CodeGen/MIR/Generic/scalable-vector-type-err13.mir
+++ b/llvm/test/CodeGen/MIR/Generic/scalable-vector-type-err13.mir
@@ -7,4 +7,4 @@
 body: |
   %0:_(
...
-# CHECK: expected <M x sN> or <M x pA> for vector type
+# CHECK: expected <M x tN> for vector type, where t = {'s', 'i', 'f', 'bf', 'p'}
diff --git a/llvm/test/CodeGen/MIR/Generic/scalable-vector-type-err14.mir b/llvm/test/CodeGen/MIR/Generic/scalable-vector-type-err14.mir
index 323e2d975692f..8a3e86ecaaab9 100644
--- a/llvm/test/CodeGen/MIR/Generic/scalable-vector-type-err14.mir
+++ b/llvm/test/CodeGen/MIR/Generic/scalable-vector-type-err14.mir
@@ -7,4 +7,4 @@
 body: |
   %0:_(
...
-# CHECK: expected <M x sN> or <M x pA> for vector type
+# CHECK: expected <M x tN> for vector type, where t = {'s', 'i', 'f', 'bf', 'p'}
diff --git a/llvm/test/CodeGen/MIR/Generic/scalable-vector-type-err15.mir b/llvm/test/CodeGen/MIR/Generic/scalable-vector-type-err15.mir
index d1613869bf671..228b367b06d40 100644
--- a/llvm/test/CodeGen/MIR/Generic/scalable-vector-type-err15.mir
+++ b/llvm/test/CodeGen/MIR/Generic/scalable-vector-type-err15.mir
@@ -7,4 +7,4 @@
 body: |
   %0:_(notatype) = IMPLICIT_DEF
...
-# CHECK: expected sN, pA, <M x sN>, <M x pA>, <vscale x M x sN>, or <vscale x M x pA> for GlobalISel type
+# CHECK: expected tN, pA, <M x tN>, <M x pA>, <vscale x M x tN>, or <vscale x M x pA> for GlobalISel type, where t = {'s', 'i', 'f', 'bf'}
diff --git a/llvm/test/CodeGen/MIR/Generic/scalable-vector-type-err2.mir b/llvm/test/CodeGen/MIR/Generic/scalable-vector-type-err2.mir
index c504a7d6be249..b187bff218a75 100644
--- a/llvm/test/CodeGen/MIR/Generic/scalable-vector-type-err2.mir
+++ b/llvm/test/CodeGen/MIR/Generic/scalable-vector-type-err2.mir
@@ -6,4 +6,4 @@
 body: |
   %0:_(
...
-# CHECK: expected <M x sN> or <M x pA> for vector type
+# CHECK: expected <M x tN> for vector type, where t = {'s', 'i', 'f', 'bf', 'p'}
diff --git a/llvm/test/CodeGen/MIR/Generic/scalable-vector-type-err3.mir b/llvm/test/CodeGen/MIR/Generic/scalable-vector-type-err3.mir
index c504a7d6be249..b187bff218a75 100644
--- a/llvm/test/CodeGen/MIR/Generic/scalable-vector-type-err3.mir
+++ b/llvm/test/CodeGen/MIR/Generic/scalable-vector-type-err3.mir
@@ -6,4 +6,4 @@
 body: |
   %0:_(
...
-# CHECK: expected <M x sN> or <M x pA> for vector type
+# CHECK: expected <M x tN> for vector type, where t = {'s', 'i', 'f', 'bf', 'p'}
diff --git a/llvm/test/CodeGen/MIR/Generic/scalable-vector-type-err4.mir b/llvm/test/CodeGen/MIR/Generic/scalable-vector-type-err4.mir
index 654f534f4d301..1969f28b18d28 100644
--- a/llvm/test/CodeGen/MIR/Generic/scalable-vector-type-err4.mir
+++ b/llvm/test/CodeGen/MIR/Generic/scalable-vector-type-err4.mir
@@ -6,5 +6,5 @@
 body: |
   %0:_(
...
-# CHECK: expected <M x sN> or <M x pA> for vector type
+# CHECK: expected <M x tN> for vector type, where t = {'s', 'i', 'f', 'bf', 'p'}
diff --git a/llvm/test/CodeGen/MIR/Generic/scalable-vector-type-err5.mir b/llvm/test/CodeGen/MIR/Generic/scalable-vector-type-err5.mir
index 26be2868c522e..49823f7432ba9 100644
--- a/llvm/test/CodeGen/MIR/Generic/scalable-vector-type-err5.mir
+++ b/llvm/test/CodeGen/MIR/Generic/scalable-vector-type-err5.mir
@@ -6,4 +6,4 @@
 body: |
   %0:_(
...
-# CHECK: expected <M x sN> or <M x pA> for vector type
+# CHECK: expected <M x tN> for vector type, where t = {'s', 'i', 'f', 'bf', 'p'}
diff --git a/llvm/test/CodeGen/MIR/Generic/scalable-vector-type-err6.mir b/llvm/test/CodeGen/MIR/Generic/scalable-vector-type-err6.mir
index 07a30f57139dc..7f8e421e76ae3 100644
--- a/llvm/test/CodeGen/MIR/Generic/scalable-vector-type-err6.mir
+++ b/llvm/test/CodeGen/MIR/Generic/scalable-vector-type-err6.mir
@@ -7,5 +7,5 @@
 body: |
   %0:_(
...
-# CHECK: expected <M x sN> or <M x pA> for vector type
+# CHECK: expected <M x tN> for vector type, where t = {'s', 'i', 'f', 'bf', 'p'}
diff --git a/llvm/test/CodeGen/MIR/Generic/scalable-vector-type-err7.mir b/llvm/test/CodeGen/MIR/Generic/scalable-vector-type-err7.mir
index dba902efe6331..ae3c472e2557c 100644
--- a/llvm/test/CodeGen/MIR/Generic/scalable-vector-type-err7.mir
+++ b/llvm/test/CodeGen/MIR/Generic/scalable-vector-type-err7.mir
@@ -6,4 +6,4 @@
 body: |
   %0:_(
...
-# CHECK: expected <M x sN> or <M x pA> for vector type
+# CHECK: expected <M x tN> for vector type, where t = {'s', 'i', 'f', 'bf', 'p'}
diff --git a/llvm/test/CodeGen/MIR/Generic/scalable-vector-type-err8.mir b/llvm/test/CodeGen/MIR/Generic/scalable-vector-type-err8.mir
index 8bedeabaa7906..9044f77f84dc5 100644
--- a/llvm/test/CodeGen/MIR/Generic/scalable-vector-type-err8.mir
+++ b/llvm/test/CodeGen/MIR/Generic/scalable-vector-type-err8.mir
@@ -7,4 +7,4 @@
 body: |
   %0:_(
...
-# CHECK: expected <M x sN> or <M x pA> for vector type
+# CHECK: expected <M x tN> for vector type, where t = {'s', 'i', 'f', 'bf', 'p'}
diff --git a/llvm/test/CodeGen/MIR/WebAssembly/typed-immediate-operand-invalid0.mir b/llvm/test/CodeGen/MIR/WebAssembly/typed-immediate-operand-invalid0.mir
--- a/llvm/test/CodeGen/MIR/WebAssembly/typed-immediate-operand-invalid0.mir
+++ b/llvm/test/CodeGen/MIR/WebAssembly/typed-immediate-operand-invalid0.mir
@@ -1,5 +1,5 @@
 # RUN: not llc -mtriple=wasm32-unknown-unknown -run-pass none -o /dev/null %s 2>&1 | FileCheck %s
-# When a typed immediate operand is only a single 'i'/'s'/'p' character
+# When a typed immediate operand is only a single type identifier
 ---
 name: test_typed_immediate_operand_invalid0
 liveins:
@@ -7,7 +7,7 @@ liveins:
 body: |
   bb.0:
     liveins: $arguments
-    ; CHECK: [[@LINE+1]]:24: expected integers after 'i'/'s'/'p' type character
+    ; CHECK: [[@LINE+1]]:24: expected integers after 'i'/'s'/'f'/'bf'/'p' type identifier
     %0:i32 = CONST_I32 i 0, implicit-def dead $arguments
     RETURN implicit-def dead $arguments
...
diff --git a/llvm/test/CodeGen/MIR/WebAssembly/typed-immediate-operand-invalid1.mir b/llvm/test/CodeGen/MIR/WebAssembly/typed-immediate-operand-invalid1.mir
index f5c16b52553bf..7f59a9e5eca0f 100644
--- a/llvm/test/CodeGen/MIR/WebAssembly/typed-immediate-operand-invalid1.mir
+++ b/llvm/test/CodeGen/MIR/WebAssembly/typed-immediate-operand-invalid1.mir
@@ -1,5 +1,5 @@
 # RUN: not llc -mtriple=wasm32-unknown-unknown -run-pass none -o /dev/null %s 2>&1 | FileCheck %s
-# When a typed immediate operand does not start with 'i', 's', or 'p'
+# When a typed immediate operand does not start with a valid type identifier
 ---
 name: test_typed_immediate_operand_invalid1
 liveins:
@@ -7,7 +7,7 @@ liveins:
 body: |
   bb.0:
     liveins: $arguments
-    ; CHECK: [[@LINE+1]]:24: a typed immediate operand should start with one of 'i', 's', or 'p'
+    ; CHECK: [[@LINE+1]]:24: a typed immediate operand should start with one of 'i', 's', 'f', 'bf', or 'p'
     %0:i32 = CONST_I32 abc 0, implicit-def dead $arguments
     RETURN implicit-def dead $arguments
...
diff --git a/llvm/unittests/CodeGen/GlobalISel/CMakeLists.txt b/llvm/unittests/CodeGen/GlobalISel/CMakeLists.txt
index 4ef6aff943f73..e85a3aa72c961 100644
--- a/llvm/unittests/CodeGen/GlobalISel/CMakeLists.txt
+++ b/llvm/unittests/CodeGen/GlobalISel/CMakeLists.txt
@@ -14,6 +14,7 @@ set(LLVM_LINK_COMPONENTS
 )
 
 add_llvm_unittest(GlobalISelTests
+  IRTranslatorBF16Test.cpp
   ConstantFoldingTest.cpp
   CSETest.cpp
   GIMatchTableExecutorTest.cpp
diff --git a/llvm/unittests/CodeGen/GlobalISel/IRTranslatorBF16Test.cpp b/llvm/unittests/CodeGen/GlobalISel/IRTranslatorBF16Test.cpp
new file mode 100644
index 0000000000000..c532634c19817
--- /dev/null
+++ b/llvm/unittests/CodeGen/GlobalISel/IRTranslatorBF16Test.cpp
@@ -0,0 +1,132 @@
+//===- IRTranslatorBF16Test.cpp - IRTranslator unit tests -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/StringRef.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/CodeGen/GlobalISel/IRTranslator.h"
+#include "llvm/CodeGen/MachineFunction.h"
+#include "llvm/CodeGen/MachineModuleInfo.h"
+#include "llvm/CodeGen/MachineRegisterInfo.h"
+#include "llvm/CodeGen/TargetFrameLowering.h"
+#include "llvm/CodeGen/TargetInstrInfo.h"
+#include "llvm/CodeGen/TargetOpcodes.h"
+#include "llvm/CodeGen/TargetPassConfig.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/LegacyPassManager.h"
+#include "llvm/IR/Module.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/MC/TargetRegistry.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/TargetSelect.h"
+#include "llvm/Target/TargetMachine.h"
+#include "llvm/Target/TargetOptions.h"
+#include "gtest/gtest.h"
+#include <optional>
+
+using namespace llvm;
+
+namespace {
+struct AArch64IRTranslatorTest : public ::testing::Test {
+  LLVMContext C;
+
+public:
+  AArch64IRTranslatorTest() {}
+  std::unique_ptr<TargetMachine> createTargetMachine() const {
+    Triple TargetTriple("aarch64--");
+    std::string Error;
+    const Target *T = TargetRegistry::lookupTarget("", TargetTriple, Error);
+    if (!T)
+      return nullptr;
+
+    TargetOptions Options;
+    return std::unique_ptr<TargetMachine>(
+        T->createTargetMachine(TargetTriple, "", "", Options, std::nullopt,
+                               std::nullopt, CodeGenOptLevel::Aggressive));
+  }
+
+  std::unique_ptr<Module> parseIR(const char *IR) {
+    SMDiagnostic Err;
+    std::unique_ptr<Module> Mod = parseAssemblyString(IR, Err, C);
+    if (!Mod)
+      Err.print("Test TargetIRTranslator", errs());
+    return Mod;
+  }
+};
+} // namespace
+
+TEST_F(AArch64IRTranslatorTest, IRTranslateBfloat16) {
+  InitializeAllTargets();
+  InitializeAllTargetMCs();
+  InitializeAllAsmPrinters();
+  InitializeAllAsmParsers();
+
+  PassRegistry *Registry = PassRegistry::getPassRegistry();
+  initializeCore(*Registry);
+  initializeCodeGen(*Registry);
+  initializeGlobalISel(*Registry);
+
+  std::unique_ptr<Module> M = parseIR(R"(
+   define void @foo(ptr %p0) {
+     %ptr1 = getelementptr bfloat, ptr %p0, i64 0
+     %ptr2 = getelementptr bfloat, ptr %p0, i64 1
+     %ptr3 = getelementptr bfloat, ptr %p0, i64 2
+     %a = load bfloat, ptr %ptr1, align 2
+     %b = load bfloat, ptr %ptr2, align 2
+     %c = load bfloat, ptr %ptr3, align 2
+     %mul = fmul bfloat %a, %b
+     %res = fadd bfloat %mul, %c
+     %ptr4 = getelementptr bfloat, ptr %p0, i64 3
+     store bfloat %res, ptr %ptr4, align 2
+     ret void
+   }
+  )");
+
+  auto TM = createTargetMachine();
+  M->setDataLayout(TM->createDataLayout());
+
+  TM->setGlobalISel(true);
+  TM->setGlobalISelExtendedLLT(true);
+  TM->setGlobalISelAbort(GlobalISelAbortMode::DisableWithDiag);
+
+  legacy::PassManager PM;
+  TargetPassConfig *TPC(TM->createPassConfig(PM));
+
+  MachineModuleInfoWrapperPass *MMIWP =
+      new MachineModuleInfoWrapperPass(TM.get());
+  PM.add(TPC);
+  PM.add(MMIWP);
+  PM.add(new IRTranslator());
+  PM.run(*M);
+
+  auto *MMI = &MMIWP->getMMI();
+  Function *F = M->getFunction("foo");
+  auto *MF = MMI->getMachineFunction(*F);
+  MachineRegisterInfo &MRI = MF->getRegInfo();
+  ASSERT_FALSE(MF->getProperties().hasProperty(
+      llvm::MachineFunctionProperties::Property::FailedISel));
+  for (auto &MI : MF->front()) {
+    if (MI.getOpcode() == TargetOpcode::G_LOAD) {
+      ASSERT_EQ(MRI.getType(MI.getOperand(0).getReg()), LLT::bfloat16());
+    }
+
+    if (MI.getOpcode() == TargetOpcode::G_FADD ||
+        MI.getOpcode() == TargetOpcode::G_FMUL) {
+      for (auto &Op : MI.operands()) {
+        ASSERT_EQ(MRI.getType(Op.getReg()), LLT::bfloat16());
+      }
+    }
+  }
+  MMI->deleteMachineFunctionFor(*F);
+
+  // Run again without extended LLT
+  TM->setGlobalISelExtendedLLT(false);
+  PM.run(*M);
+  MF = MMI->getMachineFunction(*F);
+  ASSERT_TRUE(MF->getProperties().hasProperty(
+      llvm::MachineFunctionProperties::Property::FailedISel));
+}
diff --git a/llvm/unittests/CodeGen/GlobalISel/LegalizerInfoTest.cpp b/llvm/unittests/CodeGen/GlobalISel/LegalizerInfoTest.cpp
index 988e307909232..836c81b524672 100644
--- a/llvm/unittests/CodeGen/GlobalISel/LegalizerInfoTest.cpp
+++ b/llvm/unittests/CodeGen/GlobalISel/LegalizerInfoTest.cpp
@@ -420,12 +420,14 @@ TEST(LegalizerInfoTest, RuleSets) {
 
     // Raw type form
     LI.getActionDefinitionsBuilder(G_ADD)
-        .fewerElementsIf(typeIs(0, v4s32), changeElementCountTo(0, v2s32))
-        .fewerElementsIf(typeIs(0, v8s32), changeElementCountTo(0, s32))
-        .fewerElementsIf(typeIs(0, LLT::scalable_vector(4, 16)),
-                         changeElementCountTo(0, LLT::scalable_vector(2, 16)))
-        .fewerElementsIf(typeIs(0, LLT::scalable_vector(8, 16)),
-                         changeElementCountTo(0, s16));
+        .fewerElementsIf(typeIs(0, v4s32),
+                         changeElementCountTo(0, ElementCount::getFixed(2)))
+        .fewerElementsIf(typeIs(0, v8s32),
+                         changeElementCountTo(0, ElementCount::getFixed(1)))
+        .fewerElementsIf(typeIs(0, LLT::scalable_vector(4, s16)),
+                         changeElementCountTo(0, ElementCount::getScalable(2)))
+        .fewerElementsIf(typeIs(0, LLT::scalable_vector(8, s16)),
+                         changeElementCountTo(0, ElementCount::getFixed(1)));
 
     LegacyInfo.computeTables();
diff --git a/llvm/utils/TableGen/Common/GlobalISel/GlobalISelMatchTable.cpp b/llvm/utils/TableGen/Common/GlobalISel/GlobalISelMatchTable.cpp
index efaf05adc8f06..4af29cee0a3fe 100644
--- a/llvm/utils/TableGen/Common/GlobalISel/GlobalISelMatchTable.cpp
+++ b/llvm/utils/TableGen/Common/GlobalISel/GlobalISelMatchTable.cpp
@@ -10,6 +10,7 @@
 #include "Common/CodeGenInstruction.h"
 #include "Common/CodeGenRegisters.h"
 #include "llvm/ADT/Statistic.h"
+#include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/LEB128.h"
 #include "llvm/Support/ScopedPrinter.h"
@@ -20,6 +21,11 @@
 
 STATISTIC(NumPatternEmitted, "Number of patterns emitted");
 
+static llvm::cl::opt<bool> AllowExtendedLLT(
+    "gisel-extended-llt",
+    llvm::cl::desc("Generate extended LLT names in match tables"),
+    llvm::cl::init(false));
+
 namespace llvm {
 namespace gi {
 
@@ -360,42 +366,98 @@ std::string LLTCodeGen::getCxxEnumValue() const {
 
 void LLTCodeGen::emitCxxEnumValue(raw_ostream &OS) const {
   if (Ty.isScalar()) {
-    OS << "GILLT_s" << Ty.getSizeInBits();
+    if (Ty.isBFloat(16))
+      OS << "GILLT_bf16";
+    else if (Ty.isPPCF128())
+      OS << "GILLT_ppcf128";
+    else if (Ty.isX86FP80())
+      OS << "GILLT_x86fp80";
+    else if (Ty.isFloat())
+      OS << "GILLT_f" << Ty.getSizeInBits();
+    else if (Ty.isInteger())
+      OS << "GILLT_i" << Ty.getSizeInBits();
+    else
+      OS << "GILLT_s" << Ty.getSizeInBits();
     return;
   }
   if (Ty.isVector()) {
     OS << (Ty.isScalable() ? "GILLT_nxv" : "GILLT_v")
"GILLT_nxv" : "GILLT_v") - << Ty.getElementCount().getKnownMinValue() << "s" - << Ty.getScalarSizeInBits(); + << Ty.getElementCount().getKnownMinValue(); + + LLT ElemTy = Ty.getElementType(); + if (ElemTy.isBFloat(16)) + OS << "bf16"; + else if (ElemTy.isPPCF128()) + OS << "ppcf128"; + else if (ElemTy.isX86FP80()) + OS << "x86fp80"; + else if (ElemTy.isFloat()) + OS << "f" << ElemTy.getSizeInBits(); + else if (Ty.isInteger()) + OS << "i" << ElemTy.getSizeInBits(); + else + OS << "s" << ElemTy.getSizeInBits(); return; } + if (Ty.isPointer()) { OS << "GILLT_p" << Ty.getAddressSpace(); if (Ty.getSizeInBits() > 0) OS << "s" << Ty.getSizeInBits(); return; } + llvm_unreachable("Unhandled LLT"); } void LLTCodeGen::emitCxxConstructorCall(raw_ostream &OS) const { if (Ty.isScalar()) { - OS << "LLT::scalar(" << Ty.getSizeInBits() << ")"; + if (Ty.isInteger()) + OS << "LLT::integer(" << Ty.getScalarSizeInBits() << ")"; + else if (Ty.isBFloat(16)) + OS << "LLT::bfloat()"; + else if (Ty.isPPCF128()) + OS << "LLT::ppcf128()"; + else if (Ty.isX86FP80()) + OS << "LLT::x86fp80()"; + else if (Ty.isFloat()) + OS << "LLT::floatingPoint(" << Ty.getScalarSizeInBits() + << ", LLT::FPVariant::IEEE_FLOAT)"; + else + OS << "LLT::scalar(" << Ty.getScalarSizeInBits() << ")"; return; } + if (Ty.isVector()) { OS << "LLT::vector(" << (Ty.isScalable() ? "ElementCount::getScalable(" : "ElementCount::getFixed(") - << Ty.getElementCount().getKnownMinValue() << "), " - << Ty.getScalarSizeInBits() << ")"; + << Ty.getElementCount().getKnownMinValue() << "), "; + + LLT ElemTy = Ty.getElementType(); + if (ElemTy.isInteger()) + OS << "LLT::integer(" << ElemTy.getScalarSizeInBits() << ")"; + else if (ElemTy.isBFloat(16)) + OS << "LLT::bfloat()"; + else if (ElemTy.isPPCF128()) + OS << "LLT::ppcf128()"; + else if (ElemTy.isX86FP80()) + OS << "LLT::x86fp80()"; + else if (ElemTy.isFloat()) + OS << "LLT::floatingPoint(" << ElemTy.getScalarSizeInBits() + << ", LLT::FPVariant::IEEE_FLOAT)"; + else + OS << "LLT::scalar(" << Ty.getScalarSizeInBits() << ")"; + OS << ")"; return; } + if (Ty.isPointer() && Ty.getSizeInBits() > 0) { OS << "LLT::pointer(" << Ty.getAddressSpace() << ", " << Ty.getSizeInBits() << ")"; return; } + llvm_unreachable("Unhandled LLT"); } @@ -441,11 +503,10 @@ std::optional MVTToLLT(MVT::SimpleValueType SVT) { MVT VT(SVT); if (VT.isVector() && !VT.getVectorElementCount().isScalar()) - return LLTCodeGen( - LLT::vector(VT.getVectorElementCount(), VT.getScalarSizeInBits())); + return LLTCodeGen(LLT(VT, AllowExtendedLLT)); if (VT.isInteger() || VT.isFloatingPoint()) - return LLTCodeGen(LLT::scalar(VT.getSizeInBits())); + return LLTCodeGen(LLT(VT, AllowExtendedLLT)); return std::nullopt; }