diff --git a/llvm/include/llvm/Analysis/InstructionSimplify.h b/llvm/include/llvm/Analysis/InstructionSimplify.h
--- a/llvm/include/llvm/Analysis/InstructionSimplify.h
+++ b/llvm/include/llvm/Analysis/InstructionSimplify.h
@@ -83,6 +83,10 @@
     return false;
   }
 
+  bool isInBounds(const GEPOperator *Op) const {
+    return UseInstrInfo && Op->isInBounds();
+  }
+
   bool isExact(const BinaryOperator *Op) const {
     if (UseInstrInfo && isa<PossiblyExactOperator>(Op))
       return cast<PossiblyExactOperator>(Op)->isExact();
@@ -100,13 +104,21 @@
   // Wrapper to query additional information for instructions like metadata or
   // keywords like nsw, which provides conservative results if those cannot
   // be safely used.
-  const InstrInfoQuery IIQ;
+  InstrInfoQuery IIQ;
 
   /// Controls whether simplifications are allowed to constrain the range of
   /// possible values for uses of undef. If it is false, simplifications are not
   /// allowed to assume a particular value for a use of undef for example.
   bool CanUseUndef = true;
 
+  /// Controls whether returning a more defined value is allowed. If it is
+  /// false, simplifications must not make the value more defined. For
+  /// example, these folds become illegal:
+  ///   undef => 0
+  ///   (x +nsw 1) >s x => true
+  ///   (x +nsw 1) - 1 => x
+  bool AllowRefinement = true;
+
   SimplifyQuery(const DataLayout &DL, const Instruction *CXTI = nullptr)
       : DL(DL), CxtI(CXTI) {}
 
@@ -127,6 +139,20 @@
     Copy.CanUseUndef = false;
     return Copy;
   }
+  SimplifyQuery getWithoutRefinement() const {
+    SimplifyQuery Copy = getWithoutUndef();
+    Copy.AllowRefinement = false;
+    // Setting IIQ.UseInstrInfo to false disables more transformations than
+    // strictly necessary. For example,
+    //   (INT_MAX +nsw 1) => poison
+    // is okay because it does not make the result more defined, but setting
+    // the flag to false turns this folding off too.
+    // However, quite a few refining transformations rely on
+    // nsw/inbounds/... (e.g. "x +nsw 1 >s x" => true), and clearing
+    // IIQ.UseInstrInfo is an easy way to opt out of all of them.
+    Copy.IIQ.UseInstrInfo = false;
+    return Copy;
+  }
 
   /// If CanUseUndef is true, returns whether \p V is undef.
   /// Otherwise always return false.
@@ -297,7 +323,7 @@
-/// AllowRefinement specifies whether the simplification can be a refinement,
-/// or whether it needs to be strictly identical.
+/// Q.AllowRefinement specifies whether the simplification can be a refinement,
+/// or whether it needs to be strictly identical (see getWithoutRefinement()).
 Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
-                              const SimplifyQuery &Q, bool AllowRefinement);
+                              const SimplifyQuery &Q);
 
 /// Replace all uses of 'I' with 'SimpleV' and simplify the uses recursively.
 ///
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -706,10 +706,15 @@
 
 /// Compute the constant difference between two pointer values.
-/// If the difference is not a constant, returns zero.
-static Constant *computePointerDifference(const DataLayout &DL, Value *LHS,
+/// Returns null if the difference is not constant or !Q.AllowRefinement.
+static Constant *computePointerDifference(const SimplifyQuery &Q, Value *LHS,
                                           Value *RHS) {
-  Constant *LHSOffset = stripAndComputeConstantOffsets(DL, LHS);
-  Constant *RHSOffset = stripAndComputeConstantOffsets(DL, RHS);
+  if (!Q.AllowRefinement)
+    // Folding e.g. '(gep p 1) - p' to '1' may return a more defined value.
+    // Conservatively bail out if AllowRefinement is false.
+    return nullptr;
+
+  Constant *LHSOffset = stripAndComputeConstantOffsets(Q.DL, LHS);
+  Constant *RHSOffset = stripAndComputeConstantOffsets(Q.DL, RHS);
 
   // If LHS and RHS are not related via constant offsets to the same base
   // value, there is nothing we can do here.
@@ -2488,6 +2493,11 @@
 static Constant *
 computePointerICmp(CmpInst::Predicate Pred, Value *LHS, Value *RHS,
                    const SimplifyQuery &Q) {
+  if (!Q.AllowRefinement)
+    // Conservatively bail out: folding the comparison may return a more
+    // defined value.
+    return nullptr;
+
   const DataLayout &DL = Q.DL;
   const TargetLibraryInfo *TLI = Q.TLI;
   const DominatorTree *DT = Q.DT;
@@ -3672,7 +3682,8 @@
       if (GLHS->getPointerOperand() == GRHS->getPointerOperand() &&
           GLHS->hasAllConstantIndices() && GRHS->hasAllConstantIndices() &&
           (ICmpInst::isEquality(Pred) ||
-           (GLHS->isInBounds() && GRHS->isInBounds() &&
+           (Q.IIQ.isInBounds(cast<GEPOperator>(GLHS)) &&
+            Q.IIQ.isInBounds(GRHS) &&
             Pred == ICmpInst::getSignedPredicate(Pred)))) {
         // The bases are equal and the indices are constant.  Build a constant
         // expression GEP with the same indices and a null base pointer to see
@@ -3917,7 +3928,6 @@
 
 static Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
                                      const SimplifyQuery &Q,
-                                     bool AllowRefinement,
                                      unsigned MaxRecurse) {
   // Trivial replacement.
   if (V == Op)
@@ -3939,7 +3949,7 @@
   // We can't replace %sel with %add unless we strip away the flags (which will
   // be done in InstCombine).
   // TODO: This is unsound, because it only catches some forms of refinement.
-  if (!AllowRefinement && canCreatePoison(cast<Operator>(I)))
+  if (!Q.AllowRefinement && canCreatePoison(cast<Operator>(I)))
     return nullptr;
 
   // The simplification queries below may return the original value. Consider:
@@ -4028,10 +4038,8 @@
 }
 
 Value *llvm::SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
-                                    const SimplifyQuery &Q,
-                                    bool AllowRefinement) {
-  return ::SimplifyWithOpReplaced(V, Op, RepOp, Q, AllowRefinement,
-                                  RecursionLimit);
+                                    const SimplifyQuery &Q) {
+  return ::SimplifyWithOpReplaced(V, Op, RepOp, Q, RecursionLimit);
 }
 
 /// Try to simplify a select instruction when its condition operand is an
@@ -4150,18 +4158,15 @@
   // arms of the select. See if substituting this value into the arm and
   // simplifying the result yields the same value as the other arm.
   if (Pred == ICmpInst::ICMP_EQ) {
-    if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q,
-                               /* AllowRefinement */ false, MaxRecurse) ==
-            TrueVal ||
-        SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q,
-                               /* AllowRefinement */ false, MaxRecurse) ==
-            TrueVal)
+    const SimplifyQuery QNoRefinement = Q.getWithoutRefinement();
+    if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, QNoRefinement,
+                               MaxRecurse) == TrueVal ||
+        SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, QNoRefinement,
+                               MaxRecurse) == TrueVal)
       return FalseVal;
-    if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q,
-                               /* AllowRefinement */ true, MaxRecurse) ==
+    if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q, MaxRecurse) ==
             FalseVal ||
-        SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q,
-                               /* AllowRefinement */ true, MaxRecurse) ==
+        SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q, MaxRecurse) ==
             FalseVal)
       return FalseVal;
   }
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -1114,8 +1114,7 @@
   Value *CmpLHS = Cmp.getOperand(0), *CmpRHS = Cmp.getOperand(1);
   if (TrueVal != CmpLHS &&
       isGuaranteedNotToBeUndefOrPoison(CmpRHS, SQ.AC, &Sel, &DT)) {
-    if (Value *V = SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, SQ,
-                                          /* AllowRefinement */ true))
+    if (Value *V = SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, SQ))
       return replaceOperand(Sel, Swapped ? 2 : 1, V);
 
     // Even if TrueVal does not simplify, we can directly replace a use of
@@ -1136,8 +1135,7 @@
   }
   if (TrueVal != CmpRHS &&
       isGuaranteedNotToBeUndefOrPoison(CmpLHS, SQ.AC, &Sel, &DT))
-    if (Value *V = SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, SQ,
-                                          /* AllowRefinement */ true))
+    if (Value *V = SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, SQ))
       return replaceOperand(Sel, Swapped ? 2 : 1, V);
 
   auto *FalseInst = dyn_cast<Instruction>(FalseVal);
@@ -1167,10 +1165,11 @@
   // We have an 'EQ' comparison, so the select's false value will propagate.
   // Example:
   // (X == 42) ? 43 : (X + 1) --> (X == 42) ? (X + 1) : (X + 1) --> X + 1
-  if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, SQ,
-                             /* AllowRefinement */ false) == TrueVal ||
-      SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, SQ,
-                             /* AllowRefinement */ false) == TrueVal) {
+  const SimplifyQuery SQNoRefinement = SQ.getWithoutRefinement();
+  if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, SQNoRefinement) ==
+          TrueVal ||
+      SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, SQNoRefinement) ==
+          TrueVal) {
     return replaceInstUsesWith(Sel, FalseVal);
   }
 
diff --git a/llvm/test/Transforms/InstSimplify/pr49495.ll b/llvm/test/Transforms/InstSimplify/pr49495.ll
--- a/llvm/test/Transforms/InstSimplify/pr49495.ll
+++ b/llvm/test/Transforms/InstSimplify/pr49495.ll
@@ -4,9 +4,11 @@
 ; The first comparison (a != b) should not be dropped
 define i1 @test1(i8* %a, i8* %b) {
 ; CHECK-LABEL: @test1(
-; CHECK-NEXT:    [[A2:%.*]] = getelementptr inbounds i8, i8* [[A:%.*]], i64 -1
-; CHECK-NEXT:    [[COND2:%.*]] = icmp ugt i8* [[A2]], [[B:%.*]]
-; CHECK-NEXT:    ret i1 [[COND2]]
+; CHECK-NEXT:    [[COND1:%.*]] = icmp ne i8* [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    [[A2:%.*]] = getelementptr inbounds i8, i8* [[A]], i64 -1
+; CHECK-NEXT:    [[COND2:%.*]] = icmp ugt i8* [[A2]], [[B]]
+; CHECK-NEXT:    [[RES:%.*]] = select i1 [[COND1]], i1 [[COND2]], i1 false
+; CHECK-NEXT:    ret i1 [[RES]]
 ;
   %cond1 = icmp ne i8* %a, %b
   %a2 = getelementptr inbounds i8, i8* %a, i64 -1
@@ -18,9 +20,11 @@
 ; The first comparison (a != b) should not be dropped
 define i1 @test2(i32 %a, i32 %b) {
 ; CHECK-LABEL: @test2(
-; CHECK-NEXT:    [[A2:%.*]] = add nuw i32 [[A:%.*]], 1
-; CHECK-NEXT:    [[COND2:%.*]] = icmp ult i32 [[A2]], [[B:%.*]]
-; CHECK-NEXT:    ret i1 [[COND2]]
+; CHECK-NEXT:    [[COND1:%.*]] = icmp ne i32 [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    [[A2:%.*]] = add nuw i32 [[A]], 1
+; CHECK-NEXT:    [[COND2:%.*]] = icmp ult i32 [[A2]], [[B]]
+; CHECK-NEXT:    [[RES:%.*]] = select i1 [[COND1]], i1 [[COND2]], i1 false
+; CHECK-NEXT:    ret i1 [[RES]]
 ;
   %cond1 = icmp ne i32 %a, %b
   %a2 = add nuw i32 %a, 1