diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -124,6 +124,143 @@ return ConstantVector::get(Elts); } +static const unsigned MaxDepth = 6; + +namespace { + +using FoldOperandCb = Instruction *(*)(Value *Op0, Value *Op1, + const BinaryOperator &I, + InstCombiner &IC); + +/// Used to maintain state for visitMul() and visitUDiv(). +struct OperandFoldAction { + /// Informs visit() how to fold this operand. This can be zero if this + /// action joins two actions together. + FoldOperandCb FoldAction; + + /// Which operand to fold. + Value *OperandToFold; + + union { + /// The instruction returned when FoldAction is invoked. + Instruction *FoldResult; + + /// Stores the LHS action index if this action joins two actions together. + size_t SelectLHSIdx; + }; + + OperandFoldAction(FoldOperandCb FA, Value *InputOperand) + : FoldAction(FA), OperandToFold(InputOperand), FoldResult(nullptr) {} + OperandFoldAction(FoldOperandCb FA, Value *InputOperand, size_t SLHS) + : FoldAction(FA), OperandToFold(InputOperand), SelectLHSIdx(SLHS) {} +}; + +} // end anonymous namespace + +static Instruction *combineActions(Value *Op0, BinaryOperator &I, + SmallVectorImpl &Actions, + InstCombiner &IC) { + for (unsigned i = 0, e = Actions.size(); i != e; ++i) { + FoldOperandCb Action = Actions[i].FoldAction; + Value *ActionOp1 = Actions[i].OperandToFold; + Instruction *Inst; + if (Action) + Inst = Action(Op0, ActionOp1, I, IC); + else { + // This action joins two actions together. The RHS of this action is + // simply the last action we processed, we saved the LHS action index in + // the joining action. + size_t SelectRHSIdx = i - 1; + Value *SelectRHS = Actions[SelectRHSIdx].FoldResult; + size_t SelectLHSIdx = Actions[i].SelectLHSIdx; + Value *SelectLHS = Actions[SelectLHSIdx].FoldResult; + Inst = SelectInst::Create(cast(ActionOp1)->getCondition(), + SelectLHS, SelectRHS); + } + // If this is the last action to process, return it to the InstCombiner. + // Otherwise, we insert it before the Operator and record it so that we may + // use it as part of a joining action (i.e., a SelectInst). + if (e - i != 1) { + Inst->insertBefore(&I); + Actions[i].FoldResult = Inst; + } else + return Inst; + } + return nullptr; +} + +// X udiv 2^C -> X >> C +static Instruction *foldUDivPow2Cst(Value *Op0, Value *Op1, + const BinaryOperator &I, InstCombiner &IC) { + Constant *C1 = getLogBase2(Op0->getType(), cast(Op1)); + if (!C1) + llvm_unreachable("Failed to constant fold udiv -> logbase2"); + BinaryOperator *LShr = BinaryOperator::CreateLShr(Op0, C1); + if (I.isExact()) + LShr->setIsExact(); + return LShr; +} + +// X udiv (C1 << N), where C1 is "1< X >> (N+C2) +// X udiv (zext (C1 << N)), where C1 is "1< X >> (N+C2) +static Instruction *foldUDivShl(Value *Op0, Value *Op1, const BinaryOperator &I, + InstCombiner &IC) { + Value *ShiftLeft; + if (!match(Op1, m_ZExt(m_Value(ShiftLeft)))) + ShiftLeft = Op1; + + Constant *CI; + Value *N; + if (!match(ShiftLeft, m_Shl(m_Constant(CI), m_Value(N)))) + llvm_unreachable("match should never fail here!"); + Constant *Log2Base = getLogBase2(N->getType(), CI); + if (!Log2Base) + llvm_unreachable("getLogBase2 should never fail here!"); + N = IC.Builder.CreateAdd(N, Log2Base); + if (Op1 != ShiftLeft) + N = IC.Builder.CreateZExt(N, Op1->getType()); + BinaryOperator *LShr = BinaryOperator::CreateLShr(Op0, N); + if (I.isExact()) + LShr->setIsExact(); + return LShr; +} + +// Recursively visits the possible right hand operands of a udiv +// instruction, seeing through select instructions, to determine if we can +// replace the udiv with something simpler. If we find that an operand is not +// able to simplify the udiv, we abort the entire transformation. +static size_t visitUDivOperand(Value *Op0, Value *Op1, + SmallVectorImpl &Actions, + unsigned Depth = 0) { + // Check to see if this is an unsigned division with an exact power of 2, + // if so, convert to a right shift. + if (match(Op1, m_Power2())) { + Actions.push_back(OperandFoldAction(foldUDivPow2Cst, Op1)); + return Actions.size(); + } + + // X udiv (C1 << N), where C1 is "1< X >> (N+C2) + if (match(Op1, m_Shl(m_Power2(), m_Value())) || + match(Op1, m_ZExt(m_Shl(m_Power2(), m_Value())))) { + Actions.push_back(OperandFoldAction(foldUDivShl, Op1)); + return Actions.size(); + } + + // The remaining tests are all recursive, so bail out if we hit the limit. + if (Depth++ == MaxDepth) + return 0; + + if (SelectInst *SI = dyn_cast(Op1)) + if (size_t LHSIdx = + visitUDivOperand(Op0, SI->getOperand(1), Actions, Depth)) + if (visitUDivOperand(Op0, SI->getOperand(2), Actions, Depth)) { + Actions.push_back(OperandFoldAction(nullptr, Op1, LHSIdx - 1)); + return Actions.size(); + } + + return 0; +} + // TODO: This is a specific form of a much more general pattern. // We could detect a select with any binop identity constant, or we // could use SimplifyBinOp to see if either arm of the select reduces. @@ -809,111 +946,6 @@ return nullptr; } -static const unsigned MaxDepth = 6; - -namespace { - -using FoldUDivOperandCb = Instruction *(*)(Value *Op0, Value *Op1, - const BinaryOperator &I, - InstCombiner &IC); - -/// Used to maintain state for visitUDivOperand(). -struct UDivFoldAction { - /// Informs visitUDiv() how to fold this operand. This can be zero if this - /// action joins two actions together. - FoldUDivOperandCb FoldAction; - - /// Which operand to fold. - Value *OperandToFold; - - union { - /// The instruction returned when FoldAction is invoked. - Instruction *FoldResult; - - /// Stores the LHS action index if this action joins two actions together. - size_t SelectLHSIdx; - }; - - UDivFoldAction(FoldUDivOperandCb FA, Value *InputOperand) - : FoldAction(FA), OperandToFold(InputOperand), FoldResult(nullptr) {} - UDivFoldAction(FoldUDivOperandCb FA, Value *InputOperand, size_t SLHS) - : FoldAction(FA), OperandToFold(InputOperand), SelectLHSIdx(SLHS) {} -}; - -} // end anonymous namespace - -// X udiv 2^C -> X >> C -static Instruction *foldUDivPow2Cst(Value *Op0, Value *Op1, - const BinaryOperator &I, InstCombiner &IC) { - Constant *C1 = getLogBase2(Op0->getType(), cast(Op1)); - if (!C1) - llvm_unreachable("Failed to constant fold udiv -> logbase2"); - BinaryOperator *LShr = BinaryOperator::CreateLShr(Op0, C1); - if (I.isExact()) - LShr->setIsExact(); - return LShr; -} - -// X udiv (C1 << N), where C1 is "1< X >> (N+C2) -// X udiv (zext (C1 << N)), where C1 is "1< X >> (N+C2) -static Instruction *foldUDivShl(Value *Op0, Value *Op1, const BinaryOperator &I, - InstCombiner &IC) { - Value *ShiftLeft; - if (!match(Op1, m_ZExt(m_Value(ShiftLeft)))) - ShiftLeft = Op1; - - Constant *CI; - Value *N; - if (!match(ShiftLeft, m_Shl(m_Constant(CI), m_Value(N)))) - llvm_unreachable("match should never fail here!"); - Constant *Log2Base = getLogBase2(N->getType(), CI); - if (!Log2Base) - llvm_unreachable("getLogBase2 should never fail here!"); - N = IC.Builder.CreateAdd(N, Log2Base); - if (Op1 != ShiftLeft) - N = IC.Builder.CreateZExt(N, Op1->getType()); - BinaryOperator *LShr = BinaryOperator::CreateLShr(Op0, N); - if (I.isExact()) - LShr->setIsExact(); - return LShr; -} - -// Recursively visits the possible right hand operands of a udiv -// instruction, seeing through select instructions, to determine if we can -// replace the udiv with something simpler. If we find that an operand is not -// able to simplify the udiv, we abort the entire transformation. -static size_t visitUDivOperand(Value *Op0, Value *Op1, const BinaryOperator &I, - SmallVectorImpl &Actions, - unsigned Depth = 0) { - // Check to see if this is an unsigned division with an exact power of 2, - // if so, convert to a right shift. - if (match(Op1, m_Power2())) { - Actions.push_back(UDivFoldAction(foldUDivPow2Cst, Op1)); - return Actions.size(); - } - - // X udiv (C1 << N), where C1 is "1< X >> (N+C2) - if (match(Op1, m_Shl(m_Power2(), m_Value())) || - match(Op1, m_ZExt(m_Shl(m_Power2(), m_Value())))) { - Actions.push_back(UDivFoldAction(foldUDivShl, Op1)); - return Actions.size(); - } - - // The remaining tests are all recursive, so bail out if we hit the limit. - if (Depth++ == MaxDepth) - return 0; - - if (SelectInst *SI = dyn_cast(Op1)) - if (size_t LHSIdx = - visitUDivOperand(Op0, SI->getOperand(1), I, Actions, Depth)) - if (visitUDivOperand(Op0, SI->getOperand(2), I, Actions, Depth)) { - Actions.push_back(UDivFoldAction(nullptr, Op1, LHSIdx - 1)); - return Actions.size(); - } - - return 0; -} - /// If we have zero-extended operands of an unsigned div or rem, we may be able /// to narrow the operation (sink the zext below the math). static Instruction *narrowUDivURem(BinaryOperator &I, @@ -1012,35 +1044,10 @@ } // (LHS udiv (select (select (...)))) -> (LHS >> (select (select (...)))) - SmallVector UDivActions; - if (visitUDivOperand(Op0, Op1, I, UDivActions)) - for (unsigned i = 0, e = UDivActions.size(); i != e; ++i) { - FoldUDivOperandCb Action = UDivActions[i].FoldAction; - Value *ActionOp1 = UDivActions[i].OperandToFold; - Instruction *Inst; - if (Action) - Inst = Action(Op0, ActionOp1, I, *this); - else { - // This action joins two actions together. The RHS of this action is - // simply the last action we processed, we saved the LHS action index in - // the joining action. - size_t SelectRHSIdx = i - 1; - Value *SelectRHS = UDivActions[SelectRHSIdx].FoldResult; - size_t SelectLHSIdx = UDivActions[i].SelectLHSIdx; - Value *SelectLHS = UDivActions[SelectLHSIdx].FoldResult; - Inst = SelectInst::Create(cast(ActionOp1)->getCondition(), - SelectLHS, SelectRHS); - } - - // If this is the last action to process, return it to the InstCombiner. - // Otherwise, we insert it before the UDiv and record it so that we may - // use it as part of a joining action (i.e., a SelectInst). - if (e - i != 1) { - Inst->insertBefore(&I); - UDivActions[i].FoldResult = Inst; - } else - return Inst; - } + SmallVector UDivActions; + if (visitUDivOperand(Op0, Op1, UDivActions)) + if (Instruction *Inst = combineActions(Op0, I, UDivActions, *this)) + return Inst; return nullptr; } @@ -1419,7 +1426,7 @@ // -X srem Y --> -(X srem Y) Value *X, *Y; if (match(&I, m_SRem(m_OneUse(m_NSWSub(m_Zero(), m_Value(X))), m_Value(Y)))) - return BinaryOperator::CreateNSWNeg(Builder.CreateSRem(X, Y)); + return BinaryOperator::CreateNSWNeg(Builder.CreateSRem(X, Y)); // If the sign bits of both operands are zero (i.e. we can prove they are // unsigned inputs), turn this into a urem. diff --git a/llvm/test/Transforms/InstCombine/mul.ll b/llvm/test/Transforms/InstCombine/mul.ll --- a/llvm/test/Transforms/InstCombine/mul.ll +++ b/llvm/test/Transforms/InstCombine/mul.ll @@ -607,3 +607,106 @@ %mul = mul i32 %sel, %y ret i32 %mul } + + +; TODO. 'select + mul' -> 'select + shl' for power of twos + +define i32 @shift_if_power2(i32 %x, i1 %cond) { +; CHECK-LABEL: @shift_if_power2( +; CHECK-NEXT: [[SEL:%.*]] = select i1 [[COND:%.*]], i32 16, i32 4 +; CHECK-NEXT: [[R:%.*]] = mul i32 [[SEL]], [[X:%.*]] +; CHECK-NEXT: ret i32 [[R]] +; + %sel = select i1 %cond, i32 16, i32 4 + %r = mul i32 %sel, %x + ret i32 %r +} + +define i32 @shift_if_power2_double_select(i32 %x, i32 %y, i1 %cond1, i1 %cond2) { +; CHECK-LABEL: @shift_if_power2_double_select( +; CHECK-NEXT: [[SHL_RES:%.*]] = shl i32 8, [[Y:%.*]] +; CHECK-NEXT: [[SEL1:%.*]] = select i1 [[COND1:%.*]], i32 [[SHL_RES]], i32 1024 +; CHECK-NEXT: [[SEL2:%.*]] = select i1 [[COND2:%.*]], i32 16, i32 [[SEL1]] +; CHECK-NEXT: [[R:%.*]] = mul nuw i32 [[SEL2]], [[X:%.*]] +; CHECK-NEXT: ret i32 [[R]] +; + %shl.res = shl i32 8, %y + %sel1 = select i1 %cond1, i32 %shl.res, i32 1024 + %sel2 = select i1 %cond2, i32 16, i32 %sel1 + %r = mul nuw i32 %x, %sel2 + ret i32 %r +} + +define i32 @shift_if_power2_zero(i32 %x, i1 %cond) { +; CHECK-LABEL: @shift_if_power2_zero( +; CHECK-NEXT: [[SEL:%.*]] = select i1 [[COND:%.*]], i32 1, i32 4 +; CHECK-NEXT: [[R:%.*]] = mul i32 [[SEL]], [[X:%.*]] +; CHECK-NEXT: ret i32 [[R]] +; + %sel = select i1 %cond, i32 1, i32 4 + %r = mul i32 %sel, %x + ret i32 %r +} + +define <2 x i8> @shift_if_power2_vector(<2 x i8> %px, i1 %cond) { +; CHECK-LABEL: @shift_if_power2_vector( +; CHECK-NEXT: [[SEL:%.*]] = select i1 [[COND:%.*]], <2 x i8> , <2 x i8> +; CHECK-NEXT: [[R:%.*]] = mul <2 x i8> [[SEL]], [[PX:%.*]] +; CHECK-NEXT: ret <2 x i8> [[R]] +; + %sel = select i1 %cond, <2 x i8> , <2 x i8> + %r = mul <2 x i8> %px, %sel + ret <2 x i8> %r +} + +define i32 @shift_if_extra_use(i32 %x, i1 %cond) { +; CHECK-LABEL: @shift_if_extra_use( +; CHECK-NEXT: [[SEL:%.*]] = select i1 [[COND:%.*]], i32 4, i32 128 +; CHECK-NEXT: call void @use32(i32 [[SEL]]) +; CHECK-NEXT: [[R:%.*]] = mul i32 [[SEL]], [[X:%.*]] +; CHECK-NEXT: ret i32 [[R]] +; + %sel = select i1 %cond, i32 4, i32 128 + call void @use32(i32 %sel) + %r = mul i32 %sel, %x + ret i32 %r +} + +; Negative tests for power2 + +define i32 @shift_if_one_not_power2(i32 %x, i1 %cond) { +; CHECK-LABEL: @shift_if_one_not_power2( +; CHECK-NEXT: [[SEL:%.*]] = select i1 [[COND:%.*]], i32 5, i32 4 +; CHECK-NEXT: [[R:%.*]] = mul i32 [[SEL]], [[X:%.*]] +; CHECK-NEXT: ret i32 [[R]] +; + %sel = select i1 %cond, i32 5, i32 4 + %r = mul i32 %sel, %x + ret i32 %r +} + +define i32 @shift_if_not_power2_double_select(i32 %x, i32 %y, i1 %cond1, i1 %cond2) { +; CHECK-LABEL: @shift_if_not_power2_double_select( +; CHECK-NEXT: [[SHL_RES:%.*]] = shl i32 7, [[Y:%.*]] +; CHECK-NEXT: [[SEL1:%.*]] = select i1 [[COND1:%.*]], i32 [[SHL_RES]], i32 1024 +; CHECK-NEXT: [[SEL2:%.*]] = select i1 [[COND2:%.*]], i32 16, i32 [[SEL1]] +; CHECK-NEXT: [[R:%.*]] = mul nuw i32 [[SEL2]], [[X:%.*]] +; CHECK-NEXT: ret i32 [[R]] +; + %shl.res = shl i32 7, %y + %sel1 = select i1 %cond1, i32 %shl.res, i32 1024 + %sel2 = select i1 %cond2, i32 16, i32 %sel1 + %r = mul nuw i32 %x, %sel2 + ret i32 %r +} + +define <2 x i8> @shift_if_one_not_power2_vector(<2 x i8> %px, i1 %cond) { +; CHECK-LABEL: @shift_if_one_not_power2_vector( +; CHECK-NEXT: [[SEL:%.*]] = select i1 [[COND:%.*]], <2 x i8> , <2 x i8> +; CHECK-NEXT: [[R:%.*]] = mul <2 x i8> [[SEL]], [[PX:%.*]] +; CHECK-NEXT: ret <2 x i8> [[R]] +; + %sel = select i1 %cond, <2 x i8> , <2 x i8> + %r = mul <2 x i8> %px, %sel + ret <2 x i8> %r +}