Index: llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
===================================================================
--- llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
+++ llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
@@ -99,6 +99,24 @@
   return GFCstAndRegMatch(FPValReg);
 }
 
+struct GFCstOrSplatGFCstMatch {
+  Optional<FPValueAndVReg> &FPValReg;
+  GFCstOrSplatGFCstMatch(Optional<FPValueAndVReg> &FPValReg)
+      : FPValReg(FPValReg) {}
+  bool match(const MachineRegisterInfo &MRI, Register Reg) {
+    if ((FPValReg = getFConstantSplat(Reg, MRI)))
+      return true;
+    if ((FPValReg = getFConstantVRegValWithLookThrough(Reg, MRI)))
+      return true;
+    return false;
+  }
+};
+
+inline GFCstOrSplatGFCstMatch
+m_GFCstOrSplat(Optional<FPValueAndVReg> &FPValReg) {
+  return GFCstOrSplatGFCstMatch(FPValReg);
+}
+
 /// Matcher for a specific constant value.
 struct SpecificConstantMatch {
   int64_t RequestedVal;
Index: llvm/include/llvm/CodeGen/GlobalISel/Utils.h
===================================================================
--- llvm/include/llvm/CodeGen/GlobalISel/Utils.h
+++ llvm/include/llvm/CodeGen/GlobalISel/Utils.h
@@ -341,15 +341,23 @@
 Optional<int64_t> getBuildVectorConstantSplat(const MachineInstr &MI,
                                               const MachineRegisterInfo &MRI);
 
+/// Returns the floating point scalar constant of a build vector splat if it
+/// exists. When \p AllowUndef == true, some elements can be undef but not all.
+Optional<FPValueAndVReg> getFConstantSplat(Register VReg,
+                                           const MachineRegisterInfo &MRI,
+                                           bool AllowUndef = true);
+
 /// Return true if the specified instruction is a G_BUILD_VECTOR or
 /// G_BUILD_VECTOR_TRUNC where all of the elements are 0 or undef.
 bool isBuildVectorAllZeros(const MachineInstr &MI,
-                           const MachineRegisterInfo &MRI);
+                           const MachineRegisterInfo &MRI,
+                           bool AllowUndef = false);
 
 /// Return true if the specified instruction is a G_BUILD_VECTOR or
 /// G_BUILD_VECTOR_TRUNC where all of the elements are ~0 or undef.
 bool isBuildVectorAllOnes(const MachineInstr &MI,
-                          const MachineRegisterInfo &MRI);
+                          const MachineRegisterInfo &MRI,
+                          bool AllowUndef = false);
 
 /// \returns a value when \p MI is a vector splat. The splat can be either a
 /// Register or a constant.
Index: llvm/lib/CodeGen/GlobalISel/Utils.cpp
===================================================================
--- llvm/lib/CodeGen/GlobalISel/Utils.cpp
+++ llvm/lib/CodeGen/GlobalISel/Utils.cpp
@@ -911,53 +911,82 @@
          Opcode == TargetOpcode::G_BUILD_VECTOR_TRUNC;
 }
 
-// TODO: Handle mixed undef elements.
-static bool isBuildVectorConstantSplat(const MachineInstr &MI,
-                                       const MachineRegisterInfo &MRI,
-                                       int64_t SplatValue) {
-  if (!isBuildVectorOp(MI.getOpcode()))
-    return false;
+namespace {
 
-  const unsigned NumOps = MI.getNumOperands();
-  for (unsigned I = 1; I != NumOps; ++I) {
-    Register Element = MI.getOperand(I).getReg();
-    if (!mi_match(Element, MRI, m_SpecificICst(SplatValue)))
-      return false;
+Optional<ValueAndVReg> getAnyConstantSplat(Register VReg,
+                                           const MachineRegisterInfo &MRI,
+                                           bool AllowUndef) {
+  MachineInstr *MI = getDefIgnoringCopies(VReg, MRI);
+
+  if (!isBuildVectorOp(MI->getOpcode()))
+    return None;
+
+  Optional<ValueAndVReg> SplatValAndReg = None;
+  for (unsigned I = 1, NumOps = MI->getNumOperands(); I != NumOps; ++I) {
+    Register Element = MI->getOperand(I).getReg();
+    auto ElementValAndReg =
+        getAnyConstantVRegValWithLookThrough(Element, MRI, true, true);
+
+    // If AllowUndef, treat undef as a value that will result in a constant
+    // splat.
+    if (!ElementValAndReg) {
+      if (AllowUndef &&
+          MRI.getVRegDef(Element)->getOpcode() == TargetOpcode::G_IMPLICIT_DEF)
+        continue;
+      return None;
+    }
+
+    // Record the splat value.
+    if (!SplatValAndReg)
+      SplatValAndReg = ElementValAndReg;
+
+    // Different constant than the one already recorded, not a constant splat.
+    if (SplatValAndReg->Value != ElementValAndReg->Value)
+      return None;
   }
-  return true;
+  return SplatValAndReg;
 }
 
+bool isBuildVectorConstantSplat(const MachineInstr &MI,
+                                const MachineRegisterInfo &MRI,
+                                int64_t SplatValue, bool AllowUndef) {
+  if (auto SplatValAndReg =
+          getAnyConstantSplat(MI.getOperand(0).getReg(), MRI, AllowUndef))
+    return mi_match(SplatValAndReg->VReg, MRI, m_SpecificICst(SplatValue));
+  return false;
+}
+
+} // end anonymous namespace
+
 Optional<int64_t>
 llvm::getBuildVectorConstantSplat(const MachineInstr &MI,
                                   const MachineRegisterInfo &MRI) {
-  if (!isBuildVectorOp(MI.getOpcode()))
-    return None;
-
-  const unsigned NumOps = MI.getNumOperands();
-  Optional<int64_t> Scalar;
-  for (unsigned I = 1; I != NumOps; ++I) {
-    Register Element = MI.getOperand(I).getReg();
-    int64_t ElementValue;
-    if (!mi_match(Element, MRI, m_ICst(ElementValue)))
-      return None;
-    if (!Scalar)
-      Scalar = ElementValue;
-    else if (*Scalar != ElementValue)
-      return None;
-  }
+  if (auto SplatValAndReg =
+          getAnyConstantSplat(MI.getOperand(0).getReg(), MRI, false))
+    return getIConstantVRegSExtVal(SplatValAndReg->VReg, MRI);
+  return None;
+}
 
-  return Scalar;
+Optional<FPValueAndVReg> llvm::getFConstantSplat(Register VReg,
+                                                 const MachineRegisterInfo &MRI,
+                                                 bool AllowUndef) {
+  // Match as any constant and get the APFloat from the register. This allows
+  // matching NaN splats as well.
+  if (auto SplatValAndReg = getAnyConstantSplat(VReg, MRI, AllowUndef))
+    return getFConstantVRegValWithLookThrough(SplatValAndReg->VReg, MRI);
+  return None;
 }
 
 bool llvm::isBuildVectorAllZeros(const MachineInstr &MI,
-                                 const MachineRegisterInfo &MRI) {
-  return isBuildVectorConstantSplat(MI, MRI, 0);
+                                 const MachineRegisterInfo &MRI,
+                                 bool AllowUndef) {
+  return isBuildVectorConstantSplat(MI, MRI, 0, AllowUndef);
 }
 
 bool llvm::isBuildVectorAllOnes(const MachineInstr &MI,
-                                const MachineRegisterInfo &MRI) {
-  return isBuildVectorConstantSplat(MI, MRI, -1);
+                                const MachineRegisterInfo &MRI,
+                                bool AllowUndef) {
+  return isBuildVectorConstantSplat(MI, MRI, -1, AllowUndef);
 }
 
 Optional<RegOrConstant> llvm::getVectorSplat(const MachineInstr &MI,
Index: llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp
===================================================================
--- llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp
+++ llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp
@@ -574,6 +574,57 @@
   EXPECT_EQ(FPOne, FValReg->VReg);
 }
 
+TEST_F(AArch64GISelMITest, MatchConstantSplat) {
+  setUp();
+  if (!TM)
+    return;
+
+  LLT s64 = LLT::scalar(64);
+  LLT v4s64 = LLT::fixed_vector(4, 64);
+
+  Register FPOne = B.buildFConstant(s64, 1.0).getReg(0);
+  Register FPZero = B.buildFConstant(s64, 0.0).getReg(0);
+  Register Undef = B.buildUndef(s64).getReg(0);
+  Optional<FPValueAndVReg> FValReg;
+
+  // GFCstOrSplatGFCstMatch allows undef elements as part of a splat. Undefs
+  // often come from padding a vector to a legal size, e.g. v3s64 to v4s64,
+  // where the added elements are then ignored.
+
+  EXPECT_TRUE(mi_match(FPZero, *MRI, GFCstOrSplatGFCstMatch(FValReg)));
+  EXPECT_EQ(FPZero, FValReg->VReg);
+
+  EXPECT_FALSE(mi_match(Undef, *MRI, GFCstOrSplatGFCstMatch(FValReg)));
+
+  auto ZeroSplat = B.buildBuildVector(v4s64, {FPZero, FPZero, FPZero, FPZero});
+  EXPECT_TRUE(
+      mi_match(ZeroSplat.getReg(0), *MRI, GFCstOrSplatGFCstMatch(FValReg)));
+  EXPECT_EQ(FPZero, FValReg->VReg);
+
+  auto ZeroUndef = B.buildBuildVector(v4s64, {FPZero, FPZero, FPZero, Undef});
+  EXPECT_TRUE(
+      mi_match(ZeroUndef.getReg(0), *MRI, GFCstOrSplatGFCstMatch(FValReg)));
+  EXPECT_EQ(FPZero, FValReg->VReg);
+
+  // A build vector of all undefs is not a constant splat.
+  auto UndefSplat = B.buildBuildVector(v4s64, {Undef, Undef, Undef, Undef});
+  EXPECT_FALSE(
+      mi_match(UndefSplat.getReg(0), *MRI, GFCstOrSplatGFCstMatch(FValReg)));
+
+  auto ZeroOne = B.buildBuildVector(v4s64, {FPZero, FPZero, FPZero, FPOne});
+  EXPECT_FALSE(
+      mi_match(ZeroOne.getReg(0), *MRI, GFCstOrSplatGFCstMatch(FValReg)));
+
+  auto NonConstantSplat =
+      B.buildBuildVector(v4s64, {Copies[0], Copies[0], Copies[0], Copies[0]});
+  EXPECT_FALSE(mi_match(NonConstantSplat.getReg(0), *MRI,
+                        GFCstOrSplatGFCstMatch(FValReg)));
+
+  auto Mixed = B.buildBuildVector(v4s64, {FPZero, FPZero, FPZero, Copies[0]});
+  EXPECT_FALSE(
+      mi_match(Mixed.getReg(0), *MRI, GFCstOrSplatGFCstMatch(FValReg)));
+}
+
 TEST_F(AArch64GISelMITest, MatchNeg) {
   setUp();
   if (!TM)
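
Note: below is a minimal sketch of how a combine could consume the new
m_GFCstOrSplat matcher, assuming the usual MIPatternMatch helpers (mi_match,
m_GFMul, m_Reg). The function matchFMulByOneSplat and its wiring into a
combiner are hypothetical and only illustrate the intended call pattern; the
unit tests above are the authoritative usage.

  // Hypothetical combine sketch: G_FMUL x, splat(1.0) --> x.
  // m_GFCstOrSplat matches either a scalar G_FCONSTANT or a G_BUILD_VECTOR
  // splat (undef elements allowed), so one check covers both forms.
  static bool matchFMulByOneSplat(MachineInstr &MI,
                                  const MachineRegisterInfo &MRI,
                                  Register &MatchInfo) {
    Optional<FPValueAndVReg> MaybeOne;
    Register Src;
    if (!mi_match(MI.getOperand(0).getReg(), MRI,
                  m_GFMul(m_Reg(Src), m_GFCstOrSplat(MaybeOne))))
      return false;
    // FPValueAndVReg carries both the matched APFloat and the scalar vreg it
    // came from, so the splatted value can be inspected directly.
    if (!MaybeOne->Value.isExactlyValue(1.0))
      return false;
    MatchInfo = Src;
    return true;
  }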