diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h --- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -526,7 +526,7 @@ bool FoldWithMultiUse = false); /// This is a convenience wrapper function for the above two functions. - Instruction *foldBinOpIntoSelectOrPhi(BinaryOperator &I); + Instruction *foldBinOpIntoSelectOrPhi(BinaryOperator &I, bool SelectOnly=false); Instruction *foldAddWithConstant(BinaryOperator &Add); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -1000,17 +1000,18 @@ if (simplifyDivRemOfSelectWithZeroOp(I)) return &I; - // If the divisor is a select-of-constants, try to constant fold all div ops: - // C / (select Cond, TrueC, FalseC) --> select Cond, (C / TrueC), (C / FalseC) + const APInt *C2; + + // Try to fold the `div` into select/phi nodes. For select we proceed if we + // will be able to do some degree of constant foldings. This doesn't require + // constant operands as `foldBinOpIntoSelectOrPhi` will use the dominating + // select condition to deduce constants in the context of the select inst. + // Only try to fold with phi nodes if we have constant non-zero denominator. // TODO: Adapt simplifyDivRemOfSelectWithZeroOp to allow this and other folds. 
- if (match(Op0, m_ImmConstant()) && - match(Op1, m_Select(m_Value(), m_ImmConstant(), m_ImmConstant()))) { - if (Instruction *R = FoldOpIntoSelect(I, cast(Op1), - /*FoldWithMultiUse*/ true)) - return R; - } + if (Instruction *R = + foldBinOpIntoSelectOrPhi(I, /*SelectOnly*/!match(Op1, m_APInt(C2)) || C2->isZero())) + return R; - const APInt *C2; if (match(Op1, m_APInt(C2))) { Value *X; const APInt *C1; @@ -1089,10 +1090,6 @@ return BinaryOperator::CreateNUWAdd(X, ConstantInt::get(Ty, C1->udiv(*C2))); } - - if (!C2->isZero()) // avoid X udiv 0 - if (Instruction *FoldedDiv = foldBinOpIntoSelectOrPhi(I)) - return FoldedDiv; } if (match(Op0, m_One())) { @@ -1834,15 +1831,14 @@ if (simplifyDivRemOfSelectWithZeroOp(I)) return &I; - // If the divisor is a select-of-constants, try to constant fold all rem ops: - // C % (select Cond, TrueC, FalseC) --> select Cond, (C % TrueC), (C % FalseC) + // Try to fold the `rem` into select nodes. For select we proceed if we + // will be able to do some degree of constant foldings. This doesn't require + // constant operands as `foldBinOpIntoSelectOrPhi` will use the dominating + // select condition to deduce constants in the context of the select inst. + // Phi nodes are never folded here: the call below passes SelectOnly=true. // TODO: Adapt simplifyDivRemOfSelectWithZeroOp to allow this and other folds. 
- if (match(Op0, m_ImmConstant()) && - match(Op1, m_Select(m_Value(), m_ImmConstant(), m_ImmConstant()))) { - if (Instruction *R = FoldOpIntoSelect(I, cast(Op1), - /*FoldWithMultiUse*/ true)) - return R; - } + if (Instruction *R = foldBinOpIntoSelectOrPhi(I, /*SelectOnly*/ true)) + return R; if (isa(Op1)) { if (Instruction *Op0I = dyn_cast(Op0)) { diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp --- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -1481,14 +1481,51 @@ return NewPhi; } -Instruction *InstCombinerImpl::foldBinOpIntoSelectOrPhi(BinaryOperator &I) { - if (!isa(I.getOperand(1))) +// Return std::nullopt if we should not fold. Return true if we should fold +// multi-use select and false for single-use select. +static std::optional shouldFoldOpIntoSelect(BinaryOperator &I, Value *Op, + Value *OpOther) { + if (isa(Op)) + // If we will be able to constant fold the incorporated binop, then + // multi-use. Otherwise single-use. + return match(OpOther, m_ImmConstant()) && + match(Op, m_Select(m_Value(), m_ImmConstant(), m_ImmConstant())); + + return std::nullopt; +} + +Instruction *InstCombinerImpl::foldBinOpIntoSelectOrPhi(BinaryOperator &I, + bool SelectOnly) { + + std::optional CanSpeculativelyExecuteRes; + auto CanSpeculativelyExecute = [&]() -> bool { + if (!CanSpeculativelyExecuteRes) { + const SimplifyQuery Q = SQ.getWithInstruction(&I); + CanSpeculativelyExecuteRes = + isSafeToSpeculativelyExecute(&I, Q.CxtI, Q.AC, Q.DT, &TLI); + } + return *CanSpeculativelyExecuteRes; + }; + + for (unsigned OpIdx = 0; OpIdx < 2; ++OpIdx) { + // Slightly more involved logic for select. For select we use the condition + // to infer information about the arm. This allows us to constant-fold + // even when the select arm(s) are not constant. For example if we have: `(X + // == 10 ? 
19 : Y) * X`, we can entirely constant fold the true arm as `X == + // 10` dominates it. So we end up with `X == 10 ? 190 : (X * Y)`. + if (auto MultiUse = shouldFoldOpIntoSelect(I, I.getOperand(OpIdx), + I.getOperand(1 - OpIdx))) { + if (*MultiUse || CanSpeculativelyExecute()) + if (Instruction *NewSel = FoldOpIntoSelect( + I, cast(I.getOperand(OpIdx)), *MultiUse)) + return NewSel; + } + } + + if (!isa(I.getOperand(1)) || SelectOnly) return nullptr; - if (auto *Sel = dyn_cast(I.getOperand(0))) { - if (Instruction *NewSel = FoldOpIntoSelect(I, Sel)) - return NewSel; - } else if (auto *PN = dyn_cast(I.getOperand(0))) { + if (auto *PN = dyn_cast(I.getOperand(0))) { if (Instruction *NewPhi = foldOpIntoPhi(I, PN)) return NewPhi; } diff --git a/llvm/test/Transforms/InstCombine/binop-select.ll b/llvm/test/Transforms/InstCombine/binop-select.ll --- a/llvm/test/Transforms/InstCombine/binop-select.ll +++ b/llvm/test/Transforms/InstCombine/binop-select.ll @@ -275,7 +275,7 @@ ; CHECK-LABEL: @and_sel_op0_use( ; CHECK-NEXT: [[S:%.*]] = select i1 [[B:%.*]], i32 25, i32 0 ; CHECK-NEXT: call void @use(i32 [[S]]) -; CHECK-NEXT: [[R:%.*]] = and i32 [[S]], 1 +; CHECK-NEXT: [[R:%.*]] = zext i1 [[B]] to i32 ; CHECK-NEXT: ret i32 [[R]] ; %s = select i1 %b, i32 25, i32 0 @@ -631,8 +631,8 @@ define i32 @test_mul_to_const_Cmul(i32 %x) { ; CHECK-LABEL: @test_mul_to_const_Cmul( ; CHECK-NEXT: [[C:%.*]] = icmp eq i32 [[X:%.*]], 61 -; CHECK-NEXT: [[COND:%.*]] = select i1 [[C]], i32 9, i32 14 -; CHECK-NEXT: [[DIV:%.*]] = mul i32 [[COND]], [[X]] +; CHECK-NEXT: [[TMP1:%.*]] = mul i32 [[X]], 14 +; CHECK-NEXT: [[DIV:%.*]] = select i1 [[C]], i32 549, i32 [[TMP1]] ; CHECK-NEXT: ret i32 [[DIV]] ; %c = icmp eq i32 %x, 61 diff --git a/llvm/test/Transforms/InstCombine/extractelement.ll b/llvm/test/Transforms/InstCombine/extractelement.ll --- a/llvm/test/Transforms/InstCombine/extractelement.ll +++ b/llvm/test/Transforms/InstCombine/extractelement.ll @@ -800,10 +800,7 @@ define i32 
@extelt_select_const_operand_extractelt_use(i1 %c) { ; ANY-LABEL: @extelt_select_const_operand_extractelt_use( -; ANY-NEXT: [[E:%.*]] = select i1 [[C:%.*]], i32 4, i32 7 -; ANY-NEXT: [[M:%.*]] = shl nuw nsw i32 [[E]], 1 -; ANY-NEXT: [[M_2:%.*]] = shl nuw nsw i32 [[E]], 2 -; ANY-NEXT: [[R:%.*]] = mul nuw nsw i32 [[M]], [[M_2]] +; ANY-NEXT: [[R:%.*]] = select i1 [[C:%.*]], i32 128, i32 392 ; ANY-NEXT: ret i32 [[R]] ; %s = select i1 %c, <3 x i32> , <3 x i32>