diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -240,8 +240,7 @@ } // X udiv (C1 << N), where C1 is "1<<C2" --> X >> (N+C2) - if (match(Op1, m_Shl(m_Power2(), m_Value())) || - match(Op1, m_ZExt(m_Shl(m_Power2(), m_Value())))) { + if (match(Op1, m_ZExtOrSelf(m_Shl(m_Power2(), m_Value())))) { Actions.push_back(OperandFoldAction(foldUDivShl, Op1)); return Actions.size(); } @@ -261,6 +260,94 @@ return 0; } +// X mul 2^C -> X << C +static Instruction *foldMulPow2Cst(Value *Op0, Value *Op1, + const BinaryOperator &I, InstCombiner &IC) { + Constant *C1 = cast<Constant>(Op1); + Constant *LogC1 = getLogBase2(Op0->getType(), C1); + assert(LogC1 && "Failed to constant fold mul -> logbase2"); + // Undef lanes become a zero shift; an undef amount may exceed the bit width. + LogC1 = Constant::replaceUndefsWith( + LogC1, ConstantInt::get(LogC1->getType()->getScalarType(), 0)); + BinaryOperator *Shl = BinaryOperator::CreateShl(Op0, LogC1); + if (I.hasNoUnsignedWrap()) + Shl->setHasNoUnsignedWrap(); + if (I.hasNoSignedWrap() && C1->isNotMinSignedValue()) + Shl->setHasNoSignedWrap(); + return Shl; +} + +// The caller must guarantee the shl loses no bits (a shl of 1, or nsw/nuw).
+// X mul (C1 << N), where C1 is "1<<C2" --> X << (add N, C2) +// X mul (zext (C1 << N)), where C1 is "1<<C2" --> X << (zext (add N, C2)) +static Instruction *foldMulShl(Value *Op0, Value *Op1, const BinaryOperator &I, + InstCombiner &IC) { + Value *ShiftLeft; + if (!match(Op1, m_ZExt(m_Value(ShiftLeft)))) + ShiftLeft = Op1; + + Constant *CI; + Value *N; + if (!match(ShiftLeft, m_Shl(m_Constant(CI), m_Value(N)))) + llvm_unreachable("match should never fail in foldMulShl"); + Constant *Log2Base = getLogBase2(N->getType(), CI); + assert(Log2Base && "getLogBase2 should never fail"); + // Undef lanes become a zero shift; an undef amount may exceed the bit width. + Log2Base = Constant::replaceUndefsWith( + Log2Base, ConstantInt::get(Log2Base->getType()->getScalarType(), 0)); + if (!Log2Base->isZeroValue()) + N = IC.Builder.CreateAdd(N, Log2Base); + if (Op1 != ShiftLeft) + N = IC.Builder.CreateZExt(N, Op1->getType()); + BinaryOperator *Shl = BinaryOperator::CreateShl(Op0, N); + if (I.hasNoUnsignedWrap()) + Shl->setHasNoUnsignedWrap(); + if (I.hasNoSignedWrap() && cast<ShlOperator>(ShiftLeft)->hasNoSignedWrap()) + Shl->setHasNoSignedWrap(); + return Shl; +} + +// Recursively visits the possible right hand operands of a mul +// instruction, seeing through select instructions, to determine if we can +// replace the mul with something simpler. If we find that an operand is not +// able to simplify the mul, we abort the entire transformation. +static size_t visitMulOperand(Value *Op0, Value *Op1, + SmallVectorImpl<OperandFoldAction> &Actions, + unsigned Depth = 0) { + // Check to see if this is a multiplication with an exact power of 2, if so, + // convert to a left shift.
+ if (match(Op1, m_Power2())) { + Actions.emplace_back(foldMulPow2Cst, Op1); + return Actions.size(); + } + + // X mul (C1 << N), where C1 is "1<<C2" --> X << (add N, C2) + // X mul (zext (C1 << N)), where C1 is "1<<C2" --> X << (zext (add N, C2)) + if (match(Op1, m_ZExtOrSelf(m_Shl(m_One(), m_Value()))) || + match(Op1, m_ZExtOrSelf(m_NSWShl(m_Power2(), m_Value()))) || + match(Op1, m_ZExtOrSelf(m_NUWShl(m_Power2(), m_Value())))) { + Actions.emplace_back(foldMulShl, Op1); + return Actions.size(); + } + + // The remaining tests are all recursive, so bail out if we hit the limit. + if (Depth++ == MaxDepth) + return 0; + + if (SelectInst *SI = dyn_cast<SelectInst>(Op1)) + if (size_t LHSIdx = visitMulOperand(Op0, SI->getOperand(1), Actions, Depth)) + if (visitMulOperand(Op0, SI->getOperand(2), Actions, Depth)) { + Actions.emplace_back(nullptr, Op1, LHSIdx - 1); + return Actions.size(); + } + + return 0; +} + // TODO: This is a specific form of a much more general pattern. // We could detect a select with any binop identity constant, or we // could use SimplifyBinOp to see if either arm of the select reduces. @@ -397,6 +484,17 @@ if (Value *FoldedMul = foldMulSelectToNegate(I, Builder)) return replaceInstUsesWith(I, FoldedMul); + // (LHS mul (select (select (...)))) -> (LHS << (select (select (...)))) + SmallVector<OperandFoldAction, 6> MulActions; + if (visitMulOperand(Op0, Op1, MulActions)) + if (Instruction *Inst = combineActions(Op0, I, MulActions, *this)) + return Inst; + MulActions.clear(); + // ((select (select (...))) mul RHS) -> (RHS << (select (select (...)))) + if (visitMulOperand(Op1, Op0, MulActions)) + if (Instruction *Inst = combineActions(Op1, I, MulActions, *this)) + return Inst; + // Simplify mul instructions with a constant RHS. if (isa<Constant>(Op1)) { // Canonicalize (X+C1)*CI -> X*CI+C1*CI.
@@ -469,28 +567,6 @@ if (I.getType()->isIntOrIntVectorTy(1)) return BinaryOperator::CreateAnd(Op0, Op1); - // X*(1 << Y) --> X << Y - // (1 << Y)*X --> X << Y - { - Value *Y; - BinaryOperator *BO = nullptr; - bool ShlNSW = false; - if (match(Op0, m_Shl(m_One(), m_Value(Y)))) { - BO = BinaryOperator::CreateShl(Op1, Y); - ShlNSW = cast<ShlOperator>(Op0)->hasNoSignedWrap(); - } else if (match(Op1, m_Shl(m_One(), m_Value(Y)))) { - BO = BinaryOperator::CreateShl(Op0, Y); - ShlNSW = cast<ShlOperator>(Op1)->hasNoSignedWrap(); - } - if (BO) { - if (I.hasNoUnsignedWrap()) - BO->setHasNoUnsignedWrap(); - if (I.hasNoSignedWrap() && ShlNSW) - BO->setHasNoSignedWrap(); - return BO; - } - } - // (bool X) * Y --> X ? Y : 0 // Y * (bool X) --> X ? Y : 0 if (match(Op0, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) diff --git a/llvm/test/Transforms/InstCombine/mul.ll b/llvm/test/Transforms/InstCombine/mul.ll --- a/llvm/test/Transforms/InstCombine/mul.ll +++ b/llvm/test/Transforms/InstCombine/mul.ll @@ -609,12 +609,12 @@ } -; TODO.
'select + mul' -> 'select + shl' for power of twos +; 'select + mul' -> 'select + shl' for power of twos define i32 @shift_if_power2(i32 %x, i1 %cond) { ; CHECK-LABEL: @shift_if_power2( -; CHECK-NEXT: [[SEL:%.*]] = select i1 [[COND:%.*]], i32 16, i32 4 -; CHECK-NEXT: [[R:%.*]] = mul i32 [[SEL]], [[X:%.*]] +; CHECK-NEXT: [[R_V:%.*]] = select i1 [[COND:%.*]], i32 4, i32 2 +; CHECK-NEXT: [[R:%.*]] = shl i32 [[X:%.*]], [[R_V]] ; CHECK-NEXT: ret i32 [[R]] ; %sel = select i1 %cond, i32 16, i32 4 @@ -624,8 +624,8 @@ define i32 @shift_if_power2_nuw(i32 %x, i1 %cond) { ; CHECK-LABEL: @shift_if_power2_nuw( -; CHECK-NEXT: [[SEL:%.*]] = select i1 [[COND:%.*]], i32 16, i32 4 -; CHECK-NEXT: [[R:%.*]] = mul nuw i32 [[SEL]], [[X:%.*]] +; CHECK-NEXT: [[R_V:%.*]] = select i1 [[COND:%.*]], i32 4, i32 2 +; CHECK-NEXT: [[R:%.*]] = shl nuw i32 [[X:%.*]], [[R_V]] ; CHECK-NEXT: ret i32 [[R]] ; %sel = select i1 %cond, i32 16, i32 4 @@ -635,8 +635,8 @@ define i32 @shift_if_power2_nsw(i32 %x, i1 %cond) { ; CHECK-LABEL: @shift_if_power2_nsw( -; CHECK-NEXT: [[SEL:%.*]] = select i1 [[COND:%.*]], i32 16, i32 4 -; CHECK-NEXT: [[R:%.*]] = mul nsw i32 [[SEL]], [[X:%.*]] +; CHECK-NEXT: [[R_V:%.*]] = select i1 [[COND:%.*]], i32 4, i32 2 +; CHECK-NEXT: [[R:%.*]] = shl nsw i32 [[X:%.*]], [[R_V]] ; CHECK-NEXT: ret i32 [[R]] ; %sel = select i1 %cond, i32 16, i32 4 @@ -646,8 +646,8 @@ define i32 @shift_if_power2_nuw_nsw(i32 %x, i1 %cond) { ; CHECK-LABEL: @shift_if_power2_nuw_nsw( -; CHECK-NEXT: [[SEL:%.*]] = select i1 [[COND:%.*]], i32 16, i32 4 -; CHECK-NEXT: [[R:%.*]] = mul nuw nsw i32 [[SEL]], [[X:%.*]] +; CHECK-NEXT: [[R_V:%.*]] = select i1 [[COND:%.*]], i32 4, i32 2 +; CHECK-NEXT: [[R:%.*]] = shl nuw nsw i32 [[X:%.*]], [[R_V]] ; CHECK-NEXT: ret i32 [[R]] ; %sel = select i1 %cond, i32 16, i32 4 @@ -657,8 +657,8 @@ define i32 @shift_if_power2_nuw_nsw_min(i32 %x, i1 %cond) { ; CHECK-LABEL: @shift_if_power2_nuw_nsw_min( -; CHECK-NEXT: [[SEL:%.*]] = select i1 [[COND:%.*]], i32 2, i32 -2147483648 -;
CHECK-NEXT: [[R:%.*]] = mul nuw nsw i32 [[SEL]], [[X:%.*]] +; CHECK-NEXT: [[R_V:%.*]] = select i1 [[COND:%.*]], i32 1, i32 31 +; CHECK-NEXT: [[R:%.*]] = shl nuw i32 [[X:%.*]], [[R_V]] ; CHECK-NEXT: ret i32 [[R]] ; %sel = select i1 %cond, i32 2, i32 -2147483648 @@ -683,11 +683,11 @@ define i32 @shift_if_power2_double_select_zext(i32 %x, i16 %y, i1 %cond1, i1 %cond2) { ; CHECK-LABEL: @shift_if_power2_double_select_zext( -; CHECK-NEXT: [[SHL_RES:%.*]] = shl nsw i16 8, [[Y:%.*]] -; CHECK-NEXT: [[SHL:%.*]] = zext i16 [[SHL_RES]] to i32 -; CHECK-NEXT: [[SEL1:%.*]] = select i1 [[COND1:%.*]], i32 [[SHL]], i32 1024 -; CHECK-NEXT: [[SEL2:%.*]] = select i1 [[COND2:%.*]], i32 16, i32 [[SEL1]] -; CHECK-NEXT: [[R:%.*]] = mul nuw i32 [[SEL2]], [[X:%.*]] +; CHECK-NEXT: [[TMP1:%.*]] = add i16 [[Y:%.*]], 3 +; CHECK-NEXT: [[TMP2:%.*]] = zext i16 [[TMP1]] to i32 +; CHECK-NEXT: [[DOTV:%.*]] = select i1 [[COND1:%.*]], i32 [[TMP2]], i32 10 +; CHECK-NEXT: [[R_V:%.*]] = select i1 [[COND2:%.*]], i32 4, i32 [[DOTV]] +; CHECK-NEXT: [[R:%.*]] = shl nuw i32 [[X:%.*]], [[R_V]] ; CHECK-NEXT: ret i32 [[R]] ; %shl.res = shl nsw i16 8, %y @@ -715,13 +715,29 @@ ret i32 %r } +define i32 @shift_if_power2_double_select_zext_shift_1(i32 %x, i16 %y, i1 %cond1, i1 %cond2) { +; CHECK-LABEL: @shift_if_power2_double_select_zext_shift_1( +; CHECK-NEXT: [[TMP1:%.*]] = zext i16 [[Y:%.*]] to i32 +; CHECK-NEXT: [[DOTV:%.*]] = select i1 [[COND1:%.*]], i32 [[TMP1]], i32 10 +; CHECK-NEXT: [[R_V:%.*]] = select i1 [[COND2:%.*]], i32 4, i32 [[DOTV]] +; CHECK-NEXT: [[R:%.*]] = shl i32 [[X:%.*]], [[R_V]] +; CHECK-NEXT: ret i32 [[R]] +; + %shl.res = shl i16 1, %y + %shl = zext i16 %shl.res to i32 + %sel1 = select i1 %cond1, i32 %shl, i32 1024 + %sel2 = select i1 %cond2, i32 16, i32 %sel1 + %r = mul nsw i32 %x, %sel2 + ret i32 %r +} + define i32 @shift_if_power2_double_select_zext_shift_nsw(i32 %x, i16 %y, i1 %cond1, i1 %cond2) { ; CHECK-LABEL: @shift_if_power2_double_select_zext_shift_nsw( -; CHECK-NEXT: [[SHL_RES:%.*]] =
shl nsw i16 8, [[Y:%.*]] -; CHECK-NEXT: [[SHL:%.*]] = zext i16 [[SHL_RES]] to i32 -; CHECK-NEXT: [[SEL1:%.*]] = select i1 [[COND1:%.*]], i32 [[SHL]], i32 1024 -; CHECK-NEXT: [[SEL2:%.*]] = select i1 [[COND2:%.*]], i32 16, i32 [[SEL1]] -; CHECK-NEXT: [[R:%.*]] = mul nuw nsw i32 [[SEL2]], [[X:%.*]] +; CHECK-NEXT: [[TMP1:%.*]] = add i16 [[Y:%.*]], 3 +; CHECK-NEXT: [[TMP2:%.*]] = zext i16 [[TMP1]] to i32 +; CHECK-NEXT: [[DOTV:%.*]] = select i1 [[COND1:%.*]], i32 [[TMP2]], i32 10 +; CHECK-NEXT: [[R_V:%.*]] = select i1 [[COND2:%.*]], i32 4, i32 [[DOTV]] +; CHECK-NEXT: [[R:%.*]] = shl nuw nsw i32 [[X:%.*]], [[R_V]] ; CHECK-NEXT: ret i32 [[R]] ; %shl.res = shl nsw i16 8, %y @@ -732,10 +748,27 @@ ret i32 %r } +define i32 @shift_if_power2_double_select_zext_shift_nuw(i32 %x, i16 %y, i1 %cond1, i1 %cond2) { +; CHECK-LABEL: @shift_if_power2_double_select_zext_shift_nuw( +; CHECK-NEXT: [[TMP1:%.*]] = add i16 [[Y:%.*]], 3 +; CHECK-NEXT: [[TMP2:%.*]] = zext i16 [[TMP1]] to i32 +; CHECK-NEXT: [[DOTV:%.*]] = select i1 [[COND1:%.*]], i32 [[TMP2]], i32 10 +; CHECK-NEXT: [[R_V:%.*]] = select i1 [[COND2:%.*]], i32 4, i32 [[DOTV]] +; CHECK-NEXT: [[R:%.*]] = shl nuw i32 [[X:%.*]], [[R_V]] +; CHECK-NEXT: ret i32 [[R]] +; + %shl.res = shl nuw i16 8, %y + %shl = zext i16 %shl.res to i32 + %sel1 = select i1 %cond1, i32 %shl, i32 1024 + %sel2 = select i1 %cond2, i32 16, i32 %sel1 + %r = mul nuw nsw i32 %x, %sel2 + ret i32 %r +} + define i32 @shift_if_power2_zero(i32 %x, i1 %cond) { ; CHECK-LABEL: @shift_if_power2_zero( -; CHECK-NEXT: [[SEL:%.*]] = select i1 [[COND:%.*]], i32 1, i32 4 -; CHECK-NEXT: [[R:%.*]] = mul i32 [[SEL]], [[X:%.*]] +; CHECK-NEXT: [[R_V:%.*]] = select i1 [[COND:%.*]], i32 0, i32 2 +; CHECK-NEXT: [[R:%.*]] = shl i32 [[X:%.*]], [[R_V]] ; CHECK-NEXT: ret i32 [[R]] ; %sel = select i1 %cond, i32 1, i32 4 @@ -745,8 +778,8 @@ define <2 x i8> @shift_if_power2_vector(<2 x i8> %px, i1 %cond) { ; CHECK-LABEL: @shift_if_power2_vector( -; CHECK-NEXT: [[SEL:%.*]] = select i1
[[COND:%.*]], <2 x i8> , <2 x i8> -; CHECK-NEXT: [[R:%.*]] = mul <2 x i8> [[SEL]], [[PX:%.*]] +; CHECK-NEXT: [[R_V:%.*]] = select i1 [[COND:%.*]], <2 x i8> , <2 x i8> +; CHECK-NEXT: [[R:%.*]] = shl <2 x i8> [[PX:%.*]], [[R_V]] ; CHECK-NEXT: ret <2 x i8> [[R]] ; %sel = select i1 %cond, <2 x i8> , <2 x i8> @@ -754,11 +787,23 @@ ret <2 x i8> %r } +define <2 x i32> @shift_if_power2_vector_nsw(<2 x i32> %px, i1 %cond) { +; CHECK-LABEL: @shift_if_power2_vector_nsw( +; CHECK-NEXT: [[R_V:%.*]] = select i1 [[COND:%.*]], <2 x i32> , <2 x i32> +; CHECK-NEXT: [[R:%.*]] = shl <2 x i32> [[PX:%.*]], [[R_V]] +; CHECK-NEXT: ret <2 x i32> [[R]] +; + %sel = select i1 %cond, <2 x i32> , <2 x i32> + %r = mul nsw <2 x i32> %px, %sel + ret <2 x i32> %r +} + define i32 @shift_if_extra_use(i32 %x, i1 %cond) { ; CHECK-LABEL: @shift_if_extra_use( ; CHECK-NEXT: [[SEL:%.*]] = select i1 [[COND:%.*]], i32 4, i32 128 ; CHECK-NEXT: call void @use32(i32 [[SEL]]) -; CHECK-NEXT: [[R:%.*]] = mul i32 [[SEL]], [[X:%.*]] +; CHECK-NEXT: [[R_V:%.*]] = select i1 [[COND]], i32 2, i32 7 +; CHECK-NEXT: [[R:%.*]] = shl i32 [[X:%.*]], [[R_V]] ; CHECK-NEXT: ret i32 [[R]] ; %sel = select i1 %cond, i32 4, i32 128 @@ -810,8 +855,8 @@ define <2 x i8> @shift_if_undef_vector(<2 x i8> %px, i1 %cond) { ; CHECK-LABEL: @shift_if_undef_vector( -; CHECK-NEXT: [[SEL:%.*]] = select i1 [[COND:%.*]], <2 x i8> , <2 x i8> -; CHECK-NEXT: [[R:%.*]] = mul <2 x i8> [[SEL]], [[PX:%.*]] +; CHECK-NEXT: [[R_V:%.*]] = select i1 [[COND:%.*]], <2 x i8> , <2 x i8> +; CHECK-NEXT: [[R:%.*]] = shl <2 x i8> [[PX:%.*]], [[R_V]] ; CHECK-NEXT: ret <2 x i8> [[R]] ; %sel = select i1 %cond, <2 x i8> , <2 x i8> @@ -821,8 +866,8 @@ define <2 x i8> @shift_if_different_lanes_undef_vector(<2 x i8> %px, i1 %cond) { ; CHECK-LABEL: @shift_if_different_lanes_undef_vector( -; CHECK-NEXT: [[SEL:%.*]] = select i1 [[COND:%.*]], <2 x i8> , <2 x i8> -; CHECK-NEXT: [[R:%.*]] = mul <2 x i8> [[SEL]], [[PX:%.*]] +; CHECK-NEXT: [[R_V:%.*]] = select i1 [[COND:%.*]], <2 x
i8> , <2 x i8> +; CHECK-NEXT: [[R:%.*]] = shl <2 x i8> [[PX:%.*]], [[R_V]] ; CHECK-NEXT: ret <2 x i8> [[R]] ; %sel = select i1 %cond, <2 x i8> , <2 x i8> @@ -832,8 +877,8 @@ define <2 x i8> @shift_if_same_lane_undef_vector(<2 x i8> %px, i1 %cond) { ; CHECK-LABEL: @shift_if_same_lane_undef_vector( -; CHECK-NEXT: [[SEL:%.*]] = select i1 [[COND:%.*]], <2 x i8> , <2 x i8> -; CHECK-NEXT: [[R:%.*]] = mul <2 x i8> [[SEL]], [[PX:%.*]] +; CHECK-NEXT: [[R_V:%.*]] = select i1 [[COND:%.*]], <2 x i8> , <2 x i8> +; CHECK-NEXT: [[R:%.*]] = shl <2 x i8> [[PX:%.*]], [[R_V]] ; CHECK-NEXT: ret <2 x i8> [[R]] ; %sel = select i1 %cond, <2 x i8> , <2 x i8> diff --git a/llvm/test/Transforms/InstCombine/udiv_select_to_select_shift.ll b/llvm/test/Transforms/InstCombine/udiv_select_to_select_shift.ll --- a/llvm/test/Transforms/InstCombine/udiv_select_to_select_shift.ll +++ b/llvm/test/Transforms/InstCombine/udiv_select_to_select_shift.ll @@ -19,6 +19,35 @@ ret i64 %sum } +define <2 x i8> @test2(<2 x i8> %px, i1 %cond) { +; CHECK-LABEL: @test2( +; CHECK-NEXT: [[R_V:%.*]] = select i1 [[COND:%.*]], <2 x i8> , <2 x i8> +; CHECK-NEXT: [[R:%.*]] = lshr <2 x i8> [[PX:%.*]], [[R_V]] +; CHECK-NEXT: ret <2 x i8> [[R]] +; + %sel = select i1 %cond, <2 x i8> , <2 x i8> + %r = udiv <2 x i8> %px, %sel + ret <2 x i8> %r +} + +define <2 x i8> @test3(<2 x i8> %px, i1 %cond) { +; CHECK-LABEL: @test3( +; CHECK-NEXT: ret <2 x i8> undef +; + %sel = select i1 %cond, <2 x i8> , <2 x i8> + %r = udiv <2 x i8> %px, %sel + ret <2 x i8> %r +} + +define <2 x i8> @test4(<2 x i8> %px, i1 %cond) { +; CHECK-LABEL: @test4( +; CHECK-NEXT: ret <2 x i8> undef +; + %sel = select i1 %cond, <2 x i8> , <2 x i8> + %r = udiv <2 x i8> %px, %sel + ret <2 x i8> %r +} + ; https://bugs.llvm.org/show_bug.cgi?id=34856 ; This would assert/crash because we didn't propagate the condition with the correct vector type.