Index: lib/Target/AArch64/AArch64ISelLowering.cpp
===================================================================
--- lib/Target/AArch64/AArch64ISelLowering.cpp
+++ lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -358,6 +358,7 @@
   setTargetDAGCombine(ISD::STORE);
 
   setTargetDAGCombine(ISD::MUL);
+  setTargetDAGCombine(ISD::SRA);
 
   setTargetDAGCombine(ISD::SELECT);
   setTargetDAGCombine(ISD::VSELECT);
@@ -6355,6 +6356,89 @@
   return performIntegerAbsCombine(N, DAG);
 }
 
+// If this is an sdiv x, pow2, return x and lg2.
+static bool isSDiv(SDNode *N, SDValue &X, uint64_t &lg2) {
+  assert(N && N->getOpcode() == ISD::SRA && "Unexpected opcode.");
+
+  EVT VT = N->getValueType(0);
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+
+  // Divide by pow2.
+  if (!isa<ConstantSDNode>(N1))
+    return false;
+
+  APInt Log2 = cast<ConstantSDNode>(N1)->getAPIntValue();
+  if (Log2.isNegative())
+    return false;
+
+  lg2 = Log2.getZExtValue();
+
+  // Add (N0 < 0) ? Pow2 - 1 : 0;
+  if (N0->getOpcode() != ISD::ADD)
+    return false;
+
+  N1 = N0->getOperand(1);
+  if (N1->getOpcode() != ISD::SRL)
+    return false;
+
+  N0 = N1->getOperand(0);
+  N1 = N1->getOperand(1);
+
+  // Splat of the sign bit.
+  if (N0->getOpcode() != ISD::SRA)
+    return false;
+
+  if (!isa<ConstantSDNode>(N1))
+    return false;
+
+  uint64_t C1 = cast<ConstantSDNode>(N1)->getAPIntValue().getZExtValue();
+  if (C1 != (VT.getScalarSizeInBits() - lg2))
+    return false;
+
+  X = N0->getOperand(0);
+  N1 = N0->getOperand(1);
+
+  if (!isa<ConstantSDNode>(N1))
+    return false;
+
+  uint64_t C2 = cast<ConstantSDNode>(N1)->getAPIntValue().getZExtValue();
+  if (C2 != (VT.getScalarSizeInBits() - 1))
+    return false;
+
+  return true;
+}
+
+static SDValue performSRACombine(SDNode *N, SelectionDAG &DAG,
+                                 TargetLowering::DAGCombinerInfo &DCI) {
+  if (DCI.isBeforeLegalizeOps())
+    return SDValue();
+
+  EVT VT = N->getValueType(0);
+  if (VT != MVT::i32 && VT != MVT::i64)
+    return SDValue();
+
+  SDValue X;
+  uint64_t lg2;
+  if (!isSDiv(N, X, lg2))
+    return SDValue();
+
+  SDLoc DL(N);
+
+  SDValue Zero = DAG.getConstant(0, VT);
+  SDValue Pow2MinusOne = DAG.getConstant((1ULL << lg2) - 1, VT);
+
+  SDValue CCVal;
+
+  // Add (N0 < 0) ? Pow2 - 1 : 0;
+  SDValue Cmp = getAArch64Cmp(X, Zero, ISD::SETLT, CCVal, DAG, DL);
+  SDValue Add = DAG.getNode(ISD::ADD, DL, VT, X, Pow2MinusOne);
+  SDValue CSel = DAG.getNode(AArch64ISD::CSEL, DL, VT, Add, X, CCVal, Cmp);
+
+  // Divide by pow2.
+  return DAG.getNode(ISD::SRA, DL, VT, CSel, DAG.getConstant(lg2, MVT::i64));
+}
+
 static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG,
                                  TargetLowering::DAGCombinerInfo &DCI,
                                  const AArch64Subtarget *Subtarget) {
@@ -7708,6 +7792,8 @@
     return performAddSubLongCombine(N, DCI, DAG);
   case ISD::XOR:
     return performXorCombine(N, DAG, DCI, Subtarget);
+  case ISD::SRA:
+    return performSRACombine(N, DAG, DCI);
  case ISD::MUL:
    return performMulCombine(N, DAG, DCI, Subtarget);
  case ISD::SINT_TO_FP:
Index: test/CodeGen/AArch64/sdivpow2.ll
===================================================================
--- /dev/null
+++ test/CodeGen/AArch64/sdivpow2.ll
@@ -0,0 +1,61 @@
+; RUN: llc -mtriple=arm64-linux-gnu -o - %s | FileCheck %s
+
+define i32 @test1(i32 %x) {
+; CHECK-LABEL: test1
+; CHECK: add w8, w0, #7
+; CHECK: cmp w0, #0
+; CHECK: csel w8, w8, w0, lt
+; CHECK: asr w0, w8, #3
+  %div = sdiv i32 %x, 8
+  ret i32 %div
+}
+
+define i32 @test2(i32 %x) {
+; CHECK-LABEL: test2
+; CHECK: add w8, w0, #7
+; CHECK: cmp w0, #0
+; CHECK: csel w8, w8, w0, lt
+; CHECK: neg w0, w8, asr #3
+  %div = sdiv i32 %x, -8
+  ret i32 %div
+}
+
+define i32 @test3(i32 %x) {
+; CHECK-LABEL: test3
+; CHECK: add w8, w0, #31
+; CHECK: cmp w0, #0
+; CHECK: csel w8, w8, w0, lt
+; CHECK: asr w0, w8, #5
+  %div = sdiv i32 %x, 32
+  ret i32 %div
+}
+
+define i64 @test4(i64 %x) {
+; CHECK-LABEL: test4
+; CHECK: add x8, x0, #7
+; CHECK: cmp x0, #0
+; CHECK: csel x8, x8, x0, lt
+; CHECK: asr x0, x8, #3
+  %div = sdiv i64 %x, 8
+  ret i64 %div
+}
+
+define i64 @test5(i64 %x) {
+; CHECK-LABEL: test5
+; CHECK: add x8, x0, #7
+; CHECK: cmp x0, #0
+; CHECK: csel x8, x8, x0, lt
+; CHECK: neg x0, x8, asr #3
+  %div = sdiv i64 %x, -8
+  ret i64 %div
+}
+
+define i64 @test6(i64 %x) {
+; CHECK-LABEL: test6
+; CHECK: add x8, x0, #63
+; CHECK: cmp x0, #0
+; CHECK: csel x8, x8, x0, lt
+; CHECK: asr x0, x8, #6
+  %div = sdiv i64 %x, 64
+  ret i64 %div
+}
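Note (reviewer aside, not part of the patch): the combine relies on the standard round-toward-zero identity for signed division by a power of two, x sdiv 2^k == (x < 0 ? x + (2^k - 1) : x) >> k with an arithmetic shift, which is exactly the add/cmp/csel/asr sequence the tests check for. Below is a minimal standalone C++ sketch that spot-checks this identity. It assumes two's-complement integers and that >> on a negative value is an arithmetic shift (true on the usual hosts, guaranteed from C++20); the helper name sdivPow2ViaShift is made up for illustration.

// Spot-check of the identity used by performSRACombine:
//   x sdiv 2^k  ==  (x < 0 ? x + (2^k - 1) : x) >> k   (arithmetic shift)
#include <cassert>
#include <cstdint>

static int64_t sdivPow2ViaShift(int64_t X, unsigned Lg2) {
  int64_t Bias = (int64_t(1) << Lg2) - 1;   // Pow2MinusOne in the combine.
  int64_t Adjusted = X < 0 ? X + Bias : X;  // csel of (add x, bias) vs. x.
  return Adjusted >> Lg2;                   // asr #Lg2.
}

int main() {
  for (int64_t X : {-100, -9, -8, -7, -1, 0, 1, 7, 8, 9, 100})
    for (unsigned Lg2 : {1u, 3u, 5u})
      assert(sdivPow2ViaShift(X, Lg2) == X / (int64_t(1) << Lg2));
  return 0;
}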