Index: lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
===================================================================
--- lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -174,6 +174,8 @@
   SDValue PromoteLegalFP_TO_INT(SDValue LegalOp, EVT DestVT, bool isSigned,
                                 const SDLoc &dl);
 
+  SDValue ExpandABS(SDNode *Node, const SDLoc &dl);
+
   SDValue ExpandBITREVERSE(SDValue Op, const SDLoc &dl);
   SDValue ExpandBSWAP(SDValue Op, const SDLoc &dl);
   SDValue ExpandBitCount(unsigned Opc, SDValue Op, const SDLoc &dl);
@@ -2588,6 +2590,19 @@
   return DAG.getNode(ISD::TRUNCATE, dl, DestVT, Operation);
 }
 
+/// Expand the ABS node to SRA + ADD + XOR nodes.
+/// (ABS x) -> (XOR (ADD x, (SRA x, bits - 1)), (SRA x, bits - 1))
+SDValue SelectionDAGLegalize::ExpandABS(SDNode *Node, const SDLoc &dl) {
+  EVT VT = Node->getValueType(0);
+  EVT ShVT = TLI.getShiftAmountTy(VT, DAG.getDataLayout());
+  SDValue N0 = Node->getOperand(0);
+  // Sign is all-ones when N0 is negative and zero otherwise.
+  SDValue ShAmt = DAG.getConstant(VT.getScalarSizeInBits() - 1, dl, ShVT);
+  SDValue Sign = DAG.getNode(ISD::SRA, dl, VT, N0, ShAmt);
+  SDValue Add = DAG.getNode(ISD::ADD, dl, VT, N0, Sign);
+  return DAG.getNode(ISD::XOR, dl, VT, Add, Sign);
+}
+
 /// Legalize a BITREVERSE scalar/vector operation as a series of mask + shifts.
SDValue SelectionDAGLegalize::ExpandBITREVERSE(SDValue Op, const SDLoc &dl) {
   EVT VT = Op.getValueType();
@@ -2836,6 +2851,9 @@
   SDValue Tmp1, Tmp2, Tmp3, Tmp4;
   bool NeedInvert;
   switch (Node->getOpcode()) {
+  case ISD::ABS:
+    Results.push_back(ExpandABS(Node, dl));
+    break;
   case ISD::CTPOP:
   case ISD::CTLZ:
   case ISD::CTLZ_ZERO_UNDEF:
Index: lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
===================================================================
--- lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -1385,6 +1385,9 @@
   case ISD::BITREVERSE:  ExpandIntRes_BITREVERSE(N, Lo, Hi); break;
   case ISD::BSWAP:       ExpandIntRes_BSWAP(N, Lo, Hi); break;
   case ISD::Constant:    ExpandIntRes_Constant(N, Lo, Hi); break;
+  case ISD::ABS:
+    ExpandIntRes_ABS(N, Lo, Hi);
+    break;
   case ISD::CTLZ_ZERO_UNDEF:
   case ISD::CTLZ:        ExpandIntRes_CTLZ(N, Lo, Hi); break;
   case ISD::CTPOP:       ExpandIntRes_CTPOP(N, Lo, Hi); break;
@@ -2129,6 +2132,44 @@
                      IsOpaque);
 }
 
+void DAGTypeLegalizer::ExpandIntRes_ABS(SDNode *N, SDValue &Lo, SDValue &Hi) {
+  SDLoc dl(N);
+  SDValue N0 = N->getOperand(0);
+  GetExpandedInteger(N0, Lo, Hi);
+  EVT NVT = Hi.getValueType();
+
+  if (TLI.isOperationLegalOrCustom(ISD::ADDCARRY, NVT) &&
+      TLI.isOperationLegalOrCustom(ISD::UADDO, NVT)) {
+    // abs(x) == (x + Sign) ^ Sign, where Sign is the sign bit of x replicated
+    // across the whole value. Both halves of Sign are therefore identical.
+    SDVTList VTList = DAG.getVTList(NVT, getSetCCResultType(NVT));
+    SDValue Sign = DAG.getNode(
+        ISD::SRA, dl, NVT, Hi,
+        DAG.getConstant(NVT.getSizeInBits() - 1, dl,
+                        TLI.getShiftAmountTy(NVT, DAG.getDataLayout())));
+    // Add the halves, threading the carry out of the low half into the high.
+    Lo = DAG.getNode(ISD::UADDO, dl, VTList, Sign, Lo);
+    Hi = DAG.getNode(ISD::ADDCARRY, dl, VTList, Sign, Hi, Lo.getValue(1));
+    Hi = DAG.getNode(ISD::XOR, dl, NVT, Sign, Hi);
+    Lo = DAG.getNode(ISD::XOR, dl, NVT, Sign, Lo);
+    return;
+  }
+
+  // Without carry support, negate the unexpanded value and pick per half:
+  //   abs(HiLo) -> (Hi < 0 ? -HiLo : HiLo)
+  // Taking abs() of the high half on its own would drop the borrow from the
+  // low half (e.g. -1 would yield Hi = 1 instead of Hi = 0).
+  EVT VT = N->getValueType(0);
+  SDValue Neg = DAG.getNode(ISD::SUB, dl, VT, DAG.getConstant(0, dl, VT), N0);
+  SDValue NegLo, NegHi;
+  SplitInteger(Neg, NegLo, NegHi);
+
+  SDValue HiIsNeg = DAG.getSetCC(dl, getSetCCResultType(NVT), Hi,
+                                 DAG.getConstant(0, dl, NVT), ISD::SETLT);
+  Lo = DAG.getSelect(dl, NVT, HiIsNeg, NegLo, Lo);
+  Hi = DAG.getSelect(dl, NVT, HiIsNeg, NegHi, Hi);
+}
+
 void DAGTypeLegalizer::ExpandIntRes_CTLZ(SDNode *N,
                                          SDValue &Lo, SDValue &Hi) {
   SDLoc dl(N);
Index: lib/CodeGen/SelectionDAG/LegalizeTypes.h
===================================================================
--- lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -382,6 +382,7 @@
   void ExpandIntRes_AssertSext        (SDNode *N, SDValue &Lo, SDValue &Hi);
   void ExpandIntRes_AssertZext        (SDNode *N, SDValue &Lo, SDValue &Hi);
   void ExpandIntRes_Constant          (SDNode *N, SDValue &Lo, SDValue &Hi);
+  void ExpandIntRes_ABS               (SDNode *N, SDValue &Lo, SDValue &Hi);
   void ExpandIntRes_CTLZ              (SDNode *N, SDValue &Lo, SDValue &Hi);
   void ExpandIntRes_CTPOP             (SDNode *N, SDValue &Lo, SDValue &Hi);
   void ExpandIntRes_CTTZ              (SDNode *N, SDValue &Lo, SDValue &Hi);
Index: lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
===================================================================
--- lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -116,6 +116,11 @@
   /// the remaining lanes, finally bitcasting to the proper type.
   SDValue ExpandZERO_EXTEND_VECTOR_INREG(SDValue Op);
 
+  /// Implement expand-based legalization of ABS vector operations.
+  ///
+  /// (ABS x) --> (XOR (ADD x, Sign), Sign), Sign = (SRA x, EltBits - 1).
+  SDValue ExpandABS(SDValue Op);
+
   /// Expand bswap of vectors into a shuffle if legal.
SDValue ExpandBSWAP(SDValue Op);
 
@@ -335,6 +340,7 @@
   case ISD::AND:
   case ISD::OR:
   case ISD::XOR:
+  case ISD::ABS:
   case ISD::SHL:
   case ISD::SRA:
   case ISD::SRL:
@@ -714,6 +720,8 @@
     return ExpandSIGN_EXTEND_VECTOR_INREG(Op);
   case ISD::ZERO_EXTEND_VECTOR_INREG:
     return ExpandZERO_EXTEND_VECTOR_INREG(Op);
+  case ISD::ABS:
+    return ExpandABS(Op);
   case ISD::BSWAP:
     return ExpandBSWAP(Op);
   case ISD::VSELECT:
@@ -916,6 +924,19 @@
       ShuffleMask.push_back((I * ScalarSizeInBytes) + J);
 }
 
+// Branch-free abs: Sign is all-ones in negative lanes and zero elsewhere, so
+// (x + Sign) ^ Sign negates exactly the negative lanes.
+SDValue VectorLegalizer::ExpandABS(SDValue Op) {
+  SDLoc DL(Op);
+  EVT VT = Op.getValueType();
+  SDValue N0 = Op.getOperand(0);
+
+  SDValue ShAmt = DAG.getConstant(VT.getScalarSizeInBits() - 1, DL, VT);
+  SDValue Sign = DAG.getNode(ISD::SRA, DL, VT, N0, ShAmt);
+  SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N0, Sign);
+  return DAG.getNode(ISD::XOR, DL, VT, Add, Sign);
+}
+
 SDValue VectorLegalizer::ExpandBSWAP(SDValue Op) {
   EVT VT = Op.getValueType();
 
Index: lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
===================================================================
--- lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -681,6 +681,7 @@
     SplitVecRes_ExtVecInRegOp(N, Lo, Hi);
     break;
 
+  case ISD::ABS:
   case ISD::BITREVERSE:
   case ISD::BSWAP:
   case ISD::CTLZ:
Index: lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
===================================================================
--- lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -2915,6 +2915,8 @@
   ISD::NodeType OpCode = Cond.getValueType().isVector() ?
     ISD::VSELECT : ISD::SELECT;
 
+  bool IsUnary = false;
+
   // Min/max matching is only viable if all output VTs are the same.
   if (std::equal(ValueVTs.begin(), ValueVTs.end(), ValueVTs.begin())) {
     EVT VT = ValueVTs[0];
@@ -2975,10 +2977,17 @@
         break;
       }
       break;
+    case SPF_ABS:
+      Opc = VT.isInteger() ? ISD::ABS : ISD::FABS;
+      IsUnary = true;
+      break;
+    case SPF_NABS:
+      // TODO: What DAG should be produced?
+      break;
     default: break;
     }
 
-    if (Opc != ISD::DELETED_NODE &&
+    if (!IsUnary && Opc != ISD::DELETED_NODE &&
         (TLI.isOperationLegalOrCustom(Opc, VT) ||
          (UseScalarMinMax &&
           TLI.isOperationLegalOrCustom(Opc, VT.getScalarType()))) &&
@@ -2991,15 +3000,30 @@
       RHSVal = getValue(RHS);
       BaseOps = {};
     }
+
+    if (IsUnary) {
+      OpCode = Opc;
+      LHSVal = getValue(LHS);
+      BaseOps = {};
+    }
   }
 
-  for (unsigned i = 0; i != NumValues; ++i) {
-    SmallVector<SDValue, 3> Ops(BaseOps.begin(), BaseOps.end());
-    Ops.push_back(SDValue(LHSVal.getNode(), LHSVal.getResNo() + i));
-    Ops.push_back(SDValue(RHSVal.getNode(), RHSVal.getResNo() + i));
-    Values[i] = DAG.getNode(OpCode, getCurSDLoc(),
-                            LHSVal.getNode()->getValueType(LHSVal.getResNo()+i),
-                            Ops);
+  if (IsUnary) {
+    for (unsigned i = 0; i != NumValues; ++i) {
+      Values[i] =
+          DAG.getNode(OpCode, getCurSDLoc(),
+                      LHSVal.getNode()->getValueType(LHSVal.getResNo() + i),
+                      SDValue(LHSVal.getNode(), LHSVal.getResNo() + i));
+    }
+  } else {
+    for (unsigned i = 0; i != NumValues; ++i) {
+      SmallVector<SDValue, 3> Ops(BaseOps.begin(), BaseOps.end());
+      Ops.push_back(SDValue(LHSVal.getNode(), LHSVal.getResNo() + i));
+      Ops.push_back(SDValue(RHSVal.getNode(), RHSVal.getResNo() + i));
+      Values[i] = DAG.getNode(
+          OpCode, getCurSDLoc(),
+          LHSVal.getNode()->getValueType(LHSVal.getResNo() + i), Ops);
+    }
   }
 
   setValue(&I, DAG.getNode(ISD::MERGE_VALUES, getCurSDLoc(),
Index: lib/Target/X86/X86ISelLowering.cpp
===================================================================
--- lib/Target/X86/X86ISelLowering.cpp
+++ lib/Target/X86/X86ISelLowering.cpp
@@ -31846,66 +31846,16 @@
   return Op;
 }
 
-// Given a select, detect the following pattern:
-// 1:    %2 = zext <N x i8> %0 to <N x i32>
-// 2:    %3 = zext <N x i8> %1 to <N x i32>
-// 3:    %4 = sub nsw <N x i32> %2, %3
-// 4:    %5 = icmp sgt <N x i32> %4, [0 x N] or [-1 x N]
-// 5:    %6 = sub nsw <N x i32> zeroinitializer, %4
-// 6:    %7 = select <N x i1> %5, <N x i32> %4, <N x i32> %6
+// Given an ABS node, detect the following pattern:
+// (ABS (SUB (ZERO_EXTEND a), (ZERO_EXTEND b))).
 // This is useful as it is the input into a SAD pattern.
-static bool detectZextAbsDiff(const SDValue &Select, SDValue &Op0,
-                              SDValue &Op1) {
-  // Check the condition of the select instruction is greater-than.
-  SDValue SetCC = Select->getOperand(0);
-  if (SetCC.getOpcode() != ISD::SETCC)
-    return false;
-  ISD::CondCode CC = cast<CondCodeSDNode>(SetCC.getOperand(2))->get();
-  if (CC != ISD::SETGT && CC != ISD::SETLT)
-    return false;
-
-  SDValue SelectOp1 = Select->getOperand(1);
-  SDValue SelectOp2 = Select->getOperand(2);
-
-  // The following instructions assume SelectOp1 is the subtraction operand
-  // and SelectOp2 is the negation operand.
-  // In the case of SETLT this is the other way around.
-  if (CC == ISD::SETLT)
-    std::swap(SelectOp1, SelectOp2);
-
-  // The second operand of the select should be the negation of the first
-  // operand, which is implemented as 0 - SelectOp1.
-  if (!(SelectOp2.getOpcode() == ISD::SUB &&
-        ISD::isBuildVectorAllZeros(SelectOp2.getOperand(0).getNode()) &&
-        SelectOp2.getOperand(1) == SelectOp1))
-    return false;
-
-  // The first operand of SetCC is the first operand of the select, which is the
-  // difference between the two input vectors.
-  if (SetCC.getOperand(0) != SelectOp1)
-    return false;
-
-  // In SetLT case, The second operand of the comparison can be either 1 or 0.
-  APInt SplatVal;
-  if ((CC == ISD::SETLT) &&
-      !((ISD::isConstantSplatVector(SetCC.getOperand(1).getNode(), SplatVal) &&
-         SplatVal.isOneValue()) ||
-        (ISD::isBuildVectorAllZeros(SetCC.getOperand(1).getNode()))))
-    return false;
-
-  // In SetGT case, The second operand of the comparison can be either -1 or 0.
-  if ((CC == ISD::SETGT) &&
-      !(ISD::isBuildVectorAllZeros(SetCC.getOperand(1).getNode()) ||
-        ISD::isBuildVectorAllOnes(SetCC.getOperand(1).getNode())))
-    return false;
-
-  // The first operand of the select is the difference between the two input
-  // vectors.
-  if (SelectOp1.getOpcode() != ISD::SUB)
+static bool detectZextAbsDiff(const SDValue &Abs, SDValue &Op0, SDValue &Op1) {
+  SDValue AbsSrc = Abs->getOperand(0);
+  if (AbsSrc.getOpcode() != ISD::SUB)
     return false;
 
-  Op0 = SelectOp1.getOperand(0);
-  Op1 = SelectOp1.getOperand(1);
+  Op0 = AbsSrc.getOperand(0);
+  Op1 = AbsSrc.getOperand(1);
 
   // Check if the operands of the sub are zero-extended from vectors of i8.
   if (Op0.getOpcode() != ISD::ZERO_EXTEND ||
@@ -32141,7 +32091,7 @@
 
-  // If there was a match, we want Root to be a select that is the root of an
-  // abs-diff pattern.
-  if (!Root || (Root.getOpcode() != ISD::VSELECT))
+  // If there was a match, we want Root to be an ABS node that is the root of
+  // an abs-diff pattern.
+  if (!Root || Root.getOpcode() != ISD::ABS)
     return SDValue();
 
-  // Check whether we have an abs-diff pattern feeding into the select.
+  // Check whether Root is the ABS of a zero-extended difference.
@@ -38635,10 +38585,10 @@
   // We know N is a reduction add, which means one of its operands is a phi.
-  // To match SAD, we need the other operand to be a vector select.
+  // To match SAD, we need the other operand to be an ABS node.
   SDValue SelectOp, Phi;
-  if (Op0.getOpcode() == ISD::VSELECT) {
+  if (Op0.getOpcode() == ISD::ABS) {
     SelectOp = Op0;
     Phi = Op1;
-  } else if (Op1.getOpcode() == ISD::VSELECT) {
+  } else if (Op1.getOpcode() == ISD::ABS) {
     SelectOp = Op1;
     Phi = Op0;
   } else
Index: test/CodeGen/X86/combine-abs.ll
===================================================================
--- test/CodeGen/X86/combine-abs.ll
+++ test/CodeGen/X86/combine-abs.ll
@@ -67,9 +67,6 @@
 ; AVX2-LABEL: combine_v4i64_abs_abs:
 ; AVX2:       # %bb.0:
 ; AVX2-NEXT:    vpxor %xmm1, %xmm1, %xmm1
-; AVX2-NEXT:    vpcmpgtq %ymm0, %ymm1, %ymm2
-; AVX2-NEXT:    vpaddq %ymm2, %ymm0, %ymm0
-; AVX2-NEXT:    vpxor %ymm2, %ymm0, %ymm0
 ; AVX2-NEXT:    vpcmpgtq %ymm0, %ymm1, %ymm1
 ; AVX2-NEXT:    vpaddq %ymm1, %ymm0, %ymm0
 ; AVX2-NEXT:    vpxor %ymm1, %ymm0, %ymm0