ARM: take VMLx accumulator forwarding into account when selecting and
expanding floating-point multiply-accumulate instructions.

Summary of the changes below:
  * ARMISelDAGToDAG.cpp: hasNoVMLxHazardUse() now returns true when the new
    canUseVMLxForwarding() helper accepts the node, i.e. the subtarget reports
    hasVMLxForwarding() and the node is an FADD / FSUB of two FMUL results.
  * MLxExpansionPass.cpp: FindMLxHazard() reports no hazard when the
    accumulator of a scalar VMLA / VMLS is defined by a VMUL / VMLA / VMLS of
    the same precision, and records that defining instruction in
    AccForwarding. hasRAWHazard() ignores VFP-domain readers of instructions
    recorded there. The ForceExapnd option variable is renamed to ForceExpand.
  * fmacs.ll is updated for the new A9 / HARD output and vmlx-fwd.ll is added.

Index: lib/Target/ARM/ARMISelDAGToDAG.cpp
===================================================================
--- lib/Target/ARM/ARMISelDAGToDAG.cpp
+++ lib/Target/ARM/ARMISelDAGToDAG.cpp
@@ -284,6 +284,8 @@
   /// Replace N with M in CurDAG, in a way that also ensures that M gets
   /// selected when N would have been selected.
   void replaceDAGValue(const SDValue &N, SDValue M);
+
+  bool canUseVMLxForwarding(const SDNode &N) const;
 };
 }
 
@@ -416,9 +418,30 @@
   }
 }
 
+/// Check if VMLx accumulator forwarding can be used if the specified SDNode is
+/// lowered to a VMLx instruction.
+/// This holds only on subtargets with VMLx forwarding, for an ISD::FADD or
+/// ISD::FSUB node whose two operands are both results of ISD::FMUL.
+/// Special multiplier accumulator forwarding is used if a multiply-accumulate
+/// follows a multiply or another multiply-accumulate, and depends on the
+/// result of that first instruction.
+bool ARMDAGToDAGISel::canUseVMLxForwarding(const SDNode &N) const {
+  if (!Subtarget->hasVMLxForwarding())
+    return false;
+
+  // Check the opcode first: operands 0 and 1 are only inspected for the
+  // binary FADD / FSUB nodes, never for a node of unknown arity.
+  const unsigned Opcode = N.getOpcode();
+  if (Opcode != ISD::FADD && Opcode != ISD::FSUB)
+    return false;
+
+  return N.getOperand(0).getOpcode() == ISD::FMUL &&
+         N.getOperand(1).getOpcode() == ISD::FMUL;
+}
+
 /// hasNoVMLxHazardUse - Return true if it's desirable to select a FP MLA / MLS
 /// node. VFP / NEON fp VMLA / VMLS instructions have special RAW hazards (at
-/// least on current ARM implementations) which should be avoidded.
+/// least on current ARM implementations) which should be avoided.
 bool ARMDAGToDAGISel::hasNoVMLxHazardUse(SDNode *N) const {
   if (OptLevel == CodeGenOpt::None)
     return true;
@@ -426,6 +449,9 @@
   if (!Subtarget->hasVMLxHazards())
     return true;
 
+  if (canUseVMLxForwarding(*N))
+    return true;
+
   if (!N->hasOneUse())
     return false;
 
Index: lib/Target/ARM/MLxExpansionPass.cpp
===================================================================
--- lib/Target/ARM/MLxExpansionPass.cpp
+++ lib/Target/ARM/MLxExpansionPass.cpp
@@ -30,7 +30,7 @@
 #define DEBUG_TYPE "mlx-expansion"
 
 static cl::opt<bool>
-ForceExapnd("expand-all-fp-mlx", cl::init(false), cl::Hidden);
+ForceExpand("expand-all-fp-mlx", cl::init(false), cl::Hidden);
 static cl::opt<unsigned>
 ExpandLimit("expand-limit", cl::init(~0U), cl::Hidden);
 
@@ -50,6 +50,7 @@
   private:
     const ARMBaseInstrInfo *TII;
     const TargetRegisterInfo *TRI;
+    const ARMSubtarget *STI;
     MachineRegisterInfo *MRI;
 
     bool isLikeA9;
@@ -57,13 +58,15 @@
     unsigned MIIdx;
     MachineInstr* LastMIs[4];
     SmallPtrSet<MachineInstr*, 4> IgnoreStall;
+    SmallPtrSet<MachineInstr*, 4> AccForwarding;
 
     void clearStack();
     void pushStack(MachineInstr *MI);
     MachineInstr *getAccDefMI(MachineInstr *MI) const;
     unsigned getDefReg(MachineInstr *MI) const;
     bool hasLoopHazard(MachineInstr *MI) const;
-    bool hasRAWHazard(unsigned Reg, MachineInstr *MI) const;
+    bool hasRAWHazard(MachineInstr *MI, MachineInstr *NextMI) const;
+    bool canUseVMLxForwarding(MachineInstr *MI, MachineInstr *AccDef) const;
     bool FindMLxHazard(MachineInstr *MI);
     void ExpandFPMLxInstruction(MachineBasicBlock &MBB, MachineInstr *MI,
                                 unsigned MulOpc, unsigned AddSubOpc,
@@ -182,17 +185,23 @@
   return DefMI == MI;
 }
 
-bool MLxExpansion::hasRAWHazard(unsigned Reg, MachineInstr *MI) const {
+bool MLxExpansion::hasRAWHazard(MachineInstr *MI, MachineInstr *NextMI) const {
+  unsigned Reg = getDefReg(MI);
   // FIXME: Detect integer instructions properly.
-  const MCInstrDesc &MCID = MI->getDesc();
+  const MCInstrDesc &MCID = NextMI->getDesc();
   unsigned Domain = MCID.TSFlags & ARMII::DomainMask;
-  if (MI->mayStore())
+  if (NextMI->mayStore())
     return false;
   unsigned Opcode = MCID.getOpcode();
   if (Opcode == ARM::VMOVRS || Opcode == ARM::VMOVRRD)
     return false;
-  if ((Domain & ARMII::DomainVFP) || (Domain & ARMII::DomainNEON))
-    return MI->readsRegister(Reg, TRI);
+  if (Domain & ARMII::DomainNEON)
+    return NextMI->readsRegister(Reg, TRI);
+  // A VFP reader is not treated as a hazard when MI was recorded in
+  // AccForwarding by FindMLxHazard.
+  if (Domain & ARMII::DomainVFP)
+    return NextMI->readsRegister(Reg, TRI) && !AccForwarding.count(MI);
+
   return false;
 }
 
@@ -210,14 +219,70 @@
   }
 }
 
+/// Check if VMLx accumulator forwarding can be used from the instruction
+/// AccDef defining the accumulator to the VMLx instruction MI using it.
+/// Special multiplier accumulator forwarding is used if a multiply-accumulate
+/// follows a multiply or another multiply-accumulate, and depends on the
+/// result of that first instruction.
+bool MLxExpansion::canUseVMLxForwarding(MachineInstr *MI,
+                                        MachineInstr *AccDef) const {
+  assert(STI);
+  assert(MI);
+  assert(AccDef);
+  assert(TII->isFpMLxInstruction(MI->getOpcode()));
+
+  if (!STI->hasVMLxForwarding())
+    return false;
+
+  const auto AccDefOpcode = AccDef->getOpcode();
+  switch (MI->getOpcode()) {
+  default:
+    return false;
+
+  case ARM::VMLAS:
+  case ARM::VMLSS:
+    switch (AccDefOpcode) {
+    default:
+      return false;
+
+    case ARM::VMLAS:
+    case ARM::VMLSS:
+    case ARM::VMULS:
+      return true;
+    }
+    break;
+
+  case ARM::VMLAD:
+  case ARM::VMLSD:
+    switch (AccDefOpcode) {
+    default:
+      return false;
+
+    case ARM::VMLAD:
+    case ARM::VMLSD:
+    case ARM::VMULD:
+      return true;
+    }
+    break;
+  }
+
+  return false;
+}
+
 bool MLxExpansion::FindMLxHazard(MachineInstr *MI) {
   if (NumExpand >= ExpandLimit)
     return false;
 
-  if (ForceExapnd)
+  if (ForceExpand)
     return true;
 
   MachineInstr *DefMI = getAccDefMI(MI);
+
+  if (canUseVMLxForwarding(MI, DefMI)) {
+    AccForwarding.insert(DefMI);
+    return false;
+  }
+
   if (TII->isFpMLxInstruction(DefMI->getOpcode())) {
     // r0 = vmla
     // r3 = vmla r0, r1, r2
@@ -259,7 +324,7 @@
     }
 
     // Look for VMLx RAW hazard.
-    if (i <= Limit2 && hasRAWHazard(getDefReg(MI), NextMI))
+    if (i <= Limit2 && hasRAWHazard(MI, NextMI))
       return true;
   }
 
@@ -330,6 +395,7 @@
 
   clearStack();
   IgnoreStall.clear();
+  AccForwarding.clear();
 
   unsigned Skip = 0;
   MachineBasicBlock::reverse_iterator MII = MBB.rbegin(), E = MBB.rend();
@@ -377,7 +443,7 @@
   TII = static_cast<const ARMBaseInstrInfo *>(Fn.getSubtarget().getInstrInfo());
   TRI = Fn.getSubtarget().getRegisterInfo();
   MRI = &Fn.getRegInfo();
-  const ARMSubtarget *STI = &Fn.getSubtarget<ARMSubtarget>();
+  STI = &Fn.getSubtarget<ARMSubtarget>();
   if (!STI->expandMLx())
     return false;
   isLikeA9 = STI->isLikeA9() || STI->isSwift();
Index: test/CodeGen/ARM/fmacs.ll
===================================================================
--- test/CodeGen/ARM/fmacs.ll
+++ test/CodeGen/ARM/fmacs.ll
@@ -89,13 +89,11 @@
 
 ; A9-LABEL: t5:
 ; A9: vmla.f32
-; A9: vmul.f32
-; A9: vadd.f32
+; A9: vmla.f32
 
 ; HARD-LABEL: t5:
 ; HARD: vmla.f32 s4, s0, s1
-; HARD: vmul.f32 s0, s2, s3
-; HARD: vadd.f32 s0, s4, s0
+; HARD: vmla.f32 s4, s2, s3
   %0 = fmul float %a, %b
   %1 = fadd float %e, %0
   %2 = fmul float %c, %d
Index: test/CodeGen/ARM/vmlx-fwd.ll
===================================================================
--- /dev/null
+++ test/CodeGen/ARM/vmlx-fwd.ll
@@ -0,0 +1,57 @@
+; RUN: llc -mtriple=arm-eabi -mcpu=cortex-a9 %s -o - | FileCheck %s -check-prefix=SCALAR
+; RUN: llc -mtriple=arm-eabi -mcpu=cortex-a9 %s -o - | FileCheck %s -check-prefix=VECTOR
+; RUN: llc -mtriple=arm-eabi -mcpu=swift %s -o - | FileCheck %s -check-prefix=SWIFT
+
+; SWIFT-LABEL: test1:
+; SWIFT-NOT: vml{{.*}}
+
+; SCALAR-LABEL: test1:
+define double @test1(double %a, double %b, double %c, double %d) {
+  %1 = fmul double %a, %c
+  %2 = fmul double %b, %d
+  %3 = fsub double %1, %2
+
+  %4 = fmul double %a, %d
+  %5 = fmul double %b, %c
+  %6 = fadd double %5, %4
+; SCALAR: vml{{[as]}}.f64 {{.*}}
+; SCALAR: vml{{[as]}}.f64 {{.*}}
+
+  %7 = fsub double %3, %6
+
+  ret double %7
+}
+
+; SCALAR-LABEL: test2:
+define float @test2(float %a, float %b, float %c, float %d) {
+  %1 = fmul float %a, %c
+  %2 = fmul float %b, %d
+  %3 = fsub float %1, %2
+
+  %4 = fmul float %a, %d
+  %5 = fmul float %b, %c
+  %6 = fadd float %5, %4
+; SCALAR: vml{{[as]}}.f32 {{.*}}
+; SCALAR: vml{{[as]}}.f32 {{.*}}
+
+  %7 = fsub float %3, %6
+
+  ret float %7
+}
+
+; VECTOR-LABEL: test3:
+; VECTOR-NOT: vml{{.*}}
+define <2 x float> @test3(<2 x float> %a, <2 x float> %b, <2 x float> %c, <2 x float> %d) {
+  %1 = fmul <2 x float> %a, %c
+  %2 = fmul <2 x float> %b, %d
+  %3 = fsub <2 x float> %1, %2
+
+  %4 = fmul <2 x float> %a, %d
+  %5 = fmul <2 x float> %b, %c
+  %6 = fadd <2 x float> %5, %4
+
+  %7 = fsub <2 x float> %6, %3
+
+  ret <2 x float> %7
+}
+