diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -148,6 +148,7 @@
   Instruction *SliceUpIllegalIntegerPHI(PHINode &PN);
   Instruction *visitPHINode(PHINode &PN);
   Instruction *visitGetElementPtrInst(GetElementPtrInst &GEP);
+  Value *splitConstantGEP(Value *Src);
   Instruction *visitGEPOfGEP(GetElementPtrInst &GEP, GEPOperator *Src);
   Instruction *visitGEPOfBitcast(BitCastInst *BCI, GetElementPtrInst &GEP);
   Instruction *visitAllocaInst(AllocaInst &AI);
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -1935,6 +1935,34 @@
   return SelectInst::Create(Cond, NewTrueC, NewFalseC, "", nullptr, Sel);
 }
 
+// For a chain of GEP in the form GEP (GEP ... (GEP (GEP p C1) X1) ... Xn) C2
+// where C1 and C2 are constants, and GEP X1 ... Xn each has no other uses,
+// we can split GEP p C1 from the use-def chain so that it can be reassociated
+// and merged with the outermost GEP by visitGEPOfGEP.
+//
+// Returns a freshly built copy of the chain ending at Src whose innermost GEP
+// is a clone of the all-constant GEP, or nullptr if Src is not a GEP
+// instruction or no all-constant GEP is reached through single-use GEPs.
+Value *InstCombinerImpl::splitConstantGEP(Value *Src) {
+  if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Src)) {
+    APInt Offset(DL.getIndexTypeSizeInBits(GEP->getType()), 0);
+    // Base case: clone the all-constant GEP; its use count is not checked.
+    if (GEP->hasAllConstantIndices() &&
+        GEP->accumulateConstantOffset(DL, Offset))
+      return Builder.CreateGEP(GEP->getSourceElementType(),
+                               GEP->getPointerOperand(),
+                               SmallVector<Value *>(GEP->indices()),
+                               GEP->getName(), GEP->isInBounds());
+    // Otherwise look further up the chain only through a single-use GEP.
+    if (Src->hasOneUse())
+      if (Value *NewV = splitConstantGEP(GEP->getPointerOperand()))
+        return Builder.CreateGEP(GEP->getSourceElementType(), NewV,
+                                 SmallVector<Value *>(GEP->indices()),
+                                 GEP->getName(), GEP->isInBounds());
+  }
+  return nullptr;
+}
+
 Instruction *InstCombinerImpl::visitGEPOfGEP(GetElementPtrInst &GEP,
                                              GEPOperator *Src) {
   // Combine Indices - If the source pointer to this getelementptr instruction
@@ -2444,6 +2472,19 @@
     if (Instruction *I = visitGEPOfGEP(GEP, Src))
       return I;
 
+  // This check must be after visitGEPOfGEP, since GEP (GEP p C1) C2 can be
+  // merged by it. If GEP has zero offset, do not transform because it is a zero
+  // cost bitcast.
+  APInt Offset(DL.getIndexTypeSizeInBits(GEP.getType()), 0);
+  if (GEP.hasAllConstantIndices() && !GEP.hasAllZeroIndices() &&
+      GEP.accumulateConstantOffset(DL, Offset) && !Offset.isZero())
+    if (Value *V = splitConstantGEP(GEP.getPointerOperand())) {
+      GetElementPtrInst *NewGEP = GetElementPtrInst::Create(
+          GEP.getSourceElementType(), V, SmallVector<Value *>(GEP.indices()));
+      NewGEP->setIsInBounds(GEP.isInBounds());
+      return NewGEP;
+    }
+
   // Skip if GEP source element type is scalable. The type alloc size is unknown
   // at compile-time.
   if (GEP.getNumIndices() == 1 && !IsGEPSrcEleScalable) {
diff --git a/llvm/test/Transforms/InstCombine/gep-merge-constant-indices.ll b/llvm/test/Transforms/InstCombine/gep-merge-constant-indices.ll
--- a/llvm/test/Transforms/InstCombine/gep-merge-constant-indices.ll
+++ b/llvm/test/Transforms/InstCombine/gep-merge-constant-indices.ll
@@ -230,9 +230,9 @@
 ; CHECK-LABEL: @gepChain1(
 ; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds i64, ptr [[P:%.*]], i64 2
 ; CHECK-NEXT:    call void @use(ptr nonnull [[TMP1]])
-; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i32, ptr [[TMP1]], i64 [[A:%.*]]
-; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds i32, ptr [[TMP2]], i64 [[B:%.*]]
-; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr inbounds i16, ptr [[TMP3]], i64 3
+; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr i32, ptr [[P]], i64 [[A:%.*]]
+; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr i32, ptr [[TMP2]], i64 [[B:%.*]]
+; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr i8, ptr [[TMP3]], i64 22
 ; CHECK-NEXT:    ret ptr [[TMP4]]
 ;
   %1 = getelementptr inbounds i64, ptr %p, i64 2
@@ -249,10 +249,9 @@
 ; CHECK-LABEL: @gepChain2(
 ; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds i64, ptr [[P:%.*]], i64 2
 ; CHECK-NEXT:    call void @use(ptr nonnull [[TMP1]])
-; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i32, ptr [[TMP1]], i64 [[A:%.*]]
-; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds i32, ptr [[TMP2]], i64 [[B:%.*]]
-; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr inbounds i16, ptr [[TMP3]], i64 -8
-; CHECK-NEXT:    ret ptr [[TMP4]]
+; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr i32, ptr [[P]], i64 [[A:%.*]]
+; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr i32, ptr [[TMP2]], i64 [[B:%.*]]
+; CHECK-NEXT:    ret ptr [[TMP3]]
 ;
   %1 = getelementptr inbounds i64, ptr %p, i64 2
   call void @use(ptr %1)