diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -148,6 +148,7 @@
   Instruction *SliceUpIllegalIntegerPHI(PHINode &PN);
   Instruction *visitPHINode(PHINode &PN);
   Instruction *visitGetElementPtrInst(GetElementPtrInst &GEP);
+  Value *splitConstantGEP(Value *Src);
   Instruction *visitGEPOfGEP(GetElementPtrInst &GEP, GEPOperator *Src);
   Instruction *visitGEPOfBitcast(BitCastInst *BCI, GetElementPtrInst &GEP);
   Instruction *visitAllocaInst(AllocaInst &AI);
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -1935,6 +1935,28 @@
   return SelectInst::Create(Cond, NewTrueC, NewFalseC, "", nullptr, Sel);
 }
 
+// For a chain of GEPs of the form GEP (GEP ... (GEP (GEP p C1) X1) ... Xn) C2,
+// where C1 and C2 are constants and each of the GEPs indexed by X1 ... Xn has
+// no other uses, we can split GEP p C1 from the use-def chain so that it can
+// be reassociated and merged with the outermost GEP by visitGEPOfGEP.
+Value *InstCombinerImpl::splitConstantGEP(Value *Src) {
+  if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Src)) {
+    APInt Offset(DL.getIndexTypeSizeInBits(GEP->getType()), 0);
+    if (GEP->hasAllConstantIndices() &&
+        GEP->accumulateConstantOffset(DL, Offset))
+      return Builder.CreateGEP(GEP->getSourceElementType(),
+                               GEP->getPointerOperand(),
+                               SmallVector<Value *>(GEP->indices()),
+                               GEP->getName(), GEP->isInBounds());
+    if (Src->hasOneUse())
+      if (Value *NewV = splitConstantGEP(GEP->getPointerOperand()))
+        return Builder.CreateGEP(GEP->getSourceElementType(), NewV,
+                                 SmallVector<Value *>(GEP->indices()),
+                                 GEP->getName(), GEP->isInBounds());
+  }
+  return nullptr;
+}
+
 Instruction *InstCombinerImpl::visitGEPOfGEP(GetElementPtrInst &GEP,
                                              GEPOperator *Src) {
   // Combine Indices - If the source pointer to this getelementptr instruction
@@ -2444,6 +2466,19 @@
     if (Instruction *I = visitGEPOfGEP(GEP, Src))
       return I;
 
+  // This check must come after visitGEPOfGEP, since GEP (GEP p C1) C2 can be
+  // merged by it. If the GEP has a zero offset, do not transform, because it
+  // is a zero-cost bitcast.
+  APInt Offset(DL.getIndexTypeSizeInBits(GEP.getType()), 0);
+  if (GEP.hasAllConstantIndices() && !GEP.hasAllZeroIndices() &&
+      GEP.accumulateConstantOffset(DL, Offset) && !Offset.isZero())
+    if (Value *V = splitConstantGEP(GEP.getPointerOperand())) {
+      GetElementPtrInst *NewGEP = GetElementPtrInst::Create(
+          GEP.getSourceElementType(), V, SmallVector<Value *>(GEP.indices()));
+      NewGEP->setIsInBounds(GEP.isInBounds());
+      return NewGEP;
+    }
+
   // Skip if GEP source element type is scalable. The type alloc size is unknown
   // at compile-time.
   if (GEP.getNumIndices() == 1 && !IsGEPSrcEleScalable) {
diff --git a/llvm/test/Transforms/InstCombine/gep-merge-constant-indices.ll b/llvm/test/Transforms/InstCombine/gep-merge-constant-indices.ll
--- a/llvm/test/Transforms/InstCombine/gep-merge-constant-indices.ll
+++ b/llvm/test/Transforms/InstCombine/gep-merge-constant-indices.ll
@@ -9,6 +9,8 @@
 %struct.B = type { i8, [3 x i16], %struct.A, float }
 %struct.C = type { i8, i32, i32 }
 
+declare void @use(ptr %p)
+
 ; result = (i32*) p + 3
 define ptr @mergeBasic(ptr %p) {
 ; CHECK-LABEL: @mergeBasic(
@@ -220,3 +222,76 @@
   %2 = getelementptr inbounds i32, ptr %1, i64 1
   ret ptr %2
 }
+
+; If the first and last GEPs in a chain are constant-indexed, they can be
+; merged, even if the first GEP has multiple uses.
+; result = (i8*) ((i32*) p + a + b) + 22
+define ptr @gepChain1(ptr %p, i64 %a, i64 %b) {
+; CHECK-LABEL: @gepChain1(
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds i64, ptr [[P:%.*]], i64 2
+; CHECK-NEXT:    call void @use(ptr nonnull [[TMP1]])
+; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr i32, ptr [[P]], i64 [[A:%.*]]
+; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr i32, ptr [[TMP2]], i64 [[B:%.*]]
+; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr i8, ptr [[TMP3]], i64 22
+; CHECK-NEXT:    ret ptr [[TMP4]]
+;
+  %1 = getelementptr inbounds i64, ptr %p, i64 2
+  call void @use(ptr %1)
+  %2 = getelementptr inbounds i32, ptr %1, i64 %a
+  %3 = getelementptr inbounds i32, ptr %2, i64 %b
+  %4 = getelementptr inbounds i16, ptr %3, i64 3
+  ret ptr %4
+}
+
+; The constant-indexed GEPs cancel each other out.
+; result = (i32*) p + a + b
+define ptr @gepChain2(ptr %p, i64 %a, i64 %b) {
+; CHECK-LABEL: @gepChain2(
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds i64, ptr [[P:%.*]], i64 2
+; CHECK-NEXT:    call void @use(ptr nonnull [[TMP1]])
+; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr i32, ptr [[P]], i64 [[A:%.*]]
+; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr i32, ptr [[TMP2]], i64 [[B:%.*]]
+; CHECK-NEXT:    ret ptr [[TMP3]]
+;
+  %1 = getelementptr inbounds i64, ptr %p, i64 2
+  call void @use(ptr %1)
+  %2 = getelementptr inbounds i32, ptr %1, i64 %a
+  %3 = getelementptr inbounds i32, ptr %2, i64 %b
+  %4 = getelementptr inbounds i16, ptr %3, i64 -8
+  ret ptr %4
+}
+
+; Negative test. If the last GEP has a zero offset, do not merge, since it is a
+; zero-cost bitcast.
+define ptr @gepChainZeroOffset(ptr %p, i64 %a) {
+; CHECK-LABEL: @gepChainZeroOffset(
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds i64, ptr [[P:%.*]], i64 2
+; CHECK-NEXT:    call void @use(ptr nonnull [[TMP1]])
+; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i32, ptr [[TMP1]], i64 [[A:%.*]]
+; CHECK-NEXT:    ret ptr [[TMP2]]
+;
+  %1 = getelementptr inbounds i64, ptr %p, i64 2
+  call void @use(ptr %1)
+  %2 = getelementptr inbounds i32, ptr %1, i64 %a
+  %3 = getelementptr inbounds [4 x i16], ptr %2, i64 1, i64 -4
+  ret ptr %3
+}
+
+; Negative test. Do not merge if a GEP in the middle of the chain has multiple
+; uses; the original GEPs could not be removed, so the split is not profitable.
+define ptr @gepChainMultipleUse(ptr %p, i64 %a) {
+; CHECK-LABEL: @gepChainMultipleUse(
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds i64, ptr [[P:%.*]], i64 2
+; CHECK-NEXT:    call void @use(ptr nonnull [[TMP1]])
+; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i32, ptr [[TMP1]], i64 [[A:%.*]]
+; CHECK-NEXT:    call void @use(ptr nonnull [[TMP2]])
+; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr inbounds i32, ptr [[TMP2]], i64 3
+; CHECK-NEXT:    ret ptr [[TMP3]]
+;
+  %1 = getelementptr inbounds i64, ptr %p, i64 2
+  call void @use(ptr %1)
+  %2 = getelementptr inbounds i32, ptr %1, i64 %a
+  call void @use(ptr %2)
+  %3 = getelementptr inbounds i32, ptr %2, i64 3
+  ret ptr %3
+}
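
A hand-worked IR sketch of the intended rewrite for @gepChain1, reconstructed from the test expectations above. The intermediate step and the %c*/%r* value names are only an illustration of how splitConstantGEP is expected to interact with visitGEPOfGEP over later InstCombine iterations; they are not literal output of the patch.

; Input (value names illustrative). %1 must stay because @use reads it;
; C1 = 2 x i64 = 16 bytes and the outermost C2 = 3 x i16 = 6 bytes.
  %1 = getelementptr inbounds i64, ptr %p, i64 2
  call void @use(ptr %1)
  %2 = getelementptr inbounds i32, ptr %1, i64 %a
  %3 = getelementptr inbounds i32, ptr %2, i64 %b
  %4 = getelementptr inbounds i16, ptr %3, i64 3

; After splitConstantGEP clones the chain, each cloned GEP has exactly one
; use, which is the precondition the patch comment relies on for
; visitGEPOfGEP to reassociate and merge the constants; the original %2/%3/%4
; become dead and are expected to be cleaned up.
  %c1 = getelementptr inbounds i64, ptr %p, i64 2
  %c2 = getelementptr inbounds i32, ptr %c1, i64 %a
  %c3 = getelementptr inbounds i32, ptr %c2, i64 %b
  %new = getelementptr inbounds i16, ptr %c3, i64 3

; Expected final form, matching the CHECK lines of @gepChain1: the 16-byte and
; 6-byte constant offsets fold into a single 22-byte offset.
  %r1 = getelementptr i32, ptr %p, i64 %a
  %r2 = getelementptr i32, ptr %r1, i64 %b
  %r3 = getelementptr i8, ptr %r2, i64 22

The end-to-end result can be reproduced by running the test file through opt -passes=instcombine -S and comparing against the CHECK lines with FileCheck.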