diff --git a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
--- a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
@@ -656,6 +656,11 @@
   /// (fadd (fmul x, y), z) -> (fmad x, y, z)
   bool matchCombineFAddFMulToFMadOrFMA(MachineInstr &MI, BuildFnTy &MatchInfo);
 
+  /// Transform (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z)
+  ///           (fadd (fpext (fmul x, y)), z) -> (fmad (fpext x), (fpext y), z)
+  bool matchCombineFAddFpExtFMulToFMadOrFMA(MachineInstr &MI,
+                                            BuildFnTy &MatchInfo);
+
 private:
   /// Given a non-indexed load or store instruction \p MI, find an offset that
   /// can be usefully and legally folded into it as a post-indexing operation.
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -2716,6 +2716,14 @@
     return false;
   }
 
+  /// Return true if an fpext operation input to an \p Opcode operation is free
+  /// (for instance, because half-precision floating-point numbers are
+  /// implicitly extended to float-precision) for an FMA instruction.
+  virtual bool isFPExtFoldable(const MachineInstr &MI, unsigned Opcode,
+                               LLT DestTy, LLT SrcTy) const {
+    return false;
+  }
+
   /// Return true if an fpext operation input to an \p Opcode operation is free
   /// (for instance, because half-precision floating-point numbers are
   /// implicitly extended to float-precision) for an FMA instruction.
diff --git a/llvm/include/llvm/Target/GlobalISel/Combine.td b/llvm/include/llvm/Target/GlobalISel/Combine.td
--- a/llvm/include/llvm/Target/GlobalISel/Combine.td
+++ b/llvm/include/llvm/Target/GlobalISel/Combine.td
@@ -770,6 +770,16 @@
                                                 ${info}); }]),
   (apply [{ Helper.applyBuildFn(*${root}, ${info}); }])>;
+// Transform (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z)
+//                                          -> (fmad (fpext x), (fpext y), z)
+// Transform (fadd x, (fpext (fmul y, z))) -> (fma (fpext y), (fpext z), x)
+//                                          -> (fmad (fpext y), (fpext z), x)
+def combine_fadd_fpext_fmul_to_fmad_or_fma: GICombineRule<
+  (defs root:$root, build_fn_matchinfo:$info),
+  (match (wip_match_opcode G_FADD):$root,
+         [{ return Helper.matchCombineFAddFpExtFMulToFMadOrFMA(*${root},
+                                                               ${info}); }]),
+  (apply [{ Helper.applyBuildFn(*${root}, ${info}); }])>;
 
 // FIXME: These should use the custom predicate feature once it lands.
 def undef_combines : GICombineGroup<[undef_to_fp_zero, undef_to_int_zero,
@@ -803,7 +813,8 @@
 def trivial_combines : GICombineGroup<[copy_prop, mul_to_shl, add_p2i_to_ptradd,
                                        mul_by_neg_one]>;
 
-def fma_combines : GICombineGroup<[combine_fadd_fmul_to_fmad_or_fma]>;
+def fma_combines : GICombineGroup<[combine_fadd_fmul_to_fmad_or_fma,
+    combine_fadd_fpext_fmul_to_fmad_or_fma]>;
 
 def all_combines : GICombineGroup<[trivial_combines, insert_vec_elt_combines,
     extract_vec_elt_combines, combines_for_extload,
diff --git a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
--- a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
@@ -4893,6 +4893,67 @@
   return false;
 }
 
+bool CombinerHelper::matchCombineFAddFpExtFMulToFMadOrFMA(
+    MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
+  assert(MI.getOpcode() == TargetOpcode::G_FADD);
+
+  bool AllowFusionGlobally, HasFMAD, Aggressive;
+  if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive))
+    return false;
+
+  const auto &TLI = *MI.getMF()->getSubtarget().getTargetLowering();
+  MachineInstr *LHS = MRI.getVRegDef(MI.getOperand(1).getReg());
+  MachineInstr *RHS = MRI.getVRegDef(MI.getOperand(2).getReg());
+  LLT DstType = MRI.getType(MI.getOperand(0).getReg());
+
+  unsigned PreferredFusedOpcode =
+      HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA;
+
+  // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)),
+  // prefer to fold the multiply with fewer uses.
+  if (Aggressive && isContractableFMul(*LHS, AllowFusionGlobally) &&
+      isContractableFMul(*RHS, AllowFusionGlobally)) {
+    if (hasMoreUses(*LHS, *RHS, MRI))
+      std::swap(LHS, RHS);
+  }
+
+  // fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z)
+  MachineInstr *FpExtSrc;
+  if (mi_match(LHS->getOperand(0).getReg(), MRI,
+               m_GFPExt(m_MInstr(FpExtSrc))) &&
+      isContractableFMul(*FpExtSrc, AllowFusionGlobally) &&
+      TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstType,
+                          MRI.getType(FpExtSrc->getOperand(1).getReg()))) {
+    MatchInfo = [=, &MI](MachineIRBuilder &B) {
+      auto FpExtX = B.buildFPExt(DstType, FpExtSrc->getOperand(1).getReg());
+      auto FpExtY = B.buildFPExt(DstType, FpExtSrc->getOperand(2).getReg());
+      B.buildInstr(
+          PreferredFusedOpcode, {MI.getOperand(0).getReg()},
+          {FpExtX.getReg(0), FpExtY.getReg(0), RHS->getOperand(0).getReg()});
+    };
+    return true;
+  }
+
+  // fold (fadd z, (fpext (fmul x, y))) -> (fma (fpext x), (fpext y), z)
+  // Note: Commutes FADD operands.
+  if (mi_match(RHS->getOperand(0).getReg(), MRI,
+               m_GFPExt(m_MInstr(FpExtSrc))) &&
+      isContractableFMul(*FpExtSrc, AllowFusionGlobally) &&
+      TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstType,
+                          MRI.getType(FpExtSrc->getOperand(1).getReg()))) {
+    MatchInfo = [=, &MI](MachineIRBuilder &B) {
+      auto FpExtX = B.buildFPExt(DstType, FpExtSrc->getOperand(1).getReg());
+      auto FpExtY = B.buildFPExt(DstType, FpExtSrc->getOperand(2).getReg());
+      B.buildInstr(
+          PreferredFusedOpcode, {MI.getOperand(0).getReg()},
+          {FpExtX.getReg(0), FpExtY.getReg(0), LHS->getOperand(0).getReg()});
+    };
+    return true;
+  }
+
+  return false;
+}
+
 bool CombinerHelper::tryCombine(MachineInstr &MI) {
   if (tryCombineCopy(MI))
     return true;
diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.h b/llvm/lib/Target/AMDGPU/SIISelLowering.h
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.h
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.h
@@ -253,6 +253,9 @@
   bool isFPExtFoldable(const SelectionDAG &DAG, unsigned Opcode, EVT DestVT,
                        EVT SrcVT) const override;
 
+  bool isFPExtFoldable(const MachineInstr &MI, unsigned Opcode, LLT DestTy,
+                       LLT SrcTy) const override;
+
   bool isShuffleMaskLegal(ArrayRef<int> /*Mask*/, EVT /*VT*/) const override;
 
   bool getTgtMemIntrinsic(IntrinsicInfo &, const CallInst &,
diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -924,6 +924,16 @@
          !hasFP32Denormals(DAG.getMachineFunction());
 }
 
+bool SITargetLowering::isFPExtFoldable(const MachineInstr &MI, unsigned Opcode,
+                                       LLT DestTy, LLT SrcTy) const {
+  return ((Opcode == TargetOpcode::G_FMAD && Subtarget->hasMadMixInsts()) ||
+          (Opcode == TargetOpcode::G_FMA && Subtarget->hasFmaMixInsts())) &&
+         DestTy.getScalarSizeInBits() == 32 &&
+         SrcTy.getScalarSizeInBits() == 16 &&
+         // TODO: This probably only requires no input flushing?
+         !hasFP32Denormals(*MI.getMF());
+}
+
 bool SITargetLowering::isShuffleMaskLegal(ArrayRef<int>, EVT) const {
   // SI has some legal vector types, but no legal vector operations. Say no
   // shuffles are legal in order to prefer scalarizing some vector operations.
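
For orientation, and not part of the patch itself: at the generic-MIR level the new combine rewrites roughly the following shape. Register names and LLT types here are illustrative only, assuming f16 multiply operands, an f32 result, a target whose isFPExtFoldable() hook accepts the s16 to s32 extension, and G_FMA as the preferred fused opcode:

%mul:_(s16) = G_FMUL %x, %y
%ext:_(s32) = G_FPEXT %mul
%res:_(s32) = G_FADD %ext, %z

becomes, once applyBuildFn runs the recorded MatchInfo:

%xe:_(s32) = G_FPEXT %x
%ye:_(s32) = G_FPEXT %y
%res:_(s32) = G_FMA %xe, %ye, %z

The same rewrite with the FADD operands commuted handles the (fadd z, (fpext (fmul x, y))) case matched by the second block above.
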
diff --git a/llvm/test/CodeGen/AMDGPU/GlobalISel/combine-fma-add-ext-mul.ll b/llvm/test/CodeGen/AMDGPU/GlobalISel/combine-fma-add-ext-mul.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/combine-fma-add-ext-mul.ll
@@ -0,0 +1,159 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -global-isel -march=amdgcn -mcpu=gfx900 --denormal-fp-math=preserve-sign < %s | FileCheck -check-prefix=GFX9-FAST-DENORM %s
+; RUN: llc -global-isel -march=amdgcn -mcpu=gfx1010 --denormal-fp-math=preserve-sign < %s | FileCheck -check-prefix=GFX10-FAST-DENORM %s
+
+; fold (fadd fast (fpext (fmul fast x, y)), z) -> (fma (fpext x), (fpext y), z)
+; fold (fadd fast x, (fpext (fmul fast y, z))) -> (fma (fpext y), (fpext z), x)
+
+define amdgpu_vs float @test_f16_f32_add_ext_mul(half inreg %x, half inreg %y, float inreg %z) {
+; GFX9-FAST-DENORM-LABEL: test_f16_f32_add_ext_mul:
+; GFX9-FAST-DENORM:       ; %bb.0: ; %.entry
+; GFX9-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v0, s0
+; GFX9-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v1, s1
+; GFX9-FAST-DENORM-NEXT:    v_mad_f32 v0, v0, v1, s2
+; GFX9-FAST-DENORM-NEXT:    ; return to shader part epilog
+;
+; GFX10-FAST-DENORM-LABEL: test_f16_f32_add_ext_mul:
+; GFX10-FAST-DENORM:       ; %bb.0: ; %.entry
+; GFX10-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v0, s0
+; GFX10-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v1, s1
+; GFX10-FAST-DENORM-NEXT:    v_fma_f32 v0, v0, v1, s2
+; GFX10-FAST-DENORM-NEXT:    ; return to shader part epilog
+.entry:
+  %a = fmul fast half %x, %y
+  %b = fpext half %a to float
+  %c = fadd fast float %b, %z
+  ret float %c
+}
+
+define amdgpu_vs float @test_f16_f32_add_ext_mul_rhs(half inreg %x, half inreg %y, float inreg %z) {
+; GFX9-FAST-DENORM-LABEL: test_f16_f32_add_ext_mul_rhs:
+; GFX9-FAST-DENORM:       ; %bb.0: ; %.entry
+; GFX9-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v0, s0
+; GFX9-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v1, s1
+; GFX9-FAST-DENORM-NEXT:    v_mad_f32 v0, v0, v1, s2
+; GFX9-FAST-DENORM-NEXT:    ; return to shader part epilog
+;
+; GFX10-FAST-DENORM-LABEL: test_f16_f32_add_ext_mul_rhs:
+; GFX10-FAST-DENORM:       ; %bb.0: ; %.entry
+; GFX10-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v0, s0
+; GFX10-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v1, s1
+; GFX10-FAST-DENORM-NEXT:    v_fma_f32 v0, v0, v1, s2
+; GFX10-FAST-DENORM-NEXT:    ; return to shader part epilog
+.entry:
+  %a = fmul fast half %x, %y
+  %b = fpext half %a to float
+  %c = fadd fast float %z, %b
+  ret float %c
+}
+
+define amdgpu_vs <5 x float> @test_5xf16_5xf32_add_ext_mul(<5 x half> inreg %x, <5 x half> inreg %y, <5 x float> inreg %z) {
+; GFX9-FAST-DENORM-LABEL: test_5xf16_5xf32_add_ext_mul:
+; GFX9-FAST-DENORM:       ; %bb.0: ; %.entry
+; GFX9-FAST-DENORM-NEXT:    s_pack_lh_b32_b16 s3, s3, s3
+; GFX9-FAST-DENORM-NEXT:    s_pack_lh_b32_b16 s4, s4, s4
+; GFX9-FAST-DENORM-NEXT:    s_pack_lh_b32_b16 s0, s0, s0
+; GFX9-FAST-DENORM-NEXT:    s_pack_lh_b32_b16 s1, s1, s1
+; GFX9-FAST-DENORM-NEXT:    v_mov_b32_e32 v0, s3
+; GFX9-FAST-DENORM-NEXT:    v_mov_b32_e32 v1, s4
+; GFX9-FAST-DENORM-NEXT:    v_mov_b32_e32 v2, s5
+; GFX9-FAST-DENORM-NEXT:    v_pk_mul_f16 v0, s0, v0
+; GFX9-FAST-DENORM-NEXT:    v_pk_mul_f16 v1, s1, v1
+; GFX9-FAST-DENORM-NEXT:    v_pk_mul_f16 v2, s2, v2
+; GFX9-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v3, v0
+; GFX9-FAST-DENORM-NEXT:    v_cvt_f32_f16_sdwa v4, v0 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:WORD_1
+; GFX9-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v5, v1
+; GFX9-FAST-DENORM-NEXT:    v_cvt_f32_f16_sdwa v6, v1 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:WORD_1
+; GFX9-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v7, v2
+; GFX9-FAST-DENORM-NEXT:    v_add_f32_e32 v0, s6, v3
+; GFX9-FAST-DENORM-NEXT:    v_add_f32_e32 v1, s7, v4
+; GFX9-FAST-DENORM-NEXT:    v_add_f32_e32 v2, s8, v5
+; GFX9-FAST-DENORM-NEXT:    v_add_f32_e32 v3, s9, v6
+; GFX9-FAST-DENORM-NEXT:    v_add_f32_e32 v4, s10, v7
+; GFX9-FAST-DENORM-NEXT:    ; return to shader part epilog
+;
+; GFX10-FAST-DENORM-LABEL: test_5xf16_5xf32_add_ext_mul:
+; GFX10-FAST-DENORM:       ; %bb.0: ; %.entry
+; GFX10-FAST-DENORM-NEXT:    s_lshr_b32 s11, s0, 16
+; GFX10-FAST-DENORM-NEXT:    s_lshr_b32 s12, s1, 16
+; GFX10-FAST-DENORM-NEXT:    s_lshr_b32 s13, s3, 16
+; GFX10-FAST-DENORM-NEXT:    s_lshr_b32 s14, s4, 16
+; GFX10-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v0, s0
+; GFX10-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v1, s11
+; GFX10-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v2, s1
+; GFX10-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v3, s12
+; GFX10-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v4, s2
+; GFX10-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v5, s3
+; GFX10-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v6, s13
+; GFX10-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v7, s4
+; GFX10-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v8, s14
+; GFX10-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v9, s5
+; GFX10-FAST-DENORM-NEXT:    v_fma_f32 v0, v0, v5, s6
+; GFX10-FAST-DENORM-NEXT:    v_fma_f32 v1, v1, v6, s7
+; GFX10-FAST-DENORM-NEXT:    v_fma_f32 v2, v2, v7, s8
+; GFX10-FAST-DENORM-NEXT:    v_fma_f32 v3, v3, v8, s9
+; GFX10-FAST-DENORM-NEXT:    v_fma_f32 v4, v4, v9, s10
+; GFX10-FAST-DENORM-NEXT:    ; return to shader part epilog
+.entry:
+  %a = fmul fast <5 x half> %x, %y
+  %b = fpext <5 x half> %a to <5 x float>
+  %c = fadd fast <5 x float> %b, %z
+  ret <5 x float> %c
+}
+
+define amdgpu_vs <6 x float> @test_6xf16_6xf32_add_ext_mul_rhs(<6 x half> inreg %x, <6 x half> inreg %y, <6 x float> inreg %z) {
+; GFX9-FAST-DENORM-LABEL: test_6xf16_6xf32_add_ext_mul_rhs:
+; GFX9-FAST-DENORM:       ; %bb.0: ; %.entry
+; GFX9-FAST-DENORM-NEXT:    v_mov_b32_e32 v0, s3
+; GFX9-FAST-DENORM-NEXT:    v_mov_b32_e32 v1, s4
+; GFX9-FAST-DENORM-NEXT:    v_mov_b32_e32 v2, s5
+; GFX9-FAST-DENORM-NEXT:    v_pk_mul_f16 v0, s0, v0
+; GFX9-FAST-DENORM-NEXT:    v_pk_mul_f16 v1, s1, v1
+; GFX9-FAST-DENORM-NEXT:    v_pk_mul_f16 v2, s2, v2
+; GFX9-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v3, v0
+; GFX9-FAST-DENORM-NEXT:    v_cvt_f32_f16_sdwa v4, v0 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:WORD_1
+; GFX9-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v5, v1
+; GFX9-FAST-DENORM-NEXT:    v_cvt_f32_f16_sdwa v6, v1 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:WORD_1
+; GFX9-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v7, v2
+; GFX9-FAST-DENORM-NEXT:    v_cvt_f32_f16_sdwa v8, v2 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:WORD_1
+; GFX9-FAST-DENORM-NEXT:    v_add_f32_e32 v0, s6, v3
+; GFX9-FAST-DENORM-NEXT:    v_add_f32_e32 v1, s7, v4
+; GFX9-FAST-DENORM-NEXT:    v_add_f32_e32 v2, s8, v5
+; GFX9-FAST-DENORM-NEXT:    v_add_f32_e32 v3, s9, v6
+; GFX9-FAST-DENORM-NEXT:    v_add_f32_e32 v4, s10, v7
+; GFX9-FAST-DENORM-NEXT:    v_add_f32_e32 v5, s11, v8
+; GFX9-FAST-DENORM-NEXT:    ; return to shader part epilog
+;
+; GFX10-FAST-DENORM-LABEL: test_6xf16_6xf32_add_ext_mul_rhs:
+; GFX10-FAST-DENORM:       ; %bb.0: ; %.entry
+; GFX10-FAST-DENORM-NEXT:    s_lshr_b32 s12, s0, 16
+; GFX10-FAST-DENORM-NEXT:    s_lshr_b32 s13, s1, 16
+; GFX10-FAST-DENORM-NEXT:    s_lshr_b32 s14, s2, 16
+; GFX10-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v0, s0
+; GFX10-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v2, s1
+; GFX10-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v4, s2
+; GFX10-FAST-DENORM-NEXT:    s_lshr_b32 s0, s3, 16
+; GFX10-FAST-DENORM-NEXT:    s_lshr_b32 s1, s4, 16
+; GFX10-FAST-DENORM-NEXT:    s_lshr_b32 s2, s5, 16
+; GFX10-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v1, s12
+; GFX10-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v3, s13
+; GFX10-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v5, s14
+; GFX10-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v6, s3
+; GFX10-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v7, s0
+; GFX10-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v8, s4
+; GFX10-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v9, s1
+; GFX10-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v10, s5
+; GFX10-FAST-DENORM-NEXT:    v_cvt_f32_f16_e32 v11, s2
+; GFX10-FAST-DENORM-NEXT:    v_fma_f32 v0, v0, v6, s6
+; GFX10-FAST-DENORM-NEXT:    v_fma_f32 v1, v1, v7, s7
+; GFX10-FAST-DENORM-NEXT:    v_fma_f32 v2, v2, v8, s8
+; GFX10-FAST-DENORM-NEXT:    v_fma_f32 v3, v3, v9, s9
+; GFX10-FAST-DENORM-NEXT:    v_fma_f32 v4, v4, v10, s10
+; GFX10-FAST-DENORM-NEXT:    v_fma_f32 v5, v5, v11, s11
+; GFX10-FAST-DENORM-NEXT:    ; return to shader part epilog
+.entry:
+  %a = fmul fast <6 x half> %x, %y
+  %b = fpext <6 x half> %a to <6 x float>
+  %c = fadd fast <6 x float> %z, %b
+  ret <6 x float> %c
+}
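
A note that is not part of the patch or its tests: the new rule is gated by canCombineFMadOrFMA() and isContractableFMul(), so the fold should only fire when contraction is allowed, either globally (fp-contract=fast) or via per-instruction flags such as the 'fast' flags used in the tests above. A hedged negative example for contrast; the function name here is purely illustrative:

define amdgpu_vs float @no_fuse_without_contract(half inreg %x, half inreg %y, float inreg %z) {
.entry:
  %a = fmul half %x, %y         ; no 'fast'/'contract' flags on the multiply
  %b = fpext half %a to float
  %c = fadd float %b, %z        ; ...nor on the add, so no fused mad/fma is expected
  ret float %c
}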