[AArch64] Let skipExtensionForVectorMULL truncate operands whose high half is known zero

skipExtensionForVectorMULL:
  * asserts that N is a 128-bit vector and computes TruncVT up front as the
    half-element-width vector type with the same element count;
  * before looking at the opcode, asks SelectionDAG::MaskedValueIsZero
    whether the upper half of every element of N is zero, and if so returns
    an ISD::TRUNCATE of N to TruncVT;
  * passes the cached VT to addRequiredExtensionForVectorMULL on the
    extension path;
  * on the BUILD_VECTOR path, hands TruncVT straight to getBuildVector
    instead of rebuilding the vector type from a scalar type.

UMULL selection (third hunk; NOTE(review): the enclosing function is not
named in this diff - presumably selectUmullSmull, confirm):
  * when the high-half mask check on the non-zero-extended operand succeeds,
    it now returns AArch64ISD::UMULL directly. It no longer hand-truncates
    that operand to v2i32/v4i16/v8i8 and zero-extends it back, and no longer
    rewrites N0/N1.
  * the old `default: return 0;` for other vector types is gone, so UMULL is
    now returned for any type that reaches this check.
  * NOTE(review): this assumes the selected operands are later narrowed by
    skipExtensionForVectorMULL, whose new MaskedValueIsZero path performs
    the truncation, and that its 128-bit assert holds there. That call site
    is not part of this diff - verify.

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -4426,18 +4426,25 @@
 }
 
 static SDValue skipExtensionForVectorMULL(SDValue N, SelectionDAG &DAG) {
+  EVT VT = N.getValueType();
+  assert(VT.is128BitVector() && "Unexpected vector MULL size");
+
+  unsigned NumElts = VT.getVectorNumElements();
+  unsigned OrigEltSize = VT.getScalarSizeInBits();
+  unsigned EltSize = OrigEltSize / 2;
+  MVT TruncVT = MVT::getVectorVT(MVT::getIntegerVT(EltSize), NumElts);
+
+  APInt HiBits = APInt::getHighBitsSet(OrigEltSize, EltSize);
+  if (DAG.MaskedValueIsZero(N, HiBits))
+    return DAG.getNode(ISD::TRUNCATE, SDLoc(N), TruncVT, N);
+
   if (ISD::isExtOpcode(N.getOpcode()))
     return addRequiredExtensionForVectorMULL(N.getOperand(0), DAG,
-                                             N.getOperand(0).getValueType(),
-                                             N.getValueType(),
+                                             N.getOperand(0).getValueType(), VT,
                                              N.getOpcode());
 
   assert(N.getOpcode() == ISD::BUILD_VECTOR && "expected BUILD_VECTOR");
-  EVT VT = N.getValueType();
   SDLoc dl(N);
-  unsigned EltSize = VT.getScalarSizeInBits() / 2;
-  unsigned NumElts = VT.getVectorNumElements();
-  MVT TruncVT = MVT::getIntegerVT(EltSize);
   SmallVector<SDValue, 8> Ops;
   for (unsigned i = 0; i != NumElts; ++i) {
     ConstantSDNode *C = cast<ConstantSDNode>(N.getOperand(i));
@@ -4446,7 +4453,7 @@
     // The values are implicitly truncated so sext vs. zext doesn't matter.
     Ops.push_back(DAG.getConstant(CInt.zextOrTrunc(32), dl, MVT::i32));
   }
-  return DAG.getBuildVector(MVT::getVectorVT(TruncVT, NumElts), dl, Ops);
+  return DAG.getBuildVector(TruncVT, dl, Ops);
 }
 
 static bool isSignExtended(SDValue N, SelectionDAG &DAG) {
@@ -4588,31 +4595,8 @@
     EVT VT = N0.getValueType();
     APInt Mask = APInt::getHighBitsSet(VT.getScalarSizeInBits(),
                                        VT.getScalarSizeInBits() / 2);
-    if (DAG.MaskedValueIsZero(IsN0ZExt ? N1 : N0, Mask)) {
-      EVT HalfVT;
-      switch (VT.getSimpleVT().SimpleTy) {
-      case MVT::v2i64:
-        HalfVT = MVT::v2i32;
-        break;
-      case MVT::v4i32:
-        HalfVT = MVT::v4i16;
-        break;
-      case MVT::v8i16:
-        HalfVT = MVT::v8i8;
-        break;
-      default:
-        return 0;
-      }
-      // Truncate and then extend the result.
-      SDValue NewExt =
-          DAG.getNode(ISD::TRUNCATE, DL, HalfVT, IsN0ZExt ? N1 : N0);
-      NewExt = DAG.getZExtOrTrunc(NewExt, DL, VT);
-      if (IsN0ZExt)
-        N1 = NewExt;
-      else
-        N0 = NewExt;
+    if (DAG.MaskedValueIsZero(IsN0ZExt ? N1 : N0, Mask))
       return AArch64ISD::UMULL;
-    }
   }
 
   if (!IsN1SExt && !IsN1ZExt)