diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -4479,6 +4479,14 @@
                               SelectionDAG &DAG,
                               SmallVectorImpl<SDNode *> &Created) const;
 
+  /// Targets may override this function to provide custom SREM lowering for
+  /// power-of-2 denominators. If the target returns an empty SDValue, LLVM
+  /// assumes SREM is expensive and replaces it with a series of other integer
+  /// operations.
+  virtual SDValue BuildSREMPow2(SDNode *N, const APInt &Divisor,
+                                SelectionDAG &DAG,
+                                SmallVectorImpl<SDNode *> &Created) const;
+
   /// Indicate whether this target prefers to combine FDIVs with the same
   /// divisor. If the transform should never be done, return zero. If the
   /// transform should be done, return the minimum number of divisor uses
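
For reference, srem with a power-of-2 divisor follows LLVM's truncated-division semantics: the result keeps the low Lg2 bits of the dividend and carries the dividend's sign. A minimal sketch in plain C++ (`srem_ref` is a hypothetical name, not part of the patch; assumes 0 < k < 31):

```cpp
#include <cassert>
#include <cstdint>

// x srem 2^k under truncated-division semantics: magnitude of the low k
// bits, sign of the dividend.
int32_t srem_ref(int32_t x, unsigned k) {
  int32_t lowbits = x & ((1 << k) - 1);
  return (x < 0 && lowbits != 0) ? lowbits - (1 << k) : lowbits;
}

int main() {
  for (int32_t x : {7, -7, 8, -8, 0})
    assert(srem_ref(x, 3) == x % 8); // e.g. -7 % 8 == -7, not 1
  return 0;
}
```
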
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -569,6 +569,8 @@
     SDValue BuildSDIV(SDNode *N);
     SDValue BuildSDIVPow2(SDNode *N);
     SDValue BuildUDIV(SDNode *N);
+    SDValue BuildSREMPow2(SDNode *N);
+    SDValue buildOptimizedSREM(SDValue N0, SDValue N1, SDNode *N);
     SDValue BuildLogBase2(SDValue V, const SDLoc &DL);
     SDValue BuildDivEstimate(SDValue N, SDValue Op, SDNodeFlags Flags);
     SDValue buildRsqrtEstimate(SDValue Op, SDNodeFlags Flags);
@@ -4320,12 +4322,7 @@
   return SDValue();
 }
 
-SDValue DAGCombiner::visitSDIVLike(SDValue N0, SDValue N1, SDNode *N) {
-  SDLoc DL(N);
-  EVT VT = N->getValueType(0);
-  EVT CCVT = getSetCCResultType(VT);
-  unsigned BitWidth = VT.getScalarSizeInBits();
-
+static bool isDivisorPowerOfTwo(SDValue Divisor) {
   // Helper for determining whether a value is a power-2 constant scalar or a
   // vector of such elements.
   auto IsPowerOfTwo = [](ConstantSDNode *C) {
@@ -4338,11 +4335,20 @@
     return false;
   };
 
+  return ISD::matchUnaryPredicate(Divisor, IsPowerOfTwo);
+}
+
+SDValue DAGCombiner::visitSDIVLike(SDValue N0, SDValue N1, SDNode *N) {
+  SDLoc DL(N);
+  EVT VT = N->getValueType(0);
+  EVT CCVT = getSetCCResultType(VT);
+  unsigned BitWidth = VT.getScalarSizeInBits();
+
   // fold (sdiv X, pow2) -> simple ops after legalize
   // FIXME: We check for the exact bit here because the generic lowering gives
   // better results in that case. The target-specific lowering should learn how
   // to handle exact sdivs efficiently.
-  if (!N->getFlags().hasExact() && ISD::matchUnaryPredicate(N1, IsPowerOfTwo)) {
+  if (!N->getFlags().hasExact() && isDivisorPowerOfTwo(N1)) {
     // Target-specific implementation of sdiv x, pow2.
     if (SDValue Res = BuildSDIVPow2(N))
       return Res;
@@ -4498,6 +4504,16 @@
   return SDValue();
 }
 
+SDValue DAGCombiner::buildOptimizedSREM(SDValue N0, SDValue N1, SDNode *N) {
+  if (!N->getFlags().hasExact() && isDivisorPowerOfTwo(N1) &&
+      !DAG.doesNodeExist(ISD::SDIV, N->getVTList(), {N0, N1})) {
+    // Target-specific implementation of srem x, pow2.
+    if (SDValue Res = BuildSREMPow2(N))
+      return Res;
+  }
+  return SDValue();
+}
+
 // handles ISD::SREM and ISD::UREM
 SDValue DAGCombiner::visitREM(SDNode *N) {
   unsigned Opcode = N->getOpcode();
@@ -4558,6 +4574,12 @@
   // combine will not return a DIVREM. Regardless, checking cheapness here
   // makes sense since the simplification results in fatter code.
   if (DAG.isKnownNeverZero(N1) && !TLI.isIntDivCheap(VT, Attr)) {
+    if (isSigned) {
+      // Check if we can build a faster implementation for srem.
+      SDValue OptimizedRem = buildOptimizedSREM(N0, N1, N);
+      if (OptimizedRem.getNode())
+        return OptimizedRem;
+    }
     SDValue OptimizedDiv =
         isSigned ? visitSDIVLike(N0, N1, N) : visitUDIVLike(N0, N1, N);
     if (OptimizedDiv.getNode() && OptimizedDiv.getNode() != N) {
@@ -23876,6 +23898,27 @@
   return SDValue();
 }
 
+/// Given an ISD::SREM node expressing a remainder by constant power of 2,
+/// return a DAG expression that will generate the same value.
+SDValue DAGCombiner::BuildSREMPow2(SDNode *N) {
+  ConstantSDNode *C = isConstOrConstSplat(N->getOperand(1));
+  if (!C)
+    return SDValue();
+
+  // Avoid division by zero.
+  if (C->isZero())
+    return SDValue();
+
+  SmallVector<SDNode *, 8> Built;
+  if (SDValue S = TLI.BuildSREMPow2(N, C->getAPIntValue(), DAG, Built)) {
+    for (SDNode *N : Built)
+      AddToWorklist(N);
+    return S;
+  }
+
+  return SDValue();
+}
+
 /// Determines the LogBase2 value for a non-null input value using the
 /// transform: LogBase2(V) = (EltBits - 1) - ctlz(V).
 SDValue DAGCombiner::BuildLogBase2(SDValue V, const SDLoc &DL) {
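
When the hook returns an empty SDValue, the combiner keeps the existing X - (X / C) * C path, which for a power-of-2 divisor is the bias-and-mask sequence the updated tests used to check for. A plain C++ model of that fallback (a sketch under the assumption 0 < k < 32, not DAG code; `srem_pow2_fallback` is a hypothetical name):

```cpp
#include <cstdint>

// Pre-patch style expansion of x srem 2^k: bias a negative dividend by
// 2^k - 1 so masking rounds the quotient toward zero, then subtract.
int32_t srem_pow2_fallback(int32_t x, unsigned k) {
  uint32_t mask = (1u << k) - 1;             // 2^k - 1
  uint32_t bias = x < 0 ? mask : 0;          // the old cinc/csel step
  uint32_t quotTimesPow2 = ((uint32_t)x + bias) & ~mask;
  return x - (int32_t)quotTimesPow2;         // x - trunc(x / 2^k) * 2^k
}
```
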
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -5560,6 +5560,17 @@
   return SDValue();
 }
 
+SDValue
+TargetLowering::BuildSREMPow2(SDNode *N, const APInt &Divisor,
+                              SelectionDAG &DAG,
+                              SmallVectorImpl<SDNode *> &Created) const {
+  AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+  if (TLI.isIntDivCheap(N->getValueType(0), Attr))
+    return SDValue(N, 0); // Lower SREM as SREM
+  return SDValue();
+}
+
 /// Given an ISD::SDIV node expressing a divide by constant,
 /// return a DAG expression to select that will generate the same value by
 /// multiplying by a magic number.
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -1043,6 +1043,8 @@
 
   SDValue BuildSDIVPow2(SDNode *N, const APInt &Divisor, SelectionDAG &DAG,
                         SmallVectorImpl<SDNode *> &Created) const override;
+  SDValue BuildSREMPow2(SDNode *N, const APInt &Divisor, SelectionDAG &DAG,
+                        SmallVectorImpl<SDNode *> &Created) const override;
   SDValue getSqrtEstimate(SDValue Operand, SelectionDAG &DAG, int Enabled,
                           int &ExtraSteps, bool &UseOneConst,
                           bool Reciprocal) const override;
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -13567,6 +13567,60 @@
   return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), SRA);
 }
 
+SDValue
+AArch64TargetLowering::BuildSREMPow2(SDNode *N, const APInt &Divisor,
+                                     SelectionDAG &DAG,
+                                     SmallVectorImpl<SDNode *> &Created) const {
+  AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
+  if (isIntDivCheap(N->getValueType(0), Attr))
+    return SDValue(N, 0); // Lower SREM as SREM
+
+  EVT VT = N->getValueType(0);
+
+  // For scalable and fixed types, mark them as cheap so we can handle it much
+  // later. This allows us to handle larger than legal types.
+  if (VT.isScalableVector() || Subtarget->useSVEForFixedLengthVectors())
+    return SDValue(N, 0);
+
+  // fold (srem X, pow2)
+  if ((VT != MVT::i32 && VT != MVT::i64) ||
+      !(Divisor.isPowerOf2() || Divisor.isNegatedPowerOf2()))
+    return SDValue();
+
+  unsigned Lg2 = Divisor.countTrailingZeros();
+  if (Lg2 == 0)
+    return SDValue();
+
+  SDLoc DL(N);
+  SDValue N0 = N->getOperand(0);
+  SDValue Pow2MinusOne = DAG.getConstant((1ULL << Lg2) - 1, DL, VT);
+  SDValue Zero = DAG.getConstant(0, DL, VT);
+  SDValue CCVal, CSNeg;
+  if (Lg2 == 1) {
+    SDValue Cmp = getAArch64Cmp(N0, Zero, ISD::SETGE, CCVal, DAG, DL);
+    SDValue And = DAG.getNode(ISD::AND, DL, VT, N0, Pow2MinusOne);
+    CSNeg = DAG.getNode(AArch64ISD::CSNEG, DL, VT, And, And, CCVal, Cmp);
+
+    Created.push_back(Cmp.getNode());
+    Created.push_back(And.getNode());
+  } else {
+    SDValue CCVal = DAG.getConstant(AArch64CC::MI, DL, MVT_CC);
+    SDVTList VTs = DAG.getVTList(VT, MVT::i32);
+
+    SDValue Negs = DAG.getNode(AArch64ISD::SUBS, DL, VTs, Zero, N0);
+    SDValue AndPos = DAG.getNode(ISD::AND, DL, VT, N0, Pow2MinusOne);
+    SDValue AndNeg = DAG.getNode(ISD::AND, DL, VT, Negs, Pow2MinusOne);
+    CSNeg = DAG.getNode(AArch64ISD::CSNEG, DL, VT, AndPos, AndNeg, CCVal,
+                        Negs.getValue(1));
+
+    Created.push_back(Negs.getNode());
+    Created.push_back(AndPos.getNode());
+    Created.push_back(AndNeg.getNode());
+  }
+
+  return CSNeg;
+}
+
 static bool IsSVECntIntrinsic(SDValue S) {
   switch(getIntrinsicID(S.getNode())) {
   default:
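
The Lg2 > 1 path above emits negs, two ands, and a csneg. A branch-free C++ model of what that sequence computes (an illustrative sketch only; the unsigned negation sidesteps the INT32_MIN overflow, and `srem_pow2_csneg` is a hypothetical name):

```cpp
#include <cstdint>

// Models the emitted sequence for k > 1:
//   negs w8, w0 ; and w9, w0, #mask ; and w8, w8, #mask
//   csneg w0, w9, w8, mi
int32_t srem_pow2_csneg(int32_t x, unsigned k) {
  uint32_t mask = (1u << k) - 1;
  uint32_t neg = 0u - (uint32_t)x;      // NEGS: flags from 0 - x
  uint32_t andPos = (uint32_t)x & mask; // remainder candidate for x >= 0
  uint32_t andNeg = neg & mask;         // |remainder| candidate for x < 0
  bool mi = (int32_t)neg < 0;           // MI flag set iff x > 0
  return mi ? (int32_t)andPos           // x > 0: low bits directly
            : -(int32_t)andNeg;         // x <= 0: negate the masked |x|
}
```

The x = INT32_MIN edge case still works: 0 - x wraps back to INT32_MIN, MI is set, and the masked low bits are zero, matching srem.
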
diff --git a/llvm/test/CodeGen/AArch64/srem-pow2.ll b/llvm/test/CodeGen/AArch64/srem-pow2.ll
--- a/llvm/test/CodeGen/AArch64/srem-pow2.ll
+++ b/llvm/test/CodeGen/AArch64/srem-pow2.ll
@@ -43,10 +43,9 @@
 define i32 @fold_srem_2_i32(i32 %x) {
 ; CHECK-LABEL: fold_srem_2_i32:
 ; CHECK:       // %bb.0:
+; CHECK-NEXT:    and w8, w0, #0x1
 ; CHECK-NEXT:    cmp w0, #0
-; CHECK-NEXT:    cinc w8, w0, lt
-; CHECK-NEXT:    and w8, w8, #0xfffffffe
-; CHECK-NEXT:    sub w0, w0, w8
+; CHECK-NEXT:    cneg w0, w8, lt
 ; CHECK-NEXT:    ret
   %1 = srem i32 %x, 2
   ret i32 %1
@@ -55,10 +54,9 @@
 define i64 @fold_srem_2_i64(i64 %x) {
 ; CHECK-LABEL: fold_srem_2_i64:
 ; CHECK:       // %bb.0:
+; CHECK-NEXT:    and x8, x0, #0x1
 ; CHECK-NEXT:    cmp x0, #0
-; CHECK-NEXT:    cinc x8, x0, lt
-; CHECK-NEXT:    and x8, x8, #0xfffffffffffffffe
-; CHECK-NEXT:    sub x0, x0, x8
+; CHECK-NEXT:    cneg x0, x8, lt
 ; CHECK-NEXT:    ret
   %1 = srem i64 %x, 2
   ret i64 %1
@@ -80,11 +78,10 @@
 define i32 @fold_srem_pow2_i32(i32 %x) {
 ; CHECK-LABEL: fold_srem_pow2_i32:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    add w8, w0, #63
-; CHECK-NEXT:    cmp w0, #0
-; CHECK-NEXT:    csel w8, w8, w0, lt
-; CHECK-NEXT:    and w8, w8, #0xffffffc0
-; CHECK-NEXT:    sub w0, w0, w8
+; CHECK-NEXT:    negs w8, w0
+; CHECK-NEXT:    and w9, w0, #0x3f
+; CHECK-NEXT:    and w8, w8, #0x3f
+; CHECK-NEXT:    csneg w0, w9, w8, mi
 ; CHECK-NEXT:    ret
   %1 = srem i32 %x, 64
   ret i32 %1
@@ -93,11 +90,10 @@
 define i64 @fold_srem_pow2_i64(i64 %x) {
 ; CHECK-LABEL: fold_srem_pow2_i64:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    add x8, x0, #63
-; CHECK-NEXT:    cmp x0, #0
-; CHECK-NEXT:    csel x8, x8, x0, lt
-; CHECK-NEXT:    and x8, x8, #0xffffffffffffffc0
-; CHECK-NEXT:    sub x0, x0, x8
+; CHECK-NEXT:    negs x8, x0
+; CHECK-NEXT:    and x9, x0, #0x3f
+; CHECK-NEXT:    and x8, x8, #0x3f
+; CHECK-NEXT:    csneg x0, x9, x8, mi
 ; CHECK-NEXT:    ret
   %1 = srem i64 %x, 64
   ret i64 %1
@@ -119,12 +115,10 @@
 define i32 @fold_srem_smax_i32(i32 %x) {
 ; CHECK-LABEL: fold_srem_smax_i32:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov w8, #2147483647
-; CHECK-NEXT:    cmp w0, #0
-; CHECK-NEXT:    add w8, w0, w8
-; CHECK-NEXT:    csel w8, w8, w0, lt
-; CHECK-NEXT:    and w8, w8, #0x80000000
-; CHECK-NEXT:    add w0, w0, w8
+; CHECK-NEXT:    negs w8, w0
+; CHECK-NEXT:    and w9, w0, #0x7fffffff
+; CHECK-NEXT:    and w8, w8, #0x7fffffff
+; CHECK-NEXT:    csneg w0, w9, w8, mi
 ; CHECK-NEXT:    ret
   %1 = srem i32 %x, 2147483648
   ret i32 %1
@@ -133,12 +127,10 @@
 define i64 @fold_srem_smax_i64(i64 %x) {
 ; CHECK-LABEL: fold_srem_smax_i64:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov x8, #9223372036854775807
-; CHECK-NEXT:    cmp x0, #0
-; CHECK-NEXT:    add x8, x0, x8
-; CHECK-NEXT:    csel x8, x8, x0, lt
-; CHECK-NEXT:    and x8, x8, #0x8000000000000000
-; CHECK-NEXT:    add x0, x0, x8
+; CHECK-NEXT:    negs x8, x0
+; CHECK-NEXT:    and x9, x0, #0x7fffffffffffffff
+; CHECK-NEXT:    and x8, x8, #0x7fffffffffffffff
+; CHECK-NEXT:    csneg x0, x9, x8, mi
 ; CHECK-NEXT:    ret
   %1 = srem i64 %x, -9223372036854775808
   ret i64 %1
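
The srem-by-2 checks above correspond to the Lg2 == 1 path (and + cmp + cneg). A small sketch of the same computation, assuming truncated semantics (`srem_by_2` is a hypothetical name):

```cpp
#include <cstdint>

// Models: and w8, w0, #0x1 ; cmp w0, #0 ; cneg w0, w8, lt
int32_t srem_by_2(int32_t x) {
  int32_t low = x & 1;       // 0 or 1 for any x
  return x < 0 ? -low : low; // CNEG on lt flips the sign for negatives
}
```
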
diff --git a/llvm/test/CodeGen/AArch64/srem-seteq.ll b/llvm/test/CodeGen/AArch64/srem-seteq.ll
--- a/llvm/test/CodeGen/AArch64/srem-seteq.ll
+++ b/llvm/test/CodeGen/AArch64/srem-seteq.ll
@@ -234,11 +234,11 @@
 define i32 @test_srem_pow2(i32 %X) nounwind {
 ; CHECK-LABEL: test_srem_pow2:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    add w8, w0, #15
-; CHECK-NEXT:    cmp w0, #0
-; CHECK-NEXT:    csel w8, w8, w0, lt
-; CHECK-NEXT:    and w8, w8, #0xfffffff0
-; CHECK-NEXT:    cmp w0, w8
+; CHECK-NEXT:    negs w8, w0
+; CHECK-NEXT:    and w9, w0, #0xf
+; CHECK-NEXT:    and w8, w8, #0xf
+; CHECK-NEXT:    csneg w8, w9, w8, mi
+; CHECK-NEXT:    cmp w8, #0
 ; CHECK-NEXT:    cset w0, eq
 ; CHECK-NEXT:    ret
   %srem = srem i32 %X, 16
@@ -251,12 +251,11 @@
 define i32 @test_srem_int_min(i32 %X) nounwind {
 ; CHECK-LABEL: test_srem_int_min:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov w8, #2147483647
-; CHECK-NEXT:    cmp w0, #0
-; CHECK-NEXT:    add w8, w0, w8
-; CHECK-NEXT:    csel w8, w8, w0, lt
-; CHECK-NEXT:    and w8, w8, #0x80000000
-; CHECK-NEXT:    cmn w0, w8
+; CHECK-NEXT:    negs w8, w0
+; CHECK-NEXT:    and w9, w0, #0x7fffffff
+; CHECK-NEXT:    and w8, w8, #0x7fffffff
+; CHECK-NEXT:    csneg w8, w9, w8, mi
+; CHECK-NEXT:    cmp w8, #0
 ; CHECK-NEXT:    cset w0, eq
 ; CHECK-NEXT:    ret
   %srem = srem i32 %X, 2147483648
diff --git a/llvm/test/CodeGen/AArch64/srem-vector-lkk.ll b/llvm/test/CodeGen/AArch64/srem-vector-lkk.ll
--- a/llvm/test/CodeGen/AArch64/srem-vector-lkk.ll
+++ b/llvm/test/CodeGen/AArch64/srem-vector-lkk.ll
@@ -159,36 +159,33 @@
 ; CHECK-NEXT:    smov w9, v0.h[1]
 ; CHECK-NEXT:    smov w10, v0.h[0]
 ; CHECK-NEXT:    mov w8, #37253
+; CHECK-NEXT:    smov w12, v0.h[2]
 ; CHECK-NEXT:    movk w8, #44150, lsl #16
-; CHECK-NEXT:    add w11, w9, #31
-; CHECK-NEXT:    cmp w9, #0
-; CHECK-NEXT:    add w12, w10, #63
-; CHECK-NEXT:    csel w11, w11, w9, lt
-; CHECK-NEXT:    cmp w10, #0
-; CHECK-NEXT:    and w11, w11, #0xffffffe0
-; CHECK-NEXT:    csel w12, w12, w10, lt
-; CHECK-NEXT:    sub w9, w9, w11
-; CHECK-NEXT:    and w12, w12, #0xffffffc0
-; CHECK-NEXT:    sub w10, w10, w12
-; CHECK-NEXT:    smov w12, v0.h[3]
-; CHECK-NEXT:    fmov s1, w10
-; CHECK-NEXT:    smov w10, v0.h[2]
-; CHECK-NEXT:    smull x8, w12, w8
-; CHECK-NEXT:    mov v1.h[1], w9
+; CHECK-NEXT:    negs w11, w9
+; CHECK-NEXT:    and w9, w9, #0x1f
+; CHECK-NEXT:    and w11, w11, #0x1f
+; CHECK-NEXT:    csneg w9, w9, w11, mi
+; CHECK-NEXT:    negs w11, w10
+; CHECK-NEXT:    and w10, w10, #0x3f
+; CHECK-NEXT:    and w11, w11, #0x3f
+; CHECK-NEXT:    csneg w10, w10, w11, mi
+; CHECK-NEXT:    smov w11, v0.h[3]
+; CHECK-NEXT:    fmov s0, w10
+; CHECK-NEXT:    negs w10, w12
+; CHECK-NEXT:    smull x8, w11, w8
+; CHECK-NEXT:    and w10, w10, #0x7
 ; CHECK-NEXT:    lsr x8, x8, #32
-; CHECK-NEXT:    add w9, w10, #7
-; CHECK-NEXT:    cmp w10, #0
-; CHECK-NEXT:    csel w9, w9, w10, lt
-; CHECK-NEXT:    add w8, w8, w12
-; CHECK-NEXT:    and w9, w9, #0xfffffff8
-; CHECK-NEXT:    sub w9, w10, w9
+; CHECK-NEXT:    mov v0.h[1], w9
+; CHECK-NEXT:    and w9, w12, #0x7
+; CHECK-NEXT:    add w8, w8, w11
+; CHECK-NEXT:    csneg w9, w9, w10, mi
 ; CHECK-NEXT:    asr w10, w8, #6
 ; CHECK-NEXT:    add w8, w10, w8, lsr #31
 ; CHECK-NEXT:    mov w10, #95
-; CHECK-NEXT:    mov v1.h[2], w9
-; CHECK-NEXT:    msub w8, w8, w10, w12
-; CHECK-NEXT:    mov v1.h[3], w8
-; CHECK-NEXT:    fmov d0, d1
+; CHECK-NEXT:    mov v0.h[2], w9
+; CHECK-NEXT:    msub w8, w8, w10, w11
+; CHECK-NEXT:    mov v0.h[3], w8
+; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $q0
 ; CHECK-NEXT:    ret
   %1 = srem <4 x i16> %x, <i16 64, i16 32, i16 8, i16 95>
   ret <4 x i16> %1
@@ -245,27 +242,25 @@
 ; CHECK-NEXT:    // kill: def $d0 killed $d0 def $q0
 ; CHECK-NEXT:    smov w8, v0.h[2]
 ; CHECK-NEXT:    mov w9, #17097
-; CHECK-NEXT:    smov w10, v0.h[1]
 ; CHECK-NEXT:    movk w9, #45590, lsl #16
-; CHECK-NEXT:    mov w11, #32767
+; CHECK-NEXT:    smov w10, v0.h[1]
 ; CHECK-NEXT:    smov w12, v0.h[3]
 ; CHECK-NEXT:    movi d1, #0000000000000000
+; CHECK-NEXT:    mov w11, #23
 ; CHECK-NEXT:    smull x9, w8, w9
-; CHECK-NEXT:    add w11, w10, w11
-; CHECK-NEXT:    cmp w10, #0
 ; CHECK-NEXT:    lsr x9, x9, #32
-; CHECK-NEXT:    csel w11, w11, w10, lt
 ; CHECK-NEXT:    add w9, w9, w8
-; CHECK-NEXT:    and w11, w11, #0xffff8000
 ; CHECK-NEXT:    asr w13, w9, #4
-; CHECK-NEXT:    sub w10, w10, w11
-; CHECK-NEXT:    mov w11, #47143
 ; CHECK-NEXT:    add w9, w13, w9, lsr #31
-; CHECK-NEXT:    mov w13, #23
-; CHECK-NEXT:    movk w11, #24749, lsl #16
+; CHECK-NEXT:    negs w13, w10
+; CHECK-NEXT:    and w10, w10, #0x7fff
+; CHECK-NEXT:    and w13, w13, #0x7fff
+; CHECK-NEXT:    csneg w10, w10, w13, mi
+; CHECK-NEXT:    mov w13, #47143
+; CHECK-NEXT:    movk w13, #24749, lsl #16
+; CHECK-NEXT:    msub w8, w9, w11, w8
+; CHECK-NEXT:    smull x9, w12, w13
 ; CHECK-NEXT:    mov v1.h[1], w10
-; CHECK-NEXT:    msub w8, w9, w13, w8
-; CHECK-NEXT:    smull x9, w12, w11
 ; CHECK-NEXT:    lsr x10, x9, #63
 ; CHECK-NEXT:    asr x9, x9, #43
 ; CHECK-NEXT:    add w9, w9, w10