diff --git a/llvm/include/llvm/CodeGen/MachineCombinerPattern.h b/llvm/include/llvm/CodeGen/MachineCombinerPattern.h
--- a/llvm/include/llvm/CodeGen/MachineCombinerPattern.h
+++ b/llvm/include/llvm/CodeGen/MachineCombinerPattern.h
@@ -178,6 +178,8 @@
   // X86 VNNI
   DPWSSD,
+
+  FNMADD,
 };
 
 } // end namespace llvm
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -5409,6 +5409,39 @@
   return Found;
 }
 
+static bool getFNEGPatterns(MachineInstr &Root,
+                            SmallVectorImpl<MachineCombinerPattern> &Patterns) {
+  unsigned Opc = Root.getOpcode();
+  MachineBasicBlock &MBB = *Root.getParent();
+  MachineRegisterInfo &MRI = MBB.getParent()->getRegInfo();
+
+  auto Match = [&](unsigned Opcode, MachineCombinerPattern Pattern) -> bool {
+    MachineOperand &MO = Root.getOperand(1);
+    MachineInstr *MI = MRI.getUniqueVRegDef(MO.getReg());
+    if (MI != nullptr && MRI.hasOneNonDBGUse(MI->getOperand(0).getReg()) &&
+        (MI->getOpcode() == Opcode) &&
+        Root.getFlag(MachineInstr::MIFlag::FmContract) &&
+        Root.getFlag(MachineInstr::MIFlag::FmNsz) &&
+        MI->getFlag(MachineInstr::MIFlag::FmContract) &&
+        MI->getFlag(MachineInstr::MIFlag::FmNsz)) {
+      Patterns.push_back(Pattern);
+      return true;
+    }
+    return false;
+  };
+
+  switch (Opc) {
+  default:
+    break;
+  case AArch64::FNEGDr:
+    return Match(AArch64::FMADDDrrr, MachineCombinerPattern::FNMADD);
+  case AArch64::FNEGSr:
+    return Match(AArch64::FMADDSrrr, MachineCombinerPattern::FNMADD);
+  }
+
+  return false;
+}
+
 /// Return true when a code sequence can improve throughput. It
 /// should be called only for instructions in loops.
 /// \param Pattern - combiner pattern
@@ -5578,6 +5611,8 @@
     return true;
   if (getFMAPatterns(Root, Patterns))
     return true;
+  if (getFNEGPatterns(Root, Patterns))
+    return true;
 
   // Other patterns
   if (getMiscPatterns(Root, Patterns))
@@ -5668,6 +5703,47 @@
   return MUL;
 }
 
+static MachineInstr *
+genFNegatedMAD(MachineFunction &MF, MachineRegisterInfo &MRI,
+               const TargetInstrInfo *TII, MachineInstr &Root,
+               SmallVectorImpl<MachineInstr *> &InsInstrs) {
+  MachineInstr *MAD = MRI.getUniqueVRegDef(Root.getOperand(1).getReg());
+
+  unsigned Opc = 0;
+  const TargetRegisterClass *RC = MRI.getRegClass(MAD->getOperand(0).getReg());
+  if (AArch64::FPR32RegClass.hasSubClassEq(RC))
+    Opc = AArch64::FNMADDSrrr;
+  else if (AArch64::FPR64RegClass.hasSubClassEq(RC))
+    Opc = AArch64::FNMADDDrrr;
+  else
+    return nullptr;
+
+  Register ResultReg = Root.getOperand(0).getReg();
+  Register SrcReg0 = MAD->getOperand(1).getReg();
+  Register SrcReg1 = MAD->getOperand(2).getReg();
+  Register SrcReg2 = MAD->getOperand(3).getReg();
+  bool Src0IsKill = MAD->getOperand(1).isKill();
+  bool Src1IsKill = MAD->getOperand(2).isKill();
+  bool Src2IsKill = MAD->getOperand(3).isKill();
+  if (ResultReg.isVirtual())
+    MRI.constrainRegClass(ResultReg, RC);
+  if (SrcReg0.isVirtual())
+    MRI.constrainRegClass(SrcReg0, RC);
+  if (SrcReg1.isVirtual())
+    MRI.constrainRegClass(SrcReg1, RC);
+  if (SrcReg2.isVirtual())
+    MRI.constrainRegClass(SrcReg2, RC);
+
+  MachineInstrBuilder MIB =
+      BuildMI(MF, MIMetadata(Root), TII->get(Opc), ResultReg)
+          .addReg(SrcReg0, getKillRegState(Src0IsKill))
+          .addReg(SrcReg1, getKillRegState(Src1IsKill))
+          .addReg(SrcReg2, getKillRegState(Src2IsKill));
+  InsInstrs.push_back(MIB);
+
+  return MAD;
+}
+
 /// Fold (FMUL x (DUP y lane)) into (FMUL_indexed x y lane)
 static MachineInstr *
 genIndexedMultiply(MachineInstr &Root,
@@ -6800,6 +6876,11 @@
                              &AArch64::FPR128_loRegClass, MRI);
     break;
   }
+  case MachineCombinerPattern::FNMADD: {
+    MUL = genFNegatedMAD(MF, MRI, TII, Root, InsInstrs);
+    break;
+  }
+
   } // end switch (Pattern)
   // Record MUL and ADD/SUB for deletion
   if (MUL)
diff --git a/llvm/test/CodeGen/AArch64/aarch64_fnmadd.ll b/llvm/test/CodeGen/AArch64/aarch64_fnmadd.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/aarch64_fnmadd.ll
@@ -0,0 +1,153 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 2
+; RUN: llc < %s -mtriple=aarch64-linux-gnu -O3 -verify-machineinstrs | FileCheck %s
+
+define void @fnmaddd(ptr %a, ptr %b, ptr %c) {
+; CHECK-LABEL: fnmaddd:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ldr d0, [x1]
+; CHECK-NEXT:    ldr d1, [x0]
+; CHECK-NEXT:    ldr d2, [x2]
+; CHECK-NEXT:    fnmadd d0, d0, d1, d2
+; CHECK-NEXT:    str d0, [x0]
+; CHECK-NEXT:    ret
+entry:
+  %0 = load double, ptr %a, align 8
+  %1 = load double, ptr %b, align 8
+  %mul = fmul fast double %1, %0
+  %2 = load double, ptr %c, align 8
+  %add = fadd fast double %mul, %2
+  %fneg = fneg fast double %add
+  store double %fneg, ptr %a, align 8
+  ret void
+}
+
+; Don't combine: No flags
+define void @fnmaddd_no_fast(ptr %a, ptr %b, ptr %c) {
+; CHECK-LABEL: fnmaddd_no_fast:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ldr d0, [x0]
+; CHECK-NEXT:    ldr d1, [x1]
+; CHECK-NEXT:    fmul d0, d1, d0
+; CHECK-NEXT:    ldr d1, [x2]
+; CHECK-NEXT:    fadd d0, d0, d1
+; CHECK-NEXT:    fneg d0, d0
+; CHECK-NEXT:    str d0, [x0]
+; CHECK-NEXT:    ret
+entry:
+  %0 = load double, ptr %a, align 8
+  %1 = load double, ptr %b, align 8
+  %mul = fmul double %1, %0
+  %2 = load double, ptr %c, align 8
+  %add = fadd double %mul, %2
+  %fneg = fneg double %add
+  store double %fneg, ptr %a, align 8
+  ret void
+}
+
+define void @fnmadds(ptr %a, ptr %b, ptr %c) {
+; CHECK-LABEL: fnmadds:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ldr s0, [x1]
+; CHECK-NEXT:    ldr s1, [x0]
+; CHECK-NEXT:    ldr s2, [x2]
+; CHECK-NEXT:    fnmadd s0, s0, s1, s2
+; CHECK-NEXT:    str s0, [x0]
+; CHECK-NEXT:    ret
+entry:
+  %0 = load float, ptr %a, align 4
+  %1 = load float, ptr %b, align 4
+  %mul = fmul fast float %1, %0
+  %2 = load float, ptr %c, align 4
+  %add = fadd fast float %mul, %2
+  %fneg = fneg fast float %add
+  store float %fneg, ptr %a, align 4
+  ret void
+}
+
+define void @fnmadds_nsz_contract(ptr %a, ptr %b, ptr %c) {
+; CHECK-LABEL: fnmadds_nsz_contract:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ldr s0, [x1]
+; CHECK-NEXT:    ldr s1, [x0]
+; CHECK-NEXT:    ldr s2, [x2]
+; CHECK-NEXT:    fnmadd s0, s0, s1, s2
+; CHECK-NEXT:    str s0, [x0]
+; CHECK-NEXT:    ret
+entry:
+  %0 = load float, ptr %a, align 4
+  %1 = load float, ptr %b, align 4
+  %mul = fmul contract nsz float %1, %0
+  %2 = load float, ptr %c, align 4
+  %add = fadd contract nsz float %mul, %2
+  %fneg = fneg contract nsz float %add
+  store float %fneg, ptr %a, align 4
+  ret void
+}
+
+; Don't combine: Missing nsz
+define void @fnmadds_contract(ptr %a, ptr %b, ptr %c) {
+; CHECK-LABEL: fnmadds_contract:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ldr s0, [x1]
+; CHECK-NEXT:    ldr s1, [x0]
+; CHECK-NEXT:    ldr s2, [x2]
+; CHECK-NEXT:    fmadd s0, s0, s1, s2
+; CHECK-NEXT:    fneg s0, s0
+; CHECK-NEXT:    str s0, [x0]
+; CHECK-NEXT:    ret
+entry:
+  %0 = load float, ptr %a, align 4
+  %1 = load float, ptr %b, align 4
+  %mul = fmul contract float %1, %0
+  %2 = load float, ptr %c, align 4
+  %add = fadd contract float %mul, %2
+  %fneg = fneg contract float %add
+  store float %fneg, ptr %a, align 4
+  ret void
+}
+
+; Don't combine: Missing contract
+define void @fnmadds_nsz(ptr %a, ptr %b, ptr %c) {
+; CHECK-LABEL: fnmadds_nsz:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ldr s0, [x0]
+; CHECK-NEXT:    ldr s1, [x1]
+; CHECK-NEXT:    fmul s0, s1, s0
+; CHECK-NEXT:    ldr s1, [x2]
+; CHECK-NEXT:    fadd s0, s0, s1
+; CHECK-NEXT:    fneg s0, s0
+; CHECK-NEXT:    str s0, [x0]
+; CHECK-NEXT:    ret
+entry:
+  %0 = load float, ptr %a, align 4
+  %1 = load float, ptr %b, align 4
+  %mul = fmul nsz float %1, %0
+  %2 = load float, ptr %c, align 4
+  %add = fadd nsz float %mul, %2
+  %fneg = fneg nsz float %add
+  store float %fneg, ptr %a, align 4
+  ret void
+}
+
+define void @fnmaddd_two_uses(ptr %a, ptr %b, ptr %c, ptr %d) {
+; CHECK-LABEL: fnmaddd_two_uses:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ldr d0, [x1]
+; CHECK-NEXT:    ldr d1, [x0]
+; CHECK-NEXT:    ldr d2, [x2]
+; CHECK-NEXT:    fmadd d0, d0, d1, d2
+; CHECK-NEXT:    fneg d1, d0
+; CHECK-NEXT:    str d1, [x0]
+; CHECK-NEXT:    str d0, [x3]
+; CHECK-NEXT:    ret
+entry:
+  %0 = load double, ptr %a, align 8
+  %1 = load double, ptr %b, align 8
+  %mul = fmul fast double %1, %0
+  %2 = load double, ptr %c, align 8
+  %add = fadd fast double %mul, %2
+  %fneg1 = fneg fast double %add
+  store double %fneg1, ptr %a, align 8
+  store double %add, ptr %d, align 8
+  ret void
+}