diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -13393,9 +13393,16 @@ return true; } +static bool isSplatShuffle(Value *V) { + if (auto *Shuf = dyn_cast(V)) + return all_equal(Shuf->getShuffleMask()); + return false; +} + /// Check if both Op1 and Op2 are shufflevector extracts of either the lower /// or upper half of the vector elements. -static bool areExtractShuffleVectors(Value *Op1, Value *Op2) { +static bool areExtractShuffleVectors(Value *Op1, Value *Op2, + bool AllowSplat = false) { auto areTypesHalfed = [](Value *FullV, Value *HalfV) { auto *FullTy = FullV->getType(); auto *HalfTy = HalfV->getType(); @@ -13410,25 +13417,39 @@ }; ArrayRef M1, M2; - Value *S1Op1, *S2Op1; + Value *S1Op1 = nullptr, *S2Op1 = nullptr; if (!match(Op1, m_Shuffle(m_Value(S1Op1), m_Undef(), m_Mask(M1))) || !match(Op2, m_Shuffle(m_Value(S2Op1), m_Undef(), m_Mask(M2)))) return false; + // If we allow splats, set S1Op1/S2Op1 to nullptr for the relevant arg so that + // it is not checked as an extract below. + if (AllowSplat && isSplatShuffle(Op1)) + S1Op1 = nullptr; + if (AllowSplat && isSplatShuffle(Op2)) + S2Op1 = nullptr; + // Check that the operands are half as wide as the result and we extract // half of the elements of the input vectors. - if (!areTypesHalfed(S1Op1, Op1) || !areTypesHalfed(S2Op1, Op2) || - !extractHalf(S1Op1, Op1) || !extractHalf(S2Op1, Op2)) + if ((S1Op1 && (!areTypesHalfed(S1Op1, Op1) || !extractHalf(S1Op1, Op1))) || + (S2Op1 && (!areTypesHalfed(S2Op1, Op2) || !extractHalf(S2Op1, Op2)))) return false; // Check the mask extracts either the lower or upper half of vector // elements. 
- int M1Start = -1; - int M2Start = -1; + int M1Start = 0; + int M2Start = 0; int NumElements = cast(Op1->getType())->getNumElements() * 2; - if (!ShuffleVectorInst::isExtractSubvectorMask(M1, NumElements, M1Start) || - !ShuffleVectorInst::isExtractSubvectorMask(M2, NumElements, M2Start) || - M1Start != M2Start || (M1Start != 0 && M2Start != (NumElements / 2))) + if ((S1Op1 && + !ShuffleVectorInst::isExtractSubvectorMask(M1, NumElements, M1Start)) || + (S2Op1 && + !ShuffleVectorInst::isExtractSubvectorMask(M2, NumElements, M2Start))) + return false; + + if ((M1Start != 0 && M1Start != (NumElements / 2)) || + (M2Start != 0 && M2Start != (NumElements / 2))) + return false; + if (S1Op1 && S2Op1 && M1Start != M2Start) return false; return true; @@ -13467,12 +13488,6 @@ return isOperandOfVmullHighP64(Op1) && isOperandOfVmullHighP64(Op2); } -static bool isSplatShuffle(Value *V) { - if (auto *Shuf = dyn_cast(V)) - return all_equal(Shuf->getShuffleMask()); - return false; -} - /// Check if sinking \p I's operands to I's basic block is profitable, because /// the operands can be folded into a target instruction, e.g. /// shufflevectors extracts and/or sext/zext can be folded into (u,s)subl(2). 
@@ -13482,7 +13497,8 @@ switch (II->getIntrinsicID()) { case Intrinsic::aarch64_neon_smull: case Intrinsic::aarch64_neon_umull: - if (areExtractShuffleVectors(II->getOperand(0), II->getOperand(1))) { + if (areExtractShuffleVectors(II->getOperand(0), II->getOperand(1), + /*AllowSplat=*/true)) { Ops.push_back(&II->getOperandUse(0)); Ops.push_back(&II->getOperandUse(1)); return true; diff --git a/llvm/test/CodeGen/AArch64/sinksplat.ll b/llvm/test/CodeGen/AArch64/sinksplat.ll --- a/llvm/test/CodeGen/AArch64/sinksplat.ll +++ b/llvm/test/CodeGen/AArch64/sinksplat.ll @@ -371,8 +371,7 @@ ; CHECK-NEXT: smull v0.4s, v1.4h, v2.h[3] ; CHECK-NEXT: tbz w1, #0, .LBB11_2 ; CHECK-NEXT: // %bb.1: // %l1 -; CHECK-NEXT: ext v1.16b, v1.16b, v1.16b, #8 -; CHECK-NEXT: smlal v0.4s, v1.4h, v2.h[3] +; CHECK-NEXT: smlal2 v0.4s, v1.8h, v2.h[3] ; CHECK-NEXT: .LBB11_2: // %l2 ; CHECK-NEXT: ret entry: @@ -399,8 +398,7 @@ ; CHECK-NEXT: umull v0.4s, v1.4h, v2.h[3] ; CHECK-NEXT: tbz w1, #0, .LBB12_2 ; CHECK-NEXT: // %bb.1: // %l1 -; CHECK-NEXT: ext v1.16b, v1.16b, v1.16b, #8 -; CHECK-NEXT: umlal v0.4s, v1.4h, v2.h[3] +; CHECK-NEXT: umlal2 v0.4s, v1.8h, v2.h[3] ; CHECK-NEXT: .LBB12_2: // %l2 ; CHECK-NEXT: ret entry: