Index: llvm/include/llvm/CodeGen/ISDOpcodes.h =================================================================== --- llvm/include/llvm/CodeGen/ISDOpcodes.h +++ llvm/include/llvm/CodeGen/ISDOpcodes.h @@ -614,6 +614,17 @@ MULHU, MULHS, + /// AVGFLOORS/AVGFLOORU - Halving add - Add two integers using an integer of + /// type i[N+1], halving the result by shifting it one bit right. + /// shr(add(ext(X), ext(Y)), 1) + AVGFLOORS, + AVGFLOORU, + /// AVGCEILS/AVGCEILU - Rounding halving add - Add two integers using an + /// integer of type i[N+2], add 1 and halve the result by shifting it one bit + /// right. shr(add(add(ext(X), ext(Y)), 1), 1) + AVGCEILS, + AVGCEILU, + // ABDS/ABDU - Absolute difference - Return the absolute difference between // two numbers interpreted as signed/unsigned. // i.e trunc(abs(sext(Op0) - sext(Op1))) becomes abds(Op0, Op1) Index: llvm/include/llvm/CodeGen/TargetLowering.h =================================================================== --- llvm/include/llvm/CodeGen/TargetLowering.h +++ llvm/include/llvm/CodeGen/TargetLowering.h @@ -2515,6 +2515,10 @@ case ISD::FMAXNUM_IEEE: case ISD::FMINIMUM: case ISD::FMAXIMUM: + case ISD::AVGFLOORS: + case ISD::AVGFLOORU: + case ISD::AVGCEILS: + case ISD::AVGCEILU: return true; default: return false; } Index: llvm/include/llvm/Target/TargetSelectionDAG.td =================================================================== --- llvm/include/llvm/Target/TargetSelectionDAG.td +++ llvm/include/llvm/Target/TargetSelectionDAG.td @@ -365,6 +365,10 @@ [SDNPCommutative, SDNPAssociative]>; def mulhs : SDNode<"ISD::MULHS" , SDTIntBinOp, [SDNPCommutative]>; def mulhu : SDNode<"ISD::MULHU" , SDTIntBinOp, [SDNPCommutative]>; +def avgfloors : SDNode<"ISD::AVGFLOORS" , SDTIntBinOp, [SDNPCommutative]>; +def avgflooru : SDNode<"ISD::AVGFLOORU" , SDTIntBinOp, [SDNPCommutative]>; +def avgceils : SDNode<"ISD::AVGCEILS" , SDTIntBinOp, [SDNPCommutative]>; +def avgceilu : SDNode<"ISD::AVGCEILU" , SDTIntBinOp, [SDNPCommutative]>; 
def abds : SDNode<"ISD::ABDS" , SDTIntBinOp, [SDNPCommutative]>; def abdu : SDNode<"ISD::ABDU" , SDTIntBinOp, [SDNPCommutative]>; def smullohi : SDNode<"ISD::SMUL_LOHI" , SDTIntBinHiLoOp, [SDNPCommutative]>; Index: llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp =================================================================== --- llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -12599,6 +12599,87 @@ return SDValue(); } +// Attempt to form one of the avg patterns from: +// truncate(srl(add(ext(OpB), ext(OpA)), 1)) -> avgfloor[su] +// truncate(srl(add(add(ext(OpB), ext(OpA)), 1), 1)) -> avgceil[su] +// where ext is zext (unsigned) or sext (signed). This starts at a truncate, +// meaning the shift will always be srl, as the top bits are not demanded. +// NOTE(review): the sub(ext(OpB), xor(ext(OpA), -1)) form of the rounding add +// is not matched here; confirm it cannot reach this combine on any target. +static SDValue performAvgCombine(SDNode *N, SelectionDAG &DAG) { + EVT VT = N->getValueType(0); + + SDValue Shift = N->getOperand(0); + if (Shift.getOpcode() != ISD::SRL) + return SDValue(); + + // Is the right shift using an immediate value of 1?
+ ConstantSDNode *N1C = isConstOrConstSplat(Shift.getOperand(1)); + if (!N1C || !N1C->isOne()) + return SDValue(); + + // We are looking for an avgfloor + // add(ext, ext) + // or one of these as an avgceil + // add(add(ext, ext), 1) + // add(add(ext, 1), ext) + // add(ext, add(ext, 1)) + SDValue Add = Shift.getOperand(0); + if (Add.getOpcode() != ISD::ADD) + return SDValue(); + + SDValue ExtendOpA = Add.getOperand(0); + SDValue ExtendOpB = Add.getOperand(1); + auto MatchOperands = [&](SDValue Op1, SDValue Op2, SDValue Op3) { + ConstantSDNode *ConstOp; + if ((ConstOp = isConstOrConstSplat(Op1)) && ConstOp->isOne()) { + ExtendOpA = Op2; + ExtendOpB = Op3; + return true; + } + if ((ConstOp = isConstOrConstSplat(Op2)) && ConstOp->isOne()) { + ExtendOpA = Op1; + ExtendOpB = Op3; + return true; + } + if ((ConstOp = isConstOrConstSplat(Op3)) && ConstOp->isOne()) { + ExtendOpA = Op1; + ExtendOpB = Op2; + return true; + } + return false; + }; + bool IsCeil = (ExtendOpA.getOpcode() == ISD::ADD && + MatchOperands(ExtendOpA.getOperand(0), ExtendOpA.getOperand(1), + ExtendOpB)) || + (ExtendOpB.getOpcode() == ISD::ADD && + MatchOperands(ExtendOpB.getOperand(0), ExtendOpB.getOperand(1), + ExtendOpA)); + + unsigned ExtendOpAOpc = ExtendOpA.getOpcode(); + unsigned ExtendOpBOpc = ExtendOpB.getOpcode(); + if (!(ExtendOpAOpc == ExtendOpBOpc && + (ExtendOpAOpc == ISD::ZERO_EXTEND || ExtendOpAOpc == ISD::SIGN_EXTEND))) + return SDValue(); + + // Is the result of the right shift being truncated to the same value type as + // the original operands, OpA and OpB? + SDValue OpA = ExtendOpA.getOperand(0); + SDValue OpB = ExtendOpB.getOperand(0); + EVT OpAVT = OpA.getValueType(); + assert(ExtendOpA.getValueType() == ExtendOpB.getValueType()); + if (VT != OpAVT || OpAVT != OpB.getValueType()) + return SDValue(); + + bool IsSignExtend = ExtendOpAOpc == ISD::SIGN_EXTEND; + unsigned AVGOpc = IsSignExtend ? (IsCeil ? ISD::AVGCEILS : ISD::AVGFLOORS) + : (IsCeil ?
ISD::AVGCEILU : ISD::AVGFLOORU); + if (!DAG.getTargetLoweringInfo().isOperationLegalOrCustom(AVGOpc, VT)) + return SDValue(); + + return DAG.getNode(AVGOpc, SDLoc(N), VT, OpA, OpB); +} + SDValue DAGCombiner::visitTRUNCATE(SDNode *N) { SDValue N0 = N->getOperand(0); EVT VT = N->getValueType(0); @@ -12885,6 +12966,8 @@ if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N)) return NewVSel; + if (SDValue M = performAvgCombine(N, DAG)) + return M; // Narrow a suitable binary operation with a non-opaque constant operand by // moving it ahead of the truncate. This is limited to pre-legalization Index: llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp =================================================================== --- llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp +++ llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp @@ -3287,6 +3287,10 @@ case ISD::USHLSAT: case ISD::ROTL: case ISD::ROTR: + case ISD::AVGFLOORS: + case ISD::AVGFLOORU: + case ISD::AVGCEILS: + case ISD::AVGCEILU: // Vector-predicated binary op widening.
Note that -- unlike the // unpredicated versions -- we don't have to worry about trapping on // operations like UDIV, FADD, etc., as we pass on the original vector Index: llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp =================================================================== --- llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp +++ llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp @@ -231,6 +231,10 @@ case ISD::MUL: return "mul"; case ISD::MULHU: return "mulhu"; case ISD::MULHS: return "mulhs"; + case ISD::AVGFLOORU: return "avgflooru"; + case ISD::AVGFLOORS: return "avgfloors"; + case ISD::AVGCEILU: return "avgceilu"; + case ISD::AVGCEILS: return "avgceils"; case ISD::ABDS: return "abds"; case ISD::ABDU: return "abdu"; case ISD::SDIV: return "sdiv"; Index: llvm/lib/CodeGen/TargetLoweringBase.cpp =================================================================== --- llvm/lib/CodeGen/TargetLoweringBase.cpp +++ llvm/lib/CodeGen/TargetLoweringBase.cpp @@ -817,6 +817,12 @@ setOperationAction(ISD::SUBC, VT, Expand); setOperationAction(ISD::SUBE, VT, Expand); + // Halving adds (avgfloor/avgceil) + setOperationAction(ISD::AVGFLOORS, VT, Expand); + setOperationAction(ISD::AVGFLOORU, VT, Expand); + setOperationAction(ISD::AVGCEILS, VT, Expand); + setOperationAction(ISD::AVGCEILU, VT, Expand); + // Absolute difference setOperationAction(ISD::ABDS, VT, Expand); setOperationAction(ISD::ABDU, VT, Expand); Index: llvm/lib/Target/AArch64/AArch64ISelLowering.h =================================================================== --- llvm/lib/Target/AArch64/AArch64ISelLowering.h +++ llvm/lib/Target/AArch64/AArch64ISelLowering.h @@ -232,14 +232,6 @@ SADDV, UADDV, - // Vector halving addition - SHADD, - UHADD, - - // Vector rounding halving addition - SRHADD, - URHADD, - // Add Long Pairwise SADDLP, UADDLP, Index: llvm/lib/Target/AArch64/AArch64ISelLowering.cpp ===================================================================
+++ llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -870,7 +870,6 @@ setTargetDAGCombine(ISD::SIGN_EXTEND); setTargetDAGCombine(ISD::VECTOR_SPLICE); setTargetDAGCombine(ISD::SIGN_EXTEND_INREG); - setTargetDAGCombine(ISD::TRUNCATE); setTargetDAGCombine(ISD::CONCAT_VECTORS); setTargetDAGCombine(ISD::INSERT_SUBVECTOR); setTargetDAGCombine(ISD::STORE); @@ -1047,6 +1046,10 @@ for (MVT VT : {MVT::v8i8, MVT::v4i16, MVT::v2i32, MVT::v16i8, MVT::v8i16, MVT::v4i32}) { + setOperationAction(ISD::AVGFLOORS, VT, Legal); + setOperationAction(ISD::AVGFLOORU, VT, Legal); + setOperationAction(ISD::AVGCEILS, VT, Legal); + setOperationAction(ISD::AVGCEILU, VT, Legal); setOperationAction(ISD::ABDS, VT, Legal); setOperationAction(ISD::ABDU, VT, Legal); } @@ -2094,10 +2097,6 @@ MAKE_CASE(AArch64ISD::FCMLTz) MAKE_CASE(AArch64ISD::SADDV) MAKE_CASE(AArch64ISD::UADDV) - MAKE_CASE(AArch64ISD::SRHADD) - MAKE_CASE(AArch64ISD::URHADD) - MAKE_CASE(AArch64ISD::SHADD) - MAKE_CASE(AArch64ISD::UHADD) MAKE_CASE(AArch64ISD::SDOT) MAKE_CASE(AArch64ISD::UDOT) MAKE_CASE(AArch64ISD::SMINV) @@ -4371,9 +4370,9 @@ IntNo == Intrinsic::aarch64_neon_shadd); bool IsRoundingAdd = (IntNo == Intrinsic::aarch64_neon_srhadd || IntNo == Intrinsic::aarch64_neon_urhadd); - unsigned Opcode = - IsSignedAdd ? (IsRoundingAdd ? AArch64ISD::SRHADD : AArch64ISD::SHADD) - : (IsRoundingAdd ? AArch64ISD::URHADD : AArch64ISD::UHADD); + unsigned Opcode = IsSignedAdd + ? (IsRoundingAdd ? ISD::AVGCEILS : ISD::AVGFLOORS) + : (IsRoundingAdd ? ISD::AVGCEILU : ISD::AVGFLOORU); return DAG.getNode(Opcode, dl, Op.getValueType(), Op.getOperand(1), Op.getOperand(2)); } @@ -14247,89 +14246,6 @@ return SDValue(); } -// Attempt to form urhadd(OpA, OpB) from -// truncate(vlshr(sub(zext(OpB), xor(zext(OpA), Ones(ElemSizeInBits))), 1)) -// or uhadd(OpA, OpB) from truncate(vlshr(add(zext(OpA), zext(OpB)), 1)).
-// The original form of the first expression is -// truncate(srl(add(zext(OpB), add(zext(OpA), 1)), 1)) and the -// (OpA + OpB + 1) subexpression will have been changed to (OpB - (~OpA)). -// Before this function is called the srl will have been lowered to -// AArch64ISD::VLSHR. -// This pass can also recognize signed variants of the patterns that use sign -// extension instead of zero extension and form a srhadd(OpA, OpB) or a -// shadd(OpA, OpB) from them. -static SDValue -performVectorTruncateCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, - SelectionDAG &DAG) { - EVT VT = N->getValueType(0); - - // Since we are looking for a right shift by a constant value of 1 and we are - // operating on types at least 16 bits in length (sign/zero extended OpA and - // OpB, which are at least 8 bits), it follows that the truncate will always - // discard the shifted-in bit and therefore the right shift will be logical - // regardless of the signedness of OpA and OpB. - SDValue Shift = N->getOperand(0); - if (Shift.getOpcode() != AArch64ISD::VLSHR) - return SDValue(); - - // Is the right shift using an immediate value of 1? - uint64_t ShiftAmount = Shift.getConstantOperandVal(1); - if (ShiftAmount != 1) - return SDValue(); - - SDValue ExtendOpA, ExtendOpB; - SDValue ShiftOp0 = Shift.getOperand(0); - unsigned ShiftOp0Opc = ShiftOp0.getOpcode(); - if (ShiftOp0Opc == ISD::SUB) { - - SDValue Xor = ShiftOp0.getOperand(1); - if (Xor.getOpcode() != ISD::XOR) - return SDValue(); - - // Is the XOR using a constant amount of all ones in the right hand side?
- uint64_t C; - if (!isAllConstantBuildVector(Xor.getOperand(1), C)) - return SDValue(); - - unsigned ElemSizeInBits = VT.getScalarSizeInBits(); - APInt CAsAPInt(ElemSizeInBits, C); - if (CAsAPInt != APInt::getAllOnes(ElemSizeInBits)) - return SDValue(); - - ExtendOpA = Xor.getOperand(0); - ExtendOpB = ShiftOp0.getOperand(0); - } else if (ShiftOp0Opc == ISD::ADD) { - ExtendOpA = ShiftOp0.getOperand(0); - ExtendOpB = ShiftOp0.getOperand(1); - } else - return SDValue(); - - unsigned ExtendOpAOpc = ExtendOpA.getOpcode(); - unsigned ExtendOpBOpc = ExtendOpB.getOpcode(); - if (!(ExtendOpAOpc == ExtendOpBOpc && - (ExtendOpAOpc == ISD::ZERO_EXTEND || ExtendOpAOpc == ISD::SIGN_EXTEND))) - return SDValue(); - - // Is the result of the right shift being truncated to the same value type as - // the original operands, OpA and OpB? - SDValue OpA = ExtendOpA.getOperand(0); - SDValue OpB = ExtendOpB.getOperand(0); - EVT OpAVT = OpA.getValueType(); - assert(ExtendOpA.getValueType() == ExtendOpB.getValueType()); - if (!(VT == OpAVT && OpAVT == OpB.getValueType())) - return SDValue(); - - SDLoc DL(N); - bool IsSignExtend = ExtendOpAOpc == ISD::SIGN_EXTEND; - bool IsRHADD = ShiftOp0Opc == ISD::SUB; - unsigned HADDOpc = IsSignExtend - ? (IsRHADD ? AArch64ISD::SRHADD : AArch64ISD::SHADD) - : (IsRHADD ? AArch64ISD::URHADD : AArch64ISD::UHADD); - SDValue ResultHADD = DAG.getNode(HADDOpc, DL, VT, OpA, OpB); - - return ResultHADD; -} - static bool hasPairwiseAdd(unsigned Opcode, EVT VT, bool FullFP16) { switch (Opcode) { case ISD::FADD: @@ -14432,20 +14348,20 @@ if (DCI.isBeforeLegalizeOps()) return SDValue(); - // Optimise concat_vectors of two [us]rhadds or [us]hadds that use extracted - // subvectors from the same original vectors. Combine these into a single - // [us]rhadd or [us]hadd that operates on the two original vectors.
Example: - // (v16i8 (concat_vectors (v8i8 (urhadd (extract_subvector (v16i8 OpA, <0>), - // extract_subvector (v16i8 OpB, - // <0>))), - // (v8i8 (urhadd (extract_subvector (v16i8 OpA, <8>), - // extract_subvector (v16i8 OpB, - // <8>))))) + // Optimise concat_vectors of two avgceil[su] or avgfloor[su] nodes that use + // extracted subvectors from the same original vectors. Combine these into a + // single avgceil[su] or avgfloor[su] that operates on the two original vectors. + // avgceil is the target-independent name for rhadd, avgfloor for hadd. + // Example: + // (concat_vectors (v8i8 (avgceils (extract_subvector (v16i8 OpA, <0>), + // extract_subvector (v16i8 OpB, <0>))), + // (v8i8 (avgceils (extract_subvector (v16i8 OpA, <8>), + // extract_subvector (v16i8 OpB, <8>))))) // -> - // (v16i8(urhadd(v16i8 OpA, v16i8 OpB))) + // (v16i8(avgceils(v16i8 OpA, v16i8 OpB))) if (N->getNumOperands() == 2 && N0Opc == N1Opc && - (N0Opc == AArch64ISD::URHADD || N0Opc == AArch64ISD::SRHADD || - N0Opc == AArch64ISD::UHADD || N0Opc == AArch64ISD::SHADD)) { + (N0Opc == ISD::AVGCEILU || N0Opc == ISD::AVGCEILS || + N0Opc == ISD::AVGFLOORU || N0Opc == ISD::AVGFLOORS)) { SDValue N00 = N0->getOperand(0); SDValue N01 = N0->getOperand(1); SDValue N10 = N1->getOperand(0); @@ -17927,8 +17843,6 @@ return performExtendCombine(N, DCI, DAG); case ISD::SIGN_EXTEND_INREG: return performSignExtendInRegCombine(N, DCI, DAG); - case ISD::TRUNCATE: - return performVectorTruncateCombine(N, DCI, DAG); case ISD::CONCAT_VECTORS: return performConcatVectorsCombine(N, DCI, DAG); case ISD::INSERT_SUBVECTOR: Index: llvm/lib/Target/AArch64/AArch64InstrInfo.td =================================================================== --- llvm/lib/Target/AArch64/AArch64InstrInfo.td +++ llvm/lib/Target/AArch64/AArch64InstrInfo.td @@ -630,11 +630,6 @@ def AArch64smaxv : SDNode<"AArch64ISD::SMAXV", SDT_AArch64UnaryVec>; def AArch64umaxv : SDNode<"AArch64ISD::UMAXV", SDT_AArch64UnaryVec>; -def AArch64srhadd :
SDNode<"AArch64ISD::SRHADD", SDT_AArch64binvec>; -def AArch64urhadd : SDNode<"AArch64ISD::URHADD", SDT_AArch64binvec>; -def AArch64shadd : SDNode<"AArch64ISD::SHADD", SDT_AArch64binvec>; -def AArch64uhadd : SDNode<"AArch64ISD::UHADD", SDT_AArch64binvec>; - def AArch64uabd : PatFrags<(ops node:$lhs, node:$rhs), [(abdu node:$lhs, node:$rhs), (int_aarch64_neon_uabd node:$lhs, node:$rhs)]>; @@ -4488,7 +4483,7 @@ defm SABA : SIMDThreeSameVectorBHSTied<0, 0b01111, "saba", TriOpFrag<(add node:$LHS, (AArch64sabd node:$MHS, node:$RHS))> >; defm SABD : SIMDThreeSameVectorBHS<0,0b01110,"sabd", AArch64sabd>; -defm SHADD : SIMDThreeSameVectorBHS<0,0b00000,"shadd", AArch64shadd>; +defm SHADD : SIMDThreeSameVectorBHS<0,0b00000,"shadd", avgfloors>; defm SHSUB : SIMDThreeSameVectorBHS<0,0b00100,"shsub", int_aarch64_neon_shsub>; defm SMAXP : SIMDThreeSameVectorBHS<0,0b10100,"smaxp", int_aarch64_neon_smaxp>; defm SMAX : SIMDThreeSameVectorBHS<0,0b01100,"smax", smax>; @@ -4500,14 +4495,14 @@ defm SQRSHL : SIMDThreeSameVector<0,0b01011,"sqrshl", int_aarch64_neon_sqrshl>; defm SQSHL : SIMDThreeSameVector<0,0b01001,"sqshl", int_aarch64_neon_sqshl>; defm SQSUB : SIMDThreeSameVector<0,0b00101,"sqsub", int_aarch64_neon_sqsub>; -defm SRHADD : SIMDThreeSameVectorBHS<0,0b00010,"srhadd", AArch64srhadd>; +defm SRHADD : SIMDThreeSameVectorBHS<0,0b00010,"srhadd", avgceils>; defm SRSHL : SIMDThreeSameVector<0,0b01010,"srshl", int_aarch64_neon_srshl>; defm SSHL : SIMDThreeSameVector<0,0b01000,"sshl", int_aarch64_neon_sshl>; defm SUB : SIMDThreeSameVector<1,0b10000,"sub", sub>; defm UABA : SIMDThreeSameVectorBHSTied<1, 0b01111, "uaba", TriOpFrag<(add node:$LHS, (AArch64uabd node:$MHS, node:$RHS))> >; defm UABD : SIMDThreeSameVectorBHS<1,0b01110,"uabd", AArch64uabd>; -defm UHADD : SIMDThreeSameVectorBHS<1,0b00000,"uhadd", AArch64uhadd>; +defm UHADD : SIMDThreeSameVectorBHS<1,0b00000,"uhadd", avgflooru>; defm UHSUB : SIMDThreeSameVectorBHS<1,0b00100,"uhsub", int_aarch64_neon_uhsub>; defm UMAXP : 
SIMDThreeSameVectorBHS<1,0b10100,"umaxp", int_aarch64_neon_umaxp>; defm UMAX : SIMDThreeSameVectorBHS<1,0b01100,"umax", umax>; @@ -4517,7 +4512,7 @@ defm UQRSHL : SIMDThreeSameVector<1,0b01011,"uqrshl", int_aarch64_neon_uqrshl>; defm UQSHL : SIMDThreeSameVector<1,0b01001,"uqshl", int_aarch64_neon_uqshl>; defm UQSUB : SIMDThreeSameVector<1,0b00101,"uqsub", int_aarch64_neon_uqsub>; -defm URHADD : SIMDThreeSameVectorBHS<1,0b00010,"urhadd", AArch64urhadd>; +defm URHADD : SIMDThreeSameVectorBHS<1,0b00010,"urhadd", avgceilu>; defm URSHL : SIMDThreeSameVector<1,0b01010,"urshl", int_aarch64_neon_urshl>; defm USHL : SIMDThreeSameVector<1,0b01000,"ushl", int_aarch64_neon_ushl>; defm SQRDMLAH : SIMDThreeSameVectorSQRDMLxHTiedHS<1,0b10000,"sqrdmlah",