Protect SimplifyWithOpReplaced() against more poison (PR47322).

* InstructionSimplify.h: export SimplifyWithOpReplaced() with a new
  AllowRefinement parameter.
* InstructionSimplify.cpp: when AllowRefinement is false, bail out for any
  instruction for which canCreatePoison() is true, not only for binary
  operators. The icmp-eq select simplification passes false when
  substituting into the false arm and true when substituting into the true
  arm.
* InstCombineSelect.cpp: remove the local binop-only simplifyWithOpReplaced()
  helper and call the InstSimplify version instead, with nuw/nsw/exact
  temporarily dropped from the false arm and restored if the fold fails.
* select.ll: add pr47322_more_poisonous_replacement (the cttz + lshr select
  must be kept) and refresh one autogenerated check.

Index: llvm/include/llvm/Analysis/InstructionSimplify.h
===================================================================
--- llvm/include/llvm/Analysis/InstructionSimplify.h
+++ llvm/include/llvm/Analysis/InstructionSimplify.h
@@ -292,6 +292,12 @@
 Value *SimplifyInstruction(Instruction *I, const SimplifyQuery &Q,
                            OptimizationRemarkEmitter *ORE = nullptr);
 
+/// See if V simplifies when its operand Op is replaced with RepOp.
+/// AllowRefinement specifies whether the simplification can be a refinement,
+/// or whether it needs to be strictly identical.
+Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
+                              const SimplifyQuery &Q, bool AllowRefinement);
+
 /// Replace all uses of 'I' with 'SimpleV' and simplify the uses recursively.
 ///
 /// This first performs a normal RAUW of I with SimpleV. It then recursively
Index: llvm/lib/Analysis/InstructionSimplify.cpp
===================================================================
--- llvm/lib/Analysis/InstructionSimplify.cpp
+++ llvm/lib/Analysis/InstructionSimplify.cpp
@@ -3769,10 +3769,10 @@
   return ::SimplifyFCmpInst(Predicate, LHS, RHS, FMF, Q, RecursionLimit);
 }
 
-/// See if V simplifies when its operand Op is replaced with RepOp.
-static const Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
-                                           const SimplifyQuery &Q,
-                                           unsigned MaxRecurse) {
+static Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
+                                     const SimplifyQuery &Q,
+                                     bool AllowRefinement,
+                                     unsigned MaxRecurse) {
   // Trivial replacement.
   if (V == Op)
     return RepOp;
@@ -3785,20 +3785,19 @@
   if (!I)
     return nullptr;
 
+  // Consider:
+  //   %cmp = icmp eq i32 %x, 2147483647
+  //   %add = add nsw i32 %x, 1
+  //   %sel = select i1 %cmp, i32 -2147483648, i32 %add
+  //
+  // We can't replace %sel with %add unless we strip away the flags (which will
+  // be done in InstCombine).
+  // TODO: This is unsound, because it only catches some forms of refinement.
+  if (!AllowRefinement && canCreatePoison(cast<Operator>(I)))
+    return nullptr;
+
   // If this is a binary operator, try to simplify it with the replaced op.
   if (auto *B = dyn_cast<BinaryOperator>(I)) {
-    // Consider:
-    //   %cmp = icmp eq i32 %x, 2147483647
-    //   %add = add nsw i32 %x, 1
-    //   %sel = select i1 %cmp, i32 -2147483648, i32 %add
-    //
-    // We can't replace %sel with %add unless we strip away the flags.
-    // TODO: This is an unusual limitation because better analysis results in
-    // worse simplification. InstCombine can do this fold more generally
-    // by dropping the flags. Remove this fold to save compile-time?
-    if (canCreatePoison(cast<Operator>(I)))
-      return nullptr;
-
     if (MaxRecurse) {
       if (B->getOperand(0) == Op)
         return SimplifyBinOp(B->getOpcode(), RepOp, B->getOperand(1), Q,
@@ -3865,6 +3864,13 @@
   return nullptr;
 }
 
+Value *llvm::SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
+                                    const SimplifyQuery &Q,
+                                    bool AllowRefinement) {
+  return ::SimplifyWithOpReplaced(V, Op, RepOp, Q, AllowRefinement,
+                                  RecursionLimit);
+}
+
 /// Try to simplify a select instruction when its condition operand is an
 /// integer comparison where one operand of the compare is a constant.
 static Value *simplifySelectBitTest(Value *TrueVal, Value *FalseVal, Value *X,
@@ -3985,14 +3991,18 @@
   // arms of the select. See if substituting this value into the arm and
   // simplifying the result yields the same value as the other arm.
   if (Pred == ICmpInst::ICMP_EQ) {
-    if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q, MaxRecurse) ==
+    if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q,
+                               /* AllowRefinement */ false, MaxRecurse) ==
             TrueVal ||
-        SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q, MaxRecurse) ==
+        SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q,
+                               /* AllowRefinement */ false, MaxRecurse) ==
             TrueVal)
       return FalseVal;
-    if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q, MaxRecurse) ==
+    if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q,
+                               /* AllowRefinement */ true, MaxRecurse) ==
             FalseVal ||
-        SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q, MaxRecurse) ==
+        SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q,
+                               /* AllowRefinement */ true, MaxRecurse) ==
             FalseVal)
       return FalseVal;
   }
Index: llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
===================================================================
--- llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -1149,22 +1149,6 @@
   return &Sel;
 }
 
-static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *ReplaceOp,
-                                     const SimplifyQuery &Q) {
-  // If this is a binary operator, try to simplify it with the replaced op
-  // because we know Op and ReplaceOp are equivalant.
-  // For example: V = X + 1, Op = X, ReplaceOp = 42
-  // Simplifies as: add(42, 1) --> 43
-  if (auto *BO = dyn_cast<BinaryOperator>(V)) {
-    if (BO->getOperand(0) == Op)
-      return SimplifyBinOp(BO->getOpcode(), ReplaceOp, BO->getOperand(1), Q);
-    if (BO->getOperand(1) == Op)
-      return SimplifyBinOp(BO->getOpcode(), BO->getOperand(0), ReplaceOp, Q);
-  }
-
-  return nullptr;
-}
-
 /// If we have a select with an equality comparison, then we know the value in
 /// one of the arms of the select. See if substituting this value into an arm
 /// and simplifying the result yields the same value as the other arm.
@@ -1191,20 +1175,45 @@
   if (Cmp.getPredicate() == ICmpInst::ICMP_NE)
     std::swap(TrueVal, FalseVal);
 
+  auto *FalseInst = dyn_cast<Instruction>(FalseVal);
+  if (!FalseInst)
+    return nullptr;
+
+  // InstSimplify already performed this fold if it was possible subject to
+  // current poison-generating flags. Try the transform again with
+  // poison-generating flags temporarily dropped.
+  bool WasNUW = false, WasNSW = false, WasExact = false;
+  if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(FalseVal)) {
+    WasNUW = OBO->hasNoUnsignedWrap();
+    WasNSW = OBO->hasNoSignedWrap();
+    FalseInst->setHasNoUnsignedWrap(false);
+    FalseInst->setHasNoSignedWrap(false);
+  }
+  if (auto *PEO = dyn_cast<PossiblyExactOperator>(FalseVal)) {
+    WasExact = PEO->isExact();
+    FalseInst->setIsExact(false);
+  }
+
   // Try each equivalence substitution possibility.
   // We have an 'EQ' comparison, so the select's false value will propagate.
   // Example:
   // (X == 42) ? 43 : (X + 1) --> (X == 42) ? (X + 1) : (X + 1) --> X + 1
-  // (X == 42) ? (X + 1) : 43 --> (X == 42) ? (42 + 1) : 43 --> 43
   Value *CmpLHS = Cmp.getOperand(0), *CmpRHS = Cmp.getOperand(1);
-  if (simplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q) == TrueVal ||
-      simplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q) == TrueVal ||
-      simplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q) == FalseVal ||
-      simplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q) == FalseVal) {
-    if (auto *FalseInst = dyn_cast<Instruction>(FalseVal))
-      FalseInst->dropPoisonGeneratingFlags();
+  if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q,
+                             /* AllowRefinement */ false) == TrueVal ||
+      SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q,
+                             /* AllowRefinement */ false) == TrueVal) {
     return FalseVal;
   }
+
+  // Restore poison-generating flags if the transform did not apply.
+  if (WasNUW)
+    FalseInst->setHasNoUnsignedWrap();
+  if (WasNSW)
+    FalseInst->setHasNoSignedWrap();
+  if (WasExact)
+    FalseInst->setIsExact();
+
   return nullptr;
 }
 
Index: llvm/test/Transforms/InstCombine/select.ll
===================================================================
--- llvm/test/Transforms/InstCombine/select.ll
+++ llvm/test/Transforms/InstCombine/select.ll
@@ -1924,8 +1924,8 @@
 ; CHECK:       if.false.3:
 ; CHECK-NEXT:    br label [[MERGE_3]]
 ; CHECK:       merge.3:
-; CHECK-NEXT:    [[S_3:%.*]] = phi i32 [ [[Y:%.*]], [[IF_FALSE_3]] ], [ [[X:%.*]], [[IF_TRUE_3]] ]
-; CHECK-NEXT:    [[SUM_2:%.*]] = mul i32 [[S_3]], 3
+; CHECK-NEXT:    [[S_1:%.*]] = phi i32 [ [[Y:%.*]], [[IF_FALSE_3]] ], [ [[X:%.*]], [[IF_TRUE_3]] ]
+; CHECK-NEXT:    [[SUM_2:%.*]] = mul i32 [[S_1]], 3
 ; CHECK-NEXT:    ret i32 [[SUM_2]]
 ;
 entry:
@@ -2587,3 +2587,20 @@
   call void @use_i1_i32(i1 %c.fr, i32 %v)
   ret void
 }
+
+define i32 @pr47322_more_poisonous_replacement(i32 %arg) {
+; CHECK-LABEL: @pr47322_more_poisonous_replacement(
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[ARG:%.*]], 0
+; CHECK-NEXT:    [[TRAILING:%.*]] = call i32 @llvm.cttz.i32(i32 [[ARG]], i1 immarg true), [[RNG0:!range !.*]]
+; CHECK-NEXT:    [[SHIFTED:%.*]] = lshr i32 [[ARG]], [[TRAILING]]
+; CHECK-NEXT:    [[R1_SROA_0_1:%.*]] = select i1 [[CMP]], i32 0, i32 [[SHIFTED]]
+; CHECK-NEXT:    ret i32 [[R1_SROA_0_1]]
+;
+  %cmp = icmp eq i32 %arg, 0
+  %trailing = call i32 @llvm.cttz.i32(i32 %arg, i1 immarg true)
+  %shifted = lshr i32 %arg, %trailing
+  %r1.sroa.0.1 = select i1 %cmp, i32 0, i32 %shifted
+  ret i32 %r1.sroa.0.1
+}
+
+declare i32 @llvm.cttz.i32(i32, i1 immarg)