Index: llvm/include/llvm/Analysis/Loads.h =================================================================== --- llvm/include/llvm/Analysis/Loads.h +++ llvm/include/llvm/Analysis/Loads.h @@ -173,14 +173,15 @@ unsigned MaxInstsToScan, AAResults *AA, bool *IsLoadCSE, unsigned *NumScanedInst); -/// Returns true if a pointer value \p A can be replace with another pointer -/// value \B if they are deemed equal through some means (e.g. information from +/// Returns true if a pointer value \p From can be replaced with another pointer +/// value \p To if they are deemed equal through some means (e.g. information from /// conditions). -/// NOTE: the current implementations is incomplete and unsound. It does not -/// reject all invalid cases yet, but will be made stricter in the future. In -/// particular this means returning true means unknown if replacement is safe. -bool canReplacePointersIfEqual(Value *A, Value *B, const DataLayout &DL, - Instruction *CtxI); +/// NOTE: The current implementation allows replacement in Icmp and PtrToInt +/// instructions, as well as when we are replacing with a null pointer. +/// Additionally it also allows replacement of pointers when both pointers have +/// the same underlying object. +bool canReplacePointersIfEqual(const Value *From, const Value *To, + const User *U); } #endif Index: llvm/include/llvm/Transforms/Utils/Local.h =================================================================== --- llvm/include/llvm/Transforms/Utils/Local.h +++ llvm/include/llvm/Transforms/Utils/Local.h @@ -407,6 +407,20 @@ /// the end of the given BasicBlock. Returns the number of replacements made. unsigned replaceDominatedUsesWith(Value *From, Value *To, DominatorTree &DT, const BasicBlock *BB); +/// Replace each use of 'From' with 'To' if that use is dominated by +/// the given edge and the callback ShouldReplace returns true. Returns the +/// number of replacements made. 
+unsigned replaceDominatedUsesWithIf( + Value *From, Value *To, DominatorTree &DT, const BasicBlockEdge &Edge, + function_ref + ShouldReplace); +/// Replace each use of 'From' with 'To' if that use is dominated by +/// the end of the given BasicBlock and the callback ShouldReplace returns true. +/// Returns the number of replacements made. +unsigned replaceDominatedUsesWithIf( + Value *From, Value *To, DominatorTree &DT, const BasicBlock *BB, + function_ref + ShouldReplace); /// Return true if this call calls a gc leaf function. /// Index: llvm/lib/Analysis/Loads.cpp =================================================================== --- llvm/lib/Analysis/Loads.cpp +++ llvm/lib/Analysis/Loads.cpp @@ -684,22 +684,34 @@ return Available; } -bool llvm::canReplacePointersIfEqual(Value *A, Value *B, const DataLayout &DL, - Instruction *CtxI) { - Type *Ty = A->getType(); - assert(Ty == B->getType() && Ty->isPointerTy() && - "values must have matching pointer types"); - - // NOTE: The checks in the function are incomplete and currently miss illegal - // cases! The current implementation is a starting point and the - // implementation should be made stricter over time. - if (auto *C = dyn_cast(B)) { - // Do not allow replacing a pointer with a constant pointer, unless it is - // either null or at least one byte is dereferenceable. 
- APInt OneByte(DL.getPointerTypeSizeInBits(Ty), 1); - return C->isNullValue() || - isDereferenceableAndAlignedPointer(B, Align(1), OneByte, DL, CtxI); - } +static bool canReplacePointersRecursive(const Value *From, const Value *To, + const User *U, int MaxLookup = 6) { + if (MaxLookup == 0) + return false; + + if (isa(To)) + return true; + if (isa(U)) + return true; + if (isa(U)) + return true; + if (isa(U) && all_of(U->users(), [&](const User *User) { + return isa(User) || + canReplacePointersRecursive(U, To, User, MaxLookup - 1); + })) + return true; + if (isa(U) && all_of(U->users(), [&](const User *User) { + return canReplacePointersRecursive(U, To, User, MaxLookup - 1); + })) + return true; + return getUnderlyingObject(From) == getUnderlyingObject(To); +} - return true; +bool llvm::canReplacePointersIfEqual(const Value *From, const Value *To, + const User *U) { + // Not a pointer, just return true. + if (!From->getType()->isPointerTy()) + return true; + assert(From->getType() == To->getType() && "values must have matching types"); + return canReplacePointersRecursive(From, To, U); } Index: llvm/lib/Transforms/Scalar/GVN.cpp =================================================================== --- llvm/lib/Transforms/Scalar/GVN.cpp +++ llvm/lib/Transforms/Scalar/GVN.cpp @@ -33,6 +33,7 @@ #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/InstructionPrecedenceTracking.h" #include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/Loads.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/MemoryDependenceAnalysis.h" @@ -2329,14 +2330,18 @@ if (!LHS->hasOneUse()) { unsigned NumReplacements = DominatesByEdge - ? replaceDominatedUsesWith(LHS, RHS, *DT, Root) - : replaceDominatedUsesWith(LHS, RHS, *DT, Root.getStart()); - - Changed |= NumReplacements > 0; - NumGVNEqProp += NumReplacements; - // Cached information for anything that uses LHS will be invalid. 
- if (MD) - MD->invalidateCachedPointerInfo(LHS); + ? replaceDominatedUsesWithIf(LHS, RHS, *DT, Root, + canReplacePointersIfEqual) + : replaceDominatedUsesWithIf(LHS, RHS, *DT, Root.getStart(), + canReplacePointersIfEqual); + + if (NumReplacements > 0) { + Changed = true; + NumGVNEqProp += NumReplacements; + // Cached information for anything that uses LHS will be invalid. + if (MD) + MD->invalidateCachedPointerInfo(LHS); + } } // Now try to deduce additional equalities from this one. For example, if @@ -2392,14 +2397,18 @@ if (NotCmp && isa(NotCmp)) { unsigned NumReplacements = DominatesByEdge - ? replaceDominatedUsesWith(NotCmp, NotVal, *DT, Root) - : replaceDominatedUsesWith(NotCmp, NotVal, *DT, - Root.getStart()); - Changed |= NumReplacements > 0; - NumGVNEqProp += NumReplacements; - // Cached information for anything that uses NotCmp will be invalid. - if (MD) - MD->invalidateCachedPointerInfo(NotCmp); + ? replaceDominatedUsesWithIf(NotCmp, NotVal, *DT, Root, + canReplacePointersIfEqual) + : replaceDominatedUsesWithIf(NotCmp, NotVal, *DT, + Root.getStart(), + canReplacePointersIfEqual); + if (NumReplacements > 0) { + Changed = true; + NumGVNEqProp += NumReplacements; + // Cached information for anything that uses NotCmp will be invalid. 
+ if (MD) + MD->invalidateCachedPointerInfo(NotCmp); + } } } // Ensure that any instruction in scope that gets the "A < B" value number Index: llvm/lib/Transforms/Utils/Local.cpp =================================================================== --- llvm/lib/Transforms/Utils/Local.cpp +++ llvm/lib/Transforms/Utils/Local.cpp @@ -2841,15 +2841,20 @@ } template -static unsigned replaceDominatedUsesWith(Value *From, Value *To, - const RootType &Root, - const DominatesFn &Dominates) { +static unsigned replaceDominatedUsesWith( + Value *From, Value *To, const RootType &Root, const DominatesFn &Dominates, + std::optional< + function_ref> + ShouldReplace) { assert(From->getType() == To->getType()); unsigned Count = 0; for (Use &U : llvm::make_early_inc_range(From->uses())) { if (!Dominates(Root, U)) continue; + if (ShouldReplace.has_value() && + !ShouldReplace.value()(From, To, U.getUser())) + continue; U.set(To); LLVM_DEBUG(dbgs() << "Replace dominated use of '" << From->getName() << "' as " << *To << " in " << *U << "\n"); @@ -2879,7 +2884,7 @@ auto Dominates = [&DT](const BasicBlockEdge &Root, const Use &U) { return DT.dominates(Root, U); }; - return ::replaceDominatedUsesWith(From, To, Root, Dominates); + return ::replaceDominatedUsesWith(From, To, Root, Dominates, std::nullopt); } unsigned llvm::replaceDominatedUsesWith(Value *From, Value *To, @@ -2888,7 +2893,27 @@ auto Dominates = [&DT](const BasicBlock *BB, const Use &U) { return DT.dominates(BB, U); }; - return ::replaceDominatedUsesWith(From, To, BB, Dominates); + return ::replaceDominatedUsesWith(From, To, BB, Dominates, std::nullopt); +} + +unsigned llvm::replaceDominatedUsesWithIf( + Value *From, Value *To, DominatorTree &DT, const BasicBlockEdge &Root, + function_ref + ShouldReplace) { + auto Dominates = [&DT](const BasicBlockEdge &Root, const Use &U) { + return DT.dominates(Root, U); + }; + return ::replaceDominatedUsesWith(From, To, Root, Dominates, ShouldReplace); +} + +unsigned 
llvm::replaceDominatedUsesWithIf( + Value *From, Value *To, DominatorTree &DT, const BasicBlock *BB, + function_ref + ShouldReplace) { + auto Dominates = [&DT](const BasicBlock *BB, const Use &U) { + return DT.dominates(BB, U); + }; + return ::replaceDominatedUsesWith(From, To, BB, Dominates, ShouldReplace); } bool llvm::callsGCLeafFunction(const CallBase *Call, Index: llvm/test/Transforms/GVN/assume-equal.ll =================================================================== --- llvm/test/Transforms/GVN/assume-equal.ll +++ llvm/test/Transforms/GVN/assume-equal.ll @@ -21,7 +21,7 @@ if.then: ; preds = %entry %0 = load ptr, ptr %vtable, align 8 - ; CHECK: call i32 @_ZN1A3fooEv( + ; CHECK-NOT: call i32 @_ZN1A3fooEv( %call2 = tail call i32 %0(ptr %call) #1 br label %if.end @@ -29,7 +29,7 @@ if.else: ; preds = %entry %vfn47 = getelementptr inbounds ptr, ptr %vtable, i64 1 - ; CHECK: call i32 @_ZN1A3barEv( + ; CHECK-NOT: call i32 @_ZN1A3barEv( %1 = load ptr, ptr %vfn47, align 8 %call5 = tail call i32 %1(ptr %call) #1 @@ -53,21 +53,21 @@ if.then: ; preds = %entry %0 = load ptr, ptr %vtable, align 8 -; CHECK: call i32 @_ZN1A3fooEv( +; CHECK-NOT: call i32 @_ZN1A3fooEv( %call2 = tail call i32 %0(ptr %call) #1 %vtable1 = load ptr, ptr %call, align 8, !invariant.group !0 %call1 = load ptr, ptr %vtable1, align 8 -; CHECK: call i32 @_ZN1A3fooEv( +; CHECK-NOT: call i32 @_ZN1A3fooEv( %callx = tail call i32 %call1(ptr %call) #1 %vtable2 = load ptr, ptr %call, align 8, !invariant.group !0 %call4 = load ptr, ptr %vtable2, align 8 -; CHECK: call i32 @_ZN1A3fooEv( +; CHECK-NOT: call i32 @_ZN1A3fooEv( %cally = tail call i32 %call4(ptr %call) #1 %vtable3 = load ptr, ptr %call, align 8, !invariant.group !0 %vfun = load ptr, ptr %vtable3, align 8 -; CHECK: call i32 @_ZN1A3fooEv( +; CHECK-NOT: call i32 @_ZN1A3fooEv( %unknown = tail call i32 %vfun(ptr %call) #1 br label %if.end @@ -75,7 +75,7 @@ if.else: ; preds = %entry %vfn47 = getelementptr inbounds ptr, ptr %vtable, i64 1 - ; CHECK: call 
i32 @_ZN1A3barEv( + ; CHECK-NOT: call i32 @_ZN1A3barEv( %1 = load ptr, ptr %vfn47, align 8 %call5 = tail call i32 %1(ptr %call) #1 Index: llvm/test/Transforms/GVN/condprop.ll =================================================================== --- llvm/test/Transforms/GVN/condprop.ll +++ llvm/test/Transforms/GVN/condprop.ll @@ -1,7 +1,7 @@ ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py ; RUN: opt < %s -passes=gvn -S | FileCheck %s -@a = external global i32 ; [#uses=7] +@a = external global i32 ; [#uses=8] define i32 @test1() nounwind { ; CHECK-LABEL: @test1( @@ -512,27 +512,22 @@ ret i32 %res } -; On the path from entry->if->end we know that ptr1==ptr2, so we can determine -; that gep2 does not alias ptr1 on that path (as it would require that -; ptr2==ptr2+2), so we can perform PRE of the load. +; Check that we don't propagate pointer equalities when illegal. define i32 @test13(ptr %ptr1, ptr %ptr2) { ; CHECK-LABEL: @test13( ; CHECK-NEXT: entry: ; CHECK-NEXT: [[GEP1:%.*]] = getelementptr i32, ptr [[PTR2:%.*]], i32 1 ; CHECK-NEXT: [[GEP2:%.*]] = getelementptr i32, ptr [[PTR2]], i32 2 ; CHECK-NEXT: [[CMP:%.*]] = icmp eq ptr [[PTR1:%.*]], [[PTR2]] -; CHECK-NEXT: br i1 [[CMP]], label [[IF:%.*]], label [[ENTRY_END_CRIT_EDGE:%.*]] -; CHECK: entry.end_crit_edge: -; CHECK-NEXT: [[VAL2_PRE:%.*]] = load i32, ptr [[GEP2]], align 4 -; CHECK-NEXT: br label [[END:%.*]] +; CHECK-NEXT: br i1 [[CMP]], label [[IF:%.*]], label [[END:%.*]] ; CHECK: if: ; CHECK-NEXT: [[VAL1:%.*]] = load i32, ptr [[GEP2]], align 4 ; CHECK-NEXT: br label [[END]] ; CHECK: end: -; CHECK-NEXT: [[VAL2:%.*]] = phi i32 [ [[VAL1]], [[IF]] ], [ [[VAL2_PRE]], [[ENTRY_END_CRIT_EDGE]] ] -; CHECK-NEXT: [[PHI1:%.*]] = phi ptr [ [[PTR2]], [[IF]] ], [ [[GEP1]], [[ENTRY_END_CRIT_EDGE]] ] -; CHECK-NEXT: [[PHI2:%.*]] = phi i32 [ [[VAL1]], [[IF]] ], [ 0, [[ENTRY_END_CRIT_EDGE]] ] +; CHECK-NEXT: [[PHI1:%.*]] = phi ptr [ [[PTR1]], [[IF]] ], [ [[GEP1]], [[ENTRY:%.*]] ] +; CHECK-NEXT: [[PHI2:%.*]] = 
phi i32 [ [[VAL1]], [[IF]] ], [ 0, [[ENTRY]] ] ; CHECK-NEXT: store i32 0, ptr [[PHI1]], align 4 +; CHECK-NEXT: [[VAL2:%.*]] = load i32, ptr [[GEP2]], align 4 ; CHECK-NEXT: [[RET:%.*]] = add i32 [[PHI2]], [[VAL2]] ; CHECK-NEXT: ret i32 [[RET]] ; @@ -556,32 +551,29 @@ ret i32 %ret } -define void @test14(ptr %ptr1, ptr noalias %ptr2) { +define void @test14(ptr %ptr1, ptr noalias %ptr2, i1 %c1, i1 %c2) { ; CHECK-LABEL: @test14( ; CHECK-NEXT: entry: ; CHECK-NEXT: [[GEP1:%.*]] = getelementptr inbounds i32, ptr [[PTR1:%.*]], i32 1 ; CHECK-NEXT: [[GEP2:%.*]] = getelementptr inbounds i32, ptr [[PTR1]], i32 2 ; CHECK-NEXT: br label [[LOOP:%.*]] ; CHECK: loop: -; CHECK-NEXT: br i1 undef, label [[LOOP_IF1_CRIT_EDGE:%.*]], label [[THEN:%.*]] -; CHECK: loop.if1_crit_edge: -; CHECK-NEXT: [[VAL2_PRE:%.*]] = load i32, ptr [[GEP2]], align 4 -; CHECK-NEXT: br label [[IF1:%.*]] +; CHECK-NEXT: br i1 [[C1:%.*]], label [[IF1:%.*]], label [[THEN:%.*]] ; CHECK: if1: -; CHECK-NEXT: [[VAL2:%.*]] = phi i32 [ [[VAL2_PRE]], [[LOOP_IF1_CRIT_EDGE]] ], [ [[VAL3:%.*]], [[LOOP_END:%.*]] ] +; CHECK-NEXT: [[VAL2:%.*]] = load i32, ptr [[GEP2]], align 4 ; CHECK-NEXT: store i32 [[VAL2]], ptr [[GEP2]], align 4 ; CHECK-NEXT: store i32 0, ptr [[GEP1]], align 4 ; CHECK-NEXT: br label [[THEN]] ; CHECK: then: ; CHECK-NEXT: [[CMP:%.*]] = icmp eq ptr [[GEP2]], [[PTR2:%.*]] -; CHECK-NEXT: br i1 [[CMP]], label [[LOOP_END]], label [[IF2:%.*]] +; CHECK-NEXT: br i1 [[CMP]], label [[LOOP_END:%.*]], label [[IF2:%.*]] ; CHECK: if2: ; CHECK-NEXT: br label [[LOOP_END]] ; CHECK: loop.end: -; CHECK-NEXT: [[PHI3:%.*]] = phi ptr [ [[PTR2]], [[THEN]] ], [ [[PTR1]], [[IF2]] ] -; CHECK-NEXT: [[VAL3]] = load i32, ptr [[GEP2]], align 4 +; CHECK-NEXT: [[PHI3:%.*]] = phi ptr [ [[GEP2]], [[THEN]] ], [ [[PTR1]], [[IF2]] ] +; CHECK-NEXT: [[VAL3:%.*]] = load i32, ptr [[GEP2]], align 4 ; CHECK-NEXT: store i32 [[VAL3]], ptr [[PHI3]], align 4 -; CHECK-NEXT: br i1 undef, label [[LOOP]], label [[IF1]] +; CHECK-NEXT: br i1 [[C2:%.*]], label 
[[LOOP]], label [[IF1]] ; entry: %gep1 = getelementptr inbounds i32, ptr %ptr1, i32 1 @@ -590,7 +582,7 @@ loop: %phi1 = phi ptr [ %gep3, %loop.end ], [ %gep1, %entry ] - br i1 undef, label %if1, label %then + br i1 %c1, label %if1, label %then if1: @@ -611,5 +603,36 @@ %val3 = load i32, ptr %gep2, align 4 store i32 %val3, ptr %phi3, align 4 %gep3 = getelementptr inbounds i32, ptr %ptr1, i32 1 - br i1 undef, label %loop, label %if1 + br i1 %c2, label %loop, label %if1 +} + +; Make sure we don't return 1. +define i32 @test15(ptr noalias %p, i64 %i) { +; CHECK-LABEL: @test15( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds i32, ptr [[P:%.*]], i64 [[I:%.*]] +; CHECK-NEXT: store i32 1, ptr [[ARRAYIDX]], align 4 +; CHECK-NEXT: [[CMP:%.*]] = icmp eq ptr [[P]], @a +; CHECK-NEXT: br i1 [[CMP]], label [[IF_THEN:%.*]], label [[IF_END:%.*]] +; CHECK: if.then: +; CHECK-NEXT: store i32 2, ptr [[P]], align 4 +; CHECK-NEXT: [[DOTPRE:%.*]] = load i32, ptr [[ARRAYIDX]], align 4 +; CHECK-NEXT: br label [[IF_END]] +; CHECK: if.end: +; CHECK-NEXT: [[TMP0:%.*]] = phi i32 [ [[DOTPRE]], [[IF_THEN]] ], [ 1, [[ENTRY:%.*]] ] +; CHECK-NEXT: ret i32 [[TMP0]] +; +entry: + %arrayidx = getelementptr inbounds i32, ptr %p, i64 %i + store i32 1, ptr %arrayidx, align 4 + %cmp = icmp eq ptr %p, @a + br i1 %cmp, label %if.then, label %if.end + +if.then: + store i32 2, ptr %p, align 4 + br label %if.end + +if.end: + %0 = load i32, ptr %arrayidx, align 4 + ret i32 %0 } Index: llvm/unittests/Analysis/LoadsTest.cpp =================================================================== --- llvm/unittests/Analysis/LoadsTest.cpp +++ llvm/unittests/Analysis/LoadsTest.cpp @@ -68,35 +68,47 @@ R"IR( @y = common global [1 x i32] zeroinitializer, align 4 @x = common global [1 x i32] zeroinitializer, align 4 - declare void @use(i32*) -define void @f(i32* %p) { +define void @f(i32* %p1, i32* %p2, i64 %i) { call void @use(i32* getelementptr inbounds ([1 x i32], [1 x i32]* @y, i64 0, i64 0)) 
- call void @use(i32* getelementptr inbounds (i32, i32* getelementptr inbounds ([1 x i32], [1 x i32]* @x, i64 0, i64 0), i64 1)) + call void @use(i32* %p1) + + %p1_idx = getelementptr inbounds i32, i32* %p1, i64 %i + call void @use(i32* %p1_idx) + + %icmp = icmp eq i32* %p1, getelementptr inbounds ([1 x i32], [1 x i32]* @y, i64 0, i64 0) + %ptrInt = ptrtoint i32* %p1 to i64 ret void } )IR"); - const auto &DL = M->getDataLayout(); auto *GV = M->getNamedValue("f"); ASSERT_TRUE(GV); auto *F = dyn_cast(GV); ASSERT_TRUE(F); - // NOTE: the implementation of canReplacePointersIfEqual is incomplete. - // Currently the only the cases it returns false for are really sound and - // returning true means unknown. - Value *P = &*F->arg_begin(); + Value *P1 = &*F->arg_begin(); + Value *P2 = F->getArg(1); + Value *NullPtr = Constant::getNullValue(P1->getType()); auto InstIter = F->front().begin(); - Value *ConstDerefPtr = *cast(&*InstIter)->arg_begin(); - // ConstDerefPtr is a constant pointer that is provably de-referenceable. We - // can replace an arbitrary pointer with it. - EXPECT_TRUE(canReplacePointersIfEqual(P, ConstDerefPtr, DL, nullptr)); + CallInst *UserOfY = cast(&*InstIter); + CallInst *UserOfP1 = cast(&*++InstIter); + Value *ConstDerefPtr = UserOfY->getArgOperand(0); + // We cannot replace two pointers in arbitrary instructions unless we are + // replacing with null or they have the same underlying object. 
+ EXPECT_FALSE(canReplacePointersIfEqual(ConstDerefPtr, P1, UserOfY)); + EXPECT_FALSE(canReplacePointersIfEqual(P1, ConstDerefPtr, UserOfP1)); + EXPECT_FALSE(canReplacePointersIfEqual(P1, P2, UserOfP1)); + EXPECT_TRUE(canReplacePointersIfEqual(P1, NullPtr, UserOfP1)); + + GetElementPtrInst *BasedOnPtr = cast(&*++InstIter); + CallInst *UserOfP1_Idx = cast(&*++InstIter); + EXPECT_TRUE(canReplacePointersIfEqual(BasedOnPtr, P1, UserOfP1_Idx)); + EXPECT_FALSE(canReplacePointersIfEqual(BasedOnPtr, P2, UserOfP1_Idx)); - ++InstIter; - Value *ConstUnDerefPtr = *cast(&*InstIter)->arg_begin(); - // ConstUndDerefPtr is a constant pointer that is provably not - // de-referenceable. We cannot replace an arbitrary pointer with it. - EXPECT_FALSE( - canReplacePointersIfEqual(ConstDerefPtr, ConstUnDerefPtr, DL, nullptr)); + // We can replace two arbitrary pointers in icmp and ptrtoint instructions. + ICmpInst *IcmpUser = cast(&*++InstIter); + PtrToIntInst *PtrToIntUser = cast(&*++InstIter); + EXPECT_TRUE(canReplacePointersIfEqual(P1, P2, IcmpUser)); + EXPECT_TRUE(canReplacePointersIfEqual(P1, P2, PtrToIntUser)); }