Index: llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
===================================================================
--- llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
+++ llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
@@ -328,6 +328,13 @@
   bool matchCombineFAbsOfFAbs(MachineInstr &MI, Register &Src);
   bool applyCombineFAbsOfFAbs(MachineInstr &MI, Register &Src);
 
+  /// Transform fadd x, (fmul y, z) to fmad y, z, x. Either G_FADD source may
+  /// be the G_FMUL; on a match \p MatchInfo holds {y, z, x}.
+  bool matchCombineMulAdd(MachineInstr &MI,
+                          std::tuple<Register, Register, Register> &MatchInfo);
+  bool applyCombineMulAdd(MachineInstr &MI,
+                          std::tuple<Register, Register, Register> &MatchInfo);
+
   /// Transform trunc ([asz]ext x) to x or ([asz]ext x) or (trunc x).
   bool matchCombineTruncOfExt(MachineInstr &MI,
                               std::pair<Register, unsigned> &MatchInfo);
Index: llvm/include/llvm/Target/GlobalISel/Combine.td
===================================================================
--- llvm/include/llvm/Target/GlobalISel/Combine.td
+++ llvm/include/llvm/Target/GlobalISel/Combine.td
@@ -538,6 +538,15 @@
     [{ return Helper.matchCombineInsertVecElts(*${root}, ${info}); }]),
   (apply [{ return Helper.applyCombineInsertVecElts(*${root}, ${info}); }])>;
 
-// Currently only the one combine above.
+// Transform add x, (mul y, z) -> mad x, y, z
+def add_with_mul_info :
+  GIDefMatchData<"std::tuple<Register, Register, Register>">;
+def add_with_mul: GICombineRule<
+  (defs root:$root, add_with_mul_info:$info),
+  (match (wip_match_opcode G_FADD):$root,
+    [{ return Helper.matchCombineMulAdd(*${root}, ${info}); }]),
+  (apply [{ return Helper.applyCombineMulAdd(*${root}, ${info}); }])>;
+
+// Currently only combine_insert_vec_elts_build_vector.
 def insert_vec_elt_combines : GICombineGroup<
   [combine_insert_vec_elts_build_vector]>;
@@ -580,4 +589,4 @@
     unmerge_merge, fabs_fabs_fold, unmerge_cst, unmerge_dead_to_trunc,
     unmerge_zext_to_zext, trunc_ext_fold, trunc_shl,
     const_combines, xor_of_and_with_same_reg, ptr_add_with_zero,
-    shift_immed_chain, shift_of_shifted_logic_chain]>;
+    shift_immed_chain, shift_of_shifted_logic_chain, add_with_mul]>;
Index: llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
===================================================================
--- llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
+++ llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
@@ -3131,6 +3131,75 @@
   return true;
 }
 
+bool CombinerHelper::matchCombineMulAdd(
+    MachineInstr &MI, std::tuple<Register, Register, Register> &MatchInfo) {
+  assert(MI.getOpcode() == TargetOpcode::G_FADD);
+
+  auto *MF = MI.getParent()->getParent();
+  LLVMContext &C = MF->getFunction().getContext();
+  const TargetOptions &Options = MF->getTarget().Options;
+
+  // The IR type is only used to look up the function's denormal mode.
+  // NOTE(review): 16-bit maps to bfloat here; confirm half is not intended.
+  unsigned TypeSize =
+      MRI.getType(MI.getOperand(0).getReg()).getScalarSizeInBits();
+  Type *Tp = nullptr;
+  if (TypeSize == 16)
+    Tp = Type::getBFloatTy(C);
+  else if (TypeSize == 32)
+    Tp = Type::getFloatTy(C);
+  else if (TypeSize == 64)
+    Tp = Type::getDoubleTy(C);
+  else
+    return false;
+
+  DenormalMode DenormMode = MF->getDenormalMode(Tp->getFltSemantics());
+
+  bool CanFuse =
+      Options.UnsafeFPMath || MI.getFlag(MachineInstr::MIFlag::FmContract);
+
+  // NOTE(review): FMAD availability is inferred from the denormal mode alone;
+  // G_FMAD legality is not queried here -- confirm that is intended.
+  bool IsIEEE = DenormMode == DenormalMode::getIEEE();
+  bool HasFMAD = !IsIEEE;
+  bool AllowFusionGlobally =
+      Options.AllowFPOpFusion == FPOpFusion::Fast || CanFuse || HasFMAD;
+
+  // If the addition is not contractable, do not combine. CanFuse already
+  // folds in the contract flag on MI, so no separate flag test is needed.
+  if (!AllowFusionGlobally)
+    return false;
+
+  // Try each G_FADD source in turn as the product; the other source becomes
+  // the addend.
+  for (unsigned MulIdx : {1u, 2u}) {
+    MachineInstr *MIMul = MRI.getVRegDef(MI.getOperand(MulIdx).getReg());
+    if (!MIMul || MIMul->getOpcode() != TargetOpcode::G_FMUL)
+      continue;
+
+    unsigned AddendIdx = MulIdx == 1 ? 2 : 1;
+    MatchInfo = std::make_tuple(MIMul->getOperand(1).getReg(),
+                                MIMul->getOperand(2).getReg(),
+                                MI.getOperand(AddendIdx).getReg());
+    return true;
+  }
+
+  // Neither operand is a mul instruction.
+  return false;
+}
+
+bool CombinerHelper::applyCombineMulAdd(
+    MachineInstr &MI, std::tuple<Register, Register, Register> &MatchInfo) {
+  // MatchInfo holds the two G_FMUL sources followed by the addend.
+  Register MulLHS, MulRHS, Addend;
+  std::tie(MulLHS, MulRHS, Addend) = MatchInfo;
+
+  Builder.setInstrAndDebugLoc(MI);
+  Builder.buildFMAD(MI.getOperand(0).getReg(), MulLHS, MulRHS, Addend);
+  MI.eraseFromParent();
+  return true;
+}
+
 bool CombinerHelper::tryCombine(MachineInstr &MI) {
   if (tryCombineCopy(MI))
     return true;
Index: llvm/test/CodeGen/AMDGPU/combine-g_fmad.ll
===================================================================
--- /dev/null
+++ llvm/test/CodeGen/AMDGPU/combine-g_fmad.ll
@@ -0,0 +1,73 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -march=amdgcn -mcpu=gfx1010 %s -o - | FileCheck -check-prefix=CHECK %s
+; NOTE(review): the RUN line does not pass -global-isel, so these checks may
+; not exercise the new CombinerHelper rule -- confirm and regenerate if needed.
+
+define amdgpu_vs float @test1(float inreg %arg1,
+; CHECK-LABEL: test1:
+; CHECK:       ; %bb.0: ; %.entry
+; CHECK-NEXT:    v_mov_b32_e32 v0, s2
+; CHECK-NEXT:    ; implicit-def: $vcc_hi
+; CHECK-NEXT:    v_mad_f32 v0, s1, s0, v0
+; CHECK-NEXT:    ; return to shader part epilog
+                              float inreg %arg2,
+                              float inreg %arg3
+                              ) #1 {
+.entry:
+  %t1 = fmul float %arg1, %arg2
+  %res = fadd float %t1, %arg3
+  ret float %res
+}
+
+define amdgpu_vs float @test2(float inreg %arg1,
+; CHECK-LABEL: test2:
+; CHECK:       ; %bb.0: ; %.entry
+; CHECK-NEXT:    v_mul_f32_e64 v0, s0, s1
+; CHECK-NEXT:    ; implicit-def: $vcc_hi
+; CHECK-NEXT:    v_add_f32_e32 v0, s2, v0
+; CHECK-NEXT:    ; return to shader part epilog
+                              float inreg %arg2,
+                              float inreg %arg3
+                              ) {
+.entry:
+  %t1 = fmul float %arg1, %arg2
+  %res = fadd float %t1, %arg3
+  ret float %res
+}
+
+define amdgpu_vs float @test3(float inreg %arg1,
+; CHECK-LABEL: test3:
+; CHECK:       ; %bb.0: ; %.entry
+; CHECK-NEXT:    v_mul_f32_e64 v0, s0, s2
+; CHECK-NEXT:    ; implicit-def: $vcc_hi
+; CHECK-NEXT:    v_mac_f32_e64 v0, s0, s1
+; CHECK-NEXT:    ; return to shader part epilog
+                              float inreg %arg2,
+                              float inreg %arg3
+                              ) #1 {
+.entry:
+  %t1 = fmul float %arg1, %arg2
+  %t2 = fmul float %arg1, %arg3
+  %res = fadd float %t1, %t2
+  ret float %res
+}
+
+define amdgpu_vs float @test4(float inreg %arg1,
+; CHECK-LABEL: test4:
+; CHECK:       ; %bb.0: ; %.entry
+; CHECK-NEXT:    v_mul_f32_e64 v0, s0, s1
+; CHECK-NEXT:    v_mul_f32_e64 v1, s0, s2
+; CHECK-NEXT:    ; implicit-def: $vcc_hi
+; CHECK-NEXT:    v_add_f32_e32 v0, v0, v1
+; CHECK-NEXT:    ; return to shader part epilog
+                              float inreg %arg2,
+                              float inreg %arg3
+                              ) {
+.entry:
+  %t1 = fmul float %arg1, %arg2
+  %t2 = fmul float %arg1, %arg3
+  %res = fadd float %t1, %t2
+  ret float %res
+}
+
+attributes #1 = { "denormal-fp-math-f32"="preserve-sign" }