diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -13629,15 +13629,17 @@
   }
 }
 
-/// Combines a buildvector(sext/zext) node pattern into sext/zext(buildvector)
-/// making use of the vector SExt/ZExt rather than the scalar SExt/ZExt
-static SDValue performBuildVectorExtendCombine(SDValue BV, SelectionDAG &DAG) {
+/// Combines a buildvector(sext/zext) or shuffle(sext/zext, undef) node pattern
+/// into sext/zext(buildvector) or sext/zext(shuffle) making use of the vector
+/// SExt/ZExt rather than the scalar SExt/ZExt
+static SDValue performBuildShuffleExtendCombine(SDValue BV, SelectionDAG &DAG) {
   EVT VT = BV.getValueType();
-  if (BV.getOpcode() != ISD::BUILD_VECTOR)
+  if (BV.getOpcode() != ISD::BUILD_VECTOR &&
+      BV.getOpcode() != ISD::VECTOR_SHUFFLE)
     return SDValue();
 
-  // Use the first item in the buildvector to get the size of the extend, and
-  // make sure it looks valid.
+  // Use the first item in the buildvector/shuffle to get the size of the
+  // extend, and make sure it looks valid.
   SDValue Extend = BV->getOperand(0);
   unsigned ExtendOpcode = Extend.getOpcode();
   bool IsSExt = ExtendOpcode == ISD::SIGN_EXTEND ||
@@ -13646,15 +13648,22 @@
   if (!IsSExt && ExtendOpcode != ISD::ZERO_EXTEND &&
       ExtendOpcode != ISD::AssertZext && ExtendOpcode != ISD::AND)
     return SDValue();
+  // Shuffle inputs are vectors; limit to SIGN_EXTEND and ZERO_EXTEND to
+  // ensure calculatePreExtendType will work without issue.
+  if (BV.getOpcode() == ISD::VECTOR_SHUFFLE &&
+      ExtendOpcode != ISD::SIGN_EXTEND && ExtendOpcode != ISD::ZERO_EXTEND)
+    return SDValue();
 
   // Restrict valid pre-extend data type
   EVT PreExtendType = calculatePreExtendType(Extend);
   if (PreExtendType == MVT::Other ||
-      PreExtendType.getSizeInBits() != VT.getScalarSizeInBits() / 2)
+      PreExtendType.getScalarSizeInBits() != VT.getScalarSizeInBits() / 2)
     return SDValue();
 
   // Make sure all other operands are equally extended
   for (SDValue Op : drop_begin(BV->ops())) {
+    if (Op.isUndef())
+      continue;
     unsigned Opc = Op.getOpcode();
     bool OpcIsSExt = Opc == ISD::SIGN_EXTEND || Opc == ISD::SIGN_EXTEND_INREG ||
                      Opc == ISD::AssertSext;
@@ -13662,15 +13671,26 @@
       return SDValue();
   }
 
-  EVT PreExtendVT = VT.changeVectorElementType(PreExtendType);
-  EVT PreExtendLegalType =
-      PreExtendType.getScalarSizeInBits() < 32 ? MVT::i32 : PreExtendType;
+  SDValue NBV;
   SDLoc DL(BV);
-  SmallVector<SDValue, 8> NewOps;
-  for (SDValue Op : BV->ops())
-    NewOps.push_back(
-        DAG.getAnyExtOrTrunc(Op.getOperand(0), DL, PreExtendLegalType));
-  SDValue NBV = DAG.getNode(ISD::BUILD_VECTOR, DL, PreExtendVT, NewOps);
+  if (BV.getOpcode() == ISD::BUILD_VECTOR) {
+    EVT PreExtendVT = VT.changeVectorElementType(PreExtendType);
+    EVT PreExtendLegalType =
+        PreExtendType.getScalarSizeInBits() < 32 ? MVT::i32 : PreExtendType;
+    SmallVector<SDValue, 8> NewOps;
+    for (SDValue Op : BV->ops())
+      NewOps.push_back(Op.isUndef() ? DAG.getUNDEF(PreExtendLegalType)
                                     : DAG.getAnyExtOrTrunc(Op.getOperand(0), DL,
                                                            PreExtendLegalType));
+    NBV = DAG.getNode(ISD::BUILD_VECTOR, DL, PreExtendVT, NewOps);
+  } else { // BV.getOpcode() == ISD::VECTOR_SHUFFLE
+    EVT PreExtendVT = VT.changeVectorElementType(PreExtendType.getScalarType());
+    NBV = DAG.getVectorShuffle(PreExtendVT, DL, BV.getOperand(0).getOperand(0),
+                               BV.getOperand(1).isUndef()
+                                   ? DAG.getUNDEF(PreExtendVT)
+                                   : BV.getOperand(1).getOperand(0),
+                               cast<ShuffleVectorSDNode>(BV)->getMask());
+  }
   return DAG.getNode(IsSExt ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND, DL, VT, NBV);
 }
 
@@ -13682,8 +13702,8 @@
   if (VT != MVT::v8i16 && VT != MVT::v4i32 && VT != MVT::v2i64)
     return SDValue();
 
-  SDValue Op0 = performBuildVectorExtendCombine(Mul->getOperand(0), DAG);
-  SDValue Op1 = performBuildVectorExtendCombine(Mul->getOperand(1), DAG);
+  SDValue Op0 = performBuildShuffleExtendCombine(Mul->getOperand(0), DAG);
+  SDValue Op1 = performBuildShuffleExtendCombine(Mul->getOperand(1), DAG);
 
   // Neither operand has been changed, don't make any further changes
   if (!Op0 && !Op1)
diff --git a/llvm/test/CodeGen/AArch64/aarch64-dup-ext.ll b/llvm/test/CodeGen/AArch64/aarch64-dup-ext.ll
--- a/llvm/test/CodeGen/AArch64/aarch64-dup-ext.ll
+++ b/llvm/test/CodeGen/AArch64/aarch64-dup-ext.ll
@@ -245,9 +245,8 @@
 define <8 x i16> @missing_insert(<8 x i8> %b) {
 ; CHECK-LABEL: missing_insert:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    sshll v0.8h, v0.8b, #0
-; CHECK-NEXT:    ext v1.16b, v0.16b, v0.16b, #4
-; CHECK-NEXT:    mul v0.8h, v1.8h, v0.8h
+; CHECK-NEXT:    ext v1.8b, v0.8b, v0.8b, #2
+; CHECK-NEXT:    smull v0.8h, v1.8b, v0.8b
 ; CHECK-NEXT:    ret
 entry:
   %ext.b = sext <8 x i8> %b to <8 x i16>
@@ -259,11 +258,8 @@
 define <8 x i16> @shufsext_v8i8_v8i16(<8 x i8> %src, <8 x i8> %b) {
 ; CHECK-LABEL: shufsext_v8i8_v8i16:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    sshll v0.8h, v0.8b, #0
-; CHECK-NEXT:    sshll v1.8h, v1.8b, #0
-; CHECK-NEXT:    rev64 v0.8h, v0.8h
-; CHECK-NEXT:    ext v0.16b, v0.16b, v0.16b, #8
-; CHECK-NEXT:    mul v0.8h, v0.8h, v1.8h
+; CHECK-NEXT:    rev64 v0.8b, v0.8b
+; CHECK-NEXT:    smull v0.8h, v0.8b, v1.8b
 ; CHECK-NEXT:    ret
 entry:
   %in = sext <8 x i8> %src to <8 x i16>
@@ -276,17 +272,8 @@
 define <2 x i64> @shufsext_v2i32_v2i64(<2 x i32> %src, <2 x i32> %b) {
 ; CHECK-LABEL: shufsext_v2i32_v2i64:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    sshll v0.2d, v0.2s, #0
-; CHECK-NEXT:    sshll v1.2d, v1.2s, #0
-; CHECK-NEXT:    ext v0.16b, v0.16b, v0.16b, #8
-; CHECK-NEXT:    fmov x9, d1
-; CHECK-NEXT:    mov x8, v1.d[1]
-; CHECK-NEXT:    fmov x10, d0
-; CHECK-NEXT:    mov x11, v0.d[1]
-; CHECK-NEXT:    mul x9, x10, x9
-; CHECK-NEXT:    mul x8, x11, x8
-; CHECK-NEXT:    fmov d0, x9
-; CHECK-NEXT:    mov v0.d[1], x8
+; CHECK-NEXT:    rev64 v0.2s, v0.2s
+; CHECK-NEXT:    smull v0.2d, v0.2s, v1.2s
 ; CHECK-NEXT:    ret
 entry:
   %in = sext <2 x i32> %src to <2 x i64>
@@ -299,11 +286,8 @@
 define <8 x i16> @shufzext_v8i8_v8i16(<8 x i8> %src, <8 x i8> %b) {
 ; CHECK-LABEL: shufzext_v8i8_v8i16:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    ushll v0.8h, v0.8b, #0
-; CHECK-NEXT:    ushll v1.8h, v1.8b, #0
-; CHECK-NEXT:    rev64 v0.8h, v0.8h
-; CHECK-NEXT:    ext v0.16b, v0.16b, v0.16b, #8
-; CHECK-NEXT:    mul v0.8h, v0.8h, v1.8h
+; CHECK-NEXT:    rev64 v0.8b, v0.8b
+; CHECK-NEXT:    umull v0.8h, v0.8b, v1.8b
 ; CHECK-NEXT:    ret
 entry:
   %in = zext <8 x i8> %src to <8 x i16>
@@ -316,17 +300,8 @@
 define <2 x i64> @shufzext_v2i32_v2i64(<2 x i32> %src, <2 x i32> %b) {
 ; CHECK-LABEL: shufzext_v2i32_v2i64:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    sshll v0.2d, v0.2s, #0
-; CHECK-NEXT:    sshll v1.2d, v1.2s, #0
-; CHECK-NEXT:    ext v0.16b, v0.16b, v0.16b, #8
-; CHECK-NEXT:    fmov x9, d1
-; CHECK-NEXT:    mov x8, v1.d[1]
-; CHECK-NEXT:    fmov x10, d0
-; CHECK-NEXT:    mov x11, v0.d[1]
-; CHECK-NEXT:    mul x9, x10, x9
-; CHECK-NEXT:    mul x8, x11, x8
-; CHECK-NEXT:    fmov d0, x9
-; CHECK-NEXT:    mov v0.d[1], x8
+; CHECK-NEXT:    rev64 v0.2s, v0.2s
+; CHECK-NEXT:    smull v0.2d, v0.2s, v1.2s
 ; CHECK-NEXT:    ret
 entry:
   %in = sext <2 x i32> %src to <2 x i64>
@@ -339,11 +314,8 @@
 define <8 x i16> @shufzext_v8i8_v8i16_twoin(<8 x i8> %src1, <8 x i8> %src2, <8 x i8> %b) {
 ; CHECK-LABEL: shufzext_v8i8_v8i16_twoin:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    ushll v0.8h, v0.8b, #0
-; CHECK-NEXT:    ushll v1.8h, v1.8b, #0
-; CHECK-NEXT:    trn1 v0.8h, v0.8h, v1.8h
-; CHECK-NEXT:    ushll v1.8h, v2.8b, #0
-; CHECK-NEXT:    mul v0.8h, v0.8h, v1.8h
+; CHECK-NEXT:    trn1 v0.8b, v0.8b, v1.8b
+; CHECK-NEXT:    umull v0.8h, v0.8b, v2.8b
 ; CHECK-NEXT:    ret
 entry:
   %in1 = zext <8 x i8> %src1 to <8 x i16>
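
Illustration (not part of the patch): the extended combine rewrites mul(shuffle(sext x, undef), sext y) into mul(sext(shuffle(x, undef)), sext y), so that the shuffle operates on the narrow type and instruction selection can pick a widening multiply. A minimal IR sketch mirroring the shufsext_v8i8_v8i16 test above; the function name @shuffle_sext_mul_example is hypothetical:

define <8 x i16> @shuffle_sext_mul_example(<8 x i8> %src, <8 x i8> %b) {
entry:
  ; Both multiply operands are sign-extended from <8 x i8>, and the shuffle
  ; (a full byte reversal) has a single defined input.
  %in = sext <8 x i8> %src to <8 x i16>
  %shuf = shufflevector <8 x i16> %in, <8 x i16> undef, <8 x i32> <i32 7, i32 6, i32 5, i32 4, i32 3, i32 2, i32 1, i32 0>
  %ext.b = sext <8 x i8> %b to <8 x i16>
  %out = mul <8 x i16> %shuf, %ext.b
  ; With this patch the shuffle is performed on the narrow <8 x i8> value, so
  ; codegen becomes "rev64 v0.8b" followed by "smull v0.8h, v0.8b, v1.8b"
  ; instead of extending both inputs and using a full-width mul.
  ret <8 x i16> %out
}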