diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -899,10 +899,7 @@
   return transformToIndexedCompare(GEPLHS, RHS, Cond, DL);
 }
 
-Instruction *InstCombinerImpl::foldAllocaCmp(ICmpInst &ICI,
-                                             const AllocaInst *Alloca) {
-  assert(ICI.isEquality() && "Cannot fold non-equality comparison.");
-
+bool InstCombinerImpl::foldAllocaCmp(AllocaInst *Alloca) {
   // It would be tempting to fold away comparisons between allocas and any
   // pointer not based on that alloca (e.g. an argument). However, even
   // though such pointers cannot alias, they can still compare equal.
@@ -911,21 +908,34 @@
   // doesn't escape we can argue that it's impossible to guess its value, and we
   // can therefore act as if any such guesses are wrong.
   //
-  // The code below checks that the alloca doesn't escape, and that it's only
-  // used in a comparison once (the current instruction). The
-  // single-comparison-use condition ensures that we're trivially folding all
-  // comparisons against the alloca consistently, and avoids the risk of
-  // erroneously folding a comparison of the pointer with itself.
+  // However, we need to ensure that this folding is consistent: We can't fold
+  // one comparison to false, and then leave a different comparison against the
+  // same value alone (as it might evaluate to true at runtime, leading to a
+  // contradiction). As such, this code ensures that all comparisons are folded
+  // at the same time, and there are no other escapes.
 
   struct CmpCaptureTracker : public CaptureTracker {
+    AllocaInst *Alloca;
     bool Captured = false;
-    unsigned NumCmps = 0;
+    /// The value of the map is a bit mask of which icmp operands the alloca is
+    /// used in.
+    SmallMapVector<ICmpInst *, unsigned, 4> ICmps;
+
+    explicit CmpCaptureTracker(AllocaInst *Alloca) : Alloca(Alloca) {}
 
     void tooManyUses() override { Captured = true; }
 
     bool captured(const Use *U) override {
-      if (isa<ICmpInst>(U->getUser()) && ++NumCmps == 1) {
-        // Ignore one icmp capture.
+      auto *ICmp = dyn_cast<ICmpInst>(U->getUser());
+      // We need to check that U is based *only* on the alloca, and doesn't
+      // have other contributions from a select/phi operand.
+      // TODO: We could check whether getUnderlyingObjects() reduces to one
+      // object, which would allow looking through phi nodes.
+      if (ICmp && ICmp->isEquality() && getUnderlyingObject(*U) == Alloca) {
+        // Collect equality icmps of the alloca, and don't treat them as
+        // captures.
+        auto Res = ICmps.insert({ICmp, 0});
+        Res.first->second |= 1u << U->getOperandNo();
         return false;
       }
 
@@ -934,14 +944,36 @@
     }
   };
 
-  CmpCaptureTracker Tracker;
+  CmpCaptureTracker Tracker(Alloca);
   PointerMayBeCaptured(Alloca, &Tracker);
   if (Tracker.Captured)
-    return nullptr;
+    return false;
+
+  bool Changed = false;
+  for (auto [ICmp, Operands] : Tracker.ICmps) {
+    switch (Operands) {
+    case 1:
+    case 2: {
+      // The alloca is only used in one icmp operand. Assume that the
+      // equality is false.
+      auto *Res = ConstantInt::get(
+          ICmp->getType(), ICmp->getPredicate() == ICmpInst::ICMP_NE);
+      replaceInstUsesWith(*ICmp, Res);
+      eraseInstFromFunction(*ICmp);
+      Changed = true;
+      break;
+    }
+    case 3:
+      // Both icmp operands are based on the alloca, so this is comparing
+      // pointer offsets, without leaking any information about the address
+      // of the alloca. Ignore such comparisons.
+      break;
+    default:
+      llvm_unreachable("Cannot happen");
+    }
+  }
 
-  auto *Res = ConstantInt::get(ICI.getType(),
-                               !CmpInst::isTrueWhenEqual(ICI.getPredicate()));
-  return replaceInstUsesWith(ICI, Res);
+  return Changed;
 }
 
 /// Fold "icmp pred (X+C), X".
@@ -6500,11 +6532,11 @@
   if (Op0->getType()->isPointerTy() && I.isEquality()) {
     assert(Op1->getType()->isPointerTy() && "Comparing pointer with non-pointer?");
     if (auto *Alloca = dyn_cast<AllocaInst>(getUnderlyingObject(Op0)))
-      if (Instruction *New = foldAllocaCmp(I, Alloca))
-        return New;
+      if (foldAllocaCmp(Alloca))
+        return nullptr;
     if (auto *Alloca = dyn_cast<AllocaInst>(getUnderlyingObject(Op1)))
-      if (Instruction *New = foldAllocaCmp(I, Alloca))
-        return New;
+      if (foldAllocaCmp(Alloca))
+        return nullptr;
   }
 
   if (Instruction *Res = foldICmpBitCast(I))
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -549,7 +549,7 @@
                            ICmpInst::Predicate Cond, Instruction &I);
   Instruction *foldSelectICmp(ICmpInst::Predicate Pred, SelectInst *SI,
                               Value *RHS, const ICmpInst &I);
-  Instruction *foldAllocaCmp(ICmpInst &ICI, const AllocaInst *Alloca);
+  bool foldAllocaCmp(AllocaInst *Alloca);
   Instruction *foldCmpLoadFromIndexedGlobal(LoadInst *LI,
                                             GetElementPtrInst *GEP,
                                             GlobalVariable *GV, CmpInst &ICI,
diff --git a/llvm/test/Transforms/InstCombine/compare-alloca.ll b/llvm/test/Transforms/InstCombine/compare-alloca.ll
--- a/llvm/test/Transforms/InstCombine/compare-alloca.ll
+++ b/llvm/test/Transforms/InstCombine/compare-alloca.ll
@@ -58,12 +58,7 @@
 declare void @check_compares(i1, i1)
 define void @alloca_argument_compare_two_compares(ptr %p) {
 ; CHECK-LABEL: @alloca_argument_compare_two_compares(
-; CHECK-NEXT:    [[Q1:%.*]] = alloca [8 x i64], align 8
-; CHECK-NEXT:    [[R:%.*]] = getelementptr i64, ptr [[P:%.*]], i32 1
-; CHECK-NEXT:    [[S:%.*]] = getelementptr inbounds i64, ptr [[Q1]], i32 2
-; CHECK-NEXT:    [[CMP1:%.*]] = icmp eq ptr [[Q1]], [[P]]
-; CHECK-NEXT:    [[CMP2:%.*]] = icmp eq ptr [[R]], [[S]]
-; CHECK-NEXT:    call void @check_compares(i1 [[CMP1]], i1 [[CMP2]])
+; CHECK-NEXT:    call void @check_compares(i1 false, i1 false)
 ; CHECK-NEXT:    ret void
 ;
   %q = alloca i64, i64 8
@@ -154,13 +149,10 @@
 
 declare void @witness(i1, i1)
 
-define void @neg_consistent_fold1() {
-; CHECK-LABEL: @neg_consistent_fold1(
-; CHECK-NEXT:    [[M1:%.*]] = alloca [4 x i8], align 1
+define void @consistent_fold1() {
+; CHECK-LABEL: @consistent_fold1(
 ; CHECK-NEXT:    [[RHS2:%.*]] = call ptr @hidden_inttoptr()
-; CHECK-NEXT:    [[CMP1:%.*]] = icmp eq ptr [[M1]], inttoptr (i64 2048 to ptr)
-; CHECK-NEXT:    [[CMP2:%.*]] = icmp eq ptr [[M1]], [[RHS2]]
-; CHECK-NEXT:    call void @witness(i1 [[CMP1]], i1 [[CMP2]])
+; CHECK-NEXT:    call void @witness(i1 false, i1 false)
 ; CHECK-NEXT:    ret void
 ;
   %m = alloca i8, i32 4
@@ -172,15 +164,11 @@
   ret void
 }
 
-define void @neg_consistent_fold2() {
-; CHECK-LABEL: @neg_consistent_fold2(
-; CHECK-NEXT:    [[M1:%.*]] = alloca [4 x i8], align 1
+define void @consistent_fold2() {
+; CHECK-LABEL: @consistent_fold2(
 ; CHECK-NEXT:    [[N2:%.*]] = alloca [4 x i8], align 1
-; CHECK-NEXT:    [[RHS:%.*]] = getelementptr inbounds i8, ptr [[N2]], i32 4
 ; CHECK-NEXT:    [[RHS2:%.*]] = call ptr @hidden_offset(ptr nonnull [[N2]])
-; CHECK-NEXT:    [[CMP1:%.*]] = icmp eq ptr [[M1]], [[RHS]]
-; CHECK-NEXT:    [[CMP2:%.*]] = icmp eq ptr [[M1]], [[RHS2]]
-; CHECK-NEXT:    call void @witness(i1 [[CMP1]], i1 [[CMP2]])
+; CHECK-NEXT:    call void @witness(i1 false, i1 false)
 ; CHECK-NEXT:    ret void
 ;
   %m = alloca i8, i32 4
@@ -193,14 +181,10 @@
   ret void
 }
 
-define void @neg_consistent_fold3() {
-; CHECK-LABEL: @neg_consistent_fold3(
-; CHECK-NEXT:    [[M1:%.*]] = alloca [4 x i8], align 1
-; CHECK-NEXT:    [[LGP:%.*]] = load ptr, ptr @gp, align 8
+define void @consistent_fold3() {
+; CHECK-LABEL: @consistent_fold3(
 ; CHECK-NEXT:    [[RHS2:%.*]] = call ptr @hidden_inttoptr()
-; CHECK-NEXT:    [[CMP1:%.*]] = icmp eq ptr [[M1]], [[LGP]]
-; CHECK-NEXT:    [[CMP2:%.*]] = icmp eq ptr [[M1]], [[RHS2]]
-; CHECK-NEXT:    call void @witness(i1 [[CMP1]], i1 [[CMP2]])
+; CHECK-NEXT:    call void @witness(i1 false, i1 false)
 ; CHECK-NEXT:    ret void
 ;
   %m = alloca i8, i32 4