diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -2937,14 +2937,37 @@
   // Fold (select C, Ptr, (gep Ptr, Idx)) -> (gep Ptr, (select C, 0, Idx))
   auto SelectGepWithBase = [&](GetElementPtrInst *Gep, Value *Base,
                                bool Swap) -> GetElementPtrInst * {
-    Value *Ptr = Gep->getPointerOperand();
     if (Gep->getNumOperands() != 2 || Gep->getPointerOperand() != Base ||
         !Gep->hasOneUse())
       return nullptr;
+
+    auto *BaseGep = dyn_cast<GetElementPtrInst>(Base);
     Type *ElementType = Gep->getResultElementType();
     Value *Idx = Gep->getOperand(1);
-    Value *NewT = Idx;
-    Value *NewF = Constant::getNullValue(Idx->getType());
+    Value *Ptr, *NewT, *NewF;
+
+    // Handle nested geps special case.
+    // Fold (select C, (gep (gep Ptr, Idx0), Idx1), (gep Ptr, Idx0))
+    //   --> (gep Ptr, (select C, Idx0+Idx1, Idx0))
+    // Fold (select C, (gep Ptr, Idx0), (gep (gep Ptr, Idx0), Idx1))
+    //   --> (gep Ptr, (select C, Idx0, Idx0+Idx1))
+    // NOTE(review): Idx0+Idx1 is computed in the indices' own type. If that
+    // is narrower than the pointer index width, the add can wrap where the
+    // two separately sign-extended GEP offsets would not -- confirm the
+    // indices are already canonicalized to the index width at this point.
+    if (BaseGep && BaseGep->getNumOperands() == 2 &&
+        ElementType == BaseGep->getResultElementType() &&
+        Idx->getType() == BaseGep->getOperand(1)->getType()) {
+      NewT =
+          Builder.CreateAdd(Idx, BaseGep->getOperand(1), SI.getName() + ".add");
+      NewF = BaseGep->getOperand(1);
+      Ptr = BaseGep->getPointerOperand();
+    } else {
+      NewT = Idx;
+      NewF = Constant::getNullValue(Idx->getType());
+      Ptr = Gep->getPointerOperand();
+    }
+
     if (Swap)
       std::swap(NewT, NewF);
     Value *NewSI =
diff --git a/llvm/test/Transforms/InstCombine/select-gep.ll b/llvm/test/Transforms/InstCombine/select-gep.ll
--- a/llvm/test/Transforms/InstCombine/select-gep.ll
+++ b/llvm/test/Transforms/InstCombine/select-gep.ll
@@ -102,10 +102,10 @@
 ; PR51069
 define i32* @test2c(i32* %p, i64 %x, i64 %y) {
 ; CHECK-LABEL: @test2c(
-; CHECK-NEXT:    [[GEP1:%.*]] = getelementptr inbounds i32, i32* [[P:%.*]], i64 [[X:%.*]]
-; CHECK-NEXT:    [[ICMP:%.*]] = icmp ugt i64 [[X]], [[Y:%.*]]
-; CHECK-NEXT:    [[SEL_IDX:%.*]] = select i1 [[ICMP]], i64 0, i64 6
-; CHECK-NEXT:    [[SEL:%.*]] = getelementptr i32, i32* [[GEP1]], i64 [[SEL_IDX]]
+; CHECK-NEXT:    [[ICMP:%.*]] = icmp ugt i64 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[SEL_ADD:%.*]] = add i64 [[X]], 6
+; CHECK-NEXT:    [[SEL_IDX:%.*]] = select i1 [[ICMP]], i64 [[X]], i64 [[SEL_ADD]]
+; CHECK-NEXT:    [[SEL:%.*]] = getelementptr i32, i32* [[P:%.*]], i64 [[SEL_IDX]]
 ; CHECK-NEXT:    ret i32* [[SEL]]
 ;
   %gep1 = getelementptr inbounds i32, i32* %p, i64 %x
@@ -118,10 +118,10 @@
 ; PR51069
 define i32* @test2d(i32* %p, i64 %x, i64 %y) {
 ; CHECK-LABEL: @test2d(
-; CHECK-NEXT:    [[GEP1:%.*]] = getelementptr inbounds i32, i32* [[P:%.*]], i64 [[X:%.*]]
-; CHECK-NEXT:    [[ICMP:%.*]] = icmp ugt i64 [[X]], [[Y:%.*]]
-; CHECK-NEXT:    [[SEL_IDX:%.*]] = select i1 [[ICMP]], i64 6, i64 0
-; CHECK-NEXT:    [[SEL:%.*]] = getelementptr i32, i32* [[GEP1]], i64 [[SEL_IDX]]
+; CHECK-NEXT:    [[ICMP:%.*]] = icmp ugt i64 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[SEL_ADD:%.*]] = add i64 [[X]], 6
+; CHECK-NEXT:    [[SEL_IDX:%.*]] = select i1 [[ICMP]], i64 [[SEL_ADD]], i64 [[X]]
+; CHECK-NEXT:    [[SEL:%.*]] = getelementptr i32, i32* [[P:%.*]], i64 [[SEL_IDX]]
 ; CHECK-NEXT:    ret i32* [[SEL]]
 ;
   %gep1 = getelementptr inbounds i32, i32* %p, i64 %x