Index: llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
===================================================================
--- llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
+++ llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
@@ -346,6 +346,15 @@
       MachineInstr &MI,
       std::tuple<Register, Register, Register, unsigned> &MatchInfo);
 
+  // Transform (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z)
+  //           (fadd (fpext (fmul x, y)), z) -> (fmad (fpext x), (fpext y), z)
+  bool matchCombineFAddFpExtFMulToFMadOrFMA(
+      MachineInstr &MI,
+      std::tuple<Register, Register, Register, unsigned> &MatchInfo);
+  bool applyCombineFAddFpExtFMulToFMadOrFMA(
+      MachineInstr &MI,
+      std::tuple<Register, Register, Register, unsigned> &MatchInfo);
+
   /// Transform trunc ([asz]ext x) to x or ([asz]ext x) or (trunc x).
   bool matchCombineTruncOfExt(MachineInstr &MI,
                               std::pair<Register, unsigned> &MatchInfo);
Index: llvm/include/llvm/Target/GlobalISel/Combine.td
===================================================================
--- llvm/include/llvm/Target/GlobalISel/Combine.td
+++ llvm/include/llvm/Target/GlobalISel/Combine.td
@@ -590,6 +590,18 @@
   (apply [{ return Helper.applyCombineFAddFMulToFMadOrFMA(*${root}, ${info}); }])>;
 
+// Transform (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z)
+//           (fadd x, (fpext (fmul y, z))) -> (fma (fpext y), (fpext z), x)
+def combine_fadd_fpext_fmul_to_fmad_or_fma_info :
+    GIDefMatchData<"std::tuple<Register, Register, Register, unsigned>">;
+def combine_fadd_fpext_fmul_to_fmad_or_fma : GICombineRule<
+  (defs root:$root, combine_fadd_fpext_fmul_to_fmad_or_fma_info:$info),
+  (match (wip_match_opcode G_FADD):$root,
+         [{ return Helper.matchCombineFAddFpExtFMulToFMadOrFMA(*${root},
+                                                               ${info}); }]),
+  (apply [{ return Helper.applyCombineFAddFpExtFMulToFMadOrFMA(*${root},
+                                                               ${info}); }])>;
+
 // Currently only the one combine above.
 def insert_vec_elt_combines : GICombineGroup<
                             [combine_insert_vec_elts_build_vector]>;
@@ -672,7 +684,7 @@
     const_combines, xor_of_and_with_same_reg, ptr_add_with_zero,
     shift_immed_chain, shift_of_shifted_logic_chain, load_or_combine,
     div_rem_to_divrem, funnel_shift_combines,
-    combine_fadd_fmul_to_fmad_or_fma]>;
+    combine_fadd_fmul_to_fmad_or_fma, combine_fadd_fpext_fmul_to_fmad_or_fma]>;
 
 // A combine group used to for prelegalizer combiners at -O0. The combines in
 // this group have been selected based on experiments to balance code size and
Index: llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
===================================================================
--- llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
+++ llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
@@ -4025,6 +4025,116 @@
   return true;
 }
 
+bool CombinerHelper::matchCombineFAddFpExtFMulToFMadOrFMA(
+    MachineInstr &MI,
+    std::tuple<Register, Register, Register, unsigned> &MatchInfo) {
+  assert(MI.getOpcode() == TargetOpcode::G_FADD);
+
+  auto *MF = MI.getParent()->getParent();
+  const auto &TLI = *MF->getSubtarget().getTargetLowering();
+  const TargetOptions &Options = MF->getTarget().Options;
+  LLT DstType = MRI.getType(MI.getOperand(0).getReg());
+  MachineInstr *MI0 = MRI.getVRegDef(MI.getOperand(1).getReg());
+  MachineInstr *MI1 = MRI.getVRegDef(MI.getOperand(2).getReg());
+
+  bool LegalOperations = LI;
+  // Floating-point multiply-add with intermediate rounding.
+  bool HasFMAD = (LegalOperations && TLI.isFMADLegal(MI, DstType));
+  // Floating-point multiply-add without intermediate rounding.
+  bool HasFMA = TLI.isFMAFasterThanFMulAndFAdd(*MF, DstType) &&
+                isLegalOrBeforeLegalizer({TargetOpcode::G_FMA, {DstType}});
+
+  // No valid opcode, do not combine.
+  if (!HasFMAD && !HasFMA)
+    return false;
+
+  bool CanFuse = Options.UnsafeFPMath || isContractable(MI);
+  bool AllowFusionGlobally =
+      (Options.AllowFPOpFusion == FPOpFusion::Fast || CanFuse || HasFMAD);
+
+  // If the addition is not contractable, do not combine.
+  if (!AllowFusionGlobally && !isContractable(MI))
+    return false;
+
+  unsigned PreferredFusedOpcode =
+      HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
+  bool Aggressive = TLI.enableAggressiveFMAFusion(DstType);
+
+  // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
+  // prefer to fold the multiply with fewer uses.
+  if (Aggressive && isContractableFMul(*MI0, AllowFusionGlobally) &&
+      isContractableFMul(*MI1, AllowFusionGlobally)) {
+    if (hasMoreUses(*MI0, *MI1, MRI))
+      std::swap(MI0, MI1);
+  }
+
+  // fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z)
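+  // As a rough sketch (the virtual register names below are illustrative,
+  // not taken from real MIR), the shape matched here is:
+  //   %mul:_(s16) = G_FMUL %x, %y
+  //   %ext:_(s32) = G_FPEXT %mul(s16)
+  //   %sum:_(s32) = G_FADD %ext, %z
+  // and it is rewritten so each multiplicand is extended individually:
+  //   G_FMA/G_FMAD (G_FPEXT %x), (G_FPEXT %y), %z
+  // provided isFPExtFoldable agrees the extension can be absorbed.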
+  Register FpExtReg;
+  if (mi_match(MI0->getOperand(0).getReg(), MRI, m_GFPExt(m_Reg(FpExtReg)))) {
+    MachineInstr *MI00 = MRI.getVRegDef(FpExtReg);
+    if (isContractableFMul(*MI00, AllowFusionGlobally) &&
+        TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstType,
+                            MRI.getType(MI00->getOperand(1).getReg()))) {
+      MatchInfo = {MI00->getOperand(1).getReg(), MI00->getOperand(2).getReg(),
+                   MI1->getOperand(0).getReg(), PreferredFusedOpcode};
+      return true;
+    }
+  }
+
+  // fold (fadd x, (fpext (fmul y, z))) -> (fma (fpext y), (fpext z), x)
+  // Note: Commutes FADD operands.
+  if (mi_match(MI1->getOperand(0).getReg(), MRI, m_GFPExt(m_Reg(FpExtReg)))) {
+    MachineInstr *MI10 = MRI.getVRegDef(FpExtReg);
+    if (isContractableFMul(*MI10, AllowFusionGlobally) &&
+        TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstType,
+                            MRI.getType(MI10->getOperand(1).getReg()))) {
+      MatchInfo = {MI10->getOperand(1).getReg(), MI10->getOperand(2).getReg(),
+                   MI0->getOperand(0).getReg(), PreferredFusedOpcode};
+      return true;
+    }
+  }
+
+  return false;
+}
+
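+// Materialize the fold chosen by the matcher: extend the two narrow
+// multiplicands to the destination type and feed them, together with the
+// addend, to the preferred fused opcode.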
+bool CombinerHelper::applyCombineFAddFpExtFMulToFMadOrFMA(
+    MachineInstr &MI,
+    std::tuple<Register, Register, Register, unsigned> &MatchInfo) {
+  Register Src1, Src2, Src3;
+  unsigned PreferredFusedOpcode;
+  LLT DstType = MRI.getType(MI.getOperand(0).getReg());
+  std::tie(Src1, Src2, Src3, PreferredFusedOpcode) = MatchInfo;
+
+  Builder.setInstrAndDebugLoc(MI);
+  auto FpExt1 = Builder.buildFPExt(DstType, Src1);
+  auto FpExt2 = Builder.buildFPExt(DstType, Src2);
+  Builder.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()},
+                     {FpExt1.getReg(0), FpExt2.getReg(0), Src3});
+  MI.eraseFromParent();
+
+  return true;
+}
+
 bool CombinerHelper::tryCombine(MachineInstr &MI) {
   if (tryCombineCopy(MI))
     return true;
Index: llvm/test/CodeGen/AMDGPU/GlobalISel/combine-fma-add-ext-mul.ll
===================================================================
--- /dev/null
+++ llvm/test/CodeGen/AMDGPU/GlobalISel/combine-fma-add-ext-mul.ll
@@ -0,0 +1,156 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -global-isel -march=amdgcn -mcpu=gfx900 --denormal-fp-math=preserve-sign < %s | FileCheck -check-prefix=GFX9-FAST-DENORM %s
+; RUN: llc -global-isel -march=amdgcn -mcpu=gfx1010 --denormal-fp-math=preserve-sign < %s | FileCheck -check-prefix=GFX10-FAST-DENORM %s
+
+; fold (fadd fast (fpext (fmul fast x, y)), z) -> (fma (fpext x), (fpext y), z)
+; fold (fadd fast x, (fpext (fmul fast y, z))) -> (fma (fpext y), (fpext z), x)
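+; With --denormal-fp-math=preserve-sign, gfx900 is expected to select the
+; fused result as v_mad_f32 (G_FMAD) and gfx1010 as v_fma_f32 (G_FMA).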
+
+define amdgpu_vs float @test_f16_f32_add_ext_mul(half inreg %x, half inreg %y, float inreg %z) {
+; GFX9-FAST-DENORM-LABEL: test_f16_f32_add_ext_mul:
+; GFX9-FAST-DENORM:       ; %bb.0: ; %.entry
+; GFX9-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v0, s0
+; GFX9-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v1, s1
+; GFX9-FAST-DENORM-NEXT:    v_mad_f32 v0, v0, v1, s2
+; GFX9-FAST-DENORM-NEXT:    ; return to shader part epilog
+;
+; GFX10-FAST-DENORM-LABEL: test_f16_f32_add_ext_mul:
+; GFX10-FAST-DENORM:       ; %bb.0: ; %.entry
+; GFX10-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v0, s0
+; GFX10-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v1, s1
+; GFX10-FAST-DENORM-NEXT:    v_fma_f32 v0, v0, v1, s2
+; GFX10-FAST-DENORM-NEXT:    ; return to shader part epilog
+.entry:
+  %a = fmul fast half %x, %y
+  %b = fpext half %a to float
+  %c = fadd fast float %b, %z
+  ret float %c
+}
+define amdgpu_vs float @test_f16_f32_add_ext_mul_rhs(half inreg %x, half inreg %y, float inreg %z) {
+; GFX9-FAST-DENORM-LABEL: test_f16_f32_add_ext_mul_rhs:
+; GFX9-FAST-DENORM:       ; %bb.0: ; %.entry
+; GFX9-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v0, s0
+; GFX9-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v1, s1
+; GFX9-FAST-DENORM-NEXT:    v_mad_f32 v0, v0, v1, s2
+; GFX9-FAST-DENORM-NEXT:    ; return to shader part epilog
+;
+; GFX10-FAST-DENORM-LABEL: test_f16_f32_add_ext_mul_rhs:
+; GFX10-FAST-DENORM:       ; %bb.0: ; %.entry
+; GFX10-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v0, s0
+; GFX10-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v1, s1
+; GFX10-FAST-DENORM-NEXT:    v_fma_f32 v0, v0, v1, s2
+; GFX10-FAST-DENORM-NEXT:    ; return to shader part epilog
+.entry:
+  %a = fmul fast half %x, %y
+  %b = fpext half %a to float
+  %c = fadd fast float %z, %b
+  ret float %c
+}
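+; The vector cases below exercise the fold on odd and non-power-of-two
+; widths. Note that in the current output gfx900 keeps v_pk_mul_f16 +
+; v_add_f32 for these, while gfx1010 still forms v_fma_f32 per element.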
+define amdgpu_vs <5 x float> @test_5xf16_5xf32_add_ext_mul(<5 x half> inreg %x, <5 x half> inreg %y, <5 x float> inreg %z) {
+; GFX9-FAST-DENORM-LABEL: test_5xf16_5xf32_add_ext_mul:
+; GFX9-FAST-DENORM:       ; %bb.0: ; %.entry
+; GFX9-FAST-DENORM-NEXT:    s_pack_lh_b32_b16 s3, s3, s3
+; GFX9-FAST-DENORM-NEXT:    s_pack_lh_b32_b16 s4, s4, s4
+; GFX9-FAST-DENORM-NEXT:    v_mov_b32_e32 v2, s5
+; GFX9-FAST-DENORM-NEXT:    s_pack_lh_b32_b16 s0, s0, s0
+; GFX9-FAST-DENORM-NEXT:    v_mov_b32_e32 v0, s3
+; GFX9-FAST-DENORM-NEXT:    s_pack_lh_b32_b16 s1, s1, s1
+; GFX9-FAST-DENORM-NEXT:    v_mov_b32_e32 v1, s4
+; GFX9-FAST-DENORM-NEXT:    v_pk_mul_f16 v0, s0, v0
+; GFX9-FAST-DENORM-NEXT:    v_pk_mul_f16 v1, s1, v1
+; GFX9-FAST-DENORM-NEXT:    v_pk_mul_f16 v2, s2, v2
+; GFX9-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v3, v0
+; GFX9-FAST-DENORM-NEXT:    v_cvt_f32_f16_sdwa v4, v0 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:WORD_1
+; GFX9-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v5, v1
+; GFX9-FAST-DENORM-NEXT:    v_cvt_f32_f16_sdwa v6, v1 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:WORD_1
+; GFX9-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v7, v2
+; GFX9-FAST-DENORM-NEXT:    v_add_f32_e32 v0, s6, v3
+; GFX9-FAST-DENORM-NEXT:    v_add_f32_e32 v1, s7, v4
+; GFX9-FAST-DENORM-NEXT:    v_add_f32_e32 v2, s8, v5
+; GFX9-FAST-DENORM-NEXT:    v_add_f32_e32 v3, s9, v6
+; GFX9-FAST-DENORM-NEXT:    v_add_f32_e32 v4, s10, v7
+; GFX9-FAST-DENORM-NEXT:    ; return to shader part epilog
+;
+; GFX10-FAST-DENORM-LABEL: test_5xf16_5xf32_add_ext_mul:
+; GFX10-FAST-DENORM:       ; %bb.0: ; %.entry
+; GFX10-FAST-DENORM-NEXT:    s_lshr_b32 s11, s0, 16
+; GFX10-FAST-DENORM-NEXT:    s_lshr_b32 s12, s1, 16
+; GFX10-FAST-DENORM-NEXT:    s_lshr_b32 s13, s3, 16
+; GFX10-FAST-DENORM-NEXT:    s_lshr_b32 s14, s4, 16
+; GFX10-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v0, s0
+; GFX10-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v5, s3
+; GFX10-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v1, s11
+; GFX10-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v6, s13
+; GFX10-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v2, s1
+; GFX10-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v7, s4
+; GFX10-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v3, s12
+; GFX10-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v8, s14
+; GFX10-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v4, s2
+; GFX10-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v9, s5
+; GFX10-FAST-DENORM-NEXT:    v_fma_f32 v0, v0, v5, s6
+; GFX10-FAST-DENORM-NEXT:    v_fma_f32 v1, v1, v6, s7
+; GFX10-FAST-DENORM-NEXT:    v_fma_f32 v2, v2, v7, s8
+; GFX10-FAST-DENORM-NEXT:    v_fma_f32 v3, v3, v8, s9
+; GFX10-FAST-DENORM-NEXT:    v_fma_f32 v4, v4, v9, s10
+; GFX10-FAST-DENORM-NEXT:    ; return to shader part epilog
+.entry:
+  %a = fmul fast <5 x half> %x, %y
+  %b = fpext <5 x half> %a to <5 x float>
+  %c = fadd fast <5 x float> %b, %z
+  ret <5 x float> %c
+}
+define amdgpu_vs <6 x float> @test_6xf16_6xf32_add_ext_mul_rhs(<6 x half> inreg %x, <6 x half> inreg %y, <6 x float> inreg %z) {
+; GFX9-FAST-DENORM-LABEL: test_6xf16_6xf32_add_ext_mul_rhs:
+; GFX9-FAST-DENORM:       ; %bb.0: ; %.entry
+; GFX9-FAST-DENORM-NEXT:    v_mov_b32_e32 v0, s3
+; GFX9-FAST-DENORM-NEXT:    v_mov_b32_e32 v1, s4
+; GFX9-FAST-DENORM-NEXT:    v_mov_b32_e32 v2, s5
+; GFX9-FAST-DENORM-NEXT:    v_pk_mul_f16 v0, s0, v0
+; GFX9-FAST-DENORM-NEXT:    v_pk_mul_f16 v1, s1, v1
+; GFX9-FAST-DENORM-NEXT:    v_pk_mul_f16 v2, s2, v2
+; GFX9-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v3, v0
+; GFX9-FAST-DENORM-NEXT:    v_cvt_f32_f16_sdwa v4, v0 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:WORD_1
+; GFX9-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v5, v1
+; GFX9-FAST-DENORM-NEXT:    v_cvt_f32_f16_sdwa v6, v1 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:WORD_1
+; GFX9-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v7, v2
+; GFX9-FAST-DENORM-NEXT:    v_cvt_f32_f16_sdwa v8, v2 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:WORD_1
+; GFX9-FAST-DENORM-NEXT:    v_add_f32_e32 v0, s6, v3
+; GFX9-FAST-DENORM-NEXT:    v_add_f32_e32 v1, s7, v4
+; GFX9-FAST-DENORM-NEXT:    v_add_f32_e32 v2, s8, v5
+; GFX9-FAST-DENORM-NEXT:    v_add_f32_e32 v3, s9, v6
+; GFX9-FAST-DENORM-NEXT:    v_add_f32_e32 v4, s10, v7
+; GFX9-FAST-DENORM-NEXT:    v_add_f32_e32 v5, s11, v8
+; GFX9-FAST-DENORM-NEXT:    ; return to shader part epilog
+;
+; GFX10-FAST-DENORM-LABEL: test_6xf16_6xf32_add_ext_mul_rhs:
+; GFX10-FAST-DENORM:       ; %bb.0: ; %.entry
+; GFX10-FAST-DENORM-NEXT:    s_lshr_b32 s12, s0, 16
+; GFX10-FAST-DENORM-NEXT:    s_lshr_b32 s13, s1, 16
+; GFX10-FAST-DENORM-NEXT:    s_lshr_b32 s14, s2, 16
+; GFX10-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v0, s0
+; GFX10-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v2, s1
+; GFX10-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v4, s2
+; GFX10-FAST-DENORM-NEXT:    s_lshr_b32 s0, s3, 16
+; GFX10-FAST-DENORM-NEXT:    s_lshr_b32 s1, s4, 16
+; GFX10-FAST-DENORM-NEXT:    s_lshr_b32 s2, s5, 16
+; GFX10-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v6, s3
+; GFX10-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v8, s4
+; GFX10-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v10, s5
+; GFX10-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v1, s12
+; GFX10-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v7, s0
+; GFX10-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v3, s13
+; GFX10-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v9, s1
+; GFX10-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v5, s14
+; GFX10-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v11, s2
+; GFX10-FAST-DENORM-NEXT:    v_fma_f32 v0, v0, v6, s6
+; GFX10-FAST-DENORM-NEXT:    v_fma_f32 v1, v1, v7, s7
+; GFX10-FAST-DENORM-NEXT:    v_fma_f32 v2, v2, v8, s8
+; GFX10-FAST-DENORM-NEXT:    v_fma_f32 v3, v3, v9, s9
+; GFX10-FAST-DENORM-NEXT:    v_fma_f32 v4, v4, v10, s10
+; GFX10-FAST-DENORM-NEXT:    v_fma_f32 v5, v5, v11, s11
+; GFX10-FAST-DENORM-NEXT:    ; return to shader part epilog
+.entry:
+  %a = fmul fast <6 x half> %x, %y
+  %b = fpext <6 x half> %a to <6 x float>
+  %c = fadd fast <6 x float> %z, %b
+  ret <6 x float> %c
+}