diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -1128,12 +1128,13 @@ setOperationAction(ISD::SMIN, VT, Custom); } - // AArch64 doesn't have MUL.2d: - setOperationAction(ISD::MUL, MVT::v2i64, Expand); // Custom handling for some quad-vector types to detect MULL. setOperationAction(ISD::MUL, MVT::v8i16, Custom); setOperationAction(ISD::MUL, MVT::v4i32, Custom); setOperationAction(ISD::MUL, MVT::v2i64, Custom); + setOperationAction(ISD::MUL, MVT::v4i16, Custom); + setOperationAction(ISD::MUL, MVT::v2i32, Custom); + setOperationAction(ISD::MUL, MVT::v1i64, Custom); // Saturates for (MVT VT : { MVT::v8i8, MVT::v4i16, MVT::v2i32, @@ -4592,24 +4593,44 @@ EVT VT = Op.getValueType(); // If SVE is available then i64 vector multiplications can also be made legal. - bool OverrideNEON = - VT == MVT::v1i64 || Subtarget->forceStreamingCompatibleSVE(); + bool OverrideNEON = Subtarget->forceStreamingCompatibleSVE(); if (VT.isScalableVector() || useSVEForFixedLengthVectorVT(VT, OverrideNEON)) return LowerToPredicatedOp(Op, DAG, AArch64ISD::MUL_PRED); - // Multiplications are only custom-lowered for 128-bit vectors so that - // VMULL can be detected. Otherwise v2i64 multiplications are not legal. - assert(VT.is128BitVector() && VT.isInteger() && + // Multiplications are only custom-lowered for 128-bit and 64-bit vectors so + // that VMULL can be detected. Otherwise v2i64 multiplications are not legal. + assert((VT.is128BitVector() || VT.is64BitVector()) && VT.isInteger() && "unexpected type for custom-lowering ISD::MUL"); SDNode *N0 = Op.getOperand(0).getNode(); SDNode *N1 = Op.getOperand(1).getNode(); bool isMLA = false; + EVT OVT = VT; + if (VT.is64BitVector()) { + if (N0->getOpcode() == ISD::EXTRACT_SUBVECTOR && + isNullConstant(N0->getOperand(1)) && + N1->getOpcode() == ISD::EXTRACT_SUBVECTOR && + isNullConstant(N1->getOperand(1))) { + N0 = N0->getOperand(0).getNode(); + N1 = N1->getOperand(0).getNode(); + VT = N0->getValueType(0); + } else { + if (VT == MVT::v1i64) { + if (Subtarget->hasSVE()) + return LowerToPredicatedOp(Op, DAG, AArch64ISD::MUL_PRED); + // Fall through to expand this. It is not legal. + return SDValue(); + } else + // Other vector multiplications are legal. + return Op; + } + } + SDLoc DL(Op); unsigned NewOpc = selectUmullSmull(N0, N1, DAG, DL, isMLA); if (!NewOpc) { - if (VT == MVT::v2i64) { + if (VT.getVectorElementType() == MVT::i64) { // If SVE is available then i64 vector multiplications can also be made // legal. if (Subtarget->hasSVE()) @@ -4629,7 +4650,9 @@ assert(Op0.getValueType().is64BitVector() && Op1.getValueType().is64BitVector() && "unexpected types for extended operands to VMULL"); - return DAG.getNode(NewOpc, DL, VT, Op0, Op1); + return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, OVT, + DAG.getNode(NewOpc, DL, VT, Op0, Op1), + DAG.getConstant(0, DL, MVT::i64)); } // Optimizing (zext A + zext B) * C, to (S/UMULL A, C) + (S/UMULL B, C) during // isel lowering to take advantage of no-stall back to back s/umul + s/umla. @@ -4637,11 +4660,14 @@ SDValue N00 = skipExtensionForVectorMULL(N0->getOperand(0).getNode(), DAG); SDValue N01 = skipExtensionForVectorMULL(N0->getOperand(1).getNode(), DAG); EVT Op1VT = Op1.getValueType(); - return DAG.getNode(N0->getOpcode(), DL, VT, - DAG.getNode(NewOpc, DL, VT, - DAG.getNode(ISD::BITCAST, DL, Op1VT, N00), Op1), - DAG.getNode(NewOpc, DL, VT, - DAG.getNode(ISD::BITCAST, DL, Op1VT, N01), Op1)); + return DAG.getNode( + ISD::EXTRACT_SUBVECTOR, DL, OVT, + DAG.getNode(N0->getOpcode(), DL, VT, + DAG.getNode(NewOpc, DL, VT, + DAG.getNode(ISD::BITCAST, DL, Op1VT, N00), Op1), + DAG.getNode(NewOpc, DL, VT, + DAG.getNode(ISD::BITCAST, DL, Op1VT, N01), Op1)), + DAG.getConstant(0, DL, MVT::i64)); } static inline SDValue getPTrue(SelectionDAG &DAG, SDLoc DL, EVT VT, diff --git a/llvm/test/CodeGen/AArch64/neon-extadd-extract.ll b/llvm/test/CodeGen/AArch64/neon-extadd-extract.ll --- a/llvm/test/CodeGen/AArch64/neon-extadd-extract.ll +++ b/llvm/test/CodeGen/AArch64/neon-extadd-extract.ll @@ -120,9 +120,8 @@ define <4 x i16> @mulls_v8i8_0(<8 x i8> %s0, <8 x i8> %s1) { ; CHECK-LABEL: mulls_v8i8_0: ; CHECK: // %bb.0: // %entry -; CHECK-NEXT: sshll v0.8h, v0.8b, #0 -; CHECK-NEXT: sshll v1.8h, v1.8b, #0 -; CHECK-NEXT: mul v0.4h, v0.4h, v1.4h +; CHECK-NEXT: smull v0.8h, v0.8b, v1.8b +; CHECK-NEXT: // kill: def $d0 killed $d0 killed $q0 ; CHECK-NEXT: ret entry: %s0s = sext <8 x i8> %s0 to <8 x i16> @@ -149,9 +148,8 @@ define <4 x i16> @mullu_v8i8_0(<8 x i8> %s0, <8 x i8> %s1) { ; CHECK-LABEL: mullu_v8i8_0: ; CHECK: // %bb.0: // %entry -; CHECK-NEXT: ushll v0.8h, v0.8b, #0 -; CHECK-NEXT: ushll v1.8h, v1.8b, #0 -; CHECK-NEXT: mul v0.4h, v0.4h, v1.4h +; CHECK-NEXT: umull v0.8h, v0.8b, v1.8b +; CHECK-NEXT: // kill: def $d0 killed $d0 killed $q0 ; CHECK-NEXT: ret entry: %s0s = zext <8 x i8> %s0 to <8 x i16> @@ -294,9 +292,8 @@ define <2 x i32> @mulls_v4i16_0(<4 x i16> %s0, <4 x i16> %s1) { ; CHECK-LABEL: mulls_v4i16_0: ; CHECK: // %bb.0: // %entry -; CHECK-NEXT: sshll v0.4s, v0.4h, #0 -; CHECK-NEXT: sshll v1.4s, v1.4h, #0 -; CHECK-NEXT: mul v0.2s, v0.2s, v1.2s +; CHECK-NEXT: smull v0.4s, v0.4h, v1.4h +; CHECK-NEXT: // kill: def $d0 killed $d0 killed $q0 ; CHECK-NEXT: ret entry: %s0s = sext <4 x i16> %s0 to <4 x i32> @@ -323,9 +320,8 @@ define <2 x i32> @mullu_v4i16_0(<4 x i16> %s0, <4 x i16> %s1) { ; CHECK-LABEL: mullu_v4i16_0: ; CHECK: // %bb.0: // %entry -; CHECK-NEXT: ushll v0.4s, v0.4h, #0 -; CHECK-NEXT: ushll v1.4s, v1.4h, #0 -; CHECK-NEXT: mul v0.2s, v0.2s, v1.2s +; CHECK-NEXT: umull v0.4s, v0.4h, v1.4h +; CHECK-NEXT: // kill: def $d0 killed $d0 killed $q0 ; CHECK-NEXT: ret entry: %s0s = zext <4 x i16> %s0 to <4 x i32> @@ -468,12 +464,8 @@ define <1 x i64> @mulls_v2i32_0(<2 x i32> %s0, <2 x i32> %s1) { ; CHECK-LABEL: mulls_v2i32_0: ; CHECK: // %bb.0: // %entry -; CHECK-NEXT: sshll v0.2d, v0.2s, #0 -; CHECK-NEXT: sshll v1.2d, v1.2s, #0 -; CHECK-NEXT: fmov x9, d0 -; CHECK-NEXT: fmov x8, d1 -; CHECK-NEXT: smull x8, w9, w8 -; CHECK-NEXT: fmov d0, x8 +; CHECK-NEXT: smull v0.2d, v0.2s, v1.2s +; CHECK-NEXT: // kill: def $d0 killed $d0 killed $q0 ; CHECK-NEXT: ret entry: %s0s = sext <2 x i32> %s0 to <2 x i64> @@ -504,12 +496,8 @@ define <1 x i64> @mullu_v2i32_0(<2 x i32> %s0, <2 x i32> %s1) { ; CHECK-LABEL: mullu_v2i32_0: ; CHECK: // %bb.0: // %entry -; CHECK-NEXT: ushll v0.2d, v0.2s, #0 -; CHECK-NEXT: ushll v1.2d, v1.2s, #0 -; CHECK-NEXT: fmov x9, d0 -; CHECK-NEXT: fmov x8, d1 -; CHECK-NEXT: umull x8, w9, w8 -; CHECK-NEXT: fmov d0, x8 +; CHECK-NEXT: umull v0.2d, v0.2s, v1.2s +; CHECK-NEXT: // kill: def $d0 killed $d0 killed $q0 ; CHECK-NEXT: ret entry: %s0s = zext <2 x i32> %s0 to <2 x i64>