diff --git a/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h b/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
--- a/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
@@ -94,6 +94,55 @@
   return ConstantMatch<int64_t>(Cst);
 }
 
+/// Returns the value of the integral constant splat vector defined by \p Reg
+/// as a \p ConstT, or None if \p Reg is not such a splat.
+template <typename ConstT>
+inline Optional<ConstT> matchConstantSplat(Register,
+                                           const MachineRegisterInfo &);
+
+template <>
+inline Optional<APInt> matchConstantSplat(Register Reg,
+                                          const MachineRegisterInfo &MRI) {
+  return getIConstantSplatVal(Reg, MRI);
+}
+
+/// The int64_t form sign-extends the splat value.
+template <>
+inline Optional<int64_t> matchConstantSplat(Register Reg,
+                                            const MachineRegisterInfo &MRI) {
+  return getIConstantSplatSExtVal(Reg, MRI);
+}
+
+/// Matches either a scalar integral constant or an integral constant splat
+/// vector, storing the (scalar) constant value in \p CR on success.
+template <typename ConstT> struct ICstOrSplatMatch {
+  ConstT &CR;
+  ICstOrSplatMatch(ConstT &C) : CR(C) {}
+  bool match(const MachineRegisterInfo &MRI, Register Reg) {
+    if (auto MaybeCst = matchConstant<ConstT>(Reg, MRI)) {
+      CR = *MaybeCst;
+      return true;
+    }
+
+    if (auto MaybeCstSplat = matchConstantSplat<ConstT>(Reg, MRI)) {
+      CR = *MaybeCstSplat;
+      return true;
+    }
+
+    return false;
+  }
+};
+
+/// Matches an integral constant or an integral constant splat vector and
+/// binds its scalar value to \p Cst.
+inline ICstOrSplatMatch<APInt> m_ICstOrSplat(APInt &Cst) {
+  return ICstOrSplatMatch<APInt>(Cst);
+}
+
+inline ICstOrSplatMatch<int64_t> m_ICstOrSplat(int64_t &Cst) {
+  return ICstOrSplatMatch<int64_t>(Cst);
+}
+
 struct GCstAndRegMatch {
   Optional<ValueAndVReg> &ValReg;
   GCstAndRegMatch(Optional<ValueAndVReg> &ValReg) : ValReg(ValReg) {}
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/Utils.h b/llvm/include/llvm/CodeGen/GlobalISel/Utils.h
--- a/llvm/include/llvm/CodeGen/GlobalISel/Utils.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/Utils.h
@@ -373,9 +373,23 @@
 /// If \p MI is not a splat, returns None.
 Optional<int> getSplatIndex(MachineInstr &MI);
 
-/// Returns a scalar constant of a G_BUILD_VECTOR splat if it exists.
-Optional<int64_t> getBuildVectorConstantSplat(const MachineInstr &MI,
-                                              const MachineRegisterInfo &MRI);
+/// \returns the scalar integral splat value of \p Reg if possible.
+Optional<APInt> getIConstantSplatVal(const Register Reg,
+                                     const MachineRegisterInfo &MRI);
+
+/// \returns the scalar integral splat value defined by \p MI if possible.
+Optional<APInt> getIConstantSplatVal(const MachineInstr &MI,
+                                     const MachineRegisterInfo &MRI);
+
+/// \returns the scalar sign extended integral splat value of \p Reg if
+/// possible.
+Optional<int64_t> getIConstantSplatSExtVal(const Register Reg,
+                                           const MachineRegisterInfo &MRI);
+
+/// \returns the scalar sign extended integral splat value defined by \p MI if
+/// possible.
+Optional<int64_t> getIConstantSplatSExtVal(const MachineInstr &MI,
+                                           const MachineRegisterInfo &MRI);
 
 /// Returns a floating point scalar constant of a build vector splat if it
 /// exists. When \p AllowUndef == true some elements can be undef but not all.
diff --git a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
--- a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
@@ -2945,7 +2945,7 @@
   int64_t Cst;
   if (Ty.isVector()) {
     MachineInstr *CstDef = MRI.getVRegDef(CstReg);
-    auto MaybeCst = getBuildVectorConstantSplat(*CstDef, MRI);
+    auto MaybeCst = getIConstantSplatSExtVal(*CstDef, MRI);
     if (!MaybeCst)
       return false;
     if (!isConstValidTrue(TLI, Ty.getScalarSizeInBits(), *MaybeCst, true, IsFP))
@@ -4029,10 +4029,9 @@
 
   // Given constants C0 and C1 such that C0 + C1 is bit-width:
   // (or (shl x, C0), (lshr y, C1)) -> (fshl x, y, C0) or (fshr x, y, C1)
-  // TODO: Match constant splat.
   int64_t CstShlAmt, CstLShrAmt;
-  if (mi_match(ShlAmt, MRI, m_ICst(CstShlAmt)) &&
-      mi_match(LShrAmt, MRI, m_ICst(CstLShrAmt)) &&
+  if (mi_match(ShlAmt, MRI, m_ICstOrSplat(CstShlAmt)) &&
+      mi_match(LShrAmt, MRI, m_ICstOrSplat(CstLShrAmt)) &&
       CstShlAmt + CstLShrAmt == BitWidth) {
     FshOpc = TargetOpcode::G_FSHR;
     Amt = LShrAmt;
diff --git a/llvm/lib/CodeGen/GlobalISel/Utils.cpp b/llvm/lib/CodeGen/GlobalISel/Utils.cpp
--- a/llvm/lib/CodeGen/GlobalISel/Utils.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/Utils.cpp
@@ -1071,15 +1071,39 @@
                                     AllowUndef);
 }
 
+Optional<APInt> llvm::getIConstantSplatVal(const Register Reg,
+                                           const MachineRegisterInfo &MRI) {
+  if (auto SplatValAndReg =
+          getAnyConstantSplat(Reg, MRI, /* AllowUndef */ false)) {
+    // The splat may be a G_FCONSTANT, which has no integral value.
+    if (Optional<ValueAndVReg> ValAndVReg =
+            getIConstantVRegValWithLookThrough(SplatValAndReg->VReg, MRI))
+      return ValAndVReg->Value;
+  }
+
+  return None;
+}
+
+Optional<APInt> llvm::getIConstantSplatVal(const MachineInstr &MI,
+                                           const MachineRegisterInfo &MRI) {
+  return getIConstantSplatVal(MI.getOperand(0).getReg(), MRI);
+}
+
 Optional<int64_t>
-llvm::getBuildVectorConstantSplat(const MachineInstr &MI,
-                                  const MachineRegisterInfo &MRI) {
+llvm::getIConstantSplatSExtVal(const Register Reg,
+                               const MachineRegisterInfo &MRI) {
   if (auto SplatValAndReg =
-          getAnyConstantSplat(MI.getOperand(0).getReg(), MRI, false))
+          getAnyConstantSplat(Reg, MRI, /* AllowUndef */ false))
     return getIConstantVRegSExtVal(SplatValAndReg->VReg, MRI);
   return None;
 }
 
+Optional<int64_t>
+llvm::getIConstantSplatSExtVal(const MachineInstr &MI,
+                               const MachineRegisterInfo &MRI) {
+  return getIConstantSplatSExtVal(MI.getOperand(0).getReg(), MRI);
+}
+
 Optional<FPValueAndVReg> llvm::getFConstantSplat(Register VReg,
                                                  const MachineRegisterInfo &MRI,
                                                  bool AllowUndef) {
@@ -1105,7 +1129,7 @@
   unsigned Opc = MI.getOpcode();
   if (!isBuildVectorOp(Opc))
     return None;
-  if (auto Splat = getBuildVectorConstantSplat(MI, MRI))
+  if (auto Splat = getIConstantSplatSExtVal(MI, MRI))
     return RegOrConstant(*Splat);
   auto Reg = MI.getOperand(1).getReg();
   if (any_of(make_range(MI.operands_begin() + 2, MI.operands_end()),
@@ -1176,7 +1200,7 @@
   Register Def = MI.getOperand(0).getReg();
   if (auto C = getIConstantVRegValWithLookThrough(Def, MRI))
     return C->Value;
-  auto MaybeCst = getBuildVectorConstantSplat(MI, MRI);
+  auto MaybeCst = getIConstantSplatSExtVal(MI, MRI);
   if (!MaybeCst)
     return None;
   const unsigned ScalarSize = MRI.getType(Def).getScalarSizeInBits();
diff --git a/llvm/test/CodeGen/AMDGPU/GlobalISel/combine-fsh.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/combine-fsh.mir
--- a/llvm/test/CodeGen/AMDGPU/GlobalISel/combine-fsh.mir
+++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/combine-fsh.mir
@@ -143,13 +143,9 @@
     ; CHECK-NEXT: {{  $}}
     ; CHECK-NEXT: %a:_(<2 x s32>) = COPY $vgpr0_vgpr1
     ; CHECK-NEXT: %b:_(<2 x s32>) = COPY $vgpr2_vgpr3
-    ; CHECK-NEXT: %scalar_amt0:_(s32) = G_CONSTANT i32 20
-    ; CHECK-NEXT: %amt0:_(<2 x s32>) = G_BUILD_VECTOR %scalar_amt0(s32), %scalar_amt0(s32)
     ; CHECK-NEXT: %scalar_amt1:_(s32) = G_CONSTANT i32 12
     ; CHECK-NEXT: %amt1:_(<2 x s32>) = G_BUILD_VECTOR %scalar_amt1(s32), %scalar_amt1(s32)
-    ; CHECK-NEXT: %shl:_(<2 x s32>) = G_SHL %a, %amt0(<2 x s32>)
-    ; CHECK-NEXT: %lshr:_(<2 x s32>) = G_LSHR %b, %amt1(<2 x s32>)
-    ; CHECK-NEXT: %or:_(<2 x s32>) = G_OR %shl, %lshr
+    ; CHECK-NEXT: %or:_(<2 x s32>) = G_FSHR %a, %b, %amt1(<2 x s32>)
     ; CHECK-NEXT: $vgpr4_vgpr5 = COPY %or(<2 x s32>)
     %a:_(<2 x s32>) = COPY $vgpr0_vgpr1
     %b:_(<2 x s32>) = COPY $vgpr2_vgpr3
diff --git a/llvm/test/CodeGen/AMDGPU/GlobalISel/combine-rot.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/combine-rot.mir
--- a/llvm/test/CodeGen/AMDGPU/GlobalISel/combine-rot.mir
+++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/combine-rot.mir
@@ -132,13 +132,9 @@
     ; CHECK: liveins: $vgpr0_vgpr1, $vgpr2_vgpr3
     ; CHECK-NEXT: {{  $}}
     ; CHECK-NEXT: %a:_(<2 x s32>) = COPY $vgpr0_vgpr1
-    ; CHECK-NEXT: %scalar_amt0:_(s32) = G_CONSTANT i32 20
-    ; CHECK-NEXT: %amt0:_(<2 x s32>) = G_BUILD_VECTOR %scalar_amt0(s32), %scalar_amt0(s32)
     ; CHECK-NEXT: %scalar_amt1:_(s32) = G_CONSTANT i32 12
     ; CHECK-NEXT: %amt1:_(<2 x s32>) = G_BUILD_VECTOR %scalar_amt1(s32), %scalar_amt1(s32)
-    ; CHECK-NEXT: %shl:_(<2 x s32>) = G_SHL %a, %amt0(<2 x s32>)
-    ; CHECK-NEXT: %lshr:_(<2 x s32>) = G_LSHR %a, %amt1(<2 x s32>)
-    ; CHECK-NEXT: %or:_(<2 x s32>) = G_OR %shl, %lshr
+    ; CHECK-NEXT: %or:_(<2 x s32>) = G_ROTR %a, %amt1(<2 x s32>)
     ; CHECK-NEXT: $vgpr2_vgpr3 = COPY %or(<2 x s32>)
     %a:_(<2 x s32>) = COPY $vgpr0_vgpr1
     %scalar_amt0:_(s32) = G_CONSTANT i32 20
diff --git a/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp b/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp
--- a/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp
@@ -51,6 +51,34 @@
   EXPECT_EQ(Src0->VReg, MIBCst.getReg(0));
 }
 
+TEST_F(AArch64GISelMITest, MatchIntConstantSplat) {
+  setUp();
+  if (!TM)
+    return;
+
+  LLT s64 = LLT::scalar(64);
+  LLT v4s64 = LLT::fixed_vector(4, s64);
+
+  MachineInstrBuilder FortyTwoSplat =
+      B.buildSplatVector(v4s64, B.buildConstant(s64, 42));
+  int64_t Cst;
+  EXPECT_TRUE(mi_match(FortyTwoSplat.getReg(0), *MRI, m_ICstOrSplat(Cst)));
+  EXPECT_EQ(Cst, 42);
+
+  APInt APCst;
+  EXPECT_TRUE(mi_match(FortyTwoSplat.getReg(0), *MRI, m_ICstOrSplat(APCst)));
+  EXPECT_EQ(APCst.getSExtValue(), 42);
+
+  MachineInstrBuilder NonConstantSplat =
+      B.buildBuildVector(v4s64, {Copies[0], Copies[0], Copies[0], Copies[0]});
+  EXPECT_FALSE(mi_match(NonConstantSplat.getReg(0), *MRI, m_ICstOrSplat(Cst)));
+
+  // A floating-point splat has no integral value and must not match.
+  MachineInstrBuilder FPSplat =
+      B.buildSplatVector(v4s64, B.buildFConstant(s64, 42.0));
+  EXPECT_FALSE(mi_match(FPSplat.getReg(0), *MRI, m_ICstOrSplat(APCst)));
+}
+
 TEST_F(AArch64GISelMITest, MachineInstrPtrBind) {
   setUp();
   if (!TM)