diff --git a/llvm/lib/CodeGen/GlobalISel/Utils.cpp b/llvm/lib/CodeGen/GlobalISel/Utils.cpp
--- a/llvm/lib/CodeGen/GlobalISel/Utils.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/Utils.cpp
@@ -1035,14 +1035,19 @@
   if (!MI)
     return std::nullopt;
 
-  if (!isBuildVectorOp(MI->getOpcode()))
+  bool isConcatVectorsOp = MI->getOpcode() == TargetOpcode::G_CONCAT_VECTORS;
+  if (!isBuildVectorOp(MI->getOpcode()) && !isConcatVectorsOp)
     return std::nullopt;
 
   std::optional<ValueAndVReg> SplatValAndReg;
   for (MachineOperand &Op : MI->uses()) {
     Register Element = Op.getReg();
+    // If we have a G_CONCAT_VECTORS, we recursively look into the
+    // vectors that we're concatenating to see if they're splats.
     auto ElementValAndReg =
-        getAnyConstantVRegValWithLookThrough(Element, MRI, true, true);
+        isConcatVectorsOp
+            ? getAnyConstantSplat(Element, MRI, AllowUndef)
+            : getAnyConstantVRegValWithLookThrough(Element, MRI, true, true);
 
     // If AllowUndef, treat undef as value that will result in a constant splat.
     if (!ElementValAndReg) {
diff --git a/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp b/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp
--- a/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp
@@ -57,6 +57,7 @@
     return;
 
   LLT s64 = LLT::scalar(64);
+  LLT v2s64 = LLT::fixed_vector(2, s64);
   LLT v4s64 = LLT::fixed_vector(4, s64);
 
   MachineInstrBuilder FortyTwoSplat =
@@ -68,6 +69,11 @@
   MachineInstrBuilder NonConstantSplat =
       B.buildBuildVector(v4s64, {Copies[0], Copies[0], Copies[0], Copies[0]});
   EXPECT_FALSE(mi_match(NonConstantSplat.getReg(0), *MRI, m_ICstOrSplat(Cst)));
+
+  auto ICst = B.buildConstant(s64, 15).getReg(0);
+  auto SmallSplat = B.buildBuildVector(v2s64, {ICst, ICst}).getReg(0);
+  auto LargeSplat = B.buildConcatVectors(v4s64, {SmallSplat, SmallSplat});
+  EXPECT_TRUE(mi_match(LargeSplat.getReg(0), *MRI, m_ICstOrSplat(Cst)));
 }
 
 TEST_F(AArch64GISelMITest, MachineInstrPtrBind) {
@@ -718,6 +724,7 @@
     return;
 
   LLT s64 = LLT::scalar(64);
+  LLT v2s64 = LLT::fixed_vector(2, 64);
   LLT v4s64 = LLT::fixed_vector(4, 64);
 
   Register FPOne = B.buildFConstant(s64, 1.0).getReg(0);
@@ -761,6 +768,44 @@
   auto Mixed = B.buildBuildVector(v4s64, {FPZero, FPZero, FPZero, Copies[0]});
   EXPECT_FALSE(
       mi_match(Mixed.getReg(0), *MRI, GFCstOrSplatGFCstMatch(FValReg)));
+
+  // Look through G_CONCAT_VECTORS.
+  auto SmallZeroSplat = B.buildBuildVector(v2s64, {FPZero, FPZero}).getReg(0);
+  auto LargeZeroSplat =
+      B.buildConcatVectors(v4s64, {SmallZeroSplat, SmallZeroSplat});
+  EXPECT_TRUE(mi_match(LargeZeroSplat.getReg(0), *MRI,
+                       GFCstOrSplatGFCstMatch(FValReg)));
+
+  auto SmallZeroSplat2 = B.buildBuildVector(v2s64, {FPZero, FPZero}).getReg(0);
+  auto SmallZeroSplat3 = B.buildCopy(v2s64, SmallZeroSplat).getReg(0);
+  auto LargeZeroSplat2 =
+      B.buildConcatVectors(v4s64, {SmallZeroSplat2, SmallZeroSplat3});
+  EXPECT_TRUE(mi_match(LargeZeroSplat2.getReg(0), *MRI,
+                       GFCstOrSplatGFCstMatch(FValReg)));
+
+  // Not all G_CONCAT_VECTORS are splats.
+  auto SmallOneSplat = B.buildBuildVector(v2s64, {FPOne, FPOne}).getReg(0);
+  auto LargeMixedSplat =
+      B.buildConcatVectors(v4s64, {SmallZeroSplat, SmallOneSplat});
+  EXPECT_FALSE(mi_match(LargeMixedSplat.getReg(0), *MRI,
+                        GFCstOrSplatGFCstMatch(FValReg)));
+
+  auto SmallMixedSplat = B.buildBuildVector(v2s64, {FPOne, FPZero}).getReg(0);
+  auto LargeSplat =
+      B.buildConcatVectors(v4s64, {SmallMixedSplat, SmallMixedSplat});
+  EXPECT_FALSE(
+      mi_match(LargeSplat.getReg(0), *MRI, GFCstOrSplatGFCstMatch(FValReg)));
+
+  auto SmallUndefSplat = B.buildBuildVector(v2s64, {Undef, Undef}).getReg(0);
+  auto LargeUndefSplat =
+      B.buildConcatVectors(v4s64, {SmallUndefSplat, SmallUndefSplat});
+  EXPECT_FALSE(mi_match(LargeUndefSplat.getReg(0), *MRI,
+                        GFCstOrSplatGFCstMatch(FValReg)));
+
+  auto UndefVec = B.buildUndef(v2s64).getReg(0);
+  auto LargeUndefSplat2 = B.buildConcatVectors(v4s64, {UndefVec, UndefVec});
+  EXPECT_FALSE(mi_match(LargeUndefSplat2.getReg(0), *MRI,
+                        GFCstOrSplatGFCstMatch(FValReg)));
 }
 
 TEST_F(AArch64GISelMITest, MatchNeg) {
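
Note (not part of the patch): a minimal sketch of what the new look-through
enables, reusing only the MachineIRBuilder/mi_match APIs already exercised by
the tests above. The fixture members B and MRI are assumed from
PatternMatchTest.cpp, and the splat value 7 is arbitrary.

  // A G_CONCAT_VECTORS of identical constant splats is now itself
  // recognized as a constant splat:
  //   %c:_(s64)       = G_CONSTANT i64 7
  //   %v:_(<2 x s64>) = G_BUILD_VECTOR %c(s64), %c(s64)
  //   %w:_(<4 x s64>) = G_CONCAT_VECTORS %v(<2 x s64>), %v(<2 x s64>)
  LLT s64 = LLT::scalar(64);
  LLT v2s64 = LLT::fixed_vector(2, 64);
  LLT v4s64 = LLT::fixed_vector(4, 64);
  Register C = B.buildConstant(s64, 7).getReg(0);
  Register V = B.buildBuildVector(v2s64, {C, C}).getReg(0);
  Register W = B.buildConcatVectors(v4s64, {V, V}).getReg(0);

  // getAnyConstantSplat() now recurses into each concatenated operand, so
  // the splat matcher binds Cst = 7 where it previously failed on %w.
  int64_t Cst;
  EXPECT_TRUE(mi_match(W, *MRI, m_ICstOrSplat(Cst)));
  EXPECT_EQ(Cst, 7);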