Index: llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
===================================================================
--- llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1890,7 +1890,10 @@
 ///  &A[10] - &A[0]: we should compile this to "10".  LHS/RHS are the pointer
 /// operands to the ptrtoint instructions for the LHS/RHS of the subtract.
+/// I is the subtract being simplified. Its users are only inspected, to decide
+/// whether nothing but the zero-ness of the difference is observed.
 Value *InstCombinerImpl::OptimizePointerDifference(Value *LHS, Value *RHS,
-                                                   Type *Ty, bool IsNUW) {
+                                                   Type *Ty, BinaryOperator &I,
+                                                   bool IsNUW) {
   // If LHS is a gep based on RHS or RHS is a gep based on LHS, we can optimize
   // this.
   bool Swapped = false;
@@ -1913,6 +1916,66 @@
         GEP1 = LHSGEP;
         GEP2 = RHSGEP;
       }
+    } else if (auto *PHI = dyn_cast<PHINode>(LHSGEP->getPointerOperand())) {
+      // (gep inbounds P, C) - X, where P = phi [X, ...], [gep inbounds P, C]
+      //
+      // P is a pointer recurrence that starts at X and advances by the
+      // constant C on every trip around the cycle. With C > 0 and only
+      // inbounds steps, LHS stays strictly above X inside the same object, so
+      // LHS - X is never zero. Its exact value is unknown, but when every
+      // observer of the difference only compares it against zero, it may be
+      // replaced by C, the offset of LHS from P; selecting GEP1 = LHSGEP makes
+      // the code below emit that offset, as in the (gep X, ...) - X case.
+      unsigned IdxWidth = DL.getIndexTypeSizeInBits(LHS->getType());
+      bool IsPtrRecurrence =
+          PHI->getNumIncomingValues() == 2 &&
+          ((PHI->getIncomingValue(0) == LHS &&
+            PHI->getIncomingValue(1) == RHS) ||
+           (PHI->getIncomingValue(0) == RHS &&
+            PHI->getIncomingValue(1) == LHS));
+      // Only inbounds offsets are accumulated: a non-inbounds step may wrap
+      // around the address space and bring LHS back to X.
+      APInt Step(IdxWidth, 0);
+      bool HasPositiveStep =
+          IsPtrRecurrence &&
+          LHS->stripAndAccumulateConstantOffsets(
+              DL, Step, /*AllowNonInbounds=*/false) == PHI &&
+          Step.isStrictlyPositive();
+      // Truncation (e.g. from the trunc(p)-trunc(q) caller) could turn a
+      // non-zero difference into zero, so require the full index width.
+      if (HasPositiveStep && Ty->isIntegerTy(IdxWidth) && I.hasOneUse()) {
+        // The subtract feeds a PHI either directly or through a right shift
+        // of the subtract by a constant S. Every d >= C has (d >> S) != 0
+        // exactly when (C >> S) != 0, so that is required; the shift form is
+        // limited to the unswapped case, where d is the distance from X up
+        // to LHS.
+        auto *SubUser = cast<Instruction>(I.user_back());
+        auto *DiffPHI = dyn_cast<PHINode>(SubUser);
+        const APInt *ShAmt;
+        if (!DiffPHI && !Swapped && SubUser->hasOneUse() &&
+            match(SubUser, m_Shr(m_Specific(&I), m_APInt(ShAmt))) &&
+            ShAmt->ult(IdxWidth) && !Step.lshr(ShAmt->getZExtValue()).isZero())
+          DiffPHI = dyn_cast<PHINode>(SubUser->user_back());
+
+        // icmp eq/ne V, 0 depends only on whether V is zero.
+        auto IsZeroTest = [](const User *U, const Value *V) {
+          ICmpInst::Predicate Pred;
+          return match(U, m_ICmp(Pred, m_Specific(V), m_Zero())) &&
+                 ICmpInst::isEquality(Pred);
+        };
+        // Every user of DiffPHI must observe only its zero-ness, directly or
+        // through a single-use or(DiffPHI, Y), which is zero iff both
+        // DiffPHI and Y are zero.
+        bool OnlyZeroTested =
+            DiffPHI && llvm::all_of(DiffPHI->users(), [&](const User *U) {
+              if (IsZeroTest(U, DiffPHI))
+                return true;
+              return match(U, m_c_Or(m_Specific(DiffPHI), m_Value())) &&
+                     U->hasOneUse() && IsZeroTest(U->user_back(), U);
+            });
+        if (OnlyZeroTested)
+          GEP1 = LHSGEP;
+      }
     }
   }
 
@@ -2391,14 +2454,14 @@
   Value *LHSOp, *RHSOp;
   if (match(Op0, m_PtrToInt(m_Value(LHSOp))) &&
       match(Op1, m_PtrToInt(m_Value(RHSOp))))
-    if (Value *Res = OptimizePointerDifference(LHSOp, RHSOp, I.getType(),
+    if (Value *Res = OptimizePointerDifference(LHSOp, RHSOp, I.getType(), I,
                                                I.hasNoUnsignedWrap()))
       return replaceInstUsesWith(I, Res);
 
   // trunc(p)-trunc(q) -> trunc(p-q)
   if (match(Op0, m_Trunc(m_PtrToInt(m_Value(LHSOp)))) &&
       match(Op1, m_Trunc(m_PtrToInt(m_Value(RHSOp)))))
-    if (Value *Res = OptimizePointerDifference(LHSOp, RHSOp, I.getType(),
+    if (Value *Res = OptimizePointerDifference(LHSOp, RHSOp, I.getType(), I,
                                                /* IsNUW */ false))
       return replaceInstUsesWith(I, Res);
 
Index: llvm/lib/Transforms/InstCombine/InstCombineInternal.h
===================================================================
--- llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -93,8 +93,8 @@
   Instruction *visitFNeg(UnaryOperator &I);
   Instruction *visitAdd(BinaryOperator &I);
   Instruction *visitFAdd(BinaryOperator &I);
-  Value *OptimizePointerDifference(
-      Value *LHS, Value *RHS, Type *Ty, bool isNUW);
+  Value *OptimizePointerDifference(Value *LHS, Value *RHS, Type *Ty,
+                                   BinaryOperator &I, bool isNUW);
   Instruction *visitSub(BinaryOperator &I);
   Instruction *visitFSub(BinaryOperator &I);
   Instruction *visitMul(BinaryOperator &I);
Index: llvm/test/Transforms/InstCombine/sub-gep.ll
===================================================================
--- llvm/test/Transforms/InstCombine/sub-gep.ll
+++ llvm/test/Transforms/InstCombine/sub-gep.ll
@@ -369,3 +369,104 @@
   %i6 = lshr i64 %i5, 5
   ret i64 %i6
 }
+
+define i1 @_gep_phi1(ptr noundef %str1) {
+; CHECK-LABEL: @_gep_phi1(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[CMP_I:%.*]] = icmp eq ptr [[STR1:%.*]], null
+; CHECK-NEXT:    br i1 [[CMP_I]], label [[_Z3FOOPKC_EXIT:%.*]], label [[LOR_LHS_FALSE_I:%.*]]
+; CHECK:       lor.lhs.false.i:
+; CHECK-NEXT:    [[TMP0:%.*]] = load i8, ptr [[STR1]], align 1
+; CHECK-NEXT:    [[CMP1_I:%.*]] = icmp eq i8 [[TMP0]], 0
+; CHECK-NEXT:    br i1 [[CMP1_I]], label [[_Z3FOOPKC_EXIT]], label [[WHILE_COND_I:%.*]]
+; CHECK:       while.cond.i:
+; CHECK-NEXT:    [[A_PN_I:%.*]] = phi ptr [ [[TEST_0_I:%.*]], [[WHILE_COND_I]] ], [ [[STR1]], [[LOR_LHS_FALSE_I]] ]
+; CHECK-NEXT:    [[TEST_0_I]] = getelementptr inbounds i8, ptr [[A_PN_I]], i64 1
+; CHECK-NEXT:    [[TMP1:%.*]] = load i8, ptr [[TEST_0_I]], align 1
+; CHECK-NEXT:    [[CMP3_NOT_I:%.*]] = icmp eq i8 [[TMP1]], 0
+; CHECK-NEXT:    br i1 [[CMP3_NOT_I]], label [[WHILE_END_I:%.*]], label [[WHILE_COND_I]]
+; CHECK:       while.end.i:
+; CHECK-NEXT:    br label [[_Z3FOOPKC_EXIT]]
+; CHECK:       _Z3fooPKc.exit:
+; CHECK-NEXT:    [[TOBOOL:%.*]] = phi i1 [ true, [[WHILE_END_I]] ], [ false, [[LOR_LHS_FALSE_I]] ], [ false, [[ENTRY:%.*]] ]
+; CHECK-NEXT:    ret i1 [[TOBOOL]]
+;
+entry:
+  %cmp.i = icmp eq ptr %str1, null
+  br i1 %cmp.i, label %_Z3fooPKc.exit, label %lor.lhs.false.i
+
+lor.lhs.false.i:
+  %0 = load i8, ptr %str1, align 1
+  %cmp1.i = icmp eq i8 %0, 0
+  br i1 %cmp1.i, label %_Z3fooPKc.exit, label %while.cond.i
+
+while.cond.i:
+  %a.pn.i = phi ptr [ %test.0.i, %while.cond.i ], [ %str1, %lor.lhs.false.i ]
+  %test.0.i = getelementptr inbounds i8, ptr %a.pn.i, i64 1
+  %1 = load i8, ptr %test.0.i, align 1
+  %cmp3.not.i = icmp eq i8 %1, 0
+  br i1 %cmp3.not.i, label %while.end.i, label %while.cond.i
+
+while.end.i:
+  %sub.ptr.lhs.cast.i = ptrtoint ptr %test.0.i to i64
+  %sub.ptr.rhs.cast.i = ptrtoint ptr %str1 to i64
+  %sub.ptr.sub.i = sub i64 %sub.ptr.lhs.cast.i, %sub.ptr.rhs.cast.i
+  br label %_Z3fooPKc.exit
+
+_Z3fooPKc.exit:
+  %retval.0.i = phi i64 [ %sub.ptr.sub.i, %while.end.i ], [ 0, %lor.lhs.false.i ], [ 0, %entry ]
+  %tobool = icmp ne i64 %retval.0.i, 0
+  ret i1 %tobool
+}
+
+define i1 @_gep_phi2(ptr noundef %str1, i64 %val2) {
+; CHECK-LABEL: @_gep_phi2(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[CMP_I:%.*]] = icmp eq ptr [[STR1:%.*]], null
+; CHECK-NEXT:    br i1 [[CMP_I]], label [[_Z3FOOPKC_EXIT:%.*]], label [[LOR_LHS_FALSE_I:%.*]]
+; CHECK:       lor.lhs.false.i:
+; CHECK-NEXT:    [[TMP0:%.*]] = load i8, ptr [[STR1]], align 1
+; CHECK-NEXT:    [[CMP1_I:%.*]] = icmp eq i8 [[TMP0]], 0
+; CHECK-NEXT:    br i1 [[CMP1_I]], label [[_Z3FOOPKC_EXIT]], label [[WHILE_COND_I:%.*]]
+; CHECK:       while.cond.i:
+; CHECK-NEXT:    [[A_PN_I:%.*]] = phi ptr [ [[TEST_0_I:%.*]], [[WHILE_COND_I]] ], [ [[STR1]], [[LOR_LHS_FALSE_I]] ]
+; CHECK-NEXT:    [[TEST_0_I]] = getelementptr inbounds i8, ptr [[A_PN_I]], i64 1
+; CHECK-NEXT:    [[TMP1:%.*]] = load i8, ptr [[TEST_0_I]], align 1
+; CHECK-NEXT:    [[CMP3_NOT_I:%.*]] = icmp eq i8 [[TMP1]], 0
+; CHECK-NEXT:    br i1 [[CMP3_NOT_I]], label [[WHILE_END_I:%.*]], label [[WHILE_COND_I]]
+; CHECK:       while.end.i:
+; CHECK-NEXT:    br label [[_Z3FOOPKC_EXIT]]
+; CHECK:       _Z3fooPKc.exit:
+; CHECK-NEXT:    [[RETVAL_0_I:%.*]] = phi i64 [ 1, [[WHILE_END_I]] ], [ 0, [[LOR_LHS_FALSE_I]] ], [ 0, [[ENTRY:%.*]] ]
+; CHECK-NEXT:    [[TMP2:%.*]] = or i64 [[RETVAL_0_I]], [[VAL2:%.*]]
+; CHECK-NEXT:    [[TOBOOL:%.*]] = icmp eq i64 [[TMP2]], 0
+; CHECK-NEXT:    ret i1 [[TOBOOL]]
+;
+entry:
+  %cmp.i = icmp eq ptr %str1, null
+  br i1 %cmp.i, label %_Z3fooPKc.exit, label %lor.lhs.false.i
+
+lor.lhs.false.i:
+  %0 = load i8, ptr %str1, align 1
+  %cmp1.i = icmp eq i8 %0, 0
+  br i1 %cmp1.i, label %_Z3fooPKc.exit, label %while.cond.i
+
+while.cond.i:
+  %a.pn.i = phi ptr [ %test.0.i, %while.cond.i ], [ %str1, %lor.lhs.false.i ]
+  %test.0.i = getelementptr inbounds i8, ptr %a.pn.i, i64 1
+  %1 = load i8, ptr %test.0.i, align 1
+  %cmp3.not.i = icmp eq i8 %1, 0
+  br i1 %cmp3.not.i, label %while.end.i, label %while.cond.i
+
+while.end.i:
+  %sub.ptr.lhs.cast.i = ptrtoint ptr %test.0.i to i64
+  %sub.ptr.rhs.cast.i = ptrtoint ptr %str1 to i64
+  %sub.ptr.sub.i = sub i64 %sub.ptr.lhs.cast.i, %sub.ptr.rhs.cast.i
+  br label %_Z3fooPKc.exit
+
+_Z3fooPKc.exit:
+  %retval.0.i = phi i64 [ %sub.ptr.sub.i, %while.end.i ], [ 0, %lor.lhs.false.i ], [ 0, %entry ]
+  %2 = or i64 %retval.0.i, %val2
+  %tobool = icmp eq i64 %2, 0
+  ret i1 %tobool
+}