Index: llvm/include/llvm/CodeGen/TargetLowering.h
===================================================================
--- llvm/include/llvm/CodeGen/TargetLowering.h
+++ llvm/include/llvm/CodeGen/TargetLowering.h
@@ -4449,6 +4449,14 @@
                                 SelectionDAG &DAG,
                                 SmallVectorImpl<SDNode *> &Created) const;
 
+  /// Targets may override this function to provide custom SREM lowering for
+  /// power-of-2 denominators. If the target returns an empty SDValue, LLVM
+  /// assumes SREM is expensive and replaces it with a series of other integer
+  /// operations.
+  virtual SDValue BuildSREMPow2(SDNode *N, const APInt &Divisor,
+                                SelectionDAG &DAG,
+                                SmallVectorImpl<SDNode *> &Created) const;
+
   /// Indicate whether this target prefers to combine FDIVs with the same
   /// divisor. If the transform should never be done, return zero. If the
   /// transform should be done, return the minimum number of divisor uses
Index: llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -418,7 +418,7 @@
     SDValue visitMULFIX(SDNode *N);
     SDValue useDivRem(SDNode *N);
     SDValue visitSDIV(SDNode *N);
-    SDValue visitSDIVLike(SDValue N0, SDValue N1, SDNode *N);
+    SDValue visitSDIVLike(SDValue N0, SDValue N1, SDNode *N, bool &BuildRem);
     SDValue visitUDIV(SDNode *N);
     SDValue visitUDIVLike(SDValue N0, SDValue N1, SDNode *N);
     SDValue visitREM(SDNode *N);
@@ -569,6 +569,7 @@
     SDValue BuildSDIV(SDNode *N);
     SDValue BuildSDIVPow2(SDNode *N);
     SDValue BuildUDIV(SDNode *N);
+    SDValue BuildSREMPow2(SDNode *N);
     SDValue BuildLogBase2(SDValue V, const SDLoc &DL);
     SDValue BuildDivEstimate(SDValue N, SDValue Op, SDNodeFlags Flags);
     SDValue buildRsqrtEstimate(SDValue Op, SDNodeFlags Flags);
@@ -4283,7 +4284,8 @@
   if (DAG.SignBitIsZero(N1) && DAG.SignBitIsZero(N0))
     return DAG.getNode(ISD::UDIV, DL, N1.getValueType(), N0, N1);
 
-  if (SDValue V = visitSDIVLike(N0, N1, N)) {
+  bool BuildRem = false;
+  if (SDValue V = visitSDIVLike(N0, N1, N, BuildRem)) {
     // If the corresponding remainder node exists, update its users with
     // (Dividend - (Quotient * Divisor).
     if (SDNode *RemNode = DAG.getNodeIfExists(ISD::SREM, N->getVTList(),
@@ -4308,7 +4310,8 @@
   return SDValue();
 }
 
-SDValue DAGCombiner::visitSDIVLike(SDValue N0, SDValue N1, SDNode *N) {
+SDValue DAGCombiner::visitSDIVLike(SDValue N0, SDValue N1, SDNode *N,
+                                   bool &BuildRem) {
   SDLoc DL(N);
   EVT VT = N->getValueType(0);
   EVT CCVT = getSetCCResultType(VT);
@@ -4331,6 +4334,13 @@
   // better results in that case. The target-specific lowering should learn how
   // to handle exact sdivs efficiently.
   if (!N->getFlags().hasExact() && ISD::matchUnaryPredicate(N1, IsPowerOfTwo)) {
+    if (BuildRem) {
+      // Target-specific implementation of srem x, pow2.
+      if (SDValue Res = BuildSREMPow2(N))
+        return Res;
+    }
+
+    BuildRem = false;
     // Target-specific implementation of sdiv x, pow2.
     if (SDValue Res = BuildSDIVPow2(N))
       return Res;
@@ -4377,6 +4387,7 @@
       return Res;
   }
 
+  BuildRem = false;
   // If integer divide is expensive and we satisfy the requirements, emit an
   // alternate sequence.  Targets may check function attributes for size/speed
   // trade-offs.
@@ -4546,19 +4557,25 @@
   // combine will not return a DIVREM.  Regardless, checking cheapness here
   // makes sense since the simplification results in fatter code.
   if (DAG.isKnownNeverZero(N1) && !TLI.isIntDivCheap(VT, Attr)) {
-    SDValue OptimizedDiv =
-        isSigned ? visitSDIVLike(N0, N1, N) : visitUDIVLike(N0, N1, N);
-    if (OptimizedDiv.getNode() && OptimizedDiv.getNode() != N) {
-      // If the equivalent Div node also exists, update its users.
-      unsigned DivOpcode = isSigned ? ISD::SDIV : ISD::UDIV;
-      if (SDNode *DivNode = DAG.getNodeIfExists(DivOpcode, N->getVTList(),
-                                                { N0, N1 }))
-        CombineTo(DivNode, OptimizedDiv);
-      SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, OptimizedDiv, N1);
-      SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, Mul);
-      AddToWorklist(OptimizedDiv.getNode());
-      AddToWorklist(Mul.getNode());
-      return Sub;
+    unsigned DivOpcode = isSigned ? ISD::SDIV : ISD::UDIV;
+    // check if we can build faster implementation for srem only(no div pair)
+    bool BuildRem =
+        isSigned && !DAG.doesNodeExist(DivOpcode, N->getVTList(), {N0, N1});
+    SDValue OptimizedDivOrRem = isSigned ? visitSDIVLike(N0, N1, N, BuildRem)
+                                         : visitUDIVLike(N0, N1, N);
+    if (OptimizedDivOrRem.getNode() && OptimizedDivOrRem.getNode() != N) {
+      if (!BuildRem) {
+        // If the equivalent Div node also exists, update its users.
+        if (SDNode *DivNode =
+                DAG.getNodeIfExists(DivOpcode, N->getVTList(), {N0, N1}))
+          CombineTo(DivNode, OptimizedDivOrRem);
+        SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, OptimizedDivOrRem, N1);
+        SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, Mul);
+        AddToWorklist(OptimizedDivOrRem.getNode());
+        AddToWorklist(Mul.getNode());
+        return Sub;
+      }
+      return OptimizedDivOrRem;
     }
   }
 
@@ -23838,6 +23855,27 @@
   return SDValue();
 }
 
+/// Given an ISD::SREM node expressing a remainder by constant power of 2,
+/// return a DAG expression that will generate the same value by right shifting.
+SDValue DAGCombiner::BuildSREMPow2(SDNode *N) {
+  ConstantSDNode *C = isConstOrConstSplat(N->getOperand(1));
+  if (!C)
+    return SDValue();
+
+  // Avoid division by zero.
+  if (C->isZero())
+    return SDValue();
+
+  SmallVector<SDNode *, 8> Built;
+  if (SDValue S = TLI.BuildSREMPow2(N, C->getAPIntValue(), DAG, Built)) {
+    for (SDNode *N : Built)
+      AddToWorklist(N);
+    return S;
+  }
+
+  return SDValue();
+}
+
 /// Determines the LogBase2 value for a non-null input value using the
 /// transform: LogBase2(V) = (EltBits - 1) - ctlz(V).
 SDValue DAGCombiner::BuildLogBase2(SDValue V, const SDLoc &DL) {
Index: llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -5471,6 +5471,17 @@
   return SDValue();
 }
 
+SDValue
+TargetLowering::BuildSREMPow2(SDNode *N, const APInt &Divisor,
+                              SelectionDAG &DAG,
+                              SmallVectorImpl<SDNode *> &Created) const {
+  AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+  if (TLI.isIntDivCheap(N->getValueType(0), Attr))
+    return SDValue(N, 0); // Lower SREM as SREM
+  return SDValue();
+}
+
 /// Given an ISD::SDIV node expressing a divide by constant,
 /// return a DAG expression to select that will generate the same value by
 /// multiplying by a magic number.
Index: llvm/lib/Target/AArch64/AArch64ISelLowering.h
===================================================================
--- llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -1042,6 +1042,8 @@
 
   SDValue BuildSDIVPow2(SDNode *N, const APInt &Divisor, SelectionDAG &DAG,
                         SmallVectorImpl<SDNode *> &Created) const override;
+  SDValue BuildSREMPow2(SDNode *N, const APInt &Divisor, SelectionDAG &DAG,
+                        SmallVectorImpl<SDNode *> &Created) const override;
   SDValue getSqrtEstimate(SDValue Operand, SelectionDAG &DAG, int Enabled,
                           int &ExtraSteps, bool &UseOneConst,
                           bool Reciprocal) const override;
Index: llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
===================================================================
--- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -13570,6 +13570,60 @@
   return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), SRA);
 }
 
+SDValue
+AArch64TargetLowering::BuildSREMPow2(SDNode *N, const APInt &Divisor,
+                                     SelectionDAG &DAG,
+                                     SmallVectorImpl<SDNode *> &Created) const {
+  AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
+  if (isIntDivCheap(N->getValueType(0), Attr))
+    return SDValue(N, 0); // Lower SREM as SREM
+
+  EVT VT = N->getValueType(0);
+
+  // For scalable and fixed types, mark them as cheap so we can handle it much
+  // later. This allows us to handle larger than legal types.
+  if (VT.isScalableVector() || Subtarget->useSVEForFixedLengthVectors())
+    return SDValue(N, 0);
+
+  // fold (srem X, pow2)
+  if ((VT != MVT::i32 && VT != MVT::i64) ||
+      !(Divisor.isPowerOf2() || Divisor.isNegatedPowerOf2()))
+    return SDValue();
+
+  unsigned Lg2 = Divisor.countTrailingZeros();
+  if (Lg2 == 0 || Lg2 == (VT.getScalarSizeInBits() - 1))
+    return SDValue();
+
+  SDLoc DL(N);
+  SDValue N0 = N->getOperand(0);
+  SDValue Pow2MinusOne = DAG.getConstant((1ULL << Lg2) - 1, DL, VT);
+  SDValue Zero = DAG.getConstant(0, DL, VT);
+  SDValue CCVal, CSNeg;
+  if (Lg2 == 1) {
+    SDValue Cmp = getAArch64Cmp(N0, Zero, ISD::SETLT, CCVal, DAG, DL);
+    SDValue And = DAG.getNode(ISD::AND, DL, VT, N0, Pow2MinusOne);
+    CSNeg = DAG.getNode(AArch64ISD::CSNEG, DL, VT, And, And, CCVal, Cmp);
+
+    Created.push_back(Cmp.getNode());
+    Created.push_back(And.getNode());
+  } else {
+    SDValue CCVal = DAG.getConstant(AArch64CC::MI, DL, MVT_CC);
+    SDVTList VTs = DAG.getVTList(VT, MVT::i32);
+
+    SDValue Negs = DAG.getNode(AArch64ISD::SUBS, DL, VTs, Zero, N0);
+    SDValue AndPos = DAG.getNode(ISD::AND, DL, VT, N0, Pow2MinusOne);
+    SDValue AndNeg = DAG.getNode(ISD::AND, DL, VT, Negs, Pow2MinusOne);
+    CSNeg = DAG.getNode(AArch64ISD::CSNEG, DL, VT, AndPos, AndNeg, CCVal,
+                        Negs.getValue(1));
+
+    Created.push_back(Negs.getNode());
+    Created.push_back(AndPos.getNode());
+    Created.push_back(AndNeg.getNode());
+  }
+
+  return CSNeg;
+}
+
 static bool IsSVECntIntrinsic(SDValue S) {
   switch(getIntrinsicID(S.getNode())) {
   default:
Index: llvm/test/CodeGen/AArch64/srem-lkk.ll
===================================================================
--- llvm/test/CodeGen/AArch64/srem-lkk.ll
+++ llvm/test/CodeGen/AArch64/srem-lkk.ll
@@ -95,11 +95,10 @@
 define i32 @dont_fold_srem_power_of_two(i32 %x) {
 ; CHECK-LABEL: dont_fold_srem_power_of_two:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    add w8, w0, #63
-; CHECK-NEXT:    cmp w0, #0
-; CHECK-NEXT:    csel w8, w8, w0, lt
-; CHECK-NEXT:    and w8, w8, #0xffffffc0
-; CHECK-NEXT:    sub w0, w0, w8
+; CHECK-NEXT:    negs w8, w0
+; CHECK-NEXT:    and w9, w0, #0x3f
+; CHECK-NEXT:    and w8, w8, #0x3f
+; CHECK-NEXT:    csneg w0, w9, w8, mi
 ; CHECK-NEXT:    ret
   %1 = srem i32 %x, 64
   ret i32 %1
@@ -147,3 +146,26 @@
   %1 = srem i64 %x, 98
   ret i64 %1
 }
+
+define i32 @fold_srem_two(i32 %x) {
+; CHECK-LABEL: fold_srem_two:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    and w8, w0, #0x1
+; CHECK-NEXT:    cmp w0, #0
+; CHECK-NEXT:    cneg w0, w8, ge
+; CHECK-NEXT:    ret
+  %1 = srem i32 %x, 2
+  ret i32 %1
+}
+
+define i64 @fold_srem_two_i64(i64 %x) {
+; CHECK-LABEL: fold_srem_two_i64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    and x8, x0, #0x1
+; CHECK-NEXT:    cmp x0, #0
+; CHECK-NEXT:    cneg x0, x8, ge
+; CHECK-NEXT:    ret
+  %1 = srem i64 %x, 2
+  ret i64 %1
+}
+
Index: llvm/test/CodeGen/AArch64/srem-seteq.ll
===================================================================
--- llvm/test/CodeGen/AArch64/srem-seteq.ll
+++ llvm/test/CodeGen/AArch64/srem-seteq.ll
@@ -234,11 +234,11 @@
 define i32 @test_srem_pow2(i32 %X) nounwind {
 ; CHECK-LABEL: test_srem_pow2:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    add w8, w0, #15
-; CHECK-NEXT:    cmp w0, #0
-; CHECK-NEXT:    csel w8, w8, w0, lt
-; CHECK-NEXT:    and w8, w8, #0xfffffff0
-; CHECK-NEXT:    cmp w0, w8
+; CHECK-NEXT:    negs w8, w0
+; CHECK-NEXT:    and w9, w0, #0xf
+; CHECK-NEXT:    and w8, w8, #0xf
+; CHECK-NEXT:    csneg w8, w9, w8, mi
+; CHECK-NEXT:    cmp w8, #0
 ; CHECK-NEXT:    cset w0, eq
 ; CHECK-NEXT:    ret
   %srem = srem i32 %X, 16
Index: llvm/test/CodeGen/AArch64/srem-vector-lkk.ll
===================================================================
--- llvm/test/CodeGen/AArch64/srem-vector-lkk.ll
+++ llvm/test/CodeGen/AArch64/srem-vector-lkk.ll
@@ -159,36 +159,33 @@
 ; CHECK-NEXT:    smov w9, v0.h[1]
 ; CHECK-NEXT:    smov w10, v0.h[0]
 ; CHECK-NEXT:    mov w8, #37253
+; CHECK-NEXT:    smov w12, v0.h[2]
 ; CHECK-NEXT:    movk w8, #44150, lsl #16
-; CHECK-NEXT:    add w11, w9, #31
-; CHECK-NEXT:    cmp w9, #0
-; CHECK-NEXT:    add w12, w10, #63
-; CHECK-NEXT:    csel w11, w11, w9, lt
-; CHECK-NEXT:    cmp w10, #0
-; CHECK-NEXT:    and w11, w11, #0xffffffe0
-; CHECK-NEXT:    csel w12, w12, w10, lt
-; CHECK-NEXT:    sub w9, w9, w11
-; CHECK-NEXT:    and w12, w12, #0xffffffc0
-; CHECK-NEXT:    sub w10, w10, w12
-; CHECK-NEXT:    smov w12, v0.h[3]
-; CHECK-NEXT:    fmov s1, w10
-; CHECK-NEXT:    smov w10, v0.h[2]
-; CHECK-NEXT:    smull x8, w12, w8
-; CHECK-NEXT:    mov v1.h[1], w9
+; CHECK-NEXT:    negs w11, w9
+; CHECK-NEXT:    and w9, w9, #0x1f
+; CHECK-NEXT:    and w11, w11, #0x1f
+; CHECK-NEXT:    csneg w9, w9, w11, mi
+; CHECK-NEXT:    negs w11, w10
+; CHECK-NEXT:    and w10, w10, #0x3f
+; CHECK-NEXT:    and w11, w11, #0x3f
+; CHECK-NEXT:    csneg w10, w10, w11, mi
+; CHECK-NEXT:    smov w11, v0.h[3]
+; CHECK-NEXT:    fmov s0, w10
+; CHECK-NEXT:    negs w10, w12
+; CHECK-NEXT:    smull x8, w11, w8
+; CHECK-NEXT:    and w10, w10, #0x7
 ; CHECK-NEXT:    lsr x8, x8, #32
-; CHECK-NEXT:    add w9, w10, #7
-; CHECK-NEXT:    cmp w10, #0
-; CHECK-NEXT:    csel w9, w9, w10, lt
-; CHECK-NEXT:    add w8, w8, w12
-; CHECK-NEXT:    and w9, w9, #0xfffffff8
-; CHECK-NEXT:    sub w9, w10, w9
+; CHECK-NEXT:    mov v0.h[1], w9
+; CHECK-NEXT:    and w9, w12, #0x7
+; CHECK-NEXT:    add w8, w8, w11
+; CHECK-NEXT:    csneg w9, w9, w10, mi
 ; CHECK-NEXT:    asr w10, w8, #6
 ; CHECK-NEXT:    add w8, w10, w8, lsr #31
 ; CHECK-NEXT:    mov w10, #95
-; CHECK-NEXT:    mov v1.h[2], w9
-; CHECK-NEXT:    msub w8, w8, w10, w12
-; CHECK-NEXT:    mov v1.h[3], w8
-; CHECK-NEXT:    fmov d0, d1
+; CHECK-NEXT:    mov v0.h[2], w9
+; CHECK-NEXT:    msub w8, w8, w10, w11
+; CHECK-NEXT:    mov v0.h[3], w8
+; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $q0
 ; CHECK-NEXT:    ret
   %1 = srem <4 x i16> %x, <i16 64, i16 32, i16 8, i16 95>
   ret <4 x i16> %1
@@ -245,27 +242,25 @@
 ; CHECK-NEXT:    // kill: def $d0 killed $d0 def $q0
 ; CHECK-NEXT:    smov w8, v0.h[2]
 ; CHECK-NEXT:    mov w9, #17097
-; CHECK-NEXT:    smov w10, v0.h[1]
 ; CHECK-NEXT:    movk w9, #45590, lsl #16
-; CHECK-NEXT:    mov w11, #32767
+; CHECK-NEXT:    smov w10, v0.h[1]
 ; CHECK-NEXT:    smov w12, v0.h[3]
 ; CHECK-NEXT:    movi d1, #0000000000000000
+; CHECK-NEXT:    mov w11, #23
 ; CHECK-NEXT:    smull x9, w8, w9
-; CHECK-NEXT:    add w11, w10, w11
-; CHECK-NEXT:    cmp w10, #0
 ; CHECK-NEXT:    lsr x9, x9, #32
-; CHECK-NEXT:    csel w11, w11, w10, lt
 ; CHECK-NEXT:    add w9, w9, w8
-; CHECK-NEXT:    and w11, w11, #0xffff8000
 ; CHECK-NEXT:    asr w13, w9, #4
-; CHECK-NEXT:    sub w10, w10, w11
-; CHECK-NEXT:    mov w11, #47143
 ; CHECK-NEXT:    add w9, w13, w9, lsr #31
-; CHECK-NEXT:    mov w13, #23
-; CHECK-NEXT:    movk w11, #24749, lsl #16
+; CHECK-NEXT:    negs w13, w10
+; CHECK-NEXT:    and w10, w10, #0x7fff
+; CHECK-NEXT:    and w13, w13, #0x7fff
+; CHECK-NEXT:    csneg w10, w10, w13, mi
+; CHECK-NEXT:    mov w13, #47143
+; CHECK-NEXT:    movk w13, #24749, lsl #16
+; CHECK-NEXT:    msub w8, w9, w11, w8
+; CHECK-NEXT:    smull x9, w12, w13
 ; CHECK-NEXT:    mov v1.h[1], w10
-; CHECK-NEXT:    msub w8, w9, w13, w8
-; CHECK-NEXT:    smull x9, w12, w11
 ; CHECK-NEXT:    lsr x10, x9, #63
 ; CHECK-NEXT:    asr x9, x9, #43
 ; CHECK-NEXT:    add w9, w9, w10