diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -2933,6 +2933,44 @@
   if (Instruction *I = foldSelectExtConst(SI))
     return I;
 
+  // Fold (select C, (gep Ptr, Idx), Ptr) -> (gep Ptr, (select C, Idx, 0))
+  // Fold (select C, Ptr, (gep Ptr, Idx)) -> (gep Ptr, (select C, 0, Idx))
+  // Only single-index GEPs whose pointer operand is the other select operand
+  // are handled, and only if the select is the GEP's sole user, so the
+  // original GEP becomes dead once the select is replaced.
+  auto SelectGepWithBase = [&](GetElementPtrInst *Gep, Value *Base,
+                               bool Swap) -> GetElementPtrInst * {
+    Value *Ptr = Gep->getPointerOperand();
+    if (Gep->getNumOperands() != 2 || Ptr != Base || !Gep->hasOneUse())
+      return nullptr;
+    Value *Idx = Gep->getOperand(1);
+    // A vector condition cannot select between scalar indices; this happens
+    // for a vector-of-pointers GEP with a scalar index, and building the new
+    // select would produce invalid IR.
+    if (isa<VectorType>(CondVal->getType()) && !isa<VectorType>(Idx->getType()))
+      return nullptr;
+    // GetElementPtrInst::Create takes the source element type; for a single
+    // index it is also the result element type.
+    Type *ElementType = Gep->getSourceElementType();
+    Value *NewT = Idx;
+    Value *NewF = Constant::getNullValue(Idx->getType());
+    if (Swap)
+      std::swap(NewT, NewF);
+    // Passing &SI carries the select's metadata (e.g. branch weights) over to
+    // the new select, which tests the same condition.
+    Value *NewSI =
+        Builder.CreateSelect(CondVal, NewT, NewF, SI.getName() + ".idx", &SI);
+    return Gep->isInBounds()
+               ? GetElementPtrInst::CreateInBounds(ElementType, Ptr, {NewSI})
+               : GetElementPtrInst::Create(ElementType, Ptr, {NewSI});
+  };
+  if (auto *TrueGep = dyn_cast<GetElementPtrInst>(TrueVal))
+    if (auto *NewGep = SelectGepWithBase(TrueGep, FalseVal, false))
+      return NewGep;
+  if (auto *FalseGep = dyn_cast<GetElementPtrInst>(FalseVal))
+    if (auto *NewGep = SelectGepWithBase(FalseGep, TrueVal, true))
+      return NewGep;
+
   // See if we can fold the select into one of our operands.
   if (SelType->isIntOrIntVectorTy() || SelType->isFPOrFPVectorTy()) {
     if (Instruction *FoldI = foldSelectIntoOp(SI, TrueVal, FalseVal))
diff --git a/llvm/test/Transforms/InstCombine/select-gep.ll b/llvm/test/Transforms/InstCombine/select-gep.ll
--- a/llvm/test/Transforms/InstCombine/select-gep.ll
+++ b/llvm/test/Transforms/InstCombine/select-gep.ll
@@ -74,9 +74,9 @@
 ; PR50183
 define i32* @test2a(i32* %p, i64 %x, i64 %y) {
 ; CHECK-LABEL: @test2a(
-; CHECK-NEXT:    [[GEP:%.*]] = getelementptr inbounds i32, i32* [[P:%.*]], i64 [[X:%.*]]
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ugt i64 [[X]], [[Y:%.*]]
-; CHECK-NEXT:    [[SELECT:%.*]] = select i1 [[CMP]], i32* [[GEP]], i32* [[P]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ugt i64 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[SELECT_IDX:%.*]] = select i1 [[CMP]], i64 [[X]], i64 0
+; CHECK-NEXT:    [[SELECT:%.*]] = getelementptr inbounds i32, i32* [[P:%.*]], i64 [[SELECT_IDX]]
 ; CHECK-NEXT:    ret i32* [[SELECT]]
 ;
   %gep = getelementptr inbounds i32, i32* %p, i64 %x
@@ -88,9 +88,9 @@
 ; PR50183
 define i32* @test2b(i32* %p, i64 %x, i64 %y) {
 ; CHECK-LABEL: @test2b(
-; CHECK-NEXT:    [[GEP:%.*]] = getelementptr inbounds i32, i32* [[P:%.*]], i64 [[X:%.*]]
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ugt i64 [[X]], [[Y:%.*]]
-; CHECK-NEXT:    [[SELECT:%.*]] = select i1 [[CMP]], i32* [[P]], i32* [[GEP]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ugt i64 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[SELECT_IDX:%.*]] = select i1 [[CMP]], i64 0, i64 [[X]]
+; CHECK-NEXT:    [[SELECT:%.*]] = getelementptr inbounds i32, i32* [[P:%.*]], i64 [[SELECT_IDX]]
 ; CHECK-NEXT:    ret i32* [[SELECT]]
 ;
   %gep = getelementptr inbounds i32, i32* %p, i64 %x