Index: llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -9359,7 +9359,7 @@
 // (srl (mul (zext i32:$a to i64), (zext i32:$b to i64)), 32) -> (mulhu $a, $b)
 // (sra (mul (sext i32:$a to i64), (sext i32:$b to i64)), 32) -> (mulhs $a, $b)
 static SDValue combineShiftToMULH(SDNode *N, SelectionDAG &DAG,
-                                  const TargetLowering &TLI) {
+                                  const TargetLowering &TLI, bool LegalTypes) {
   assert((N->getOpcode() == ISD::SRL || N->getOpcode() == ISD::SRA) &&
          "SRL or SRA node is required here!");
@@ -9452,7 +9452,8 @@
   unsigned MulhOpcode = IsSignExt ? ISD::MULHS : ISD::MULHU;

   // Combine to mulh if mulh is legal/custom for the narrow type on the target.
-  if (!TLI.isOperationLegalOrCustom(MulhOpcode, NarrowVT))
+  if (!TLI.isOperationLegalOrCustom(MulhOpcode, NarrowVT) &&
+      (LegalTypes || !TLI.isOperationCustom(MulhOpcode, NarrowVT)))
     return SDValue();

   SDValue Result =
@@ -9673,7 +9674,7 @@

   // Try to transform this shift into a multiply-high if
   // it matches the appropriate pattern detected in combineShiftToMULH.
-  if (SDValue MULH = combineShiftToMULH(N, DAG, TLI))
+  if (SDValue MULH = combineShiftToMULH(N, DAG, TLI, LegalTypes))
     return MULH;

   // Attempt to convert a sra of a load into a narrower sign-extending load.
@@ -9932,7 +9933,7 @@

   // Try to transform this shift into a multiply-high if
   // it matches the appropriate pattern detected in combineShiftToMULH.
-  if (SDValue MULH = combineShiftToMULH(N, DAG, TLI))
+  if (SDValue MULH = combineShiftToMULH(N, DAG, TLI, LegalTypes))
     return MULH;

   return SDValue();
@@ -10087,7 +10088,7 @@
 // (ABS (SUB (EXTEND a), (EXTEND b))).
 // Generates UABD/SABD instruction.
 static SDValue combineABSToABD(SDNode *N, SelectionDAG &DAG,
-                               const TargetLowering &TLI) {
+                               const TargetLowering &TLI, bool LegalTypes) {
   SDValue AbsOp1 = N->getOperand(0);
   SDValue Op0, Op1;
@@ -10111,7 +10112,8 @@
   // fold abs(sext(x) - sext(y)) -> zext(abds(x, y))
   // fold abs(zext(x) - zext(y)) -> zext(abdu(x, y))
   // NOTE: Extensions must be equivalent.
-  if (VT1 == VT2 && TLI.isOperationLegalOrCustom(ABDOpcode, VT1)) {
+  if (VT1 == VT2 && (TLI.isOperationLegalOrCustom(ABDOpcode, VT1) ||
+                     (!LegalTypes && TLI.isOperationCustom(ABDOpcode, VT1)))) {
     Op0 = Op0.getOperand(0);
     Op1 = Op1.getOperand(0);
     SDValue ABD = DAG.getNode(ABDOpcode, SDLoc(N), VT1, Op0, Op1);
@@ -10140,7 +10142,7 @@
   if (DAG.SignBitIsZero(N0))
     return N0;

-  if (SDValue ABD = combineABSToABD(N, DAG, TLI))
+  if (SDValue ABD = combineABSToABD(N, DAG, TLI, LegalTypes))
     return ABD;

   // fold (abs (sign_extend_inreg x)) -> (zero_extend (abs (truncate x)))
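[Editorial note, not part of the patch] The DAGCombiner hunks above all apply the same relaxed gate: the combine may fire either when the narrow operation is legal or custom, or, before type legalization, when the target has marked the operation Custom even though the narrow type itself is not legal yet. A minimal sketch of that predicate, using a hypothetical helper name that does not exist in the tree:

// Hypothetical helper, for illustration only.  isOperationLegalOrCustom()
// also requires the type to be legal, while isOperationCustom() only checks
// the operation action, so the second clause is what lets the combine run on
// smaller-than-legal types before type legalization.
static bool isLegalOrCustomForCombine(const TargetLowering &TLI,
                                      unsigned Opcode, EVT VT,
                                      bool LegalTypes) {
  if (TLI.isOperationLegalOrCustom(Opcode, VT))
    return true;
  // Before type legalization, trust a Custom marking on a not-yet-legal type.
  return !LegalTypes && TLI.isOperationCustom(Opcode, VT);
}
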
Index: llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -932,11 +932,12 @@

 // Attempt to form ext(avgfloor(A, B)) from shr(add(ext(A), ext(B)), 1).
 // or to form ext(avgceil(A, B)) from shr(add(ext(A), ext(B), 1), 1).
-static SDValue combineShiftToAVG(SDValue Op, SelectionDAG &DAG,
+static SDValue combineShiftToAVG(SDValue Op,
+                                 TargetLowering::TargetLoweringOpt &TLO,
                                  const TargetLowering &TLI,
                                  const APInt &DemandedBits,
-                                 const APInt &DemandedElts,
-                                 unsigned Depth) {
+                                 const APInt &DemandedElts, unsigned Depth) {
+  SelectionDAG &DAG = TLO.DAG;
   assert((Op.getOpcode() == ISD::SRL || Op.getOpcode() == ISD::SRA) &&
          "SRL or SRA node is required here!");
   // Is the right shift using an immediate value of 1?
@@ -1045,7 +1046,8 @@
   EVT NVT = EVT::getIntegerVT(*DAG.getContext(), PowerOf2Ceil(MinWidth));
   if (VT.isVector())
     NVT = EVT::getVectorVT(*DAG.getContext(), NVT, VT.getVectorElementCount());
-  if (!TLI.isOperationLegalOrCustom(AVGOpc, NVT))
+  if (!TLI.isOperationLegalOrCustom(AVGOpc, NVT) &&
+      (TLO.LegalTypes() || !TLI.isOperationCustom(AVGOpc, NVT)))
     return SDValue();

   SDLoc DL(Op);
@@ -1829,7 +1831,7 @@
     EVT ShiftVT = Op1.getValueType();

     // Try to match AVG patterns.
-    if (SDValue AVG = combineShiftToAVG(Op, TLO.DAG, *this, DemandedBits,
+    if (SDValue AVG = combineShiftToAVG(Op, TLO, *this, DemandedBits,
                                         DemandedElts, Depth + 1))
       return TLO.CombineTo(Op, AVG);

@@ -1910,7 +1912,7 @@
       return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SRL, dl, VT, Op0, Op1));

     // Try to match AVG patterns.
-    if (SDValue AVG = combineShiftToAVG(Op, TLO.DAG, *this, DemandedBits,
+    if (SDValue AVG = combineShiftToAVG(Op, TLO, *this, DemandedBits,
                                         DemandedElts, Depth + 1))
       return TLO.CombineTo(Op, AVG);
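[Editorial note, not part of the patch] combineShiftToAVG now receives the TargetLoweringOpt rather than a bare SelectionDAG because TLO already carries both the DAG and the legalization-phase flags that the new check needs. Roughly, the members relied on above are (paraphrased from TargetLowering.h, trimmed to what is used here):

// Paraphrase for context only; see TargetLowering.h for the real definition.
struct TargetLoweringOpt {
  SelectionDAG &DAG;
  bool LegalTys; // true once type legalization has run
  bool LegalOps; // true once operation legalization has run

  bool LegalTypes() const { return LegalTys; }
  bool LegalOperations() const { return LegalOps; }
};
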
Index: llvm/lib/Target/ARM/ARMISelLowering.cpp
===================================================================
--- llvm/lib/Target/ARM/ARMISelLowering.cpp
+++ llvm/lib/Target/ARM/ARMISelLowering.cpp
@@ -434,6 +434,19 @@
     }
   }

+  // Mark some nodes as Custom so that the generic combines fire on smaller
+  // than legal types.
+  for (auto VT : {MVT::v8i8, MVT::v4i8, MVT::v4i16}) {
+    setOperationAction(ISD::MULHS, VT, Custom);
+    setOperationAction(ISD::MULHU, VT, Custom);
+    setOperationAction(ISD::AVGFLOORS, VT, Custom);
+    setOperationAction(ISD::AVGFLOORU, VT, Custom);
+    setOperationAction(ISD::AVGCEILS, VT, Custom);
+    setOperationAction(ISD::AVGCEILU, VT, Custom);
+    setOperationAction(ISD::ABDS, VT, Custom);
+    setOperationAction(ISD::ABDU, VT, Custom);
+  }
+
   // Predicate types
   const MVT pTypes[] = {MVT::v16i1, MVT::v8i1, MVT::v4i1, MVT::v2i1};
   for (auto VT : pTypes) {
@@ -5100,6 +5113,28 @@
   return DAG.getNode(ISD::TRUNCATE, dl, VT, Add);
 }

+// Custom lower MULH, ABD, HADD and RHADD nodes that are smaller than legal,
+// using a bitcast to a larger legal type upon which we perform the operation.
+// This allows DAG combine to recognize the nodes where it usually would not.
+static SDValue LowerBinopWithBitcast(SDNode *N, SelectionDAG &DAG) {
+  EVT VT = N->getValueType(0);
+
+  assert((VT == MVT::v4i8 || VT == MVT::v8i8 || VT == MVT::v4i16) &&
+         "Expected smaller than legal type!");
+
+  EVT ExtVT = VT.getVectorNumElements() == 4 ? MVT::v4i32 : MVT::v8i16;
+  EVT BinOpVT = VT.getScalarType() == MVT::i8 ? MVT::v16i8 : MVT::v8i16;
+
+  SDLoc DL(N);
+  SDValue Ext0 = DAG.getNode(ISD::ANY_EXTEND, DL, ExtVT, N->getOperand(0));
+  SDValue Ext1 = DAG.getNode(ISD::ANY_EXTEND, DL, ExtVT, N->getOperand(1));
+  SDValue BC0 = DAG.getNode(ARMISD::VECTOR_REG_CAST, DL, BinOpVT, Ext0);
+  SDValue BC1 = DAG.getNode(ARMISD::VECTOR_REG_CAST, DL, BinOpVT, Ext1);
+  SDValue BinOp = DAG.getNode(N->getOpcode(), DL, BinOpVT, BC0, BC1);
+  SDValue BC2 = DAG.getNode(ARMISD::VECTOR_REG_CAST, DL, ExtVT, BinOp);
+  return DAG.getNode(ISD::TRUNCATE, DL, VT, BC2);
+}
+
 SDValue ARMTargetLowering::LowerSELECT(SDValue Op, SelectionDAG &DAG) const {
   SDValue Cond = Op.getOperand(0);
   SDValue SelectTrue = Op.getOperand(1);
@@ -10545,6 +10580,16 @@
   case ISD::USUBSAT:
     Res = LowerADDSUBSAT(SDValue(N, 0), DAG, Subtarget);
     break;
+  case ISD::MULHS:
+  case ISD::MULHU:
+  case ISD::ABDS:
+  case ISD::ABDU:
+  case ISD::AVGFLOORS:
+  case ISD::AVGFLOORU:
+  case ISD::AVGCEILS:
+  case ISD::AVGCEILU:
+    Res = LowerBinopWithBitcast(N, DAG);
+    break;
   case ISD::READCYCLECOUNTER:
     ReplaceREADCYCLECOUNTER(N, Results, DAG, Subtarget);
     return;
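[Editorial note, not part of the patch] A short justification for the ISD::ANY_EXTEND in LowerBinopWithBitcast, written as comments that could sit above the function: every opcode routed to it is purely lane-wise, so the lanes of the widened-and-recast operands that hold real data produce the wanted results, the remaining lanes hold garbage, and the trailing VECTOR_REG_CAST plus TRUNCATE only read the data lanes back out. The lane positions below assume a little-endian lane layout and use the v8i8 MULHS case.

// Suggested commentary, illustration only.
//   inputs          : a, b as v8i8
//   any_extend      : v8i16, low byte of each element = ai/bi, high byte undef
//   VECTOR_REG_CAST : v16i8, real values in byte lanes 0, 2, 4, ...; odd lanes undef
//   mulhs (v16i8)   : lane 2k = mulhs(ak, bk), selected as vmulh.s8;
//                     lane 2k+1 = garbage that is never observed
//   VECTOR_REG_CAST : back to v8i16
//   truncate        : keeps only the low byte of each i16 element, i.e. the
//                     lanes holding the real results
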
Index: llvm/test/CodeGen/Thumb2/mve-vabdus.ll
===================================================================
--- llvm/test/CodeGen/Thumb2/mve-vabdus.ll
+++ llvm/test/CodeGen/Thumb2/mve-vabdus.ll
@@ -19,9 +19,8 @@
 define arm_aapcs_vfpcc <8 x i8> @vabd_v8s8(<8 x i8> %src1, <8 x i8> %src2) {
 ; CHECK-LABEL: vabd_v8s8:
 ; CHECK:       @ %bb.0:
-; CHECK-NEXT:    vmovlb.s8 q1, q1
-; CHECK-NEXT:    vmovlb.s8 q0, q0
-; CHECK-NEXT:    vabd.s16 q0, q0, q1
+; CHECK-NEXT:    vabd.s8 q0, q0, q1
+; CHECK-NEXT:    vmovlb.u8 q0, q0
 ; CHECK-NEXT:    bx lr
   %sextsrc1 = sext <8 x i8> %src1 to <8 x i16>
   %sextsrc2 = sext <8 x i8> %src2 to <8 x i16>
@@ -36,12 +35,9 @@
 define arm_aapcs_vfpcc <4 x i8> @vabd_v4s8(<4 x i8> %src1, <4 x i8> %src2) {
 ; CHECK-LABEL: vabd_v4s8:
 ; CHECK:       @ %bb.0:
-; CHECK-NEXT:    vmovlb.s8 q1, q1
-; CHECK-NEXT:    vmovlb.s8 q0, q0
-; CHECK-NEXT:    vmovlb.s16 q1, q1
-; CHECK-NEXT:    vmovlb.s16 q0, q0
-; CHECK-NEXT:    vsub.i32 q0, q0, q1
-; CHECK-NEXT:    vabs.s32 q0, q0
+; CHECK-NEXT:    vabd.s8 q0, q0, q1
+; CHECK-NEXT:    vmov.i32 q1, #0xff
+; CHECK-NEXT:    vand q0, q0, q1
 ; CHECK-NEXT:    bx lr
   %sextsrc1 = sext <4 x i8> %src1 to <4 x i16>
   %sextsrc2 = sext <4 x i8> %src2 to <4 x i16>
@@ -71,9 +67,8 @@
 define arm_aapcs_vfpcc <4 x i16> @vabd_v4s16(<4 x i16> %src1, <4 x i16> %src2) {
 ; CHECK-LABEL: vabd_v4s16:
 ; CHECK:       @ %bb.0:
-; CHECK-NEXT:    vmovlb.s16 q1, q1
-; CHECK-NEXT:    vmovlb.s16 q0, q0
-; CHECK-NEXT:    vabd.s32 q0, q0, q1
+; CHECK-NEXT:    vabd.s16 q0, q0, q1
+; CHECK-NEXT:    vmovlb.u16 q0, q0
 ; CHECK-NEXT:    bx lr
   %sextsrc1 = sext <4 x i16> %src1 to <4 x i32>
   %sextsrc2 = sext <4 x i16> %src2 to <4 x i32>
@@ -152,9 +147,8 @@
 define arm_aapcs_vfpcc <8 x i8> @vabd_v8u8(<8 x i8> %src1, <8 x i8> %src2) {
 ; CHECK-LABEL: vabd_v8u8:
 ; CHECK:       @ %bb.0:
-; CHECK-NEXT:    vmovlb.u8 q1, q1
+; CHECK-NEXT:    vabd.u8 q0, q0, q1
 ; CHECK-NEXT:    vmovlb.u8 q0, q0
-; CHECK-NEXT:    vabd.u16 q0, q0, q1
 ; CHECK-NEXT:    bx lr
   %zextsrc1 = zext <8 x i8> %src1 to <8 x i16>
   %zextsrc2 = zext <8 x i8> %src2 to <8 x i16>
@@ -169,11 +163,9 @@
 define arm_aapcs_vfpcc <4 x i8> @vabd_v4u8(<4 x i8> %src1, <4 x i8> %src2) {
 ; CHECK-LABEL: vabd_v4u8:
 ; CHECK:       @ %bb.0:
-; CHECK-NEXT:    vmov.i32 q2, #0xff
-; CHECK-NEXT:    vand q1, q1, q2
-; CHECK-NEXT:    vand q0, q0, q2
-; CHECK-NEXT:    vsub.i32 q0, q0, q1
-; CHECK-NEXT:    vabs.s32 q0, q0
+; CHECK-NEXT:    vabd.u8 q0, q0, q1
+; CHECK-NEXT:    vmov.i32 q1, #0xff
+; CHECK-NEXT:    vand q0, q0, q1
 ; CHECK-NEXT:    bx lr
   %zextsrc1 = zext <4 x i8> %src1 to <4 x i16>
   %zextsrc2 = zext <4 x i8> %src2 to <4 x i16>
@@ -203,9 +195,8 @@
 define arm_aapcs_vfpcc <4 x i16> @vabd_v4u16(<4 x i16> %src1, <4 x i16> %src2) {
 ; CHECK-LABEL: vabd_v4u16:
 ; CHECK:       @ %bb.0:
-; CHECK-NEXT:    vmovlb.u16 q1, q1
+; CHECK-NEXT:    vabd.u16 q0, q0, q1
 ; CHECK-NEXT:    vmovlb.u16 q0, q0
-; CHECK-NEXT:    vabd.u32 q0, q0, q1
 ; CHECK-NEXT:    bx lr
   %zextsrc1 = zext <4 x i16> %src1 to <4 x i32>
   %zextsrc2 = zext <4 x i16> %src2 to <4 x i32>
Index: llvm/test/CodeGen/Thumb2/mve-vhadd.ll
===================================================================
--- llvm/test/CodeGen/Thumb2/mve-vhadd.ll
+++ llvm/test/CodeGen/Thumb2/mve-vhadd.ll
@@ -49,10 +49,8 @@
 define arm_aapcs_vfpcc <4 x i16> @vhaddu_v4i16(<4 x i16> %s0, <4 x i16> %s1) {
 ; CHECK-LABEL: vhaddu_v4i16:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    vmovlb.u16 q1, q1
+; CHECK-NEXT:    vhadd.u16 q0, q0, q1
 ; CHECK-NEXT:    vmovlb.u16 q0, q0
-; CHECK-NEXT:    vadd.i32 q0, q0, q1
-; CHECK-NEXT:    vshr.u32 q0, q0, #1
 ; CHECK-NEXT:    bx lr
 entry:
   %s0s = zext <4 x i16> %s0 to <4 x i32>
@@ -114,11 +112,9 @@
 define arm_aapcs_vfpcc <4 x i8> @vhaddu_v4i8(<4 x i8> %s0, <4 x i8> %s1) {
 ; CHECK-LABEL: vhaddu_v4i8:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    vmov.i32 q2, #0xff
-; CHECK-NEXT:    vand q1, q1, q2
-; CHECK-NEXT:    vand q0, q0, q2
-; CHECK-NEXT:    vadd.i32 q0, q0, q1
-; CHECK-NEXT:    vshr.u32 q0, q0, #1
+; CHECK-NEXT:    vhadd.u8 q0, q0, q1
+; CHECK-NEXT:    vmov.i32 q1, #0xff
+; CHECK-NEXT:    vand q0, q0, q1
 ; CHECK-NEXT:    bx lr
 entry:
   %s0s = zext <4 x i8> %s0 to <4 x i16>
@@ -149,10 +145,8 @@
 define arm_aapcs_vfpcc <8 x i8> @vhaddu_v8i8(<8 x i8> %s0, <8 x i8> %s1) {
 ; CHECK-LABEL: vhaddu_v8i8:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    vmovlb.u8 q1, q1
+; CHECK-NEXT:    vhadd.u8 q0, q0, q1
 ; CHECK-NEXT:    vmovlb.u8 q0, q0
-; CHECK-NEXT:    vadd.i16 q0, q0, q1
-; CHECK-NEXT:    vshr.u16 q0, q0, #1
 ; CHECK-NEXT:    bx lr
 entry:
   %s0s = zext <8 x i8> %s0 to <8 x i16>
@@ -244,12 +238,8 @@
 define arm_aapcs_vfpcc <4 x i16> @vrhaddu_v4i16(<4 x i16> %s0, <4 x i16> %s1) {
 ; CHECK-LABEL: vrhaddu_v4i16:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    vmovlb.u16 q1, q1
+; CHECK-NEXT:    vrhadd.u16 q0, q0, q1
 ; CHECK-NEXT:    vmovlb.u16 q0, q0
-; CHECK-NEXT:    vadd.i32 q0, q0, q1
-; CHECK-NEXT:    movs r0, #1
-; CHECK-NEXT:    vadd.i32 q0, q0, r0
-; CHECK-NEXT:    vshr.u32 q0, q0, #1
 ; CHECK-NEXT:    bx lr
 entry:
   %s0s = zext <4 x i16> %s0 to <4 x i32>
@@ -317,13 +307,9 @@
 define arm_aapcs_vfpcc <4 x i8> @vrhaddu_v4i8(<4 x i8> %s0, <4 x i8> %s1) {
 ; CHECK-LABEL: vrhaddu_v4i8:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    vmov.i32 q2, #0xff
-; CHECK-NEXT:    movs r0, #1
-; CHECK-NEXT:    vand q1, q1, q2
-; CHECK-NEXT:    vand q0, q0, q2
-; CHECK-NEXT:    vadd.i32 q0, q0, q1
-; CHECK-NEXT:    vadd.i32 q0, q0, r0
-; CHECK-NEXT:    vshr.u32 q0, q0, #1
+; CHECK-NEXT:    vrhadd.u8 q0, q0, q1
+; CHECK-NEXT:    vmov.i32 q1, #0xff
+; CHECK-NEXT:    vand q0, q0, q1
 ; CHECK-NEXT:    bx lr
 entry:
   %s0s = zext <4 x i8> %s0 to <4 x i16>
@@ -358,12 +344,8 @@
 define arm_aapcs_vfpcc <8 x i8> @vrhaddu_v8i8(<8 x i8> %s0, <8 x i8> %s1) {
 ; CHECK-LABEL: vrhaddu_v8i8:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    vmovlb.u8 q1, q1
+; CHECK-NEXT:    vrhadd.u8 q0, q0, q1
 ; CHECK-NEXT:    vmovlb.u8 q0, q0
-; CHECK-NEXT:    vadd.i16 q0, q0, q1
-; CHECK-NEXT:    movs r0, #1
-; CHECK-NEXT:    vadd.i16 q0, q0, r0
-; CHECK-NEXT:    vshr.u16 q0, q0, #1
 ; CHECK-NEXT:    bx lr
 entry:
   %s0s = zext <8 x i8> %s0 to <8 x i16>
Index: llvm/test/CodeGen/Thumb2/mve-vmulh.ll
===================================================================
--- llvm/test/CodeGen/Thumb2/mve-vmulh.ll
+++ llvm/test/CodeGen/Thumb2/mve-vmulh.ll
@@ -74,8 +74,8 @@
 define arm_aapcs_vfpcc <4 x i16> @vmulhs_v4i16(<4 x i16> %s0, <4 x i16> %s1) {
 ; CHECK-LABEL: vmulhs_v4i16:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    vmullb.s16 q0, q0, q1
-; CHECK-NEXT:    vshr.s32 q0, q0, #16
+; CHECK-NEXT:    vmulh.s16 q0, q0, q1
+; CHECK-NEXT:    vmovlb.s16 q0, q0
 ; CHECK-NEXT:    bx lr
 entry:
   %s0s = sext <4 x i16> %s0 to <4 x i32>
@@ -89,8 +89,8 @@
 define arm_aapcs_vfpcc <4 x i16> @vmulhu_v4i16(<4 x i16> %s0, <4 x i16> %s1) {
 ; CHECK-LABEL: vmulhu_v4i16:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    vmullb.u16 q0, q0, q1
-; CHECK-NEXT:    vshr.u32 q0, q0, #16
+; CHECK-NEXT:    vmulh.u16 q0, q0, q1
+; CHECK-NEXT:    vmovlb.u16 q0, q0
 ; CHECK-NEXT:    bx lr
 entry:
   %s0s = zext <4 x i16> %s0 to <4 x i32>
@@ -132,12 +132,9 @@
 define arm_aapcs_vfpcc <4 x i8> @vmulhs_v4i8(<4 x i8> %s0, <4 x i8> %s1) {
 ; CHECK-LABEL: vmulhs_v4i8:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    vmovlb.s8 q1, q1
+; CHECK-NEXT:    vmulh.s8 q0, q0, q1
 ; CHECK-NEXT:    vmovlb.s8 q0, q0
-; CHECK-NEXT:    vmovlb.s16 q1, q1
 ; CHECK-NEXT:    vmovlb.s16 q0, q0
-; CHECK-NEXT:    vmul.i32 q0, q0, q1
-; CHECK-NEXT:    vshr.s32 q0, q0, #8
 ; CHECK-NEXT:    bx lr
 entry:
   %s0s = sext <4 x i8> %s0 to <4 x i16>
@@ -151,11 +148,9 @@
 define arm_aapcs_vfpcc <4 x i8> @vmulhu_v4i8(<4 x i8> %s0, <4 x i8> %s1) {
 ; CHECK-LABEL: vmulhu_v4i8:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    vmov.i32 q2, #0xff
-; CHECK-NEXT:    vand q1, q1, q2
-; CHECK-NEXT:    vand q0, q0, q2
-; CHECK-NEXT:    vmul.i32 q0, q0, q1
-; CHECK-NEXT:    vshr.u32 q0, q0, #8
+; CHECK-NEXT:    vmulh.u8 q0, q0, q1
+; CHECK-NEXT:    vmov.i32 q1, #0xff
+; CHECK-NEXT:    vand q0, q0, q1
 ; CHECK-NEXT:    bx lr
 entry:
   %s0s = zext <4 x i8> %s0 to <4 x i16>
@@ -169,8 +164,8 @@
 define arm_aapcs_vfpcc <8 x i8> @vmulhs_v8i8(<8 x i8> %s0, <8 x i8> %s1) {
 ; CHECK-LABEL: vmulhs_v8i8:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    vmullb.s8 q0, q0, q1
-; CHECK-NEXT:    vshr.s16 q0, q0, #8
+; CHECK-NEXT:    vmulh.s8 q0, q0, q1
+; CHECK-NEXT:    vmovlb.s8 q0, q0
 ; CHECK-NEXT:    bx lr
 entry:
   %s0s = sext <8 x i8> %s0 to <8 x i16>
@@ -184,8 +179,8 @@
 define arm_aapcs_vfpcc <8 x i8> @vmulhu_v8i8(<8 x i8> %s0, <8 x i8> %s1) {
 ; CHECK-LABEL: vmulhu_v8i8:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    vmullb.u8 q0, q0, q1
-; CHECK-NEXT:    vshr.u16 q0, q0, #8
+; CHECK-NEXT:    vmulh.u8 q0, q0, q1
+; CHECK-NEXT:    vmovlb.u8 q0, q0
 ; CHECK-NEXT:    bx lr
 entry:
   %s0s = zext <8 x i8> %s0 to <8 x i16>