diff --git a/llvm/include/llvm/CodeGen/MachineCombinerPattern.h b/llvm/include/llvm/CodeGen/MachineCombinerPattern.h
--- a/llvm/include/llvm/CodeGen/MachineCombinerPattern.h
+++ b/llvm/include/llvm/CodeGen/MachineCombinerPattern.h
@@ -153,7 +153,18 @@
   FMLSv4f32_OP1,
   FMLSv4f32_OP2,
   FMLSv4i32_indexed_OP1,
-  FMLSv4i32_indexed_OP2
+  FMLSv4i32_indexed_OP2,
+
+  FMULv2i32_indexed_OP1,
+  FMULv2i32_indexed_OP2,
+  FMULv2i64_indexed_OP1,
+  FMULv2i64_indexed_OP2,
+  FMULv4i16_indexed_OP1,
+  FMULv4i16_indexed_OP2,
+  FMULv4i32_indexed_OP1,
+  FMULv4i32_indexed_OP2,
+  FMULv8i16_indexed_OP1,
+  FMULv8i16_indexed_OP2,
 };
 
 } // end namespace llvm
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -4525,6 +4525,60 @@
   return Found;
 }
 
+static bool getFMULPatterns(MachineInstr &Root,
+                            SmallVectorImpl<MachineCombinerPattern> &Patterns) {
+  MachineBasicBlock &MBB = *Root.getParent();
+  bool Found = false;
+
+  auto Match = [&](unsigned Opcode, int Operand,
+                   MachineCombinerPattern Pattern) -> bool {
+    MachineRegisterInfo &MRI = MBB.getParent()->getRegInfo();
+    MachineOperand &MO = Root.getOperand(Operand);
+    MachineInstr *MI = nullptr;
+    if (MO.isReg() && Register::isVirtualRegister(MO.getReg()))
+      MI = MRI.getUniqueVRegDef(MO.getReg());
+    if (MI && MI->getOpcode() == Opcode) {
+      Patterns.push_back(Pattern);
+      return true;
+    }
+    return false;
+  };
+
+  typedef MachineCombinerPattern MCP;
+
+  switch (Root.getOpcode()) {
+  default:
+    return false;
+  case AArch64::FMULv2f32:
+    assert(Root.getOperand(1).isReg() && Root.getOperand(2).isReg());
+    Found = Match(AArch64::DUPv2i32lane, 1, MCP::FMULv2i32_indexed_OP1);
+    Found |= Match(AArch64::DUPv2i32lane, 2, MCP::FMULv2i32_indexed_OP2);
+    break;
+  case AArch64::FMULv2f64:
+    assert(Root.getOperand(1).isReg() && Root.getOperand(2).isReg());
+    Found = Match(AArch64::DUPv2i64lane, 1, MCP::FMULv2i64_indexed_OP1);
+    Found |= Match(AArch64::DUPv2i64lane, 2, MCP::FMULv2i64_indexed_OP2);
+    break;
+  case AArch64::FMULv4f16:
+    assert(Root.getOperand(1).isReg() && Root.getOperand(2).isReg());
+    Found = Match(AArch64::DUPv4i16lane, 1, MCP::FMULv4i16_indexed_OP1);
+    Found |= Match(AArch64::DUPv4i16lane, 2, MCP::FMULv4i16_indexed_OP2);
+    break;
+  case AArch64::FMULv4f32:
+    assert(Root.getOperand(1).isReg() && Root.getOperand(2).isReg());
+    Found = Match(AArch64::DUPv4i32lane, 1, MCP::FMULv4i32_indexed_OP1);
+    Found |= Match(AArch64::DUPv4i32lane, 2, MCP::FMULv4i32_indexed_OP2);
+    break;
+  case AArch64::FMULv8f16:
+    assert(Root.getOperand(1).isReg() && Root.getOperand(2).isReg());
+    Found = Match(AArch64::DUPv8i16lane, 1, MCP::FMULv8i16_indexed_OP1);
+    Found |= Match(AArch64::DUPv8i16lane, 2, MCP::FMULv8i16_indexed_OP2);
+    break;
+  }
+
+  return Found;
+}
+
 /// Return true when a code sequence can improve throughput. It
 /// should be called only for instructions in loops.
 /// \param Pattern - combiner pattern
@@ -4588,6 +4642,16 @@
   case MachineCombinerPattern::FMLSv2f64_OP2:
   case MachineCombinerPattern::FMLSv4i32_indexed_OP2:
   case MachineCombinerPattern::FMLSv4f32_OP2:
+  case MachineCombinerPattern::FMULv2i32_indexed_OP1:
+  case MachineCombinerPattern::FMULv2i32_indexed_OP2:
+  case MachineCombinerPattern::FMULv2i64_indexed_OP1:
+  case MachineCombinerPattern::FMULv2i64_indexed_OP2:
+  case MachineCombinerPattern::FMULv4i16_indexed_OP1:
+  case MachineCombinerPattern::FMULv4i16_indexed_OP2:
+  case MachineCombinerPattern::FMULv4i32_indexed_OP1:
+  case MachineCombinerPattern::FMULv4i32_indexed_OP2:
+  case MachineCombinerPattern::FMULv8i16_indexed_OP1:
+  case MachineCombinerPattern::FMULv8i16_indexed_OP2:
   case MachineCombinerPattern::MULADDv8i8_OP1:
   case MachineCombinerPattern::MULADDv8i8_OP2:
   case MachineCombinerPattern::MULADDv16i8_OP1:
@@ -4644,6 +4708,8 @@
   if (getMaddPatterns(Root, Patterns))
     return true;
   // Floating point patterns
+  if (getFMULPatterns(Root, Patterns))
+    return true;
   if (getFMAPatterns(Root, Patterns))
     return true;
 
@@ -4732,6 +4798,31 @@
   return MUL;
 }
 
+static MachineInstr *genIndexedMultiply(
+    MachineFunction &MF, MachineRegisterInfo &MRI, const TargetInstrInfo *TII,
+    MachineInstr &Root, SmallVectorImpl<MachineInstr *> &InsInstrs,
+    unsigned IdxDupOp, unsigned MulOpc, const TargetRegisterClass *RC) {
+  assert((IdxDupOp == 1 || IdxDupOp == 2) && "Invalid index of FMUL operand");
+
+  MachineInstr *Dup = MRI.getUniqueVRegDef(Root.getOperand(IdxDupOp).getReg());
+  Register DupSrcReg = Dup->getOperand(1).getReg();
+  unsigned DupSrcLane = Dup->getOperand(2).getImm();
+
+  unsigned IdxMulOp = IdxDupOp == 1 ? 2 : 1;
+  MachineOperand &MulOp = Root.getOperand(IdxMulOp);
+
+  Register ResultReg = Root.getOperand(0).getReg();
+
+  MachineInstrBuilder MIB;
+  MIB = BuildMI(MF, Root.getDebugLoc(), TII->get(MulOpc), ResultReg)
+            .add(MulOp)
+            .addReg(DupSrcReg)
+            .addImm(DupSrcLane);
+
+  InsInstrs.push_back(MIB);
+  return &Root;
+}
+
 /// genFusedMultiplyAcc - Helper to generate fused multiply accumulate
 /// instructions.
 ///
@@ -5690,12 +5781,58 @@
     }
     break;
   }
+  case MachineCombinerPattern::FMULv2i32_indexed_OP1:
+  case MachineCombinerPattern::FMULv2i32_indexed_OP2: {
+    RC = &AArch64::FPR64RegClass;
+    unsigned IdxDupOp =
+        (Pattern == MachineCombinerPattern::FMULv2i32_indexed_OP1) ? 1 : 2;
+    Opc = AArch64::FMULv2i32_indexed;
+    genIndexedMultiply(MF, MRI, TII, Root, InsInstrs, IdxDupOp, Opc, RC);
+    break;
+  }
+  case MachineCombinerPattern::FMULv2i64_indexed_OP1:
+  case MachineCombinerPattern::FMULv2i64_indexed_OP2: {
+    RC = &AArch64::FPR128RegClass;
+    unsigned IdxDupOp =
+        (Pattern == MachineCombinerPattern::FMULv2i64_indexed_OP1) ? 1 : 2;
+    Opc = AArch64::FMULv2i64_indexed;
+    genIndexedMultiply(MF, MRI, TII, Root, InsInstrs, IdxDupOp, Opc, RC);
+    break;
+  }
+  case MachineCombinerPattern::FMULv4i16_indexed_OP1:
+  case MachineCombinerPattern::FMULv4i16_indexed_OP2: {
+    RC = &AArch64::FPR64RegClass;
+    unsigned IdxDupOp =
+        (Pattern == MachineCombinerPattern::FMULv4i16_indexed_OP1) ? 1 : 2;
+    Opc = AArch64::FMULv4i16_indexed;
+    genIndexedMultiply(MF, MRI, TII, Root, InsInstrs, IdxDupOp, Opc, RC);
+    break;
+  }
+  case MachineCombinerPattern::FMULv4i32_indexed_OP1:
+  case MachineCombinerPattern::FMULv4i32_indexed_OP2: {
+    RC = &AArch64::FPR128RegClass;
+    unsigned IdxDupOp =
+        (Pattern == MachineCombinerPattern::FMULv4i32_indexed_OP1) ? 1 : 2;
+    Opc = AArch64::FMULv4i32_indexed;
+    genIndexedMultiply(MF, MRI, TII, Root, InsInstrs, IdxDupOp, Opc, RC);
+    break;
+  }
+  case MachineCombinerPattern::FMULv8i16_indexed_OP1:
+  case MachineCombinerPattern::FMULv8i16_indexed_OP2: {
+    RC = &AArch64::FPR128RegClass;
+    unsigned IdxDupOp =
+        (Pattern == MachineCombinerPattern::FMULv8i16_indexed_OP1) ? 1 : 2;
+    Opc = AArch64::FMULv8i16_indexed;
+    genIndexedMultiply(MF, MRI, TII, Root, InsInstrs, IdxDupOp, Opc, RC);
+    break;
+  }
   } // end switch (Pattern)
   // Record MUL and ADD/SUB for deletion
   // FIXME: This assertion fails in CodeGen/AArch64/tailmerging_in_mbp.ll and
   // CodeGen/AArch64/urem-seteq-nonzero.ll.
   // assert(MUL && "MUL was never set");
-  DelInstrs.push_back(MUL);
+  if (MUL)
+    DelInstrs.push_back(MUL);
   DelInstrs.push_back(&Root);
 }
 
diff --git a/llvm/test/CodeGen/AArch64/arm64-fma-combines.ll b/llvm/test/CodeGen/AArch64/arm64-fma-combines.ll
--- a/llvm/test/CodeGen/AArch64/arm64-fma-combines.ll
+++ b/llvm/test/CodeGen/AArch64/arm64-fma-combines.ll
@@ -134,3 +134,49 @@
 for.end:                                          ; preds = %for.body
   ret void
 }
+
+define void @indexed_2s(<2 x float> %shuf, <2 x float> %mu, <2 x float> %ad, <2 x float>* %ret) {
+; CHECK-LABEL: %entry
+; CHECK: fmla.2s {{v[0-9]+}}, {{v[0-9]+}}, {{v[0-9]+}}[0]
+;
+entry:
+  %shuffle = shufflevector <2 x float> %shuf, <2 x float> undef, <2 x i32> zeroinitializer
+  br label %for.cond
+
+for.cond:
+  %mul = fmul fast <2 x float> %mu, %shuffle
+  %add = fadd fast <2 x float> %mul, %ad
+  store <2 x float> %add, <2 x float>* %ret, align 16
+  br label %for.cond
+}
+
+define void @indexed_2d(<2 x double> %shuf, <2 x double> %mu, <2 x double> %ad, <2 x double>* %ret) {
+; CHECK-LABEL: %entry
+; CHECK: fmla.2d {{v[0-9]+}}, {{v[0-9]+}}, {{v[0-9]+}}[0]
+;
+entry:
+  %shuffle = shufflevector <2 x double> %shuf, <2 x double> undef, <2 x i32> zeroinitializer
+  br label %for.cond
+
+for.cond:
+  %mul = fmul fast <2 x double> %mu, %shuffle
+  %add = fadd fast <2 x double> %mul, %ad
+  store <2 x double> %add, <2 x double>* %ret, align 16
+  br label %for.cond
+}
+
+define void @indexed_4s(<4 x float> %shuf, <4 x float> %mu, <4 x float> %ad, <4 x float>* %ret) {
+; CHECK-LABEL: %entry
+; CHECK: fmla.4s {{v[0-9]+}}, {{v[0-9]+}}, {{v[0-9]+}}[0]
+;
+entry:
+  %shuffle = shufflevector <4 x float> %shuf, <4 x float> undef, <4 x i32> zeroinitializer
+  br label %for.cond
+
+for.cond:
+  %mul = fmul fast <4 x float> %mu, %shuffle
+  %add = fadd fast <4 x float> %mul, %ad
+  store <4 x float> %add, <4 x float>* %ret, align 16
+  br label %for.cond
+}
+
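Note for readers (illustrative only, not part of the patch; the virtual register numbers are made up): this is the rewrite that getFMULPatterns matches and genIndexedMultiply emits, sketched at the MIR level for the FMULv4i32_indexed_OP2 case:

  ; Before: splat lane 0 of %1, then a plain vector multiply.
  %5:fpr128 = DUPv4i32lane %1:fpr128, 0
  %6:fpr128 = FMULv4f32 %2:fpr128, %5:fpr128

  ; After: a single by-element multiply reads the lane directly. The DUP is
  ; not deleted here; it becomes dead and is cleaned up if it has no other
  ; uses.
  %6:fpr128 = FMULv4i32_indexed %2:fpr128, %1:fpr128, 0

In the tests above, the DUP (from the shufflevector splat) sits in %entry while the FMUL sits in the loop, so instruction selection cannot fold the pair across the block boundary. Once the machine combiner forms the indexed FMUL, the existing FMA patterns, which run immediately afterwards in getMachineCombinerPatterns, fuse it with the FADD into the indexed fmla that the CHECK lines expect.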