diff --git a/llvm/lib/Transforms/Utils/RelLookupTableConverter.cpp b/llvm/lib/Transforms/Utils/RelLookupTableConverter.cpp
index d53a3144bf57d..a814867652cd1 100644
--- a/llvm/lib/Transforms/Utils/RelLookupTableConverter.cpp
+++ b/llvm/lib/Transforms/Utils/RelLookupTableConverter.cpp
@@ -21,29 +21,20 @@
 using namespace llvm;
 
-static bool shouldConvertToRelLookupTable(Module &M, GlobalVariable &GV) {
+struct LookupTableInfo {
+  Value *Index;
+  SmallVector<Constant *> Ptrs;
+};
+
+static bool shouldConvertToRelLookupTable(LookupTableInfo &Info, Module &M,
+                                          GlobalVariable &GV) {
   // If lookup table has more than one user,
   // do not generate a relative lookup table.
   // This is to simplify the analysis that needs to be done for this pass.
   // TODO: Add support for lookup tables with multiple uses.
   // For ex, this can happen when a function that uses a lookup table gets
   // inlined into multiple call sites.
-  if (!GV.hasInitializer() ||
-      !GV.isConstant() ||
-      !GV.hasOneUse())
-    return false;
-
-  GetElementPtrInst *GEP =
-      dyn_cast<GetElementPtrInst>(GV.use_begin()->getUser());
-  if (!GEP || !GEP->hasOneUse() ||
-      GV.getValueType() != GEP->getSourceElementType())
-    return false;
-
-  LoadInst *Load = dyn_cast<LoadInst>(GEP->use_begin()->getUser());
-  if (!Load || !Load->hasOneUse() ||
-      Load->getType() != GEP->getResultElementType())
-    return false;
-
+  //
   // If the original lookup table does not have local linkage and is
   // not dso_local, do not generate a relative lookup table.
   // This optimization creates a relative lookup table that consists of
   // To be able to generate these offsets, relative lookup table and
   // its elements should have internal linkage and be dso_local, which means
   // that they should resolve to symbols within the same linkage unit.
-  if (!GV.hasLocalLinkage() ||
-      !GV.isDSOLocal() ||
-      !GV.isImplicitDSOLocal())
+  if (!GV.hasInitializer() || !GV.isConstant() || !GV.hasOneUse() ||
+      !GV.hasLocalLinkage() || !GV.isDSOLocal() || !GV.isImplicitDSOLocal())
     return false;
 
-  ConstantArray *Array = dyn_cast<ConstantArray>(GV.getInitializer());
-  if (!Array)
+  auto *GEP = dyn_cast<GetElementPtrInst>(GV.use_begin()->getUser());
+  if (!GEP || !GEP->hasOneUse())
+    return false;
+
+  auto *Load = dyn_cast<LoadInst>(GEP->use_begin()->getUser());
+  if (!Load || !Load->hasOneUse())
     return false;
 
   // If values are not 64-bit pointers, do not generate a relative lookup table.
   const DataLayout &DL = M.getDataLayout();
-  Type *ElemType = Array->getType()->getElementType();
+  Type *ElemType = Load->getType();
   if (!ElemType->isPointerTy() || DL.getPointerTypeSizeInBits(ElemType) != 64)
     return false;
 
+  // Make sure this is a gep of the form GV + scale*var.
+  unsigned IndexWidth =
+      DL.getIndexTypeSizeInBits(Load->getPointerOperand()->getType());
+  SmallMapVector<Value *, APInt, 4> VarOffsets;
+  APInt ConstOffset(IndexWidth, 0);
+  if (!GEP->collectOffset(DL, IndexWidth, VarOffsets, ConstOffset) ||
+      !ConstOffset.isZero() || VarOffsets.size() != 1)
+    return false;
+
+  // This can't be a pointer lookup table if the stride is smaller than a
+  // pointer.
+  Info.Index = VarOffsets.front().first;
+  const APInt &Stride = VarOffsets.front().second;
+  if (Stride.ult(DL.getTypeStoreSize(ElemType)))
+    return false;
+
   SmallVector<GlobalVariable *> GVOps;
   Triple TT = M.getTargetTriple();
   // FIXME: This should be removed in the future.
@@ -80,14 +90,20 @@ static bool shouldConvertToRelLookupTable(Module &M, GlobalVariable &GV) {
   // https://github.com/rust-lang/rust/issues/141306.
       || (TT.isX86() && TT.isOSDarwin());
 
-  for (const Use &Op : Array->operands()) {
-    Constant *ConstOp = cast<Constant>(&Op);
+  APInt Offset(IndexWidth, 0);
+  uint64_t GVSize = DL.getTypeAllocSize(GV.getValueType());
+  for (; Offset.ult(GVSize); Offset += Stride) {
+    Constant *C =
+        ConstantFoldLoadFromConst(GV.getInitializer(), ElemType, Offset, DL);
+    if (!C)
+      return false;
+
     GlobalValue *GVOp;
-    APInt Offset;
+    APInt GVOffset;
 
     // If an operand is not a constant offset from a lookup table,
     // do not generate a relative lookup table.
-    if (!IsConstantOffsetFromGlobal(ConstOp, GVOp, Offset, DL))
+    if (!IsConstantOffsetFromGlobal(C, GVOp, GVOffset, DL))
       return false;
 
     // If operand is mutable, do not generate a relative lookup table.
@@ -102,6 +118,8 @@ static bool shouldConvertToRelLookupTable(Module &M, GlobalVariable &GV) {
 
     if (ShouldDropUnnamedAddr)
       GVOps.push_back(GlovalVarOp);
+
+    Info.Ptrs.push_back(C);
   }
 
   if (ShouldDropUnnamedAddr)
@@ -111,14 +129,12 @@ static bool shouldConvertToRelLookupTable(Module &M, GlobalVariable &GV) {
   return true;
 }
 
-static GlobalVariable *createRelLookupTable(Function &Func,
+static GlobalVariable *createRelLookupTable(LookupTableInfo &Info,
+                                            Function &Func,
                                             GlobalVariable &LookupTable) {
   Module &M = *Func.getParent();
-  ConstantArray *LookupTableArr =
-      cast<ConstantArray>(LookupTable.getInitializer());
-  unsigned NumElts = LookupTableArr->getType()->getNumElements();
   ArrayType *IntArrayTy =
-      ArrayType::get(Type::getInt32Ty(M.getContext()), NumElts);
+      ArrayType::get(Type::getInt32Ty(M.getContext()), Info.Ptrs.size());
 
   GlobalVariable *RelLookupTable = new GlobalVariable(
       M, IntArrayTy, LookupTable.isConstant(), LookupTable.getLinkage(),
@@ -127,10 +143,9 @@ static GlobalVariable *createRelLookupTable(Function &Func,
       LookupTable.isExternallyInitialized());
 
   uint64_t Idx = 0;
-  SmallVector<Constant *> RelLookupTableContents(NumElts);
+  SmallVector<Constant *> RelLookupTableContents(Info.Ptrs.size());
 
-  for (Use &Operand : LookupTableArr->operands()) {
-    Constant *Element = cast<Constant>(Operand);
+  for (Constant *Element : Info.Ptrs) {
     Type *IntPtrTy = M.getDataLayout().getIntPtrType(M.getContext());
     Constant *Base = llvm::ConstantExpr::getPtrToInt(RelLookupTable, IntPtrTy);
     Constant *Target = llvm::ConstantExpr::getPtrToInt(Element, IntPtrTy);
@@ -148,7 +163,8 @@ static GlobalVariable *createRelLookupTable(Function &Func,
   return RelLookupTable;
 }
 
-static void convertToRelLookupTable(GlobalVariable &LookupTable) {
+static void convertToRelLookupTable(LookupTableInfo &Info,
+                                    GlobalVariable &LookupTable) {
   GetElementPtrInst *GEP =
       cast<GetElementPtrInst>(LookupTable.use_begin()->getUser());
   LoadInst *Load = cast<LoadInst>(GEP->use_begin()->getUser());
@@ -159,21 +175,21 @@ static void convertToRelLookupTable(GlobalVariable &LookupTable) {
   Function &Func = *BB->getParent();
 
   // Generate an array that consists of relative offsets.
-  GlobalVariable *RelLookupTable = createRelLookupTable(Func, LookupTable);
+  GlobalVariable *RelLookupTable =
+      createRelLookupTable(Info, Func, LookupTable);
 
   // Place new instruction sequence before GEP.
   Builder.SetInsertPoint(GEP);
-  Value *Index = GEP->getOperand(2);
-  IntegerType *IntTy = cast<IntegerType>(Index->getType());
-  Value *Offset =
-      Builder.CreateShl(Index, ConstantInt::get(IntTy, 2), "reltable.shift");
+  IntegerType *IntTy = cast<IntegerType>(Info.Index->getType());
+  Value *Offset = Builder.CreateShl(Info.Index, ConstantInt::get(IntTy, 2),
+                                    "reltable.shift");
 
   // Insert the call to load.relative intrinsic before LOAD.
   // GEP might not be immediately followed by a LOAD, like it can be hoisted
   // outside the loop or another instruction might be inserted them in between.
   Builder.SetInsertPoint(Load);
   Function *LoadRelIntrinsic = llvm::Intrinsic::getOrInsertDeclaration(
-      &M, Intrinsic::load_relative, {Index->getType()});
+      &M, Intrinsic::load_relative, {Info.Index->getType()});
 
   // Create a call to load.relative intrinsic that computes the target address
   // by adding base address (lookup table address) and relative offset.
@@ -205,10 +221,11 @@ static bool convertToRelativeLookupTables(
   bool Changed = false;
 
   for (GlobalVariable &GV : llvm::make_early_inc_range(M.globals())) {
-    if (!shouldConvertToRelLookupTable(M, GV))
+    LookupTableInfo Info;
+    if (!shouldConvertToRelLookupTable(Info, M, GV))
       continue;
 
-    convertToRelLookupTable(GV);
+    convertToRelLookupTable(Info, GV);
 
     // Remove the original lookup table.
     GV.eraseFromParent();
diff --git a/llvm/test/Transforms/RelLookupTableConverter/X86/relative_lookup_table.ll b/llvm/test/Transforms/RelLookupTableConverter/X86/relative_lookup_table.ll
index 16511725a7f1a..efd644543c016 100644
--- a/llvm/test/Transforms/RelLookupTableConverter/X86/relative_lookup_table.ll
+++ b/llvm/test/Transforms/RelLookupTableConverter/X86/relative_lookup_table.ll
@@ -69,6 +69,35 @@ target triple = "x86_64-unknown-linux-gnu"
   ptr @.str.9
 ], align 16
 
+@table3 = internal constant [2 x ptr] [
+  ptr @.str.8,
+  ptr @.str.9
+], align 16
+
+@table4 = internal constant [2 x ptr] [
+  ptr @.str.8,
+  ptr @.str.9
+], align 16
+
+@table5 = internal constant [2 x ptr] [
+  ptr @.str.8,
+  ptr @.str.9
+], align 16
+
+@skip.table = internal constant [4 x ptr] [
+  ptr @.str.8,
+  ptr null,
+  ptr @.str.9,
+  ptr null
+], align 16
+
+@wrong.skip.table = internal constant [4 x ptr] [
+  ptr null,
+  ptr @.str.8,
+  ptr null,
+  ptr @.str.9
+], align 16
+
 ;.
 ; CHECK: @.str = private unnamed_addr constant [5 x i8] c"zero\00", align 1
 ; CHECK: @.str.1 = private unnamed_addr constant [4 x i8] c"one\00", align 1
@@ -97,6 +126,11 @@ target triple = "x86_64-unknown-linux-gnu"
 ; CHECK: @user_defined_lookup_table.table.rel = internal unnamed_addr constant [3 x i32] [i32 trunc (i64 sub (i64 ptrtoint (ptr @.str to i64), i64 ptrtoint (ptr @user_defined_lookup_table.table.rel to i64)) to i32), i32 trunc (i64 sub (i64 ptrtoint (ptr @.str.1 to i64), i64 ptrtoint (ptr @user_defined_lookup_table.table.rel to i64)) to i32), i32 trunc (i64 sub (i64 ptrtoint (ptr @.str.2 to i64), i64 ptrtoint (ptr @user_defined_lookup_table.table.rel to i64)) to i32)], align 4
 ; CHECK: @table.rel = internal unnamed_addr constant [2 x i32] [i32 trunc (i64 sub (i64 ptrtoint (ptr @.str.8 to i64), i64 ptrtoint (ptr @table.rel to i64)) to i32), i32 trunc (i64 sub (i64 ptrtoint (ptr @.str.9 to i64), i64 ptrtoint (ptr @table.rel to i64)) to i32)], align 4
 ; CHECK: @table2.rel = internal unnamed_addr constant [2 x i32] [i32 trunc (i64 sub (i64 ptrtoint (ptr @.str.8 to i64), i64 ptrtoint (ptr @table2.rel to i64)) to i32), i32 trunc (i64 sub (i64 ptrtoint (ptr @.str.9 to i64), i64 ptrtoint (ptr @table2.rel to i64)) to i32)], align 4
+; CHECK: @table3.rel = internal unnamed_addr constant [2 x i32] [i32 trunc (i64 sub (i64 ptrtoint (ptr @.str.8 to i64), i64 ptrtoint (ptr @table3.rel to i64)) to i32), i32 trunc (i64 sub (i64 ptrtoint (ptr @.str.9 to i64), i64 ptrtoint (ptr @table3.rel to i64)) to i32)], align 4
+; CHECK: @table4 = internal constant [2 x ptr] [ptr @.str.8, ptr @.str.9], align 16
+; CHECK: @table5 = internal constant [2 x ptr] [ptr @.str.8, ptr @.str.9], align 16
+; CHECK: @skip.table.rel = internal unnamed_addr constant [2 x i32] [i32 trunc (i64 sub (i64 ptrtoint (ptr @.str.8 to i64), i64 ptrtoint (ptr @skip.table.rel to i64)) to i32), i32 trunc (i64 sub (i64 ptrtoint (ptr @.str.9 to i64), i64 ptrtoint (ptr @skip.table.rel to i64)) to i32)], align 4
+; CHECK: @wrong.skip.table = internal constant [4 x ptr] [ptr null, ptr @.str.8, ptr null, ptr @.str.9], align 16
 ;.
 define ptr @external_linkage(i32 %cond) {
 ; CHECK-LABEL: define ptr @external_linkage(
@@ -323,6 +357,69 @@ entry:
   ret ptr %1
 }
 
+define ptr @gep_no_leading_zero(i64 %index) {
+; CHECK-LABEL: define ptr @gep_no_leading_zero(
+; CHECK-SAME: i64 [[INDEX:%.*]]) {
+; CHECK-NEXT:    [[RELTABLE_SHIFT:%.*]] = shl i64 [[INDEX]], 2
+; CHECK-NEXT:    [[LOAD:%.*]] = call ptr @llvm.load.relative.i64(ptr @table3.rel, i64 [[RELTABLE_SHIFT]])
+; CHECK-NEXT:    ret ptr [[LOAD]]
+;
+  %gep = getelementptr ptr, ptr @table3, i64 %index
+  %load = load ptr, ptr %gep
+  ret ptr %load
+}
+
+define ptr @gep_wrong_stride(i64 %index) {
+; CHECK-LABEL: define ptr @gep_wrong_stride(
+; CHECK-SAME: i64 [[INDEX:%.*]]) {
+; CHECK-NEXT:    [[GEP:%.*]] = getelementptr i8, ptr @table4, i64 [[INDEX]]
+; CHECK-NEXT:    [[LOAD:%.*]] = load ptr, ptr [[GEP]], align 8
+; CHECK-NEXT:    ret ptr [[LOAD]]
+;
+  %gep = getelementptr i8, ptr @table4, i64 %index
+  %load = load ptr, ptr %gep
+  ret ptr %load
+}
+
+define ptr @gep_wrong_constant_offset(i64 %index) {
+; CHECK-LABEL: define ptr @gep_wrong_constant_offset(
+; CHECK-SAME: i64 [[INDEX:%.*]]) {
+; CHECK-NEXT:    [[GEP:%.*]] = getelementptr { ptr, i32 }, ptr @table5, i64 [[INDEX]], i32 1
+; CHECK-NEXT:    [[LOAD:%.*]] = load ptr, ptr [[GEP]], align 8
+; CHECK-NEXT:    ret ptr [[LOAD]]
+;
+  %gep = getelementptr { ptr, i32 }, ptr @table5, i64 %index, i32 1
+  %load = load ptr, ptr %gep
+  ret ptr %load
+}
+
+; This is intentionally advancing by two pointers.
+define ptr @table_with_skipped_elements(i64 %index) {
+; CHECK-LABEL: define ptr @table_with_skipped_elements(
+; CHECK-SAME: i64 [[INDEX:%.*]]) {
+; CHECK-NEXT:    [[RELTABLE_SHIFT:%.*]] = shl i64 [[INDEX]], 2
+; CHECK-NEXT:    [[LOAD:%.*]] = call ptr @llvm.load.relative.i64(ptr @skip.table.rel, i64 [[RELTABLE_SHIFT]])
+; CHECK-NEXT:    ret ptr [[LOAD]]
+;
+  %gep = getelementptr [2 x ptr], ptr @skip.table, i64 %index
+  %load = load ptr, ptr %gep
+  ret ptr %load
+}
+
+; Same as previous test, but the elements are at the wrong position in the
+; table.
+define ptr @table_with_skipped_elements_wrong(i64 %index) {
+; CHECK-LABEL: define ptr @table_with_skipped_elements_wrong(
+; CHECK-SAME: i64 [[INDEX:%.*]]) {
+; CHECK-NEXT:    [[GEP:%.*]] = getelementptr [2 x ptr], ptr @wrong.skip.table, i64 [[INDEX]]
+; CHECK-NEXT:    [[LOAD:%.*]] = load ptr, ptr [[GEP]], align 8
+; CHECK-NEXT:    ret ptr [[LOAD]]
+;
+  %gep = getelementptr [2 x ptr], ptr @wrong.skip.table, i64 %index
+  %load = load ptr, ptr %gep
+  ret ptr %load
+}
+
 !llvm.module.flags = !{!0, !1}
 !0 = !{i32 7, !"PIC Level", i32 2}
 !1 = !{i32 1, !"Code Model", i32 1}