Index: llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
===================================================================
--- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -13508,33 +13508,15 @@
   }
 }
 
-/// Combines a dup(sext/zext) node pattern into sext/zext(dup)
-/// making use of the vector SExt/ZExt rather than the scalar SExt/ZExt
-static SDValue performCommonVectorExtendCombine(SDValue VectorShuffle,
-                                                SelectionDAG &DAG) {
-  ShuffleVectorSDNode *ShuffleNode =
-      dyn_cast<ShuffleVectorSDNode>(VectorShuffle.getNode());
-  if (!ShuffleNode)
-    return SDValue();
-
-  // Ensuring the mask is zero before continuing
-  if (!ShuffleNode->isSplat() || ShuffleNode->getSplatIndex() != 0)
-    return SDValue();
-
-  SDValue InsertVectorElt = VectorShuffle.getOperand(0);
-
-  if (InsertVectorElt.getOpcode() != ISD::INSERT_VECTOR_ELT)
-    return SDValue();
-
-  SDValue InsertLane = InsertVectorElt.getOperand(2);
-  ConstantSDNode *Constant = dyn_cast<ConstantSDNode>(InsertLane.getNode());
-  // Ensures the insert is inserting into lane 0
-  if (!Constant || Constant->getZExtValue() != 0)
+static SDValue performBuildVectorExtendCombine(SDValue BV, SelectionDAG &DAG) {
+  EVT VT = BV.getValueType();
+  if (BV.getOpcode() != ISD::BUILD_VECTOR)
     return SDValue();
 
-  SDValue Extend = InsertVectorElt.getOperand(1);
+  // Use the first item in the buildvector to get the size of the extend, and
+  // make sure it looks valid.
+  SDValue Extend = BV->getOperand(0);
   unsigned ExtendOpcode = Extend.getOpcode();
-
   bool IsSExt = ExtendOpcode == ISD::SIGN_EXTEND ||
                 ExtendOpcode == ISD::SIGN_EXTEND_INREG ||
                 ExtendOpcode == ISD::AssertSext;
@@ -13544,30 +13526,28 @@
 
   // Restrict valid pre-extend data type
   EVT PreExtendType = calculatePreExtendType(Extend);
-  if (PreExtendType != MVT::i8 && PreExtendType != MVT::i16 &&
-      PreExtendType != MVT::i32)
-    return SDValue();
-
-  EVT TargetType = VectorShuffle.getValueType();
-  EVT PreExtendVT = TargetType.changeVectorElementType(PreExtendType);
-  if (TargetType.getScalarSizeInBits() != PreExtendVT.getScalarSizeInBits() * 2)
+  if (PreExtendType.getSizeInBits() != VT.getScalarSizeInBits() / 2)
     return SDValue();
 
-  SDLoc DL(VectorShuffle);
-
-  SDValue InsertVectorNode = DAG.getNode(
-      InsertVectorElt.getOpcode(), DL, PreExtendVT, DAG.getUNDEF(PreExtendVT),
-      DAG.getAnyExtOrTrunc(Extend.getOperand(0), DL, PreExtendType),
-      DAG.getConstant(0, DL, MVT::i64));
-
-  std::vector<int> ShuffleMask(TargetType.getVectorNumElements());
-
-  SDValue VectorShuffleNode =
-      DAG.getVectorShuffle(PreExtendVT, DL, InsertVectorNode,
-                           DAG.getUNDEF(PreExtendVT), ShuffleMask);
+  // Make sure all other operands are equally extended
+  for (SDValue Op : drop_begin(BV->ops())) {
+    unsigned Opc = Op.getOpcode();
+    bool OpcIsSExt = Opc == ISD::SIGN_EXTEND || Opc == ISD::SIGN_EXTEND_INREG ||
+                     Opc == ISD::AssertSext;
+    if (OpcIsSExt != IsSExt || calculatePreExtendType(Op) != PreExtendType)
+      return SDValue();
+  }
 
-  return DAG.getNode(IsSExt ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND, DL,
-                     TargetType, VectorShuffleNode);
+  EVT PreExtendVT = VT.changeVectorElementType(PreExtendType);
+  EVT PreExtendLegalType =
+      PreExtendType.getScalarSizeInBits() < 32 ? MVT::i32 : PreExtendType;
+  SDLoc DL(BV);
+  SmallVector<SDValue> NewOps;
+  for (SDValue Op : BV->ops())
+    NewOps.push_back(
+        DAG.getAnyExtOrTrunc(Op.getOperand(0), DL, PreExtendLegalType));
+  SDValue NBV = DAG.getNode(ISD::BUILD_VECTOR, DL, PreExtendVT, NewOps);
+  return DAG.getNode(IsSExt ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND, DL, VT, NBV);
 }
 
 /// Combines a mul(dup(sext/zext)) node pattern into mul(sext/zext(dup))
@@ -13578,8 +13558,8 @@
   if (VT != MVT::v8i16 && VT != MVT::v4i32 && VT != MVT::v2i64)
     return SDValue();
 
-  SDValue Op0 = performCommonVectorExtendCombine(Mul->getOperand(0), DAG);
-  SDValue Op1 = performCommonVectorExtendCombine(Mul->getOperand(1), DAG);
+  SDValue Op0 = performBuildVectorExtendCombine(Mul->getOperand(0), DAG);
+  SDValue Op1 = performBuildVectorExtendCombine(Mul->getOperand(1), DAG);
 
   // Neither operands have been changed, don't make any further changes
   if (!Op0 && !Op1)
Index: llvm/test/CodeGen/AArch64/aarch64-dup-ext.ll
===================================================================
--- llvm/test/CodeGen/AArch64/aarch64-dup-ext.ll
+++ llvm/test/CodeGen/AArch64/aarch64-dup-ext.ll
@@ -156,10 +156,8 @@
 define <8 x i16> @nonsplat_shuffleinsert(i8 %src, <8 x i8> %b) {
 ; CHECK-LABEL: nonsplat_shuffleinsert:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    sxtb w8, w0
-; CHECK-NEXT:    sshll v0.8h, v0.8b, #0
-; CHECK-NEXT:    dup v1.8h, w8
-; CHECK-NEXT:    mul v0.8h, v1.8h, v0.8h
+; CHECK-NEXT:    dup v1.8b, w0
+; CHECK-NEXT:    smull v0.8h, v1.8b, v0.8b
 ; CHECK-NEXT:    ret
 entry:
   %in = sext i8 %src to i16
Index: llvm/test/CodeGen/AArch64/aarch64-matrix-umull-smull.ll
===================================================================
--- llvm/test/CodeGen/AArch64/aarch64-matrix-umull-smull.ll
+++ llvm/test/CodeGen/AArch64/aarch64-matrix-umull-smull.ll
@@ -201,25 +201,22 @@
 ; CHECK-NEXT:    b .LBB3_6
 ; CHECK-NEXT:  .LBB3_3: // %vector.ph
 ; CHECK-NEXT:    and x10, x9, #0xfffffff0
+; CHECK-NEXT:    dup v0.4h, w8
 ; CHECK-NEXT:    add x11, x2, #32
 ; CHECK-NEXT:    add x12, x0, #16
 ; CHECK-NEXT:    mov x13, x10
-; CHECK-NEXT:    dup v0.4s, w8
+; CHECK-NEXT:    dup v1.8h, w8
 ; CHECK-NEXT:  .LBB3_4: // %vector.body
 ; CHECK-NEXT:    // =>This Inner Loop Header: Depth=1
-; CHECK-NEXT:    ldp q1, q2, [x12, #-16]
+; CHECK-NEXT:    ldp q2, q3, [x12, #-16]
 ; CHECK-NEXT:    subs x13, x13, #16
 ; CHECK-NEXT:    add x12, x12, #32
-; CHECK-NEXT:    sshll2 v3.4s, v1.8h, #0
-; CHECK-NEXT:    sshll v1.4s, v1.4h, #0
-; CHECK-NEXT:    sshll2 v4.4s, v2.8h, #0
-; CHECK-NEXT:    sshll v2.4s, v2.4h, #0
-; CHECK-NEXT:    mul v3.4s, v0.4s, v3.4s
-; CHECK-NEXT:    mul v1.4s, v0.4s, v1.4s
-; CHECK-NEXT:    mul v4.4s, v0.4s, v4.4s
-; CHECK-NEXT:    mul v2.4s, v0.4s, v2.4s
-; CHECK-NEXT:    stp q1, q3, [x11, #-32]
-; CHECK-NEXT:    stp q2, q4, [x11], #64
+; CHECK-NEXT:    smull2 v4.4s, v1.8h, v2.8h
+; CHECK-NEXT:    smull v2.4s, v0.4h, v2.4h
+; CHECK-NEXT:    smull2 v5.4s, v1.8h, v3.8h
+; CHECK-NEXT:    smull v3.4s, v0.4h, v3.4h
+; CHECK-NEXT:    stp q2, q4, [x11, #-32]
+; CHECK-NEXT:    stp q3, q5, [x11], #64
 ; CHECK-NEXT:    b.ne .LBB3_4
 ; CHECK-NEXT:  // %bb.5: // %middle.block
 ; CHECK-NEXT:    cmp x10, x9
@@ -317,25 +314,22 @@
 ; CHECK-NEXT:    b .LBB4_6
 ; CHECK-NEXT:  .LBB4_3: // %vector.ph
 ; CHECK-NEXT:    and x10, x9, #0xfffffff0
+; CHECK-NEXT:    dup v0.4h, w8
 ; CHECK-NEXT:    add x11, x2, #32
 ; CHECK-NEXT:    add x12, x0, #16
 ; CHECK-NEXT:    mov x13, x10
-; CHECK-NEXT:    dup v0.4s, w8
+; CHECK-NEXT:    dup v1.8h, w8
 ; CHECK-NEXT:  .LBB4_4: // %vector.body
 ; CHECK-NEXT:    // =>This Inner Loop Header: Depth=1
-; CHECK-NEXT:    ldp q1, q2, [x12, #-16]
+; CHECK-NEXT:    ldp q2, q3, [x12, #-16]
 ; CHECK-NEXT:    subs x13, x13, #16
 ; CHECK-NEXT:    add x12, x12, #32
-; CHECK-NEXT:    ushll2 v3.4s, v1.8h, #0
-; CHECK-NEXT:    ushll v1.4s, v1.4h, #0
-; CHECK-NEXT:    ushll2 v4.4s, v2.8h, #0
-; CHECK-NEXT:    ushll v2.4s, v2.4h, #0
-; CHECK-NEXT:    mul v3.4s, v0.4s, v3.4s
-; CHECK-NEXT:    mul v1.4s, v0.4s, v1.4s
-; CHECK-NEXT:    mul v4.4s, v0.4s, v4.4s
-; CHECK-NEXT:    mul v2.4s, v0.4s, v2.4s
-; CHECK-NEXT:    stp q1, q3, [x11, #-32]
-; CHECK-NEXT:    stp q2, q4, [x11], #64
+; CHECK-NEXT:    umull2 v4.4s, v1.8h, v2.8h
+; CHECK-NEXT:    umull v2.4s, v0.4h, v2.4h
+; CHECK-NEXT:    umull2 v5.4s, v1.8h, v3.8h
+; CHECK-NEXT:    umull v3.4s, v0.4h, v3.4h
+; CHECK-NEXT:    stp q2, q4, [x11, #-32]
+; CHECK-NEXT:    stp q3, q5, [x11], #64
 ; CHECK-NEXT:    b.ne .LBB4_4
 ; CHECK-NEXT:  // %bb.5: // %middle.block
 ; CHECK-NEXT:    cmp x10, x9
@@ -435,12 +429,13 @@
 ; CHECK-NEXT:    mov w0, wzr
 ; CHECK-NEXT:    ret
 ; CHECK-NEXT:  .LBB5_4: // %vector.ph
+; CHECK-NEXT:    dup v1.8b, w9
 ; CHECK-NEXT:    and x11, x10, #0xfffffff0
-; CHECK-NEXT:    add x8, x0, #8
 ; CHECK-NEXT:    movi v0.2d, #0000000000000000
+; CHECK-NEXT:    add x8, x0, #8
+; CHECK-NEXT:    movi v2.2d, #0000000000000000
 ; CHECK-NEXT:    mov x12, x11
-; CHECK-NEXT:    movi v1.2d, #0000000000000000
-; CHECK-NEXT:    dup v2.8h, w9
+; CHECK-NEXT:    sshll v1.8h, v1.8b, #0
 ; CHECK-NEXT:  .LBB5_5: // %vector.body
 ; CHECK-NEXT:    // =>This Inner Loop Header: Depth=1
 ; CHECK-NEXT:    ldp d3, d4, [x8, #-8]
@@ -448,11 +443,11 @@
 ; CHECK-NEXT:    add x8, x8, #16
 ; CHECK-NEXT:    ushll v3.8h, v3.8b, #0
 ; CHECK-NEXT:    ushll v4.8h, v4.8b, #0
-; CHECK-NEXT:    mla v0.8h, v2.8h, v3.8h
-; CHECK-NEXT:    mla v1.8h, v2.8h, v4.8h
+; CHECK-NEXT:    mla v0.8h, v1.8h, v3.8h
+; CHECK-NEXT:    mla v2.8h, v1.8h, v4.8h
 ; CHECK-NEXT:    b.ne .LBB5_5
 ; CHECK-NEXT:  // %bb.6: // %middle.block
-; CHECK-NEXT:    add v0.8h, v1.8h, v0.8h
+; CHECK-NEXT:    add v0.8h, v2.8h, v0.8h
 ; CHECK-NEXT:    cmp x11, x10
 ; CHECK-NEXT:    addv h0, v0.8h
 ; CHECK-NEXT:    fmov w8, s0
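
Note (illustration, not part of the patch): the rewritten combine no longer
requires a splat (dup) of a single extended scalar; any BUILD_VECTOR whose
operands are all sign- or zero-extended from the same narrow type is rewritten
as sext/zext(build_vector), letting a following mul select smull/umull. A
minimal sketch in the style of the .ll tests above; the function name, RUN
line, and expected smull are assumptions, not taken from the patch's tests:

; RUN: llc -mtriple=aarch64 < %s | FileCheck %s

; Four distinct i16 values are sign-extended and packed into a <4 x i32>.
; With the combine, the scalar extends should fold into the build_vector and
; the multiply should lower to a single smull (modulo register allocation
; and scheduling).
; CHECK-LABEL: buildvector_smull:
; CHECK: smull v{{[0-9]+}}.4s, v{{[0-9]+}}.4h, v{{[0-9]+}}.4h
define <4 x i32> @buildvector_smull(i16 %a, i16 %b, i16 %c, i16 %d, <4 x i16> %v) {
entry:
  %ea = sext i16 %a to i32
  %eb = sext i16 %b to i32
  %ec = sext i16 %c to i32
  %ed = sext i16 %d to i32
  %b0 = insertelement <4 x i32> undef, i32 %ea, i64 0
  %b1 = insertelement <4 x i32> %b0, i32 %eb, i64 1
  %b2 = insertelement <4 x i32> %b1, i32 %ec, i64 2
  %b3 = insertelement <4 x i32> %b2, i32 %ed, i64 3
  %ev = sext <4 x i16> %v to <4 x i32>
  %m = mul <4 x i32> %b3, %ev
  ret <4 x i32> %m
}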