Index: llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp =================================================================== --- llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -1940,7 +1940,8 @@ /// &A[10] - &A[0]: we should compile this to "10". LHS/RHS are the pointer /// operands to the ptrtoint instructions for the LHS/RHS of the subtract. Value *InstCombinerImpl::OptimizePointerDifference(Value *LHS, Value *RHS, - Type *Ty, bool IsNUW) { + Type *Ty, BinaryOperator &I, + bool IsNUW) { // If LHS is a gep based on RHS or RHS is a gep based on LHS, we can optimize // this. bool Swapped = false; @@ -1963,6 +1964,63 @@ GEP1 = LHSGEP; GEP2 = RHSGEP; } + } else if (isa(LHSGEP->getPointerOperand())) { + // ( gep (PHI(X+A, X)), ...) - ( gep X, ...) + auto *PHI = dyn_cast(LHSGEP->getPointerOperand()); + if (PHI->getNumIncomingValues() == 2) { + auto *FirstInst = cast(PHI->getIncomingValue(0)); + auto *SecondInst = cast(PHI->getIncomingValue(1)); + + // Check that the first PHI incoming value is the LHS and the second + // is the RHS. + if (FirstInst == LHS && SecondInst == RHS) { + // Verify that the GEP is indexed at incrementing addresses and that the + // only use of the SUB is to check whether one pointer is higher than the other. + APInt Offset1(DL.getIndexTypeSizeInBits(FirstInst->getType()), 0); + FirstInst = FirstInst->stripAndAccumulateConstantOffsets( + DL, Offset1, /* AllowNonInbounds */ true); + APInt Offset2(DL.getIndexTypeSizeInBits(SecondInst->getType()), 0); + SecondInst = SecondInst->stripAndAccumulateConstantOffsets( + DL, Offset2, /* AllowNonInbounds */ true); + if (Offset1.slt(Offset2)) + return nullptr; + + // Check that the subtract has exactly one use. Handle the cases where + // that use is a PHI, or an ashr/lshr whose single user is a PHI. + if (I.hasOneUse()) { + PHINode *PHI2 = dyn_cast(I.user_back()); + if (!PHI2) { + // Not an 8-bit pointer. Need to check shift amt as power of 2? 
+ Instruction *PHIUser = cast(I.user_back()); + PHI2 = dyn_cast(PHIUser->user_back()); + if ((PHIUser->getOpcode() != Instruction::AShr && + PHIUser->getOpcode() != Instruction::LShr) || + !PHIUser->hasOneUse() || !PHI2) + return nullptr; + } + + // PHI now is an early value or difference of 2 pointers. Can verify + // if other Incoming Values are 0. Only use of this PHI is a cmp or + // a cmp(or), means it reduces to a type bool. + for (const auto *U : PHI2->users()) { + ICmpInst::Predicate EqPred; + // Match to specific. Handling specific predicates. Can relax to + // any predicate when compared with Zero. + if (!(U->hasOneUse() && + ((match(U, m_Or(m_Specific(PHI2), m_Value())) && + match(U->user_back(), + m_ICmp(EqPred, m_Specific(U), m_Zero())) && + EqPred == ICmpInst::ICMP_EQ) || + (match(U, m_ICmp(EqPred, m_Specific(PHI2), m_Zero())) && + EqPred == ICmpInst::ICMP_NE)))) + return nullptr; + } + // If we have reached here, the sub of the two ptrtoint's can be folded as + // X+A > X + GEP1 = LHSGEP; + } + } + } } } @@ -2441,14 +2499,14 @@ Value *LHSOp, *RHSOp; if (match(Op0, m_PtrToInt(m_Value(LHSOp))) && match(Op1, m_PtrToInt(m_Value(RHSOp)))) - if (Value *Res = OptimizePointerDifference(LHSOp, RHSOp, I.getType(), + if (Value *Res = OptimizePointerDifference(LHSOp, RHSOp, I.getType(), I, I.hasNoUnsignedWrap())) return replaceInstUsesWith(I, Res); // trunc(p)-trunc(q) -> trunc(p-q) if (match(Op0, m_Trunc(m_PtrToInt(m_Value(LHSOp)))) && match(Op1, m_Trunc(m_PtrToInt(m_Value(RHSOp))))) - if (Value *Res = OptimizePointerDifference(LHSOp, RHSOp, I.getType(), + if (Value *Res = OptimizePointerDifference(LHSOp, RHSOp, I.getType(), I, /* IsNUW */ false)) return replaceInstUsesWith(I, Res); Index: llvm/lib/Transforms/InstCombine/InstCombineInternal.h =================================================================== --- llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -93,8 +93,6 @@ Instruction 
*visitFNeg(UnaryOperator &I); Instruction *visitAdd(BinaryOperator &I); Instruction *visitFAdd(BinaryOperator &I); - Value *OptimizePointerDifference( - Value *LHS, Value *RHS, Type *Ty, bool isNUW); Instruction *visitSub(BinaryOperator &I); Instruction *visitFSub(BinaryOperator &I); Instruction *visitMul(BinaryOperator &I); @@ -102,6 +100,8 @@ Instruction *visitURem(BinaryOperator &I); Instruction *visitSRem(BinaryOperator &I); Instruction *visitFRem(BinaryOperator &I); + Value *OptimizePointerDifference(Value *LHS, Value *RHS, Type *Ty, + BinaryOperator &I, bool isNUW); bool simplifyDivRemOfSelectWithZeroOp(BinaryOperator &I); Instruction *commonIRemTransforms(BinaryOperator &I); Instruction *commonIDivTransforms(BinaryOperator &I); Index: llvm/test/Transforms/InstCombine/sub-gep.ll =================================================================== --- llvm/test/Transforms/InstCombine/sub-gep.ll +++ llvm/test/Transforms/InstCombine/sub-gep.ll @@ -386,11 +386,10 @@ ; CHECK-NEXT: [[CMP3_NOT_I:%.*]] = icmp eq i8 [[TMP1]], 0 ; CHECK-NEXT: br i1 [[CMP3_NOT_I]], label [[WHILE_END_I:%.*]], label [[WHILE_COND_I]] ; CHECK: while.end.i: -; CHECK-NEXT: [[TMP2:%.*]] = icmp ne ptr [[TEST_0_I]], [[STR1]] ; CHECK-NEXT: br label [[_Z3FOOPKC_EXIT]] ; CHECK: _Z3fooPKc.exit: -; CHECK-NEXT: [[RETVAL_0_I:%.*]] = phi i1 [ [[TMP2]], [[WHILE_END_I]] ], [ false, [[LOR_LHS_FALSE_I]] ], [ false, [[ENTRY:%.*]] ] -; CHECK-NEXT: ret i1 [[RETVAL_0_I]] +; CHECK-NEXT: [[TOBOOL:%.*]] = phi i1 [ true, [[WHILE_END_I]] ], [ false, [[LOR_LHS_FALSE_I]] ], [ false, [[ENTRY:%.*]] ] +; CHECK-NEXT: ret i1 [[TOBOOL]] ; entry: %cmp.i = icmp eq ptr %str1, null @@ -436,12 +435,9 @@ ; CHECK-NEXT: [[CMP3_NOT_I:%.*]] = icmp eq i8 [[TMP1]], 0 ; CHECK-NEXT: br i1 [[CMP3_NOT_I]], label [[WHILE_END_I:%.*]], label [[WHILE_COND_I]] ; CHECK: while.end.i: -; CHECK-NEXT: [[SUB_PTR_LHS_CAST_I:%.*]] = ptrtoint ptr [[TEST_0_I]] to i64 -; CHECK-NEXT: [[SUB_PTR_RHS_CAST_I:%.*]] = ptrtoint ptr [[STR1]] to i64 -; CHECK-NEXT: 
[[SUB_PTR_SUB_I:%.*]] = sub i64 [[SUB_PTR_LHS_CAST_I]], [[SUB_PTR_RHS_CAST_I]] ; CHECK-NEXT: br label [[_Z3FOOPKC_EXIT]] ; CHECK: _Z3fooPKc.exit: -; CHECK-NEXT: [[RETVAL_0_I:%.*]] = phi i64 [ [[SUB_PTR_SUB_I]], [[WHILE_END_I]] ], [ 0, [[LOR_LHS_FALSE_I]] ], [ 0, [[ENTRY:%.*]] ] +; CHECK-NEXT: [[RETVAL_0_I:%.*]] = phi i64 [ 1, [[WHILE_END_I]] ], [ 0, [[LOR_LHS_FALSE_I]] ], [ 0, [[ENTRY:%.*]] ] ; CHECK-NEXT: [[TMP2:%.*]] = or i64 [[RETVAL_0_I]], [[VAL2:%.*]] ; CHECK-NEXT: [[TOBOOL:%.*]] = icmp eq i64 [[TMP2]], 0 ; CHECK-NEXT: ret i1 [[TOBOOL]] @@ -474,3 +470,56 @@ %tobool = icmp eq i64 %2, 0 ret i1 %tobool } + +define i1 @_gep_phi3(ptr noundef %str1, i64 %val2) { +; CHECK-LABEL: @_gep_phi3( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[CMP_I:%.*]] = icmp eq ptr [[STR1:%.*]], null +; CHECK-NEXT: br i1 [[CMP_I]], label [[_Z3FOOPKC_EXIT:%.*]], label [[LOR_LHS_FALSE_I:%.*]] +; CHECK: lor.lhs.false.i: +; CHECK-NEXT: [[TMP0:%.*]] = load i16, ptr [[STR1]], align 2 +; CHECK-NEXT: [[CMP1_I:%.*]] = icmp eq i16 [[TMP0]], 0 +; CHECK-NEXT: br i1 [[CMP1_I]], label [[_Z3FOOPKC_EXIT]], label [[WHILE_COND_I:%.*]] +; CHECK: while.cond.i: +; CHECK-NEXT: [[A_PN_I:%.*]] = phi ptr [ [[TEST_0_I:%.*]], [[WHILE_COND_I]] ], [ [[STR1]], [[LOR_LHS_FALSE_I]] ] +; CHECK-NEXT: [[TEST_0_I]] = getelementptr inbounds i16, ptr [[A_PN_I]], i64 1 +; CHECK-NEXT: [[TMP1:%.*]] = load i16, ptr [[TEST_0_I]], align 2 +; CHECK-NEXT: [[CMP3_NOT_I:%.*]] = icmp eq i16 [[TMP1]], 0 +; CHECK-NEXT: br i1 [[CMP3_NOT_I]], label [[WHILE_END_I:%.*]], label [[WHILE_COND_I]] +; CHECK: while.end.i: +; CHECK-NEXT: br label [[_Z3FOOPKC_EXIT]] +; CHECK: _Z3fooPKc.exit: +; CHECK-NEXT: [[RETVAL_0_I:%.*]] = phi i64 [ 1, [[WHILE_END_I]] ], [ 0, [[LOR_LHS_FALSE_I]] ], [ 0, [[ENTRY:%.*]] ] +; CHECK-NEXT: [[TMP2:%.*]] = or i64 [[RETVAL_0_I]], [[VAL2:%.*]] +; CHECK-NEXT: [[TOBOOL:%.*]] = icmp eq i64 [[TMP2]], 0 +; CHECK-NEXT: ret i1 [[TOBOOL]] +; +entry: + %cmp.i = icmp eq ptr %str1, null + br i1 %cmp.i, label %_Z3fooPKc.exit, label 
%lor.lhs.false.i + +lor.lhs.false.i: + %0 = load i16, ptr %str1, align 2 + %cmp1.i = icmp eq i16 %0, 0 + br i1 %cmp1.i, label %_Z3fooPKc.exit, label %while.cond.i + +while.cond.i: + %a.pn.i = phi ptr [ %test.0.i, %while.cond.i ], [ %str1, %lor.lhs.false.i ] + %test.0.i = getelementptr inbounds i16, ptr %a.pn.i, i64 1 + %1 = load i16, ptr %test.0.i, align 2 + %cmp3.not.i = icmp eq i16 %1, 0 + br i1 %cmp3.not.i, label %while.end.i, label %while.cond.i + +while.end.i: + %sub.ptr.lhs.cast.i = ptrtoint ptr %test.0.i to i64 + %sub.ptr.rhs.cast.i = ptrtoint ptr %str1 to i64 + %sub.ptr.sub.i = sub i64 %sub.ptr.lhs.cast.i, %sub.ptr.rhs.cast.i + %sub.ptr.div = ashr exact i64 %sub.ptr.sub.i, 1 + br label %_Z3fooPKc.exit + +_Z3fooPKc.exit: + %retval.0.i = phi i64 [ %sub.ptr.div, %while.end.i ], [ 0, %lor.lhs.false.i ], [ 0, %entry ] + %2 = or i64 %retval.0.i, %val2 + %tobool = icmp eq i64 %2, 0 + ret i1 %tobool +}