diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -188,6 +188,39 @@
 static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth,
                        bool AssumeNonZero, bool DoFold);
 
+/// Decide whether binop \p I should be folded into its select operand \p Op;
+/// \p OpOther is the other operand of \p I.
+///
+/// \returns std::nullopt if \p I should not be folded into \p Op. Otherwise
+/// returns the FoldWithMultiUse flag to pass to FoldOpIntoSelect: true (fold
+/// even a multi-use select) when \p OpOther and both select arms are immediate
+/// constants, false (fold a single-use select only) otherwise.
+///
+/// \p Q and \p TLI must both be given or both be omitted. When they are given,
+/// the single-use fold is only offered if \p I is safe to speculatively
+/// execute; the div/rem callers pass them, the mul caller does not.
+static std::optional<bool>
+shouldFoldOpIntoSelect(BinaryOperator &I, Value *Op, Value *OpOther,
+                       const SimplifyQuery *Q = nullptr,
+                       const TargetLibraryInfo *TLI = nullptr) {
+  assert((Q == nullptr) == (TLI == nullptr) &&
+         "Q and TLI must be provided together");
+  if (isa<SelectInst>(Op)) {
+    // If we will be able to constant fold the incorporated binop in both arms,
+    // then multi-use is fine.
+    if (match(OpOther, m_ImmConstant()) &&
+        match(Op, m_Select(m_Value(), m_ImmConstant(), m_ImmConstant())))
+      return true;
+    // Otherwise single-use only and, when Q/TLI are given, only if `I` is safe
+    // to speculatively execute.
+    if (!Q || !TLI ||
+        isSafeToSpeculativelyExecute(&I, Q->CxtI, Q->AC, Q->DT, TLI))
+      return false;
+  }
+
+  return std::nullopt;
+}
+
 Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
   if (Value *V =
@@ -505,6 +538,14 @@
     return Shl;
   }
 
+  if (auto MultiUse = shouldFoldOpIntoSelect(I, Op0, Op1))
+    if (Instruction *R = FoldOpIntoSelect(I, cast<SelectInst>(Op0), *MultiUse))
+      return R;
+
+  if (auto MultiUse = shouldFoldOpIntoSelect(I, Op1, Op0))
+    if (Instruction *R = FoldOpIntoSelect(I, cast<SelectInst>(Op1), *MultiUse))
+      return R;
+
   bool Changed = false;
   if (!HasNSW && willNotOverflowSignedMul(Op0, Op1, I)) {
     Changed = true;
@@ -1000,15 +1041,19 @@
   if (simplifyDivRemOfSelectWithZeroOp(I))
     return &I;
 
-  // If the divisor is a select-of-constants, try to constant fold all div ops:
-  // C / (select Cond, TrueC, FalseC) --> select Cond, (C / TrueC), (C / FalseC)
+  // If either operand is a select, try to fold the division into the select.
+  // We don't need to entirely constant fold the division. If we are able to
+  // fold only one arm but get a constant divisor (pulling out the non-folded
+  // arm) that is also preferable.
   // TODO: Adapt simplifyDivRemOfSelectWithZeroOp to allow this and other folds.
-  if (match(Op0, m_ImmConstant()) &&
-      match(Op1, m_Select(m_Value(), m_ImmConstant(), m_ImmConstant()))) {
-    if (Instruction *R = FoldOpIntoSelect(I, cast<SelectInst>(Op1),
-                                          /*FoldWithMultiUse*/ true))
+  const SimplifyQuery Q = SQ.getWithInstruction(&I);
+  if (auto MultiUse = shouldFoldOpIntoSelect(I, Op0, Op1, &Q, &TLI))
+    if (Instruction *R = FoldOpIntoSelect(I, cast<SelectInst>(Op0), *MultiUse))
+      return R;
+
+  if (auto MultiUse = shouldFoldOpIntoSelect(I, Op1, Op0, &Q, &TLI))
+    if (Instruction *R = FoldOpIntoSelect(I, cast<SelectInst>(Op1), *MultiUse))
       return R;
-  }
 
   const APInt *C2;
   if (match(Op1, m_APInt(C2))) {
@@ -1834,15 +1879,19 @@
   if (simplifyDivRemOfSelectWithZeroOp(I))
     return &I;
 
-  // If the divisor is a select-of-constants, try to constant fold all rem ops:
-  // C % (select Cond, TrueC, FalseC) --> select Cond, (C % TrueC), (C % FalseC)
+  // If either operand is a select, try to fold the `rem` op into the select.
+  // We don't need to entirely constant fold the `rem` op. If we are able to
+  // fold only one arm but get a constant divisor (pulling out the non-folded
+  // arm) that is also preferable.
   // TODO: Adapt simplifyDivRemOfSelectWithZeroOp to allow this and other folds.
-  if (match(Op0, m_ImmConstant()) &&
-      match(Op1, m_Select(m_Value(), m_ImmConstant(), m_ImmConstant()))) {
-    if (Instruction *R = FoldOpIntoSelect(I, cast<SelectInst>(Op1),
-                                          /*FoldWithMultiUse*/ true))
+  const SimplifyQuery Q = SQ.getWithInstruction(&I);
+  if (auto MultiUse = shouldFoldOpIntoSelect(I, Op0, Op1, &Q, &TLI))
+    if (Instruction *R = FoldOpIntoSelect(I, cast<SelectInst>(Op0), *MultiUse))
+      return R;
+
+  if (auto MultiUse = shouldFoldOpIntoSelect(I, Op1, Op0, &Q, &TLI))
+    if (Instruction *R = FoldOpIntoSelect(I, cast<SelectInst>(Op1), *MultiUse))
       return R;
-  }
 
   if (isa<Constant>(Op1)) {
     if (Instruction *Op0I = dyn_cast<Instruction>(Op0)) {
diff --git a/llvm/test/Transforms/InstCombine/binop-select.ll b/llvm/test/Transforms/InstCombine/binop-select.ll
--- a/llvm/test/Transforms/InstCombine/binop-select.ll
+++ b/llvm/test/Transforms/InstCombine/binop-select.ll
@@ -407,9 +407,8 @@
 
 define <2 x i32> @test_udiv_to_const_shr(i1 %c, <2 x i32> %x, <2 x i32> %yy) {
 ; CHECK-LABEL: @test_udiv_to_const_shr(
-; CHECK-NEXT:    [[Y:%.*]] = shl nuw <2 x i32> , [[YY:%.*]]
-; CHECK-NEXT:    [[COND:%.*]] = select i1 [[C:%.*]], <2 x i32> , <2 x i32> [[Y]]
-; CHECK-NEXT:    [[DIV:%.*]] = udiv <2 x i32> , [[COND]]
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr <2 x i32> , [[YY:%.*]]
+; CHECK-NEXT:    [[DIV:%.*]] = select i1 [[C:%.*]], <2 x i32> , <2 x i32> [[TMP1]]
 ; CHECK-NEXT:    ret <2 x i32> [[DIV]]
 ;
   %y = shl <2 x i32> , %yy
@@ -436,8 +435,8 @@
 define i32 @test_udiv_to_const_Cudiv(i32 %x) {
 ; CHECK-LABEL: @test_udiv_to_const_Cudiv(
 ; CHECK-NEXT:    [[C:%.*]] = icmp eq i32 [[X:%.*]], 90
-; CHECK-NEXT:    [[COND:%.*]] = select i1 [[C]], i32 7, i32 19
-; CHECK-NEXT:    [[DIV:%.*]] = udiv i32 [[X]], [[COND]]
+; CHECK-NEXT:    [[TMP1:%.*]] = udiv i32 [[X]], 19
+; CHECK-NEXT:    [[DIV:%.*]] = select i1 [[C]], i32 12, i32 [[TMP1]]
 ; CHECK-NEXT:    ret i32 [[DIV]]
 ;
   %c = icmp eq i32 %x, 90
@@ -468,8 +467,8 @@
 ; CHECK-NEXT:    [[YNZ:%.*]] = icmp ne i32 [[Y:%.*]], 0
 ; CHECK-NEXT:    call void @llvm.assume(i1 [[YNZ]])
 ; CHECK-NEXT:    [[C:%.*]] = icmp eq i32 [[X:%.*]], 90
-; CHECK-NEXT:    [[COND:%.*]] = select i1 [[C]], i32 19, i32 [[Y]]
-; CHECK-NEXT:    [[DIV:%.*]] = udiv i32 [[X]], [[COND]]
+; CHECK-NEXT:    [[TMP1:%.*]] = udiv i32 [[X]], [[Y]]
+; CHECK-NEXT:    [[DIV:%.*]] = select i1 [[C]], i32 4, i32 [[TMP1]]
 ; CHECK-NEXT:    ret i32 [[DIV]]
 ;
   %ynz = icmp ne i32 %y, 0
@@ -509,8 +508,8 @@
 define <2 x i32> @test_sdiv_to_const_Csdiv(<2 x i32> %x) {
 ; CHECK-LABEL: @test_sdiv_to_const_Csdiv(
 ; CHECK-NEXT:    [[C_NOT:%.*]] = icmp eq <2 x i32> [[X:%.*]],
-; CHECK-NEXT:    [[COND:%.*]] = select <2 x i1> [[C_NOT]], <2 x i32> , <2 x i32>
-; CHECK-NEXT:    [[DIV:%.*]] = sdiv <2 x i32> [[X]], [[COND]]
+; CHECK-NEXT:    [[TMP1:%.*]] = sdiv <2 x i32> [[X]],
+; CHECK-NEXT:    [[DIV:%.*]] = select <2 x i1> [[C_NOT]], <2 x i32> , <2 x i32> [[TMP1]]
 ; CHECK-NEXT:    ret <2 x i32> [[DIV]]
 ;
   %c = icmp ne <2 x i32> %x,
@@ -522,8 +521,8 @@
 define <2 x i32> @test_sdiv_to_const_Csdiv_todo_no_common_bit(<2 x i32> %x) {
 ; CHECK-LABEL: @test_sdiv_to_const_Csdiv_todo_no_common_bit(
 ; CHECK-NEXT:    [[C_NOT:%.*]] = icmp eq <2 x i32> [[X:%.*]],
-; CHECK-NEXT:    [[COND:%.*]] = select <2 x i1> [[C_NOT]], <2 x i32> , <2 x i32>
-; CHECK-NEXT:    [[DIV:%.*]] = sdiv <2 x i32> [[X]], [[COND]]
+; CHECK-NEXT:    [[TMP1:%.*]] = sdiv <2 x i32> [[X]],
+; CHECK-NEXT:    [[DIV:%.*]] = select <2 x i1> [[C_NOT]], <2 x i32> , <2 x i32> [[TMP1]]
 ; CHECK-NEXT:    ret <2 x i32> [[DIV]]
 ;
   %c = icmp ne <2 x i32> %x,
@@ -535,8 +534,8 @@
 define i32 @test_srem_to_const_Csrem(i32 %x) {
 ; CHECK-LABEL: @test_srem_to_const_Csrem(
 ; CHECK-NEXT:    [[C:%.*]] = icmp eq i32 [[X:%.*]], 24
-; CHECK-NEXT:    [[COND:%.*]] = select i1 [[C]], i32 18, i32 16
-; CHECK-NEXT:    [[DIV:%.*]] = srem i32 [[X]], [[COND]]
+; CHECK-NEXT:    [[TMP1:%.*]] = srem i32 [[X]], 16
+; CHECK-NEXT:    [[DIV:%.*]] = select i1 [[C]], i32 6, i32 [[TMP1]]
 ; CHECK-NEXT:    ret i32 [[DIV]]
 ;
   %c = icmp eq i32 %x, 24
@@ -548,8 +547,8 @@
 define i32 @test_srem_to_const_Csrem_todo_no_common_bits(i32 %x) {
 ; CHECK-LABEL: @test_srem_to_const_Csrem_todo_no_common_bits(
 ; CHECK-NEXT:    [[C:%.*]] = icmp eq i32 [[X:%.*]], 24
-; CHECK-NEXT:    [[COND:%.*]] = select i1 [[C]], i32 7, i32 16
-; CHECK-NEXT:    [[DIV:%.*]] = srem i32 [[X]], [[COND]]
+; CHECK-NEXT:    [[TMP1:%.*]] = srem i32 [[X]], 16
+; CHECK-NEXT:    [[DIV:%.*]] = select i1 [[C]], i32 3, i32 [[TMP1]]
 ; CHECK-NEXT:    ret i32 [[DIV]]
 ;
   %c = icmp eq i32 %x, 24
@@ -589,10 +588,11 @@
 
 define i32 @test_urem_to_const_and_ind_x(i32 %x, i32 %yy) {
 ; CHECK-LABEL: @test_urem_to_const_and_ind_x(
-; CHECK-NEXT:    [[Y:%.*]] = shl nuw i32 1, [[YY:%.*]]
 ; CHECK-NEXT:    [[C_NOT:%.*]] = icmp eq i32 [[X:%.*]], 24
-; CHECK-NEXT:    [[COND:%.*]] = select i1 [[C_NOT]], i32 19, i32 [[Y]]
-; CHECK-NEXT:    [[DIV:%.*]] = urem i32 [[X]], [[COND]]
+; CHECK-NEXT:    [[NOTMASK:%.*]] = shl nsw i32 -1, [[YY:%.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = xor i32 [[NOTMASK]], -1
+; CHECK-NEXT:    [[TMP2:%.*]] = and i32 [[TMP1]], [[X]]
+; CHECK-NEXT:    [[DIV:%.*]] = select i1 [[C_NOT]], i32 5, i32 [[TMP2]]
 ; CHECK-NEXT:    ret i32 [[DIV]]
 ;
   %y = shl i32 1, %yy
@@ -604,9 +604,10 @@
 
 define i32 @test_urem_to_const_and(i1 %c, i32 %yy) {
 ; CHECK-LABEL: @test_urem_to_const_and(
-; CHECK-NEXT:    [[Y:%.*]] = shl nuw i32 1, [[YY:%.*]]
-; CHECK-NEXT:    [[COND:%.*]] = select i1 [[C:%.*]], i32 [[Y]], i32 19
-; CHECK-NEXT:    [[DIV:%.*]] = urem i32 44, [[COND]]
+; CHECK-NEXT:    [[NOTMASK:%.*]] = shl nsw i32 -1, [[YY:%.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[NOTMASK]], 44
+; CHECK-NEXT:    [[TMP2:%.*]] = xor i32 [[TMP1]], 44
+; CHECK-NEXT:    [[DIV:%.*]] = select i1 [[C:%.*]], i32 [[TMP2]], i32 6
 ; CHECK-NEXT:    ret i32 [[DIV]]
 ;
   %y = shl i32 1, %yy
@@ -631,8 +632,8 @@
 define i32 @test_mul_to_const_Cmul(i32 %x) {
 ; CHECK-LABEL: @test_mul_to_const_Cmul(
 ; CHECK-NEXT:    [[C:%.*]] = icmp eq i32 [[X:%.*]], 61
-; CHECK-NEXT:    [[COND:%.*]] = select i1 [[C]], i32 9, i32 14
-; CHECK-NEXT:    [[DIV:%.*]] = mul i32 [[COND]], [[X]]
+; CHECK-NEXT:    [[TMP1:%.*]] = mul i32 [[X]], 14
+; CHECK-NEXT:    [[DIV:%.*]] = select i1 [[C]], i32 549, i32 [[TMP1]]
 ; CHECK-NEXT:    ret i32 [[DIV]]
 ;
   %c = icmp eq i32 %x, 61