diff --git a/llvm/include/llvm/Analysis/Loads.h b/llvm/include/llvm/Analysis/Loads.h
--- a/llvm/include/llvm/Analysis/Loads.h
+++ b/llvm/include/llvm/Analysis/Loads.h
@@ -175,7 +175,7 @@
 /// reject all invalid cases yet, but will be made stricter in the future. In
 /// particular this means returning true means unknown if replacement is safe.
 bool canReplacePointersIfEqual(Value *A, Value *B, const DataLayout &DL,
-                               Instruction *CtxI);
+                               const Instruction *CtxI);
 }
 
 #endif
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -27,6 +27,7 @@
 #include "llvm/Analysis/CmpInstAnalysis.h"
 #include "llvm/Analysis/ConstantFolding.h"
 #include "llvm/Analysis/InstSimplifyFolder.h"
+#include "llvm/Analysis/Loads.h"
 #include "llvm/Analysis/LoopAnalysisManager.h"
 #include "llvm/Analysis/MemoryBuiltins.h"
 #include "llvm/Analysis/OverflowInstAnalysis.h"
@@ -4326,19 +4327,36 @@
   // Note that the equivalence/replacement opportunity does not hold for vectors
   // because each element of a vector select is chosen independently.
   if (Pred == ICmpInst::ICMP_EQ && !CondVal->getType()->isVectorTy()) {
-    if (simplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q,
-                               /* AllowRefinement */ false, MaxRecurse) ==
-            TrueVal ||
-        simplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q,
-                               /* AllowRefinement */ false, MaxRecurse) ==
-            TrueVal)
+    // Two pointers that compare equal may still differ in provenance, so each
+    // fold below is only sound when the comparison operand it implicitly
+    // substitutes may be replaced by the other one. Non-pointer operands are
+    // always interchangeable here.
+    bool IsPtrCmp = CmpLHS->getType()->isPointerTy();
+    bool CanReplaceLWithR =
+        !IsPtrCmp || canReplacePointersIfEqual(CmpLHS, CmpRHS, Q.DL, Q.CxtI);
+    bool CanReplaceRWithL =
+        !IsPtrCmp || canReplacePointersIfEqual(CmpRHS, CmpLHS, Q.DL, Q.CxtI);
+
+    // FalseVal[X := Y] == TrueVal: returning FalseVal puts X back where
+    // TrueVal used Y, i.e. it replaces Y with X on the equal path.
+    if ((CanReplaceRWithL &&
+         simplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q,
+                                /* AllowRefinement */ false,
+                                MaxRecurse) == TrueVal) ||
+        (CanReplaceLWithR &&
+         simplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q,
+                                /* AllowRefinement */ false,
+                                MaxRecurse) == TrueVal))
       return FalseVal;
-    if (simplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q,
-                               /* AllowRefinement */ true, MaxRecurse) ==
-            FalseVal ||
-        simplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q,
-                               /* AllowRefinement */ true, MaxRecurse) ==
-            FalseVal)
+    // TrueVal[X := Y] == FalseVal: returning FalseVal replaces X with Y.
+    if ((CanReplaceLWithR &&
+         simplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q,
+                                /* AllowRefinement */ true,
+                                MaxRecurse) == FalseVal) ||
+        (CanReplaceRWithL &&
+         simplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q,
+                                /* AllowRefinement */ true,
+                                MaxRecurse) == FalseVal))
       return FalseVal;
   }
 
diff --git a/llvm/lib/Analysis/Loads.cpp b/llvm/lib/Analysis/Loads.cpp
--- a/llvm/lib/Analysis/Loads.cpp
+++ b/llvm/lib/Analysis/Loads.cpp
@@ -644,7 +644,7 @@
 }
 
 bool llvm::canReplacePointersIfEqual(Value *A, Value *B, const DataLayout &DL,
-                                     Instruction *CtxI) {
+                                     const Instruction *CtxI) {
   Type *Ty = A->getType();
   assert(Ty == B->getType() && Ty->isPointerTy() &&
          "values must have matching pointer types");
diff --git a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
--- a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
+++ b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
@@ -22,6 +22,7 @@
 #include "llvm/Analysis/GlobalsModRef.h"
 #include "llvm/Analysis/GuardUtils.h"
 #include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/Analysis/Loads.h"
 #include "llvm/Analysis/MemorySSA.h"
 #include "llvm/Analysis/MemorySSAUpdater.h"
 #include "llvm/Analysis/TargetLibraryInfo.h"
diff --git a/llvm/test/Transforms/InstSimplify/select-ptr-eq.ll b/llvm/test/Transforms/InstSimplify/select-ptr-eq.ll
--- a/llvm/test/Transforms/InstSimplify/select-ptr-eq.ll
+++ b/llvm/test/Transforms/InstSimplify/select-ptr-eq.ll
@@ -5,7 +5,9 @@
 
 define ptr @case1(ptr %p) {
 ; CHECK-LABEL: @case1(
-; CHECK-NEXT:    ret ptr @y
+; CHECK-NEXT:    [[C:%.*]] = icmp eq ptr [[P:%.*]], @y
+; CHECK-NEXT:    [[RES:%.*]] = select i1 [[C]], ptr [[P]], ptr @y
+; CHECK-NEXT:    ret ptr [[RES]]
 ;
   %c = icmp eq ptr %p, @y
   %res = select i1 %c, ptr %p, ptr @y
@@ -14,7 +16,9 @@
 
 define ptr @case2(ptr %p) {
 ; CHECK-LABEL: @case2(
-; CHECK-NEXT:    ret ptr @y
+; CHECK-NEXT:    [[C:%.*]] = icmp eq ptr @y, [[P:%.*]]
+; CHECK-NEXT:    [[RES:%.*]] = select i1 [[C]], ptr [[P]], ptr @y
+; CHECK-NEXT:    ret ptr [[RES]]
 ;
   %c = icmp eq ptr @y, %p
   %res = select i1 %c, ptr %p, ptr @y