AMDGPU/GlobalISel: Move post-legalizer combines into a helper class

The file-static match/apply functions in AMDGPUPostLegalizerCombiner.cpp
become members of AMDGPUPostLegalizerCombinerHelper, which holds references
to the MachineIRBuilder, MachineFunction, MachineRegisterInfo and
CombinerHelper. The generated combiner receives that object through the new
StateClass AMDGPUPostLegalizerCombinerHelperState, so the TableGen rules call
PostLegalizerHelper.<fn>(...) and no longer pass MRI, *MF or Helper. The
apply functions reposition the shared builder with setInstrAndDebugLoc
instead of constructing a local MachineIRBuilder each time.

diff --git a/llvm/lib/Target/AMDGPU/AMDGPUCombine.td b/llvm/lib/Target/AMDGPU/AMDGPUCombine.td
--- a/llvm/lib/Target/AMDGPU/AMDGPUCombine.td
+++ b/llvm/lib/Target/AMDGPU/AMDGPUCombine.td
@@ -11,22 +11,22 @@
 // TODO: This really belongs after legalization after scalarization.
 // TODO: GICombineRules should accept subtarget predicates
 
-def fmin_fmax_legacy_matchdata : GIDefMatchData<"FMinFMaxLegacyInfo">;
+def fmin_fmax_legacy_matchdata : GIDefMatchData<"AMDGPUPostLegalizerCombinerHelper::FMinFMaxLegacyInfo">;
 
 def fcmp_select_to_fmin_fmax_legacy : GICombineRule<
   (defs root:$select, fmin_fmax_legacy_matchdata:$matchinfo),
   (match (wip_match_opcode G_SELECT):$select,
-         [{ return matchFMinFMaxLegacy(*${select}, MRI, *MF, ${matchinfo}); }]),
-  (apply [{ applySelectFCmpToFMinToFMaxLegacy(*${select}, ${matchinfo}); }])>;
+         [{ return PostLegalizerHelper.matchFMinFMaxLegacy(*${select}, ${matchinfo}); }]),
+  (apply [{ PostLegalizerHelper.applySelectFCmpToFMinToFMaxLegacy(*${select}, ${matchinfo}); }])>;
 
 
 def uchar_to_float : GICombineRule<
   (defs root:$itofp),
   (match (wip_match_opcode G_UITOFP, G_SITOFP):$itofp,
-         [{ return matchUCharToFloat(*${itofp}, MRI, *MF, Helper); }]),
-  (apply [{ applyUCharToFloat(*${itofp}); }])>;
+         [{ return PostLegalizerHelper.matchUCharToFloat(*${itofp}); }]),
+  (apply [{ PostLegalizerHelper.applyUCharToFloat(*${itofp}); }])>;
 
-def cvt_f32_ubyteN_matchdata : GIDefMatchData<"CvtF32UByteMatchInfo">;
+def cvt_f32_ubyteN_matchdata : GIDefMatchData<"AMDGPUPostLegalizerCombinerHelper::CvtF32UByteMatchInfo">;
 
 def cvt_f32_ubyteN : GICombineRule<
   (defs root:$cvt_f32_ubyteN, cvt_f32_ubyteN_matchdata:$matchinfo),
@@ -34,8 +34,8 @@
                            G_AMDGPU_CVT_F32_UBYTE1,
                            G_AMDGPU_CVT_F32_UBYTE2,
                            G_AMDGPU_CVT_F32_UBYTE3):$cvt_f32_ubyteN,
-         [{ return matchCvtF32UByteN(*${cvt_f32_ubyteN}, MRI, *MF, ${matchinfo}); }]),
-  (apply [{ applyCvtF32UByteN(*${cvt_f32_ubyteN}, ${matchinfo}); }])>;
+         [{ return PostLegalizerHelper.matchCvtF32UByteN(*${cvt_f32_ubyteN}, ${matchinfo}); }]),
+  (apply [{ PostLegalizerHelper.applyCvtF32UByteN(*${cvt_f32_ubyteN}, ${matchinfo}); }])>;
 
 // Combines which should only apply on SI/VI
 def gfx6gfx7_combines : GICombineGroup<[fcmp_select_to_fmin_fmax_legacy]>;
@@ -51,6 +51,8 @@
   [all_combines, gfx6gfx7_combines,
    uchar_to_float, cvt_f32_ubyteN]> {
   let DisableRuleOption = "amdgpupostlegalizercombiner-disable-rule";
+  let StateClass = "AMDGPUPostLegalizerCombinerHelperState";
+  let AdditionalArguments = [];
 }
 
 def AMDGPURegBankCombinerHelper : GICombinerHelper<
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUPostLegalizerCombiner.cpp b/llvm/lib/Target/AMDGPU/AMDGPUPostLegalizerCombiner.cpp
--- a/llvm/lib/Target/AMDGPU/AMDGPUPostLegalizerCombiner.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUPostLegalizerCombiner.cpp
@@ -29,17 +29,47 @@
 using namespace llvm;
 using namespace MIPatternMatch;
 
-struct FMinFMaxLegacyInfo {
-  Register LHS;
-  Register RHS;
-  Register True;
-  Register False;
-  CmpInst::Predicate Pred;
+class AMDGPUPostLegalizerCombinerHelper {
+protected:
+  MachineIRBuilder &B;
+  MachineFunction &MF;
+  MachineRegisterInfo &MRI;
+  CombinerHelper &Helper;
+
+public:
+  AMDGPUPostLegalizerCombinerHelper(MachineIRBuilder &B, CombinerHelper &Helper)
+      : B(B), MF(B.getMF()), MRI(*B.getMRI()), Helper(Helper) {}
+
+  struct FMinFMaxLegacyInfo {
+    Register LHS;
+    Register RHS;
+    Register True;
+    Register False;
+    CmpInst::Predicate Pred;
+  };
+
+  // TODO: Make sure fmin_legacy/fmax_legacy don't canonicalize
+  bool matchFMinFMaxLegacy(MachineInstr &MI, FMinFMaxLegacyInfo &Info);
+  void applySelectFCmpToFMinToFMaxLegacy(MachineInstr &MI,
+                                         const FMinFMaxLegacyInfo &Info);
+
+  bool matchUCharToFloat(MachineInstr &MI);
+  void applyUCharToFloat(MachineInstr &MI);
+
+  // FIXME: Should be able to have 2 separate matchdatas rather than custom
+  // struct boilerplate.
+  struct CvtF32UByteMatchInfo {
+    Register CvtVal;
+    unsigned ShiftOffset;
+  };
+
+  bool matchCvtF32UByteN(MachineInstr &MI, CvtF32UByteMatchInfo &MatchInfo);
+  void applyCvtF32UByteN(MachineInstr &MI,
+                         const CvtF32UByteMatchInfo &MatchInfo);
 };
 
-// TODO: Make sure fmin_legacy/fmax_legacy don't canonicalize
-static bool matchFMinFMaxLegacy(MachineInstr &MI, MachineRegisterInfo &MRI,
-                                MachineFunction &MF, FMinFMaxLegacyInfo &Info) {
+bool AMDGPUPostLegalizerCombinerHelper::matchFMinFMaxLegacy(
+    MachineInstr &MI, FMinFMaxLegacyInfo &Info) {
   // FIXME: Combines should have subtarget predicates, and we shouldn't need
   // this here.
   if (!MF.getSubtarget<GCNSubtarget>().hasFminFmaxLegacy())
@@ -77,12 +107,11 @@
   }
 }
 
-static void applySelectFCmpToFMinToFMaxLegacy(MachineInstr &MI,
-                                              const FMinFMaxLegacyInfo &Info) {
-
-  auto buildNewInst = [&MI](unsigned Opc, Register X, Register Y) {
-    MachineIRBuilder MIB(MI);
-    MIB.buildInstr(Opc, {MI.getOperand(0)}, {X, Y}, MI.getFlags());
+void AMDGPUPostLegalizerCombinerHelper::applySelectFCmpToFMinToFMaxLegacy(
+    MachineInstr &MI, const FMinFMaxLegacyInfo &Info) {
+  B.setInstrAndDebugLoc(MI);
+  auto buildNewInst = [&MI, this](unsigned Opc, Register X, Register Y) {
+    B.buildInstr(Opc, {MI.getOperand(0)}, {X, Y}, MI.getFlags());
   };
 
   switch (Info.Pred) {
@@ -127,8 +156,7 @@
   MI.eraseFromParent();
 }
 
-static bool matchUCharToFloat(MachineInstr &MI, MachineRegisterInfo &MRI,
-                              MachineFunction &MF, CombinerHelper &Helper) {
+bool AMDGPUPostLegalizerCombinerHelper::matchUCharToFloat(MachineInstr &MI) {
   Register DstReg = MI.getOperand(0).getReg();
 
   // TODO: We could try to match extracting the higher bytes, which would be
@@ -147,15 +175,15 @@
   return false;
 }
 
-static void applyUCharToFloat(MachineInstr &MI) {
-  MachineIRBuilder B(MI);
+void AMDGPUPostLegalizerCombinerHelper::applyUCharToFloat(MachineInstr &MI) {
+  B.setInstrAndDebugLoc(MI);
 
   const LLT S32 = LLT::scalar(32);
 
   Register DstReg = MI.getOperand(0).getReg();
   Register SrcReg = MI.getOperand(1).getReg();
-  LLT Ty = B.getMRI()->getType(DstReg);
-  LLT SrcTy = B.getMRI()->getType(SrcReg);
+  LLT Ty = MRI.getType(DstReg);
+  LLT SrcTy = MRI.getType(SrcReg);
   if (SrcTy != S32)
     SrcReg = B.buildAnyExtOrTrunc(S32, SrcReg).getReg(0);
 
@@ -171,16 +199,8 @@
   MI.eraseFromParent();
 }
 
-// FIXME: Should be able to have 2 separate matchdatas rather than custom struct
-// boilerplate.
-struct CvtF32UByteMatchInfo {
-  Register CvtVal;
-  unsigned ShiftOffset;
-};
-
-static bool matchCvtF32UByteN(MachineInstr &MI, MachineRegisterInfo &MRI,
-                              MachineFunction &MF,
-                              CvtF32UByteMatchInfo &MatchInfo) {
+bool AMDGPUPostLegalizerCombinerHelper::matchCvtF32UByteN(
+    MachineInstr &MI, CvtF32UByteMatchInfo &MatchInfo) {
   Register SrcReg = MI.getOperand(1).getReg();
 
   // Look through G_ZEXT.
@@ -207,14 +227,14 @@
   return false;
 }
 
-static void applyCvtF32UByteN(MachineInstr &MI,
-                              const CvtF32UByteMatchInfo &MatchInfo) {
-  MachineIRBuilder B(MI);
+void AMDGPUPostLegalizerCombinerHelper::applyCvtF32UByteN(
+    MachineInstr &MI, const CvtF32UByteMatchInfo &MatchInfo) {
+  B.setInstrAndDebugLoc(MI);
   unsigned NewOpc = AMDGPU::G_AMDGPU_CVT_F32_UBYTE0 + MatchInfo.ShiftOffset / 8;
 
   const LLT S32 = LLT::scalar(32);
   Register CvtSrc = MatchInfo.CvtVal;
-  LLT SrcTy = B.getMRI()->getType(MatchInfo.CvtVal);
+  LLT SrcTy = MRI.getType(MatchInfo.CvtVal);
   if (SrcTy != S32) {
     assert(SrcTy.isScalar() && SrcTy.getSizeInBits() >= 8);
     CvtSrc = B.buildAnyExt(S32, CvtSrc).getReg(0);
@@ -225,6 +245,18 @@
   MI.eraseFromParent();
 }
 
+class AMDGPUPostLegalizerCombinerHelperState {
+protected:
+  CombinerHelper &Helper;
+  AMDGPUPostLegalizerCombinerHelper &PostLegalizerHelper;
+
+public:
+  AMDGPUPostLegalizerCombinerHelperState(
+      CombinerHelper &Helper,
+      AMDGPUPostLegalizerCombinerHelper &PostLegalizerHelper)
+      : Helper(Helper), PostLegalizerHelper(PostLegalizerHelper) {}
+};
+
 #define AMDGPUPOSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_DEPS
 #include "AMDGPUGenPostLegalizeGICombiner.inc"
 #undef AMDGPUPOSTLEGALIZERCOMBINERHELPER_GENCOMBINERHELPER_DEPS
@@ -259,9 +291,11 @@
     MachineInstr &MI, MachineIRBuilder &B) const {
   CombinerHelper Helper(Observer, B, KB, MDT, LInfo);
 
-  AMDGPUGenPostLegalizerCombinerHelper Generated(GeneratedRuleCfg);
+  AMDGPUPostLegalizerCombinerHelper PostLegalizerHelper(B, Helper);
+  AMDGPUGenPostLegalizerCombinerHelper Generated(GeneratedRuleCfg, Helper,
+                                                 PostLegalizerHelper);
 
-  if (Generated.tryCombineAll(Observer, MI, B, Helper))
+  if (Generated.tryCombineAll(Observer, MI, B))
     return true;
 
   switch (MI.getOpcode()) {