Index: llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
===================================================================
--- llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
+++ llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
@@ -97,6 +97,28 @@
 };
 
 inline FCstRegMatch m_FCst(Register &Reg) { return FCstRegMatch(Reg); }
 
+template <bool AllowUndef, bool HandleFConstants, bool HandleIConstants>
+struct CstOrSplatCstRegMatch {
+  Register &CR;
+  CstOrSplatCstRegMatch(Register &C) : CR(C) {}
+  bool match(const MachineRegisterInfo &MRI, Register Reg) {
+    if (MRI.getType(Reg).isVector()) {
+      if (auto MaybeCst = getVectorConstantSplatUndef(
+              Reg, MRI, AllowUndef, HandleFConstants, HandleIConstants)) {
+        CR = *MaybeCst;
+        return true;
+      }
+      return false;
+    }
+    // Try to match a scalar constant.
+    return CstRegMatch(CR).match(MRI, Reg);
+  }
+};
+
+struct FCstOrSplatFCstRegMatch : CstOrSplatCstRegMatch<true, true, false> {
+  FCstOrSplatFCstRegMatch(Register &C) : CstOrSplatCstRegMatch(C) {}
+};
+
 /// Matcher for a specific constant value.
 struct SpecificConstantMatch {
   int64_t RequestedVal;
Index: llvm/include/llvm/CodeGen/GlobalISel/Utils.h
===================================================================
--- llvm/include/llvm/CodeGen/GlobalISel/Utils.h
+++ llvm/include/llvm/CodeGen/GlobalISel/Utils.h
@@ -334,6 +334,14 @@
 Optional<int64_t> getBuildVectorConstantSplat(const MachineInstr &MI,
                                               const MachineRegisterInfo &MRI);
 
+/// Returns the register that holds the scalar constant of a G_BUILD_VECTOR
+/// splat, if any. With \p AllowUndef, some (not all) elements can be undef.
+Optional<Register> getVectorConstantSplatUndef(Register VReg,
+                                               const MachineRegisterInfo &MRI,
+                                               bool AllowUndef,
+                                               bool MatchFConstant,
+                                               bool MatchIConstant);
+
 /// Return true if the specified instruction is a G_BUILD_VECTOR or
 /// G_BUILD_VECTOR_TRUNC where all of the elements are 0 or undef.
 bool isBuildVectorAllZeros(const MachineInstr &MI,
Index: llvm/lib/CodeGen/GlobalISel/Utils.cpp
===================================================================
--- llvm/lib/CodeGen/GlobalISel/Utils.cpp
+++ llvm/lib/CodeGen/GlobalISel/Utils.cpp
@@ -938,6 +938,56 @@
   return RegOrConstant(Reg);
 }
 
+Optional<Register> llvm::getVectorConstantSplatUndef(
+    Register VReg, const MachineRegisterInfo &MRI, bool AllowUndef,
+    bool MatchFConstant, bool MatchIConstant) {
+  // Look through copies.
+  MachineInstr *MI;
+  while ((MI = MRI.getVRegDef(VReg)) && !isBuildVectorOp(MI->getOpcode())) {
+    switch (MI->getOpcode()) {
+    case TargetOpcode::COPY:
+      VReg = MI->getOperand(1).getReg();
+      if (Register::isPhysicalRegister(VReg))
+        return None;
+      break;
+    default:
+      return None;
+    }
+  }
+
+  if (!isBuildVectorOp(MI->getOpcode()))
+    return None;
+
+  Optional<Register> ValReg = None;
+  Optional<APInt> SplatVal = None;
+  for (unsigned I = 1, NumOps = MI->getNumOperands(); I != NumOps; ++I) {
+
+    Register Element = MI->getOperand(I).getReg();
+    auto ValAndReg = getConstantVRegValWithLookThrough(
+        Element, MRI, true, MatchFConstant, MatchIConstant, true);
+
+    // With AllowUndef, treat undef elements as matching the splat value.
+    if (!ValAndReg) {
+      if (AllowUndef &&
+          MRI.getVRegDef(Element)->getOpcode() == TargetOpcode::G_IMPLICIT_DEF)
+        continue;
+      return None;
+    }
+
+    // Record the splat value.
+    if (!SplatVal) {
+      ValReg = ValAndReg->VReg;
+      SplatVal = ValAndReg->Value;
+    }
+
+    // Different constant than the one already recorded: not a constant splat.
+    if (SplatVal != ValAndReg->Value)
+      return None;
+  }
+
+  return ValReg;
+}
+
 bool llvm::matchUnaryPredicate(
     const MachineRegisterInfo &MRI, Register Reg,
     std::function<bool(const Constant *ConstVal)> Match, bool AllowUndefs) {
Index: llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp
===================================================================
--- llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp
+++ llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp
@@ -584,6 +584,61 @@
   EXPECT_EQ(FPOne, Reg);
 }
 
+TEST_F(AArch64GISelMITest, MatchConstantSplat) {
+  setUp();
+  if (!TM)
+    return;
+
+  LLT s64 = LLT::scalar(64);
+  LLT v4s64 = LLT::vector(4, 64);
+
+  Register FPOne = B.buildFConstant(s64, 1.0).getReg(0);
+  Register FPZero = B.buildFConstant(s64, 0.0).getReg(0);
+  Register Undef = B.buildUndef(s64).getReg(0);
+  Register Reg;
+
+  // FCstOrSplatFCstRegMatch allows undef elements in a splat. Undefs often
+  // come from padding a vector to legalize to an available operation, after
+  // which the added elements are ignored, e.g. widening v3s64 to v4s64.
+
+  EXPECT_TRUE(mi_match(FPZero, *MRI, FCstOrSplatFCstRegMatch(Reg)));
+  EXPECT_EQ(FPZero, Reg);
+
+  EXPECT_FALSE(mi_match(Undef, *MRI, FCstOrSplatFCstRegMatch(Reg)));
+
+  auto ZeroSplat = B.buildBuildVector(v4s64, {FPZero, FPZero, FPZero, FPZero});
+  EXPECT_TRUE(
+      mi_match(ZeroSplat.getReg(0), *MRI, FCstOrSplatFCstRegMatch(Reg)));
+  EXPECT_EQ(FPZero, Reg);
+
+  auto ZeroUndef = B.buildBuildVector(v4s64, {FPZero, FPZero, FPZero, Undef});
+  EXPECT_TRUE(
+      mi_match(ZeroUndef.getReg(0), *MRI, FCstOrSplatFCstRegMatch(Reg)));
+  EXPECT_EQ(FPZero, Reg);
+
+  // ZeroUndef fails the splat match if undef is not allowed.
+  EXPECT_FALSE(mi_match(
+      ZeroUndef.getReg(0), *MRI,
+      CstOrSplatCstRegMatch<false, true, false>(Reg)));
+
+  // A vector of all undefs is not a constant splat.
+  auto UndefSplat = B.buildBuildVector(v4s64, {Undef, Undef, Undef, Undef});
+  EXPECT_FALSE(
+      mi_match(UndefSplat.getReg(0), *MRI, FCstOrSplatFCstRegMatch(Reg)));
+
+  auto ZeroOne = B.buildBuildVector(v4s64, {FPZero, FPZero, FPZero, FPOne});
+  EXPECT_FALSE(mi_match(ZeroOne.getReg(0), *MRI, FCstOrSplatFCstRegMatch(Reg)));
+
+  auto NonConstantSplat =
+      B.buildBuildVector(v4s64, {Copies[0], Copies[0], Copies[0], Copies[0]});
+  EXPECT_FALSE(mi_match(NonConstantSplat.getReg(0), *MRI,
+                        FCstOrSplatFCstRegMatch(Reg)));
+
+  auto Mixed = B.buildBuildVector(v4s64, {FPZero, FPZero, FPZero, Copies[0]});
+  EXPECT_FALSE(mi_match(Mixed.getReg(0), *MRI, FCstOrSplatFCstRegMatch(Reg)));
+}
+
 TEST_F(AArch64GISelMITest, MatchNeg) {
   setUp();
   if (!TM)
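
For context, a minimal usage sketch of the new matcher from a combine-style helper. This is hypothetical and not part of the patch: matchFMulByOneSplat and the fmul-by-1.0 fold are invented for illustration; only FCstOrSplatFCstRegMatch, mi_match, and getConstantFPVRegVal come from the patched headers.

// Hypothetical combine helper, assuming the matcher above is in scope:
// fold G_FMUL x, (splat 1.0) -> x, for both a scalar G_FCONSTANT and a
// G_BUILD_VECTOR splat (undef elements allowed).
#include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
#include "llvm/CodeGen/GlobalISel/Utils.h"
using namespace llvm;
using namespace llvm::MIPatternMatch;

static bool matchFMulByOneSplat(MachineInstr &MI,
                                const MachineRegisterInfo &MRI,
                                Register &MatchInfo) {
  if (MI.getOpcode() != TargetOpcode::G_FMUL)
    return false;
  // The matcher yields the register holding the scalar constant (for a
  // vector, the splatted G_FCONSTANT); the value itself is read back
  // through getConstantFPVRegVal.
  Register SplatReg;
  if (!mi_match(MI.getOperand(2).getReg(), MRI,
                FCstOrSplatFCstRegMatch(SplatReg)))
    return false;
  const ConstantFP *Cst = getConstantFPVRegVal(SplatReg, MRI);
  if (!Cst || !Cst->isExactlyValue(1.0))
    return false;
  MatchInfo = MI.getOperand(1).getReg();
  return true;
}

Returning the register rather than an APInt/APFloat keeps one matcher working for both integer and FP splats; the caller decides how to interpret the constant it points at.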