diff --git a/llvm/include/llvm/Analysis/InstructionSimplify.h b/llvm/include/llvm/Analysis/InstructionSimplify.h --- a/llvm/include/llvm/Analysis/InstructionSimplify.h +++ b/llvm/include/llvm/Analysis/InstructionSimplify.h @@ -188,10 +188,12 @@ Value *simplifyMulInst(Value *LHS, Value *RHS, const SimplifyQuery &Q); /// Given operands for an SDiv, fold the result or return null. -Value *simplifySDivInst(Value *LHS, Value *RHS, const SimplifyQuery &Q); +Value *simplifySDivInst(Value *LHS, Value *RHS, bool IsExact, + const SimplifyQuery &Q); /// Given operands for a UDiv, fold the result or return null. -Value *simplifyUDivInst(Value *LHS, Value *RHS, const SimplifyQuery &Q); +Value *simplifyUDivInst(Value *LHS, Value *RHS, bool IsExact, + const SimplifyQuery &Q); /// Given operands for an FDiv, fold the result or return null. Value * diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp --- a/llvm/lib/Analysis/InstructionSimplify.cpp +++ b/llvm/lib/Analysis/InstructionSimplify.cpp @@ -1143,13 +1143,24 @@ /// These are simplifications common to SDiv and UDiv. static Value *simplifyDiv(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1, - const SimplifyQuery &Q, unsigned MaxRecurse) { + bool IsExact, const SimplifyQuery &Q, + unsigned MaxRecurse) { if (Constant *C = foldOrCommuteConstant(Opcode, Op0, Op1, Q)) return C; if (Value *V = simplifyDivRem(Opcode, Op0, Op1, Q, MaxRecurse)) return V; + // If this is an exact divide by a constant, then the dividend (Op0) must have + // at least as many trailing zeros as the divisor to divide evenly. If it has + // fewer trailing zeros, then the result must be poison. 
+ const APInt *DivC; + if (IsExact && match(Op1, m_APInt(DivC)) && DivC->countTrailingZeros()) { + KnownBits KnownOp0 = computeKnownBits(Op0, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + if (KnownOp0.countMaxTrailingZeros() < DivC->countTrailingZeros()) + return PoisonValue::get(Op0->getType()); + } + bool IsSigned = Opcode == Instruction::SDiv; // (X rem Y) / Y -> 0 @@ -1230,28 +1241,30 @@ /// Given operands for an SDiv, see if we can fold the result. /// If not, this returns null. -static Value *simplifySDivInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, - unsigned MaxRecurse) { +static Value *simplifySDivInst(Value *Op0, Value *Op1, bool IsExact, + const SimplifyQuery &Q, unsigned MaxRecurse) { // If two operands are negated and no signed overflow, return -1. if (isKnownNegation(Op0, Op1, /*NeedNSW=*/true)) return Constant::getAllOnesValue(Op0->getType()); - return simplifyDiv(Instruction::SDiv, Op0, Op1, Q, MaxRecurse); + return simplifyDiv(Instruction::SDiv, Op0, Op1, IsExact, Q, MaxRecurse); } -Value *llvm::simplifySDivInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) { - return ::simplifySDivInst(Op0, Op1, Q, RecursionLimit); +Value *llvm::simplifySDivInst(Value *Op0, Value *Op1, bool IsExact, + const SimplifyQuery &Q) { + return ::simplifySDivInst(Op0, Op1, IsExact, Q, RecursionLimit); } /// Given operands for a UDiv, see if we can fold the result. /// If not, this returns null. 
-static Value *simplifyUDivInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, - unsigned MaxRecurse) { - return simplifyDiv(Instruction::UDiv, Op0, Op1, Q, MaxRecurse); +static Value *simplifyUDivInst(Value *Op0, Value *Op1, bool IsExact, + const SimplifyQuery &Q, unsigned MaxRecurse) { + return simplifyDiv(Instruction::UDiv, Op0, Op1, IsExact, Q, MaxRecurse); } -Value *llvm::simplifyUDivInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) { - return ::simplifyUDivInst(Op0, Op1, Q, RecursionLimit); +Value *llvm::simplifyUDivInst(Value *Op0, Value *Op1, bool IsExact, + const SimplifyQuery &Q) { + return ::simplifyUDivInst(Op0, Op1, IsExact, Q, RecursionLimit); } /// Given operands for an SRem, see if we can fold the result. @@ -1405,6 +1418,7 @@ return IsExact ? Op0 : Constant::getNullValue(Op0->getType()); // The low bit cannot be shifted out of an exact shift if it is set. + // TODO: Generalize by counting trailing zeros (see fold for exact division). if (IsExact) { KnownBits Op0Known = computeKnownBits(Op0, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT); @@ -5678,9 +5692,9 @@ case Instruction::Mul: return simplifyMulInst(LHS, RHS, Q, MaxRecurse); case Instruction::SDiv: - return simplifySDivInst(LHS, RHS, Q, MaxRecurse); + return simplifySDivInst(LHS, RHS, /* IsExact */ false, Q, MaxRecurse); case Instruction::UDiv: - return simplifyUDivInst(LHS, RHS, Q, MaxRecurse); + return simplifyUDivInst(LHS, RHS, /* IsExact */ false, Q, MaxRecurse); case Instruction::SRem: return simplifySRemInst(LHS, RHS, Q, MaxRecurse); case Instruction::URem: @@ -6553,9 +6567,11 @@ case Instruction::Mul: return simplifyMulInst(NewOps[0], NewOps[1], Q); case Instruction::SDiv: - return simplifySDivInst(NewOps[0], NewOps[1], Q); + return simplifySDivInst(NewOps[0], NewOps[1], + Q.IIQ.isExact(cast<BinaryOperator>(I)), Q); case Instruction::UDiv: - return simplifyUDivInst(NewOps[0], NewOps[1], Q); + return simplifyUDivInst(NewOps[0], NewOps[1], + Q.IIQ.isExact(cast<BinaryOperator>(I)), Q); case Instruction::FDiv: return 
simplifyFDivInst(NewOps[0], NewOps[1], I->getFastMathFlags(), Q); case Instruction::SRem: diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -1207,7 +1207,7 @@ } Instruction *InstCombinerImpl::visitUDiv(BinaryOperator &I) { - if (Value *V = simplifyUDivInst(I.getOperand(0), I.getOperand(1), + if (Value *V = simplifyUDivInst(I.getOperand(0), I.getOperand(1), I.isExact(), SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); @@ -1287,7 +1287,7 @@ } Instruction *InstCombinerImpl::visitSDiv(BinaryOperator &I) { - if (Value *V = simplifySDivInst(I.getOperand(0), I.getOperand(1), + if (Value *V = simplifySDivInst(I.getOperand(0), I.getOperand(1), I.isExact(), SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); diff --git a/llvm/test/Transforms/InstCombine/udiv-simplify.ll b/llvm/test/Transforms/InstCombine/udiv-simplify.ll --- a/llvm/test/Transforms/InstCombine/udiv-simplify.ll +++ b/llvm/test/Transforms/InstCombine/udiv-simplify.ll @@ -95,13 +95,11 @@ ret i8 %u } -; TODO: This can't divide evenly, so it is poison. +; This can't divide evenly, so it is poison. define i8 @udiv_exact_demanded_low_bits_set(i8 %a) { ; CHECK-LABEL: @udiv_exact_demanded_low_bits_set( -; CHECK-NEXT: [[O:%.*]] = or i8 [[A:%.*]], 3 -; CHECK-NEXT: [[U:%.*]] = udiv exact i8 [[O]], 12 -; CHECK-NEXT: ret i8 [[U]] +; CHECK-NEXT: ret i8 poison ; %o = or i8 %a, 3 %u = udiv exact i8 %o, 12 diff --git a/llvm/test/Transforms/InstSimplify/div.ll b/llvm/test/Transforms/InstSimplify/div.ll --- a/llvm/test/Transforms/InstSimplify/div.ll +++ b/llvm/test/Transforms/InstSimplify/div.ll @@ -333,17 +333,19 @@ ret i1 %rem } +; Can't divide evenly, so create poison. 
+ define i8 @sdiv_exact_trailing_zeros(i8 %x) { ; CHECK-LABEL: @sdiv_exact_trailing_zeros( -; CHECK-NEXT: [[O:%.*]] = or i8 [[X:%.*]], 1 -; CHECK-NEXT: [[R:%.*]] = sdiv exact i8 [[O]], -42 -; CHECK-NEXT: ret i8 [[R]] +; CHECK-NEXT: ret i8 poison ; %o = or i8 %x, 1 ; odd number %r = sdiv exact i8 %o, -42 ; can't divide exactly ret i8 %r } +; Negative test - could divide evenly. + define i8 @sdiv_exact_trailing_zeros_eq(i8 %x) { ; CHECK-LABEL: @sdiv_exact_trailing_zeros_eq( ; CHECK-NEXT: [[O:%.*]] = or i8 [[X:%.*]], 2 @@ -355,6 +357,8 @@ ret i8 %r } +; Negative test - must be exact div. + define i8 @sdiv_trailing_zeros(i8 %x) { ; CHECK-LABEL: @sdiv_trailing_zeros( ; CHECK-NEXT: [[O:%.*]] = or i8 [[X:%.*]], 1 @@ -366,17 +370,32 @@ ret i8 %r } +; TODO: Match non-splat vector constants. + +define <2 x i8> @sdiv_exact_trailing_zeros_nonuniform_vector(<2 x i8> %x) { +; CHECK-LABEL: @sdiv_exact_trailing_zeros_nonuniform_vector( +; CHECK-NEXT: [[O:%.*]] = or <2 x i8> [[X:%.*]], +; CHECK-NEXT: [[R:%.*]] = sdiv exact <2 x i8> [[O]], +; CHECK-NEXT: ret <2 x i8> [[R]] +; + %o = or <2 x i8> %x, + %r = sdiv exact <2 x i8> %o, + ret <2 x i8> %r +} + +; Can't divide evenly, so create poison. + define <2 x i8> @udiv_exact_trailing_zeros(<2 x i8> %x) { ; CHECK-LABEL: @udiv_exact_trailing_zeros( -; CHECK-NEXT: [[O:%.*]] = or <2 x i8> [[X:%.*]], -; CHECK-NEXT: [[R:%.*]] = udiv exact <2 x i8> [[O]], -; CHECK-NEXT: ret <2 x i8> [[R]] +; CHECK-NEXT: ret <2 x i8> poison ; %o = or <2 x i8> %x, %r = udiv exact <2 x i8> %o, ; can't divide exactly ret <2 x i8> %r } +; Negative test - could divide evenly. + define <2 x i8> @udiv_exact_trailing_zeros_eq(<2 x i8> %x) { ; CHECK-LABEL: @udiv_exact_trailing_zeros_eq( ; CHECK-NEXT: [[O:%.*]] = or <2 x i8> [[X:%.*]], @@ -388,6 +407,8 @@ ret <2 x i8> %r } +; Negative test - must be exact div. 
+ define i8 @udiv_trailing_zeros(i8 %x) { ; CHECK-LABEL: @udiv_trailing_zeros( ; CHECK-NEXT: [[O:%.*]] = or i8 [[X:%.*]], 1 @@ -399,4 +420,17 @@ ret i8 %r } +; Negative test - only the first element is poison + +define <2 x i8> @udiv_exact_trailing_zeros_nonuniform_vector(<2 x i8> %x) { +; CHECK-LABEL: @udiv_exact_trailing_zeros_nonuniform_vector( +; CHECK-NEXT: [[O:%.*]] = or <2 x i8> [[X:%.*]], +; CHECK-NEXT: [[R:%.*]] = udiv exact <2 x i8> [[O]], +; CHECK-NEXT: ret <2 x i8> [[R]] +; + %o = or <2 x i8> %x, + %r = udiv exact <2 x i8> %o, + ret <2 x i8> %r +} + !0 = !{i32 0, i32 3}