diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -188,6 +188,15 @@ static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth, bool DoFold); +static std::optional<bool> shouldFoldOpIntoSelect(BinaryOperator &I, Value *Op, + Value *OpOther) { + if (match(Op, m_Select(m_Value(), m_Value(), m_Value()))) + return match(OpOther, m_ImmConstant()) && + match(Op, m_Select(m_Value(), m_ImmConstant(), m_ImmConstant())); + + return std::nullopt; +} + Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); if (Value *V = @@ -501,6 +510,14 @@ return Shl; } + if (auto MultiUse = shouldFoldOpIntoSelect(I, Op0, Op1)) + if (Instruction *R = FoldOpIntoSelect(I, cast<SelectInst>(Op0), *MultiUse)) + return R; + + if (auto MultiUse = shouldFoldOpIntoSelect(I, Op1, Op0)) + if (Instruction *R = FoldOpIntoSelect(I, cast<SelectInst>(Op1), *MultiUse)) + return R; + bool Changed = false; if (!HasNSW && willNotOverflowSignedMul(Op0, Op1, I)) { Changed = true; @@ -999,11 +1016,17 @@ // If the divisor is a select-of-constants, try to constant fold all div ops: // C / (select Cond, TrueC, FalseC) --> select Cond, (C / TrueC), (C / FalseC) // TODO: Adapt simplifyDivRemOfSelectWithZeroOp to allow this and other folds. 
- if (match(Op0, m_ImmConstant()) && - match(Op1, m_Select(m_Value(), m_ImmConstant(), m_ImmConstant()))) { - if (Instruction *R = FoldOpIntoSelect(I, cast<SelectInst>(Op1), - /*FoldWithMultiUse*/ true)) - return R; + const SimplifyQuery Q = SQ.getWithInstruction(&I); + if (isKnownNonZero(Op1, DL, 0, Q.AC, Q.CxtI, Q.DT)) { + if (auto MultiUse = shouldFoldOpIntoSelect(I, Op0, Op1)) + if (Instruction *R = + FoldOpIntoSelect(I, cast<SelectInst>(Op0), *MultiUse)) + return R; + + if (auto MultiUse = shouldFoldOpIntoSelect(I, Op1, Op0)) + if (Instruction *R = + FoldOpIntoSelect(I, cast<SelectInst>(Op1), *MultiUse)) + return R; } const APInt *C2; @@ -1820,11 +1843,17 @@ // If the divisor is a select-of-constants, try to constant fold all rem ops: // C % (select Cond, TrueC, FalseC) --> select Cond, (C % TrueC), (C % FalseC) // TODO: Adapt simplifyDivRemOfSelectWithZeroOp to allow this and other folds. - if (match(Op0, m_ImmConstant()) && - match(Op1, m_Select(m_Value(), m_ImmConstant(), m_ImmConstant()))) { - if (Instruction *R = FoldOpIntoSelect(I, cast<SelectInst>(Op1), - /*FoldWithMultiUse*/ true)) - return R; + const SimplifyQuery Q = SQ.getWithInstruction(&I); + if (isKnownNonZero(Op1, DL, 0, Q.AC, Q.CxtI, Q.DT)) { + if (auto MultiUse = shouldFoldOpIntoSelect(I, Op0, Op1)) + if (Instruction *R = + FoldOpIntoSelect(I, cast<SelectInst>(Op0), *MultiUse)) + return R; + + if (auto MultiUse = shouldFoldOpIntoSelect(I, Op1, Op0)) + if (Instruction *R = + FoldOpIntoSelect(I, cast<SelectInst>(Op1), *MultiUse)) + return R; } if (isa<Constant>(Op1)) { diff --git a/llvm/test/Transforms/InstCombine/binop-select.ll b/llvm/test/Transforms/InstCombine/binop-select.ll --- a/llvm/test/Transforms/InstCombine/binop-select.ll +++ b/llvm/test/Transforms/InstCombine/binop-select.ll @@ -352,9 +352,8 @@ define <2 x i32> @test_udiv_to_const_shr(i1 %c, <2 x i32> %x, <2 x i32> %yy) { ; CHECK-LABEL: @test_udiv_to_const_shr( -; CHECK-NEXT: [[Y:%.*]] = shl nuw <2 x i32> , [[YY:%.*]] -; CHECK-NEXT: [[COND:%.*]] = select i1 [[C:%.*]], <2 x i32> , <2 x i32> [[Y]] -; CHECK-NEXT: 
[[DIV:%.*]] = udiv <2 x i32> , [[COND]] +; CHECK-NEXT: [[TMP1:%.*]] = lshr <2 x i32> , [[YY:%.*]] +; CHECK-NEXT: [[DIV:%.*]] = select i1 [[C:%.*]], <2 x i32> , <2 x i32> [[TMP1]] ; CHECK-NEXT: ret <2 x i32> [[DIV]] ; %y = shl <2 x i32> , %yy @@ -381,8 +380,8 @@ define i32 @test_udiv_to_const_Cudiv(i32 %x) { ; CHECK-LABEL: @test_udiv_to_const_Cudiv( ; CHECK-NEXT: [[C:%.*]] = icmp eq i32 [[X:%.*]], 90 -; CHECK-NEXT: [[COND:%.*]] = select i1 [[C]], i32 7, i32 19 -; CHECK-NEXT: [[DIV:%.*]] = udiv i32 [[X]], [[COND]] +; CHECK-NEXT: [[TMP1:%.*]] = udiv i32 [[X]], 19 +; CHECK-NEXT: [[DIV:%.*]] = select i1 [[C]], i32 12, i32 [[TMP1]] ; CHECK-NEXT: ret i32 [[DIV]] ; %c = icmp eq i32 %x, 90 @@ -413,8 +412,8 @@ ; CHECK-NEXT: [[YNZ:%.*]] = icmp ne i32 [[Y:%.*]], 0 ; CHECK-NEXT: call void @llvm.assume(i1 [[YNZ]]) ; CHECK-NEXT: [[C:%.*]] = icmp eq i32 [[X:%.*]], 90 -; CHECK-NEXT: [[COND:%.*]] = select i1 [[C]], i32 19, i32 [[Y]] -; CHECK-NEXT: [[DIV:%.*]] = udiv i32 [[X]], [[COND]] +; CHECK-NEXT: [[TMP1:%.*]] = udiv i32 [[X]], [[Y]] +; CHECK-NEXT: [[DIV:%.*]] = select i1 [[C]], i32 4, i32 [[TMP1]] ; CHECK-NEXT: ret i32 [[DIV]] ; %ynz = icmp ne i32 %y, 0 @@ -454,8 +453,8 @@ define <2 x i32> @test_sdiv_to_const_Csdiv(<2 x i32> %x) { ; CHECK-LABEL: @test_sdiv_to_const_Csdiv( ; CHECK-NEXT: [[C_NOT:%.*]] = icmp eq <2 x i32> [[X:%.*]], -; CHECK-NEXT: [[COND:%.*]] = select <2 x i1> [[C_NOT]], <2 x i32> , <2 x i32> -; CHECK-NEXT: [[DIV:%.*]] = sdiv <2 x i32> [[X]], [[COND]] +; CHECK-NEXT: [[TMP1:%.*]] = sdiv <2 x i32> [[X]], +; CHECK-NEXT: [[DIV:%.*]] = select <2 x i1> [[C_NOT]], <2 x i32> , <2 x i32> [[TMP1]] ; CHECK-NEXT: ret <2 x i32> [[DIV]] ; %c = icmp ne <2 x i32> %x, @@ -467,8 +466,8 @@ define i32 @test_srem_to_const_Csrem(i32 %x) { ; CHECK-LABEL: @test_srem_to_const_Csrem( ; CHECK-NEXT: [[C:%.*]] = icmp eq i32 [[X:%.*]], 24 -; CHECK-NEXT: [[COND:%.*]] = select i1 [[C]], i32 7, i32 16 -; CHECK-NEXT: [[DIV:%.*]] = srem i32 [[X]], [[COND]] +; CHECK-NEXT: [[TMP1:%.*]] = srem i32 
[[X]], 16 +; CHECK-NEXT: [[DIV:%.*]] = select i1 [[C]], i32 3, i32 [[TMP1]] ; CHECK-NEXT: ret i32 [[DIV]] ; %c = icmp eq i32 %x, 24 @@ -494,10 +493,11 @@ define i32 @test_urem_to_const_and_ind_x(i32 %x, i32 %yy) { ; CHECK-LABEL: @test_urem_to_const_and_ind_x( -; CHECK-NEXT: [[Y:%.*]] = shl nuw i32 1, [[YY:%.*]] ; CHECK-NEXT: [[C_NOT:%.*]] = icmp eq i32 [[X:%.*]], 24 -; CHECK-NEXT: [[COND:%.*]] = select i1 [[C_NOT]], i32 19, i32 [[Y]] -; CHECK-NEXT: [[DIV:%.*]] = urem i32 [[X]], [[COND]] +; CHECK-NEXT: [[NOTMASK:%.*]] = shl nsw i32 -1, [[YY:%.*]] +; CHECK-NEXT: [[TMP1:%.*]] = xor i32 [[NOTMASK]], -1 +; CHECK-NEXT: [[TMP2:%.*]] = and i32 [[TMP1]], [[X]] +; CHECK-NEXT: [[DIV:%.*]] = select i1 [[C_NOT]], i32 5, i32 [[TMP2]] ; CHECK-NEXT: ret i32 [[DIV]] ; %y = shl i32 1, %yy @@ -509,9 +509,10 @@ define i32 @test_urem_to_const_and(i1 %c, i32 %yy) { ; CHECK-LABEL: @test_urem_to_const_and( -; CHECK-NEXT: [[Y:%.*]] = shl nuw i32 1, [[YY:%.*]] -; CHECK-NEXT: [[COND:%.*]] = select i1 [[C:%.*]], i32 [[Y]], i32 19 -; CHECK-NEXT: [[DIV:%.*]] = urem i32 44, [[COND]] +; CHECK-NEXT: [[NOTMASK:%.*]] = shl nsw i32 -1, [[YY:%.*]] +; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[NOTMASK]], 44 +; CHECK-NEXT: [[TMP2:%.*]] = xor i32 [[TMP1]], 44 +; CHECK-NEXT: [[DIV:%.*]] = select i1 [[C:%.*]], i32 [[TMP2]], i32 6 ; CHECK-NEXT: ret i32 [[DIV]] ; %y = shl i32 1, %yy @@ -536,8 +537,8 @@ define i32 @test_mul_to_const_Cmul(i32 %x) { ; CHECK-LABEL: @test_mul_to_const_Cmul( ; CHECK-NEXT: [[C:%.*]] = icmp eq i32 [[X:%.*]], 61 -; CHECK-NEXT: [[COND:%.*]] = select i1 [[C]], i32 9, i32 14 -; CHECK-NEXT: [[DIV:%.*]] = mul i32 [[COND]], [[X]] +; CHECK-NEXT: [[TMP1:%.*]] = mul i32 [[X]], 14 +; CHECK-NEXT: [[DIV:%.*]] = select i1 [[C]], i32 549, i32 [[TMP1]] ; CHECK-NEXT: ret i32 [[DIV]] ; %c = icmp eq i32 %x, 61