Index: llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
===================================================================
--- llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
+++ llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
@@ -73,13 +73,17 @@
 
 inline ConstantMatch m_ICst(int64_t &Cst) { return ConstantMatch(Cst); }
 
-struct ICstRegMatch {
+/// Matches a register that getConstantVRegValWithLookThrough resolves to a
+/// constant; on success, binds \p CR to the VReg of the found constant.
+/// \p HandleIConstants / \p HandleFConstants select whether G_CONSTANT and
+/// G_FCONSTANT definitions are accepted.
+template <bool HandleIConstants, bool HandleFConstants> struct CstRegMatch {
   Register &CR;
-  ICstRegMatch(Register &C) : CR(C) {}
+  CstRegMatch(Register &C) : CR(C) {}
   bool match(const MachineRegisterInfo &MRI, Register Reg) {
     if (auto MaybeCst = getConstantVRegValWithLookThrough(
-            Reg, MRI, /*LookThroughInstrs*/ true,
-            /*HandleFConstants*/ false)) {
+            Reg, MRI, /*LookThroughInstrs*/ true, HandleFConstants,
+            /*LookThroughAnyExt*/ false, HandleIConstants)) {
       CR = MaybeCst->VReg;
       return true;
     }
@@ -87,8 +91,18 @@
   }
 };
 
+/// Like CstRegMatch, but accepts only G_CONSTANT definitions.
+struct ICstRegMatch : CstRegMatch<true, false> {
+  ICstRegMatch(Register &C) : CstRegMatch(C) {}
+};
 inline ICstRegMatch m_ICst(Register &Reg) { return ICstRegMatch(Reg); }
 
+/// Like CstRegMatch, but accepts only G_FCONSTANT definitions.
+struct FCstRegMatch : CstRegMatch<false, true> {
+  FCstRegMatch(Register &C) : CstRegMatch(C) {}
+};
+inline FCstRegMatch m_FCst(Register &Reg) { return FCstRegMatch(Reg); }
+
 /// Matcher for a specific constant value.
 struct SpecificConstantMatch {
   int64_t RequestedVal;
Index: llvm/include/llvm/CodeGen/GlobalISel/Utils.h
===================================================================
--- llvm/include/llvm/CodeGen/GlobalISel/Utils.h
+++ llvm/include/llvm/CodeGen/GlobalISel/Utils.h
@@ -183,12 +183,14 @@
 /// When \p LookThroughInstrs == false this function behaves like
 /// getConstantVRegVal.
 /// When \p HandleFConstants == false the function bails on G_FCONSTANTs.
 /// When \p LookThroughAnyExt == true the function treats G_ANYEXT same as
 /// G_SEXT.
+/// When \p HandleIConstants == false the function bails on G_CONSTANTs.
 Optional<ValueAndVReg>
 getConstantVRegValWithLookThrough(Register VReg, const MachineRegisterInfo &MRI,
                                   bool LookThroughInstrs = true,
                                   bool HandleFConstants = true,
-                                  bool LookThroughAnyExt = false);
+                                  bool LookThroughAnyExt = false,
+                                  bool HandleIConstants = true);
 const ConstantInt *getConstantIntVRegVal(Register VReg,
                                          const MachineRegisterInfo &MRI);
Index: llvm/lib/CodeGen/GlobalISel/Utils.cpp
===================================================================
--- llvm/lib/CodeGen/GlobalISel/Utils.cpp
+++ llvm/lib/CodeGen/GlobalISel/Utils.cpp
@@ -288,11 +288,11 @@
 
 Optional<ValueAndVReg> llvm::getConstantVRegValWithLookThrough(
     Register VReg, const MachineRegisterInfo &MRI, bool LookThroughInstrs,
-    bool HandleFConstant, bool LookThroughAnyExt) {
+    bool HandleFConstant, bool LookThroughAnyExt, bool HandleIConstant) {
   SmallVector<std::pair<unsigned, unsigned>, 4> SeenOpcodes;
   MachineInstr *MI;
-  auto IsConstantOpcode = [HandleFConstant](unsigned Opcode) {
-    return Opcode == TargetOpcode::G_CONSTANT ||
+  auto IsConstantOpcode = [HandleFConstant, HandleIConstant](unsigned Opcode) {
+    return (HandleIConstant && Opcode == TargetOpcode::G_CONSTANT) ||
            (HandleFConstant && Opcode == TargetOpcode::G_FCONSTANT);
   };
   auto GetImmediateValue = [HandleFConstant,
Index: llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp
===================================================================
--- llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp
+++ llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp
@@ -555,6 +555,35 @@
   EXPECT_FALSE(mi_match(FortyTwo.getReg(0), *MRI, m_AllOnesInt()));
 }
 
+TEST_F(AArch64GISelMITest, MatchFPOrIntConstExplicitly) {
+  setUp();
+  if (!TM)
+    return;
+
+  Register IntOne = B.buildConstant(LLT::scalar(64), 1).getReg(0);
+  Register FPOne = B.buildFConstant(LLT::scalar(64), 1.0).getReg(0);
+  Register Reg;
+
+  EXPECT_TRUE(mi_match(IntOne, *MRI, m_ICst(Reg)));
+  EXPECT_EQ(IntOne, Reg);
+  EXPECT_FALSE(mi_match(IntOne, *MRI, m_FCst(Reg)));
+
+  EXPECT_FALSE(mi_match(FPOne, *MRI, m_ICst(Reg)));
+  EXPECT_TRUE(mi_match(FPOne, *MRI, m_FCst(Reg)));
+  EXPECT_EQ(FPOne, Reg);
+
+  const bool HandleFConstants = true;
+  const bool HandleIConstants = true;
+  // Match any constant.
+  EXPECT_TRUE(mi_match(IntOne, *MRI,
+                       CstRegMatch<HandleIConstants, HandleFConstants>(Reg)));
+  EXPECT_EQ(IntOne, Reg);
+
+  EXPECT_TRUE(mi_match(FPOne, *MRI,
+                       CstRegMatch<HandleIConstants, HandleFConstants>(Reg)));
+  EXPECT_EQ(FPOne, Reg);
+}
+
 TEST_F(AArch64GISelMITest, MatchNeg) {
   setUp();
   if (!TM)