diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -526,7 +526,7 @@
                                 bool FoldWithMultiUse = false);
 
   /// This is a convenience wrapper function for the above two functions.
-  Instruction *foldBinOpIntoSelectOrPhi(BinaryOperator &I);
+  Instruction *foldBinOpIntoSelectOrPhi(BinaryOperator &I, bool SelectOnly = false);
 
   Instruction *foldAddWithConstant(BinaryOperator &Add);
 
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -1000,17 +1000,18 @@
   if (simplifyDivRemOfSelectWithZeroOp(I))
     return &I;
 
-  // If the divisor is a select-of-constants, try to constant fold all div ops:
-  // C / (select Cond, TrueC, FalseC) --> select Cond, (C / TrueC), (C / FalseC)
+  const APInt *C2;
+
+  // Try to fold the `div` into select/phi nodes. For select we proceed if we
+  // will be able to do some degree of constant folding. This doesn't require
+  // constant operands as `foldBinOpIntoSelectOrPhi` will use the dominating
+  // select condition to deduce constants in the context of the select inst.
+  // Only try to fold with phi nodes if we have constant non-zero denominator.
   // TODO: Adapt simplifyDivRemOfSelectWithZeroOp to allow this and other folds.
-  if (match(Op0, m_ImmConstant()) &&
-      match(Op1, m_Select(m_Value(), m_ImmConstant(), m_ImmConstant()))) {
-    if (Instruction *R = FoldOpIntoSelect(I, cast<SelectInst>(Op1),
-                                          /*FoldWithMultiUse*/ true))
-      return R;
-  }
+  if (Instruction *R = foldBinOpIntoSelectOrPhi(
+          I, /*SelectOnly*/ !match(Op1, m_APInt(C2)) || C2->isZero()))
+    return R;
 
-  const APInt *C2;
   if (match(Op1, m_APInt(C2))) {
     Value *X;
     const APInt *C1;
@@ -1089,10 +1090,6 @@
       return BinaryOperator::CreateNUWAdd(X, ConstantInt::get(Ty, C1->udiv(*C2)));
     }
-
-    if (!C2->isZero()) // avoid X udiv 0
-      if (Instruction *FoldedDiv = foldBinOpIntoSelectOrPhi(I))
-        return FoldedDiv;
   }
 
   if (match(Op0, m_One())) {
@@ -1834,15 +1831,14 @@
   if (simplifyDivRemOfSelectWithZeroOp(I))
     return &I;
 
-  // If the divisor is a select-of-constants, try to constant fold all rem ops:
-  // C % (select Cond, TrueC, FalseC) --> select Cond, (C % TrueC), (C % FalseC)
+  // Try to fold the `rem` into select nodes. For select we proceed if we
+  // will be able to do some degree of constant folding. This doesn't require
+  // constant operands as `foldBinOpIntoSelectOrPhi` will use the dominating
+  // select condition to deduce constants in the context of the select inst.
+  // Phi nodes are skipped here (SelectOnly); see the `div` handling above.
   // TODO: Adapt simplifyDivRemOfSelectWithZeroOp to allow this and other folds.
-  if (match(Op0, m_ImmConstant()) &&
-      match(Op1, m_Select(m_Value(), m_ImmConstant(), m_ImmConstant()))) {
-    if (Instruction *R = FoldOpIntoSelect(I, cast<SelectInst>(Op1),
-                                          /*FoldWithMultiUse*/ true))
-      return R;
-  }
+  if (Instruction *R = foldBinOpIntoSelectOrPhi(I, /*SelectOnly*/ true))
+    return R;
 
   if (isa<Constant>(Op1)) {
     if (Instruction *Op0I = dyn_cast<Instruction>(Op0)) {
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -1481,14 +1481,92 @@
   return NewPhi;
 }
 
-Instruction *InstCombinerImpl::foldBinOpIntoSelectOrPhi(BinaryOperator &I) {
-  if (!isa<Constant>(I.getOperand(1)))
+// Return std::nullopt if we should not fold. Return true if we should fold
+// multi-use select and false for single-use select.
+static std::optional<bool> shouldFoldOpIntoSelect(BinaryOperator &I, Value *Op,
+                                                  Value *OpOther) {
+  if (isa<SelectInst>(Op))
+    // If we will be able to constant fold the incorporated binop, then
+    // multi-use. Otherwise single-use.
+    return match(OpOther, m_ImmConstant()) &&
+           match(Op, m_Select(m_Value(), m_ImmConstant(), m_ImmConstant()));
+
+  return std::nullopt;
+}
+
+Instruction *InstCombinerImpl::foldBinOpIntoSelectOrPhi(BinaryOperator &I,
+                                                        bool SelectOnly) {
+
+  std::optional<bool> CanSpeculativelyExecuteRes;
+  auto CanSpeculativelyExecute = [&]() -> bool {
+    if (!CanSpeculativelyExecuteRes) {
+      const SimplifyQuery Q = SQ.getWithInstruction(&I);
+      CanSpeculativelyExecuteRes =
+          isSafeToSpeculativelyExecute(&I, Q.CxtI, Q.AC, Q.DT, &TLI);
+      // isSafeToSpeculativelyExecute doesn't look through knownbits to see if
+      // div/rem are speculatable. This is non-trivial to generically implement
+      // in isSafeToSpeculativelyExecute as some of its users expect to be able
+      // to perform transformations that invalidate the knownbits analysis and
+      // maintain the isSafeToSpeculativelyExecute property. 
So for now,
+      // implement knownbits checks here.
+      // TODO: If isSafeToSpeculativelyExecute is updated to use knownbits
+      // analysis for div/rem remove this code.
+      if (!(*CanSpeculativelyExecuteRes)) {
+        if (!isGuaranteedNotToBePoison(I.getOperand(1), Q.AC, Q.CxtI, Q.DT))
+          return false;
+        switch (I.getOpcode()) {
+        case Instruction::UDiv:
+        case Instruction::URem:
+          // Unsigned is safe as long as denominator is non-poison and non-zero.
+          CanSpeculativelyExecuteRes = isKnownNonZero(
+              I.getOperand(1), DL, /*Depth*/ 0, Q.AC, Q.CxtI, Q.DT);
+          break;
+        case Instruction::SDiv:
+        case Instruction::SRem: {
+          KnownBits KnownDenom = llvm::computeKnownBits(
+              I.getOperand(1), DL, /*Depth*/ 0, Q.AC, Q.CxtI, Q.DT);
+          if (!(KnownDenom.isNonZero() ||
+                isKnownNonZero(I.getOperand(1), DL, /*Depth*/ 0, Q.AC, Q.CxtI,
+                               Q.DT)))
+            return false;
+
+          KnownBits KnownNum = llvm::computeKnownBits(
+              I.getOperand(0), DL, /*Depth*/ 0, Q.AC, Q.CxtI, Q.DT);
+          // Given the non-poison, non-zero denominator established above,
+          // signed div/rem is safe iff the denominator is known to never be
+          // -1, or the numerator is non-poison and known to never be INT_MIN.
+          CanSpeculativelyExecuteRes =
+              !KnownDenom.getMaxValue().isAllOnes() ||
+              (isGuaranteedNotToBePoison(I.getOperand(0), Q.AC, Q.CxtI, Q.DT) &&
+               !KnownNum.getSignedMinValue().isMinSignedValue());
+        } break;
+        default:
+          break;
+        }
+      }
+    }
+    return *CanSpeculativelyExecuteRes;
+  };
+
+  for (unsigned OpIdx = 0; OpIdx < 2; ++OpIdx) {
+    // Slightly more involved logic for select. For select we use the condition
+    // to infer information about the arm. This allows us to constant-fold
+    // even when the select arm(s) are not constant. For example if we have: `(X
+    // == 10 ? 19 : Y) * X`, we can entirely constant fold the true arm as `X ==
+ if (auto MultiUse = shouldFoldOpIntoSelect(I, I.getOperand(OpIdx), + I.getOperand(1 - OpIdx))) { + if (*MultiUse || CanSpeculativelyExecute()) + if (Instruction *NewSel = FoldOpIntoSelect( + I, cast(I.getOperand(OpIdx)), *MultiUse)) + return NewSel; + } + } + + if (!isa(I.getOperand(1)) || SelectOnly) return nullptr; - if (auto *Sel = dyn_cast(I.getOperand(0))) { - if (Instruction *NewSel = FoldOpIntoSelect(I, Sel)) - return NewSel; - } else if (auto *PN = dyn_cast(I.getOperand(0))) { + if (auto *PN = dyn_cast(I.getOperand(0))) { if (Instruction *NewPhi = foldOpIntoPhi(I, PN)) return NewPhi; } diff --git a/llvm/test/Transforms/InstCombine/binop-select.ll b/llvm/test/Transforms/InstCombine/binop-select.ll --- a/llvm/test/Transforms/InstCombine/binop-select.ll +++ b/llvm/test/Transforms/InstCombine/binop-select.ll @@ -275,7 +275,7 @@ ; CHECK-LABEL: @and_sel_op0_use( ; CHECK-NEXT: [[S:%.*]] = select i1 [[B:%.*]], i32 25, i32 0 ; CHECK-NEXT: call void @use(i32 [[S]]) -; CHECK-NEXT: [[R:%.*]] = and i32 [[S]], 1 +; CHECK-NEXT: [[R:%.*]] = zext i1 [[B]] to i32 ; CHECK-NEXT: ret i32 [[R]] ; %s = select i1 %b, i32 25, i32 0 @@ -407,9 +407,8 @@ define <2 x i32> @test_udiv_to_const_shr(i1 %c, <2 x i32> %x, <2 x i32> %yy) { ; CHECK-LABEL: @test_udiv_to_const_shr( -; CHECK-NEXT: [[Y:%.*]] = shl nuw <2 x i32> , [[YY:%.*]] -; CHECK-NEXT: [[COND:%.*]] = select i1 [[C:%.*]], <2 x i32> , <2 x i32> [[Y]] -; CHECK-NEXT: [[DIV:%.*]] = udiv <2 x i32> , [[COND]] +; CHECK-NEXT: [[TMP1:%.*]] = lshr <2 x i32> , [[YY:%.*]] +; CHECK-NEXT: [[DIV:%.*]] = select i1 [[C:%.*]], <2 x i32> , <2 x i32> [[TMP1]] ; CHECK-NEXT: ret <2 x i32> [[DIV]] ; %y = shl <2 x i32> , %yy @@ -436,8 +435,8 @@ define i32 @test_udiv_to_const_Cudiv(i32 %x) { ; CHECK-LABEL: @test_udiv_to_const_Cudiv( ; CHECK-NEXT: [[C:%.*]] = icmp eq i32 [[X:%.*]], 90 -; CHECK-NEXT: [[COND:%.*]] = select i1 [[C]], i32 7, i32 19 -; CHECK-NEXT: [[DIV:%.*]] = udiv i32 [[X]], [[COND]] +; CHECK-NEXT: [[TMP1:%.*]] = udiv i32 [[X]], 19 +; 
CHECK-NEXT: [[DIV:%.*]] = select i1 [[C]], i32 12, i32 [[TMP1]] ; CHECK-NEXT: ret i32 [[DIV]] ; %c = icmp eq i32 %x, 90 @@ -468,8 +467,8 @@ ; CHECK-NEXT: [[YNZ:%.*]] = icmp ne i32 [[Y:%.*]], 0 ; CHECK-NEXT: call void @llvm.assume(i1 [[YNZ]]) ; CHECK-NEXT: [[C:%.*]] = icmp eq i32 [[X:%.*]], 90 -; CHECK-NEXT: [[COND:%.*]] = select i1 [[C]], i32 19, i32 [[Y]] -; CHECK-NEXT: [[DIV:%.*]] = udiv i32 [[X]], [[COND]] +; CHECK-NEXT: [[TMP1:%.*]] = udiv i32 [[X]], [[Y]] +; CHECK-NEXT: [[DIV:%.*]] = select i1 [[C]], i32 4, i32 [[TMP1]] ; CHECK-NEXT: ret i32 [[DIV]] ; %ynz = icmp ne i32 %y, 0 @@ -509,8 +508,8 @@ define <2 x i32> @test_sdiv_to_const_Csdiv(<2 x i32> %x) { ; CHECK-LABEL: @test_sdiv_to_const_Csdiv( ; CHECK-NEXT: [[C_NOT:%.*]] = icmp eq <2 x i32> [[X:%.*]], -; CHECK-NEXT: [[COND:%.*]] = select <2 x i1> [[C_NOT]], <2 x i32> , <2 x i32> -; CHECK-NEXT: [[DIV:%.*]] = sdiv <2 x i32> [[X]], [[COND]] +; CHECK-NEXT: [[TMP1:%.*]] = sdiv <2 x i32> [[X]], +; CHECK-NEXT: [[DIV:%.*]] = select <2 x i1> [[C_NOT]], <2 x i32> , <2 x i32> [[TMP1]] ; CHECK-NEXT: ret <2 x i32> [[DIV]] ; %c = icmp ne <2 x i32> %x, @@ -522,8 +521,8 @@ define <2 x i32> @test_sdiv_to_const_Csdiv_todo_no_common_bit(<2 x i32> %x) { ; CHECK-LABEL: @test_sdiv_to_const_Csdiv_todo_no_common_bit( ; CHECK-NEXT: [[C_NOT:%.*]] = icmp eq <2 x i32> [[X:%.*]], -; CHECK-NEXT: [[COND:%.*]] = select <2 x i1> [[C_NOT]], <2 x i32> , <2 x i32> -; CHECK-NEXT: [[DIV:%.*]] = sdiv <2 x i32> [[X]], [[COND]] +; CHECK-NEXT: [[TMP1:%.*]] = sdiv <2 x i32> [[X]], +; CHECK-NEXT: [[DIV:%.*]] = select <2 x i1> [[C_NOT]], <2 x i32> , <2 x i32> [[TMP1]] ; CHECK-NEXT: ret <2 x i32> [[DIV]] ; %c = icmp ne <2 x i32> %x, @@ -535,8 +534,8 @@ define i32 @test_srem_to_const_Csrem(i32 %x) { ; CHECK-LABEL: @test_srem_to_const_Csrem( ; CHECK-NEXT: [[C:%.*]] = icmp eq i32 [[X:%.*]], 24 -; CHECK-NEXT: [[COND:%.*]] = select i1 [[C]], i32 18, i32 16 -; CHECK-NEXT: [[DIV:%.*]] = srem i32 [[X]], [[COND]] +; CHECK-NEXT: [[TMP1:%.*]] = srem i32 [[X]], 16 
+; CHECK-NEXT: [[DIV:%.*]] = select i1 [[C]], i32 6, i32 [[TMP1]] ; CHECK-NEXT: ret i32 [[DIV]] ; %c = icmp eq i32 %x, 24 @@ -548,8 +547,8 @@ define i32 @test_srem_to_const_Csrem_todo_no_common_bits(i32 %x) { ; CHECK-LABEL: @test_srem_to_const_Csrem_todo_no_common_bits( ; CHECK-NEXT: [[C:%.*]] = icmp eq i32 [[X:%.*]], 24 -; CHECK-NEXT: [[COND:%.*]] = select i1 [[C]], i32 7, i32 16 -; CHECK-NEXT: [[DIV:%.*]] = srem i32 [[X]], [[COND]] +; CHECK-NEXT: [[TMP1:%.*]] = srem i32 [[X]], 16 +; CHECK-NEXT: [[DIV:%.*]] = select i1 [[C]], i32 3, i32 [[TMP1]] ; CHECK-NEXT: ret i32 [[DIV]] ; %c = icmp eq i32 %x, 24 @@ -576,8 +575,7 @@ define i32 @test_srem_fail_no_speculation(i32 %x) { ; CHECK-LABEL: @test_srem_fail_no_speculation( ; CHECK-NEXT: [[C:%.*]] = icmp eq i32 [[X:%.*]], 24 -; CHECK-NEXT: [[COND:%.*]] = select i1 [[C]], i32 7, i32 -1 -; CHECK-NEXT: [[DIV:%.*]] = srem i32 [[X]], [[COND]] +; CHECK-NEXT: [[DIV:%.*]] = select i1 [[C]], i32 3, i32 0 ; CHECK-NEXT: ret i32 [[DIV]] ; %c = icmp eq i32 %x, 24 @@ -589,10 +587,11 @@ define i32 @test_urem_to_const_and_ind_x(i32 %x, i32 %yy) { ; CHECK-LABEL: @test_urem_to_const_and_ind_x( -; CHECK-NEXT: [[Y:%.*]] = shl nuw i32 1, [[YY:%.*]] ; CHECK-NEXT: [[C_NOT:%.*]] = icmp eq i32 [[X:%.*]], 24 -; CHECK-NEXT: [[COND:%.*]] = select i1 [[C_NOT]], i32 19, i32 [[Y]] -; CHECK-NEXT: [[DIV:%.*]] = urem i32 [[X]], [[COND]] +; CHECK-NEXT: [[NOTMASK:%.*]] = shl nsw i32 -1, [[YY:%.*]] +; CHECK-NEXT: [[TMP1:%.*]] = xor i32 [[NOTMASK]], -1 +; CHECK-NEXT: [[TMP2:%.*]] = and i32 [[TMP1]], [[X]] +; CHECK-NEXT: [[DIV:%.*]] = select i1 [[C_NOT]], i32 5, i32 [[TMP2]] ; CHECK-NEXT: ret i32 [[DIV]] ; %y = shl i32 1, %yy @@ -604,9 +603,10 @@ define i32 @test_urem_to_const_and(i1 %c, i32 %yy) { ; CHECK-LABEL: @test_urem_to_const_and( -; CHECK-NEXT: [[Y:%.*]] = shl nuw i32 1, [[YY:%.*]] -; CHECK-NEXT: [[COND:%.*]] = select i1 [[C:%.*]], i32 [[Y]], i32 19 -; CHECK-NEXT: [[DIV:%.*]] = urem i32 44, [[COND]] +; CHECK-NEXT: [[NOTMASK:%.*]] = shl nsw i32 -1, 
[[YY:%.*]] +; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[NOTMASK]], 44 +; CHECK-NEXT: [[TMP2:%.*]] = xor i32 [[TMP1]], 44 +; CHECK-NEXT: [[DIV:%.*]] = select i1 [[C:%.*]], i32 [[TMP2]], i32 6 ; CHECK-NEXT: ret i32 [[DIV]] ; %y = shl i32 1, %yy @@ -631,8 +631,8 @@ define i32 @test_mul_to_const_Cmul(i32 %x) { ; CHECK-LABEL: @test_mul_to_const_Cmul( ; CHECK-NEXT: [[C:%.*]] = icmp eq i32 [[X:%.*]], 61 -; CHECK-NEXT: [[COND:%.*]] = select i1 [[C]], i32 9, i32 14 -; CHECK-NEXT: [[DIV:%.*]] = mul i32 [[COND]], [[X]] +; CHECK-NEXT: [[TMP1:%.*]] = mul i32 [[X]], 14 +; CHECK-NEXT: [[DIV:%.*]] = select i1 [[C]], i32 549, i32 [[TMP1]] ; CHECK-NEXT: ret i32 [[DIV]] ; %c = icmp eq i32 %x, 61 diff --git a/llvm/test/Transforms/InstCombine/extractelement.ll b/llvm/test/Transforms/InstCombine/extractelement.ll --- a/llvm/test/Transforms/InstCombine/extractelement.ll +++ b/llvm/test/Transforms/InstCombine/extractelement.ll @@ -800,10 +800,7 @@ define i32 @extelt_select_const_operand_extractelt_use(i1 %c) { ; ANY-LABEL: @extelt_select_const_operand_extractelt_use( -; ANY-NEXT: [[E:%.*]] = select i1 [[C:%.*]], i32 4, i32 7 -; ANY-NEXT: [[M:%.*]] = shl nuw nsw i32 [[E]], 1 -; ANY-NEXT: [[M_2:%.*]] = shl nuw nsw i32 [[E]], 2 -; ANY-NEXT: [[R:%.*]] = mul nuw nsw i32 [[M]], [[M_2]] +; ANY-NEXT: [[R:%.*]] = select i1 [[C:%.*]], i32 128, i32 392 ; ANY-NEXT: ret i32 [[R]] ; %s = select i1 %c, <3 x i32> , <3 x i32>