diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -4421,6 +4421,14 @@
   /// only the first Count elements of the vector are used.
   SDValue expandVecReduce(SDNode *Node, SelectionDAG &DAG) const;
 
+  /// Expand an SREM or UREM node using SDIVREM/UDIVREM if that is legal or
+  /// custom, otherwise using SDIV/UDIV as X - (X / Y) * Y.
+  /// \param Node The SREM or UREM node to expand.
+  /// \param [out] Result Set to the remainder on success; untouched otherwise.
+  /// \returns true if the expansion was successful, false if neither divide
+  ///          opcode is legal or custom for the node's value type.
+  bool expandREM(SDNode *Node, SDValue &Result, SelectionDAG &DAG) const;
+
   //===--------------------------------------------------------------------===//
   // Instruction Emitting Hooks
   //
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -3343,26 +3343,10 @@
     break;
   }
   case ISD::UREM:
-  case ISD::SREM: {
-    EVT VT = Node->getValueType(0);
-    bool isSigned = Node->getOpcode() == ISD::SREM;
-    unsigned DivOpc = isSigned ? ISD::SDIV : ISD::UDIV;
-    unsigned DivRemOpc = isSigned ? ISD::SDIVREM : ISD::UDIVREM;
-    Tmp2 = Node->getOperand(0);
-    Tmp3 = Node->getOperand(1);
-    if (TLI.isOperationLegalOrCustom(DivRemOpc, VT)) {
-      SDVTList VTs = DAG.getVTList(VT, VT);
-      Tmp1 = DAG.getNode(DivRemOpc, dl, VTs, Tmp2, Tmp3).getValue(1);
-      Results.push_back(Tmp1);
-    } else if (TLI.isOperationLegalOrCustom(DivOpc, VT)) {
-      // X % Y -> X-X/Y*Y
-      Tmp1 = DAG.getNode(DivOpc, dl, VT, Tmp2, Tmp3);
-      Tmp1 = DAG.getNode(ISD::MUL, dl, VT, Tmp1, Tmp3);
-      Tmp1 = DAG.getNode(ISD::SUB, dl, VT, Tmp2, Tmp1);
+  case ISD::SREM:
+    if (TLI.expandREM(Node, Tmp1, DAG))
       Results.push_back(Tmp1);
-    }
     break;
-  }
   case ISD::UDIV:
   case ISD::SDIV: {
     bool isSigned = Node->getOpcode() == ISD::SDIV;
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -145,6 +145,7 @@
   void ExpandFixedPointDiv(SDNode *Node, SmallVectorImpl<SDValue> &Results);
   SDValue ExpandStrictFPOp(SDNode *Node);
   void ExpandStrictFPOp(SDNode *Node, SmallVectorImpl<SDValue> &Results);
+  void ExpandREM(SDNode *Node, SmallVectorImpl<SDValue> &Results);
 
   void UnrollStrictFPOp(SDNode *Node, SmallVectorImpl<SDValue> &Results);
 
@@ -867,6 +868,10 @@
   case ISD::VECREDUCE_FMIN:
     Results.push_back(TLI.expandVecReduce(Node, DAG));
     return;
+  case ISD::SREM:
+  case ISD::UREM:
+    ExpandREM(Node, Results);
+    return;
   }
 
   Results.push_back(DAG.UnrollVectorOp(Node));
@@ -1353,6 +1358,17 @@
   UnrollStrictFPOp(Node, Results);
 }
 
+void VectorLegalizer::ExpandREM(SDNode *Node,
+                                SmallVectorImpl<SDValue> &Results) {
+  assert((Node->getOpcode() == ISD::SREM || Node->getOpcode() == ISD::UREM) &&
+         "Expected REM node");
+
+  SDValue Result;
+  if (!TLI.expandREM(Node, Result, DAG))
+    Result = DAG.UnrollVectorOp(Node);
+  Results.push_back(Result);
+}
+
 void VectorLegalizer::UnrollStrictFPOp(SDNode *Node,
                                        SmallVectorImpl<SDValue> &Results) {
   EVT VT = Node->getValueType(0);
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -7823,3 +7823,29 @@
   Res = DAG.getNode(ISD::ANY_EXTEND, dl, Node->getValueType(0), Res);
   return Res;
 }
+
+bool TargetLowering::expandREM(SDNode *Node, SDValue &Result,
+                               SelectionDAG &DAG) const {
+  EVT VT = Node->getValueType(0);
+  SDLoc dl(Node);
+  bool IsSigned = Node->getOpcode() == ISD::SREM;
+  unsigned DivOpc = IsSigned ? ISD::SDIV : ISD::UDIV;
+  unsigned DivRemOpc = IsSigned ? ISD::SDIVREM : ISD::UDIVREM;
+  SDValue Dividend = Node->getOperand(0);
+  SDValue Divisor = Node->getOperand(1);
+  if (isOperationLegalOrCustom(DivRemOpc, VT)) {
+    SDVTList VTs = DAG.getVTList(VT, VT);
+    // The remainder is the second result (value 1) of [SU]DIVREM.
+    Result = DAG.getNode(DivRemOpc, dl, VTs, Dividend, Divisor).getValue(1);
+    return true;
+  }
+  if (isOperationLegalOrCustom(DivOpc, VT)) {
+    // X % Y -> X-X/Y*Y
+    SDValue Divide = DAG.getNode(DivOpc, dl, VT, Dividend, Divisor);
+    SDValue Mul = DAG.getNode(ISD::MUL, dl, VT, Divide, Divisor);
+    Result = DAG.getNode(ISD::SUB, dl, VT, Dividend, Mul);
+    return true;
+  }
+  // Neither divide form is available; the caller chooses the fallback.
+  return false;
+}
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -199,6 +199,10 @@
     setOperationAction(ISD::UADDSAT, VT, Legal);
     setOperationAction(ISD::SSUBSAT, VT, Legal);
     setOperationAction(ISD::USUBSAT, VT, Legal);
+    setOperationAction(ISD::UREM, VT, Expand);
+    setOperationAction(ISD::SREM, VT, Expand);
+    setOperationAction(ISD::SDIVREM, VT, Expand);
+    setOperationAction(ISD::UDIVREM, VT, Expand);
   }
 
   for (auto VT :
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
@@ -443,7 +443,7 @@
     setOperationAction(ISD::UREM, VT, Expand);
     setOperationAction(ISD::SMUL_LOHI, VT, Expand);
     setOperationAction(ISD::UMUL_LOHI, VT, Expand);
-    setOperationAction(ISD::SDIVREM, VT, Custom);
+    setOperationAction(ISD::SDIVREM, VT, Expand);
     setOperationAction(ISD::UDIVREM, VT, Expand);
     setOperationAction(ISD::SELECT, VT, Expand);
     setOperationAction(ISD::VSELECT, VT, Expand);
diff --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp
--- a/llvm/lib/Target/ARM/ARMISelLowering.cpp
+++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp
@@ -210,6 +210,8 @@
   setOperationAction(ISD::SREM, VT, Expand);
   setOperationAction(ISD::UREM, VT, Expand);
   setOperationAction(ISD::FREM, VT, Expand);
+  setOperationAction(ISD::SDIVREM, VT, Expand);
+  setOperationAction(ISD::UDIVREM, VT, Expand);
 
   if (!VT.isFloatingPoint() &&
       VT != MVT::v2i64 && VT != MVT::v1i64)
@@ -284,6 +286,8 @@
     setOperationAction(ISD::SDIV, VT, Expand);
     setOperationAction(ISD::UREM, VT, Expand);
     setOperationAction(ISD::SREM, VT, Expand);
+    setOperationAction(ISD::UDIVREM, VT, Expand);
+    setOperationAction(ISD::SDIVREM, VT, Expand);
     setOperationAction(ISD::CTPOP, VT, Expand);
 
     // Vector reductions
diff --git a/llvm/test/CodeGen/AArch64/llvm-ir-to-intrinsic.ll b/llvm/test/CodeGen/AArch64/llvm-ir-to-intrinsic.ll
--- a/llvm/test/CodeGen/AArch64/llvm-ir-to-intrinsic.ll
+++ b/llvm/test/CodeGen/AArch64/llvm-ir-to-intrinsic.ll
@@ -60,6 +60,36 @@
 }
 
 ;
+; SREM
+;
+
+define <vscale x 4 x i32> @srem_i32(<vscale x 4 x i32> %a, <vscale x 4 x i32> %b) {
+; CHECK-LABEL: srem_i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    mov z2.d, z0.d
+; CHECK-NEXT:    sdiv z2.s, p0/m, z2.s, z1.s
+; CHECK-NEXT:    mul z2.s, p0/m, z2.s, z1.s
+; CHECK-NEXT:    sub z0.s, z0.s, z2.s
+; CHECK-NEXT:    ret
+  %rem = srem <vscale x 4 x i32> %a, %b
+  ret <vscale x 4 x i32> %rem
+}
+
+define <vscale x 2 x i64> @srem_i64(<vscale x 2 x i64> %a, <vscale x 2 x i64> %b) {
+; CHECK-LABEL: srem_i64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    mov z2.d, z0.d
+; CHECK-NEXT:    sdiv z2.d, p0/m, z2.d, z1.d
+; CHECK-NEXT:    mul z2.d, p0/m, z2.d, z1.d
+; CHECK-NEXT:    sub z0.d, z0.d, z2.d
+; CHECK-NEXT:    ret
+  %rem = srem <vscale x 2 x i64> %a, %b
+  ret <vscale x 2 x i64> %rem
+}
+
+;
 ; UDIV
 ;
 
@@ -117,6 +147,36 @@
   ret <vscale x 2 x i64> %div
 }
 
+;
+; UREM
+;
+
+define <vscale x 4 x i32> @urem_i32(<vscale x 4 x i32> %a, <vscale x 4 x i32> %b) {
+; CHECK-LABEL: urem_i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    mov z2.d, z0.d
+; CHECK-NEXT:    udiv z2.s, p0/m, z2.s, z1.s
+; CHECK-NEXT:    mul z2.s, p0/m, z2.s, z1.s
+; CHECK-NEXT:    sub z0.s, z0.s, z2.s
+; CHECK-NEXT:    ret
+  %rem = urem <vscale x 4 x i32> %a, %b
+  ret <vscale x 4 x i32> %rem
+}
+
+define <vscale x 2 x i64> @urem_i64(<vscale x 2 x i64> %a, <vscale x 2 x i64> %b) {
+; CHECK-LABEL: urem_i64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    mov z2.d, z0.d
+; CHECK-NEXT:    udiv z2.d, p0/m, z2.d, z1.d
+; CHECK-NEXT:    mul z2.d, p0/m, z2.d, z1.d
+; CHECK-NEXT:    sub z0.d, z0.d, z2.d
+; CHECK-NEXT:    ret
+  %rem = urem <vscale x 2 x i64> %a, %b
+  ret <vscale x 2 x i64> %rem
+}
+
 ;
 ; SMIN
 ;