diff --git a/llvm/include/llvm/CodeGen/MachineCombinerPattern.h b/llvm/include/llvm/CodeGen/MachineCombinerPattern.h --- a/llvm/include/llvm/CodeGen/MachineCombinerPattern.h +++ b/llvm/include/llvm/CodeGen/MachineCombinerPattern.h @@ -38,6 +38,51 @@ MULSUBX_OP2, MULADDXI_OP1, MULSUBXI_OP1, + // NEON integers vectors + MULADDv8i8_OP1, + MULADDv8i8_OP2, + MULADDv16i8_OP1, + MULADDv16i8_OP2, + MULADDv4i16_OP1, + MULADDv4i16_OP2, + MULADDv8i16_OP1, + MULADDv8i16_OP2, + MULADDv2i32_OP1, + MULADDv2i32_OP2, + MULADDv4i32_OP1, + MULADDv4i32_OP2, + + MULSUBv8i8_OP1, + MULSUBv8i8_OP2, + MULSUBv16i8_OP1, + MULSUBv16i8_OP2, + MULSUBv4i16_OP1, + MULSUBv4i16_OP2, + MULSUBv8i16_OP1, + MULSUBv8i16_OP2, + MULSUBv2i32_OP1, + MULSUBv2i32_OP2, + MULSUBv4i32_OP1, + MULSUBv4i32_OP2, + + MULADDv4i16_indexed_OP1, + MULADDv4i16_indexed_OP2, + MULADDv8i16_indexed_OP1, + MULADDv8i16_indexed_OP2, + MULADDv2i32_indexed_OP1, + MULADDv2i32_indexed_OP2, + MULADDv4i32_indexed_OP1, + MULADDv4i32_indexed_OP2, + + MULSUBv4i16_indexed_OP1, + MULSUBv4i16_indexed_OP2, + MULSUBv8i16_indexed_OP1, + MULSUBv8i16_indexed_OP2, + MULSUBv2i32_indexed_OP1, + MULSUBv2i32_indexed_OP2, + MULSUBv4i32_indexed_OP1, + MULSUBv4i32_indexed_OP2, + // Floating Point FMULADDH_OP1, FMULADDH_OP2, diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp --- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp @@ -3571,6 +3571,18 @@ // Note: MSUB Wd,Wn,Wm,Wi -> Wd = Wi - WnxWm, not Wd=WnxWm - Wi. 
case AArch64::SUBXri: case AArch64::SUBSXri: + case AArch64::ADDv8i8: + case AArch64::ADDv16i8: + case AArch64::ADDv4i16: + case AArch64::ADDv8i16: + case AArch64::ADDv2i32: + case AArch64::ADDv4i32: + case AArch64::SUBv8i8: + case AArch64::SUBv16i8: + case AArch64::SUBv4i16: + case AArch64::SUBv8i16: + case AArch64::SUBv2i32: + case AArch64::SUBv4i32: return true; default: break; @@ -3713,6 +3725,13 @@ } }; + auto setVFound = [&](int Opcode, int Operand, MachineCombinerPattern Pattern) { + if (canCombine(MBB, Root.getOperand(Operand), Opcode)) { + Patterns.push_back(Pattern); + Found = true; + } + }; + typedef MachineCombinerPattern MCP; switch (Opc) { @@ -3748,6 +3767,62 @@ case AArch64::SUBXri: setFound(AArch64::MADDXrrr, 1, AArch64::XZR, MCP::MULSUBXI_OP1); break; + case AArch64::ADDv8i8: + setVFound(AArch64::MULv8i8, 1, MCP::MULADDv8i8_OP1); + setVFound(AArch64::MULv8i8, 2, MCP::MULADDv8i8_OP2); + break; + case AArch64::ADDv16i8: + setVFound(AArch64::MULv16i8, 1, MCP::MULADDv16i8_OP1); + setVFound(AArch64::MULv16i8, 2, MCP::MULADDv16i8_OP2); + break; + case AArch64::ADDv4i16: + setVFound(AArch64::MULv4i16, 1, MCP::MULADDv4i16_OP1); + setVFound(AArch64::MULv4i16, 2, MCP::MULADDv4i16_OP2); + setVFound(AArch64::MULv4i16_indexed, 1, MCP::MULADDv4i16_indexed_OP1); + setVFound(AArch64::MULv4i16_indexed, 2, MCP::MULADDv4i16_indexed_OP2); + break; + case AArch64::ADDv8i16: + setVFound(AArch64::MULv8i16, 1, MCP::MULADDv8i16_OP1); + setVFound(AArch64::MULv8i16, 2, MCP::MULADDv8i16_OP2); + setVFound(AArch64::MULv8i16_indexed, 1, MCP::MULADDv8i16_indexed_OP1); + setVFound(AArch64::MULv8i16_indexed, 2, MCP::MULADDv8i16_indexed_OP2); + break; + case AArch64::ADDv2i32: + setVFound(AArch64::MULv2i32, 1, MCP::MULADDv2i32_OP1); + setVFound(AArch64::MULv2i32, 2, MCP::MULADDv2i32_OP2); + setVFound(AArch64::MULv2i32_indexed, 1, MCP::MULADDv2i32_indexed_OP1); + setVFound(AArch64::MULv2i32_indexed, 2, MCP::MULADDv2i32_indexed_OP2); + break; + case AArch64::ADDv4i32: +
setVFound(AArch64::MULv4i32, 1, MCP::MULADDv4i32_OP1); + setVFound(AArch64::MULv4i32, 2, MCP::MULADDv4i32_OP2); + setVFound(AArch64::MULv4i32_indexed, 1, MCP::MULADDv4i32_indexed_OP1); + setVFound(AArch64::MULv4i32_indexed, 2, MCP::MULADDv4i32_indexed_OP2); + break; + // MLS computes Acc - Vn * Vm, so only a MUL feeding operand 2 can fold; + // sub(mul, x) (the *_OP1 form) would need a negated Acc and is skipped. + case AArch64::SUBv8i8: + setVFound(AArch64::MULv8i8, 2, MCP::MULSUBv8i8_OP2); + break; + case AArch64::SUBv16i8: + setVFound(AArch64::MULv16i8, 2, MCP::MULSUBv16i8_OP2); + break; + case AArch64::SUBv4i16: + setVFound(AArch64::MULv4i16, 2, MCP::MULSUBv4i16_OP2); + setVFound(AArch64::MULv4i16_indexed, 2, MCP::MULSUBv4i16_indexed_OP2); + break; + case AArch64::SUBv8i16: + setVFound(AArch64::MULv8i16, 2, MCP::MULSUBv8i16_OP2); + setVFound(AArch64::MULv8i16_indexed, 2, MCP::MULSUBv8i16_indexed_OP2); + break; + case AArch64::SUBv2i32: + setVFound(AArch64::MULv2i32, 2, MCP::MULSUBv2i32_OP2); + setVFound(AArch64::MULv2i32_indexed, 2, MCP::MULSUBv2i32_indexed_OP2); + break; + case AArch64::SUBv4i32: + setVFound(AArch64::MULv4i32, 2, MCP::MULSUBv4i32_OP2); + setVFound(AArch64::MULv4i32_indexed, 2, MCP::MULSUBv4i32_indexed_OP2); + break; } return Found; } @@ -3960,6 +4035,47 @@ case MachineCombinerPattern::FMLSv2f64_OP2: case MachineCombinerPattern::FMLSv4i32_indexed_OP2: case MachineCombinerPattern::FMLSv4f32_OP2: + case MachineCombinerPattern::MULADDv8i8_OP1: + case
MachineCombinerPattern::MULADDv8i8_OP2: + case MachineCombinerPattern::MULADDv16i8_OP1: + case MachineCombinerPattern::MULADDv16i8_OP2: + case MachineCombinerPattern::MULADDv4i16_OP1: + case MachineCombinerPattern::MULADDv4i16_OP2: + case MachineCombinerPattern::MULADDv8i16_OP1: + case MachineCombinerPattern::MULADDv8i16_OP2: + case MachineCombinerPattern::MULADDv2i32_OP1: + case MachineCombinerPattern::MULADDv2i32_OP2: + case MachineCombinerPattern::MULADDv4i32_OP1: + case MachineCombinerPattern::MULADDv4i32_OP2: + case MachineCombinerPattern::MULSUBv8i8_OP1: + case MachineCombinerPattern::MULSUBv8i8_OP2: + case MachineCombinerPattern::MULSUBv16i8_OP1: + case MachineCombinerPattern::MULSUBv16i8_OP2: + case MachineCombinerPattern::MULSUBv4i16_OP1: + case MachineCombinerPattern::MULSUBv4i16_OP2: + case MachineCombinerPattern::MULSUBv8i16_OP1: + case MachineCombinerPattern::MULSUBv8i16_OP2: + case MachineCombinerPattern::MULSUBv2i32_OP1: + case MachineCombinerPattern::MULSUBv2i32_OP2: + case MachineCombinerPattern::MULSUBv4i32_OP1: + case MachineCombinerPattern::MULSUBv4i32_OP2: + case MachineCombinerPattern::MULADDv4i16_indexed_OP1: + case MachineCombinerPattern::MULADDv4i16_indexed_OP2: + case MachineCombinerPattern::MULADDv8i16_indexed_OP1: + case MachineCombinerPattern::MULADDv8i16_indexed_OP2: + case MachineCombinerPattern::MULADDv2i32_indexed_OP1: + case MachineCombinerPattern::MULADDv2i32_indexed_OP2: + case MachineCombinerPattern::MULADDv4i32_indexed_OP1: + case MachineCombinerPattern::MULADDv4i32_indexed_OP2: + case MachineCombinerPattern::MULSUBv4i16_indexed_OP1: + case MachineCombinerPattern::MULSUBv4i16_indexed_OP2: + case MachineCombinerPattern::MULSUBv8i16_indexed_OP1: + case MachineCombinerPattern::MULSUBv8i16_indexed_OP2: + case MachineCombinerPattern::MULSUBv2i32_indexed_OP1: + case MachineCombinerPattern::MULSUBv2i32_indexed_OP2: + case MachineCombinerPattern::MULSUBv4i32_indexed_OP1: + case MachineCombinerPattern::MULSUBv4i32_indexed_OP2: + // 
All NEON integer MLA/MLS patterns above count as throughput patterns. return true; } // end switch (Pattern) return false; @@ -4063,6 +4187,30 @@ return MUL; } +/// genFusedMultiplyAcc - Helper to generate fused multiply accumulate +/// instructions. +/// +/// \see genFusedMultiply +static MachineInstr *genFusedMultiplyAcc( + MachineFunction &MF, MachineRegisterInfo &MRI, const TargetInstrInfo *TII, + MachineInstr &Root, SmallVectorImpl<MachineInstr *> &InsInstrs, + unsigned IdxMulOpd, unsigned MaddOpc, const TargetRegisterClass *RC) { + return genFusedMultiply(MF, MRI, TII, Root, InsInstrs, IdxMulOpd, MaddOpc, RC, + FMAInstKind::Accumulator); +} + +/// genFusedMultiplyIdx - Helper to generate fused multiply accumulate +/// instructions in their indexed (by element) form. +/// +/// \see genFusedMultiply +static MachineInstr *genFusedMultiplyIdx( + MachineFunction &MF, MachineRegisterInfo &MRI, const TargetInstrInfo *TII, + MachineInstr &Root, SmallVectorImpl<MachineInstr *> &InsInstrs, + unsigned IdxMulOpd, unsigned MaddOpc, const TargetRegisterClass *RC) { + return genFusedMultiply(MF, MRI, TII, Root, InsInstrs, IdxMulOpd, MaddOpc, RC, + FMAInstKind::Indexed); +} + /// genMaddR - Generate madd instruction and combine mul and add using /// an extra virtual register /// Example - an ADD intermediate needs to be stored in a register: @@ -4302,6 +4450,211 @@ } break; } + + case MachineCombinerPattern::MULADDv8i8_OP1: + Opc = AArch64::MLAv8i8; + RC = &AArch64::FPR64RegClass; + MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC); + break; + case MachineCombinerPattern::MULADDv8i8_OP2: + Opc = AArch64::MLAv8i8; + RC = &AArch64::FPR64RegClass; + MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC); + break; + case MachineCombinerPattern::MULADDv16i8_OP1: + Opc = AArch64::MLAv16i8; + RC = &AArch64::FPR128RegClass; + MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC); + break; + case MachineCombinerPattern::MULADDv16i8_OP2: + Opc = AArch64::MLAv16i8; + RC = &AArch64::FPR128RegClass; + MUL = genFusedMultiplyAcc(MF, MRI, TII, Root,
InsInstrs, 2, Opc, RC); + break; + case MachineCombinerPattern::MULADDv4i16_OP1: + Opc = AArch64::MLAv4i16; + RC = &AArch64::FPR64RegClass; + MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC); + break; + case MachineCombinerPattern::MULADDv4i16_OP2: + Opc = AArch64::MLAv4i16; + RC = &AArch64::FPR64RegClass; + MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC); + break; + case MachineCombinerPattern::MULADDv8i16_OP1: + Opc = AArch64::MLAv8i16; + RC = &AArch64::FPR128RegClass; + MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC); + break; + case MachineCombinerPattern::MULADDv8i16_OP2: + Opc = AArch64::MLAv8i16; + RC = &AArch64::FPR128RegClass; + MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC); + break; + case MachineCombinerPattern::MULADDv2i32_OP1: + Opc = AArch64::MLAv2i32; + RC = &AArch64::FPR64RegClass; + MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC); + break; + case MachineCombinerPattern::MULADDv2i32_OP2: + Opc = AArch64::MLAv2i32; + RC = &AArch64::FPR64RegClass; + MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC); + break; + case MachineCombinerPattern::MULADDv4i32_OP1: + Opc = AArch64::MLAv4i32; + RC = &AArch64::FPR128RegClass; + MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC); + break; + case MachineCombinerPattern::MULADDv4i32_OP2: + Opc = AArch64::MLAv4i32; + RC = &AArch64::FPR128RegClass; + MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC); + break; + + case MachineCombinerPattern::MULSUBv8i8_OP1: + Opc = AArch64::MLSv8i8; + RC = &AArch64::FPR64RegClass; + MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC); + break; + case MachineCombinerPattern::MULSUBv8i8_OP2: + Opc = AArch64::MLSv8i8; + RC = &AArch64::FPR64RegClass; + MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC); + break; + case MachineCombinerPattern::MULSUBv16i8_OP1: + Opc = AArch64::MLSv16i8; + RC = 
&AArch64::FPR128RegClass; + MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC); + break; + case MachineCombinerPattern::MULSUBv16i8_OP2: + Opc = AArch64::MLSv16i8; + RC = &AArch64::FPR128RegClass; + MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC); + break; + case MachineCombinerPattern::MULSUBv4i16_OP1: + Opc = AArch64::MLSv4i16; + RC = &AArch64::FPR64RegClass; + MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC); + break; + case MachineCombinerPattern::MULSUBv4i16_OP2: + Opc = AArch64::MLSv4i16; + RC = &AArch64::FPR64RegClass; + MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC); + break; + case MachineCombinerPattern::MULSUBv8i16_OP1: + Opc = AArch64::MLSv8i16; + RC = &AArch64::FPR128RegClass; + MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC); + break; + case MachineCombinerPattern::MULSUBv8i16_OP2: + Opc = AArch64::MLSv8i16; + RC = &AArch64::FPR128RegClass; + MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC); + break; + case MachineCombinerPattern::MULSUBv2i32_OP1: + Opc = AArch64::MLSv2i32; + RC = &AArch64::FPR64RegClass; + MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC); + break; + case MachineCombinerPattern::MULSUBv2i32_OP2: + Opc = AArch64::MLSv2i32; + RC = &AArch64::FPR64RegClass; + MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC); + break; + case MachineCombinerPattern::MULSUBv4i32_OP1: + Opc = AArch64::MLSv4i32; + RC = &AArch64::FPR128RegClass; + MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC); + break; + case MachineCombinerPattern::MULSUBv4i32_OP2: + Opc = AArch64::MLSv4i32; + RC = &AArch64::FPR128RegClass; + MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC); + break; + + case MachineCombinerPattern::MULADDv4i16_indexed_OP1: + Opc = AArch64::MLAv4i16_indexed; + RC = &AArch64::FPR64RegClass; + MUL = genFusedMultiplyIdx(MF, MRI, TII, Root, InsInstrs, 1, Opc, 
RC); + break; + case MachineCombinerPattern::MULADDv4i16_indexed_OP2: + Opc = AArch64::MLAv4i16_indexed; + RC = &AArch64::FPR64RegClass; + MUL = genFusedMultiplyIdx(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC); + break; + case MachineCombinerPattern::MULADDv8i16_indexed_OP1: + Opc = AArch64::MLAv8i16_indexed; + RC = &AArch64::FPR128RegClass; + MUL = genFusedMultiplyIdx(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC); + break; + case MachineCombinerPattern::MULADDv8i16_indexed_OP2: + Opc = AArch64::MLAv8i16_indexed; + RC = &AArch64::FPR128RegClass; + MUL = genFusedMultiplyIdx(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC); + break; + case MachineCombinerPattern::MULADDv2i32_indexed_OP1: + Opc = AArch64::MLAv2i32_indexed; + RC = &AArch64::FPR64RegClass; + MUL = genFusedMultiplyIdx(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC); + break; + case MachineCombinerPattern::MULADDv2i32_indexed_OP2: + Opc = AArch64::MLAv2i32_indexed; + RC = &AArch64::FPR64RegClass; + MUL = genFusedMultiplyIdx(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC); + break; + case MachineCombinerPattern::MULADDv4i32_indexed_OP1: + Opc = AArch64::MLAv4i32_indexed; + RC = &AArch64::FPR128RegClass; + MUL = genFusedMultiplyIdx(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC); + break; + case MachineCombinerPattern::MULADDv4i32_indexed_OP2: + Opc = AArch64::MLAv4i32_indexed; + RC = &AArch64::FPR128RegClass; + MUL = genFusedMultiplyIdx(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC); + break; + + case MachineCombinerPattern::MULSUBv4i16_indexed_OP1: + Opc = AArch64::MLSv4i16_indexed; + RC = &AArch64::FPR64RegClass; + MUL = genFusedMultiplyIdx(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC); + break; + case MachineCombinerPattern::MULSUBv4i16_indexed_OP2: + Opc = AArch64::MLSv4i16_indexed; + RC = &AArch64::FPR64RegClass; + MUL = genFusedMultiplyIdx(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC); + break; + case MachineCombinerPattern::MULSUBv8i16_indexed_OP1: + Opc = AArch64::MLSv8i16_indexed; + RC = &AArch64::FPR128RegClass; + MUL = 
genFusedMultiplyIdx(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC); + break; + case MachineCombinerPattern::MULSUBv8i16_indexed_OP2: + Opc = AArch64::MLSv8i16_indexed; + RC = &AArch64::FPR128RegClass; + MUL = genFusedMultiplyIdx(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC); + break; + case MachineCombinerPattern::MULSUBv2i32_indexed_OP1: + Opc = AArch64::MLSv2i32_indexed; + RC = &AArch64::FPR64RegClass; + MUL = genFusedMultiplyIdx(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC); + break; + case MachineCombinerPattern::MULSUBv2i32_indexed_OP2: + Opc = AArch64::MLSv2i32_indexed; + RC = &AArch64::FPR64RegClass; + MUL = genFusedMultiplyIdx(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC); + break; + case MachineCombinerPattern::MULSUBv4i32_indexed_OP1: + Opc = AArch64::MLSv4i32_indexed; + RC = &AArch64::FPR128RegClass; + MUL = genFusedMultiplyIdx(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC); + break; + case MachineCombinerPattern::MULSUBv4i32_indexed_OP2: + Opc = AArch64::MLSv4i32_indexed; + RC = &AArch64::FPR128RegClass; + MUL = genFusedMultiplyIdx(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC); + break; + // Floating Point Support case MachineCombinerPattern::FMULADDH_OP1: Opc = AArch64::FMADDHrrr; diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td --- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td @@ -3793,10 +3793,11 @@ defm FRECPS : SIMDThreeSameVectorFP<0,0,0b111,"frecps", int_aarch64_neon_frecps>; defm FRSQRTS : SIMDThreeSameVectorFP<0,1,0b111,"frsqrts", int_aarch64_neon_frsqrts>; defm FSUB : SIMDThreeSameVectorFP<0,1,0b010,"fsub", fsub>; -defm MLA : SIMDThreeSameVectorBHSTied<0, 0b10010, "mla", - TriOpFrag<(add node:$LHS, (mul node:$MHS, node:$RHS))> >; -defm MLS : SIMDThreeSameVectorBHSTied<1, 0b10010, "mls", - TriOpFrag<(sub node:$LHS, (mul node:$MHS, node:$RHS))> >; + +// MLA and MLS are generated in MachineCombine +defm MLA : SIMDThreeSameVectorBHSTied<0, 0b10010, "mla", 
null_frag>; +defm MLS : SIMDThreeSameVectorBHSTied<1, 0b10010, "mls", null_frag>; + defm MUL : SIMDThreeSameVectorBHS<0, 0b10011, "mul", mul>; defm PMUL : SIMDThreeSameVectorB<1, 0b10011, "pmul", int_aarch64_neon_pmul>; defm SABA : SIMDThreeSameVectorBHSTied<0, 0b01111, "saba", @@ -5526,10 +5527,11 @@ defm SQDMULH : SIMDIndexedHS<0, 0b1100, "sqdmulh", int_aarch64_neon_sqdmulh>; defm SQRDMULH : SIMDIndexedHS<0, 0b1101, "sqrdmulh", int_aarch64_neon_sqrdmulh>; -defm MLA : SIMDVectorIndexedHSTied<1, 0b0000, "mla", - TriOpFrag<(add node:$LHS, (mul node:$MHS, node:$RHS))>>; -defm MLS : SIMDVectorIndexedHSTied<1, 0b0100, "mls", - TriOpFrag<(sub node:$LHS, (mul node:$MHS, node:$RHS))>>; + +// Generated by MachineCombine +defm MLA : SIMDVectorIndexedHSTied<1, 0b0000, "mla", null_frag>; +defm MLS : SIMDVectorIndexedHSTied<1, 0b0100, "mls", null_frag>; + defm MUL : SIMDVectorIndexedHS<0, 0b1000, "mul", mul>; defm SMLAL : SIMDVectorIndexedLongSDTied<0, 0b0010, "smlal", TriOpFrag<(add node:$LHS, (int_aarch64_neon_smull node:$MHS, node:$RHS))>>; diff --git a/llvm/test/CodeGen/AArch64/GlobalISel/select-with-no-legality-check.mir b/llvm/test/CodeGen/AArch64/GlobalISel/select-with-no-legality-check.mir --- a/llvm/test/CodeGen/AArch64/GlobalISel/select-with-no-legality-check.mir +++ b/llvm/test/CodeGen/AArch64/GlobalISel/select-with-no-legality-check.mir @@ -1433,8 +1433,9 @@ ; CHECK: [[COPY:%[0-9]+]]:fpr64 = COPY $d2 ; CHECK: [[COPY1:%[0-9]+]]:fpr64 = COPY $d1 ; CHECK: [[COPY2:%[0-9]+]]:fpr64 = COPY $d0 - ; CHECK: [[MLAv8i8_:%[0-9]+]]:fpr64 = MLAv8i8 [[COPY2]], [[COPY1]], [[COPY]] - ; CHECK: $noreg = PATCHABLE_RET [[MLAv8i8_]] + ; CHECK: [[MULv8i8_:%[0-9]+]]:fpr64 = MULv8i8 [[COPY1]], [[COPY]] + ; CHECK: [[ADDv8i8_:%[0-9]+]]:fpr64 = ADDv8i8 [[MULv8i8_]], [[COPY2]] + ; CHECK: $noreg = PATCHABLE_RET [[ADDv8i8_]] %4:fpr(<8 x s8>) = COPY $d2 %3:fpr(<8 x s8>) = COPY $d1 %2:fpr(<8 x s8>) = COPY $d0 @@ -1468,8 +1469,9 @@ ; CHECK: [[COPY:%[0-9]+]]:fpr128 = COPY $q2 ; CHECK: 
[[COPY1:%[0-9]+]]:fpr128 = COPY $q1 ; CHECK: [[COPY2:%[0-9]+]]:fpr128 = COPY $q0 - ; CHECK: [[MLAv16i8_:%[0-9]+]]:fpr128 = MLAv16i8 [[COPY2]], [[COPY1]], [[COPY]] - ; CHECK: $noreg = PATCHABLE_RET [[MLAv16i8_]] + ; CHECK: [[MULv16i8_:%[0-9]+]]:fpr128 = MULv16i8 [[COPY1]], [[COPY]] + ; CHECK: [[ADDv16i8_:%[0-9]+]]:fpr128 = ADDv16i8 [[MULv16i8_]], [[COPY2]] + ; CHECK: $noreg = PATCHABLE_RET [[ADDv16i8_]] %4:fpr(<16 x s8>) = COPY $q2 %3:fpr(<16 x s8>) = COPY $q1 %2:fpr(<16 x s8>) = COPY $q0 @@ -1503,8 +1505,9 @@ ; CHECK: [[COPY:%[0-9]+]]:fpr64 = COPY $d2 ; CHECK: [[COPY1:%[0-9]+]]:fpr64 = COPY $d1 ; CHECK: [[COPY2:%[0-9]+]]:fpr64 = COPY $d0 - ; CHECK: [[MLAv4i16_:%[0-9]+]]:fpr64 = MLAv4i16 [[COPY2]], [[COPY1]], [[COPY]] - ; CHECK: $noreg = PATCHABLE_RET [[MLAv4i16_]] + ; CHECK: [[MULv4i16_:%[0-9]+]]:fpr64 = MULv4i16 [[COPY1]], [[COPY]] + ; CHECK: [[ADDv4i16_:%[0-9]+]]:fpr64 = ADDv4i16 [[MULv4i16_]], [[COPY2]] + ; CHECK: $noreg = PATCHABLE_RET [[ADDv4i16_]] %4:fpr(<4 x s16>) = COPY $d2 %3:fpr(<4 x s16>) = COPY $d1 %2:fpr(<4 x s16>) = COPY $d0 @@ -1538,8 +1541,9 @@ ; CHECK: [[COPY:%[0-9]+]]:fpr128 = COPY $q2 ; CHECK: [[COPY1:%[0-9]+]]:fpr128 = COPY $q1 ; CHECK: [[COPY2:%[0-9]+]]:fpr128 = COPY $q0 - ; CHECK: [[MLAv8i16_:%[0-9]+]]:fpr128 = MLAv8i16 [[COPY2]], [[COPY1]], [[COPY]] - ; CHECK: $noreg = PATCHABLE_RET [[MLAv8i16_]] + ; CHECK: [[MULv8i16_:%[0-9]+]]:fpr128 = MULv8i16 [[COPY1]], [[COPY]] + ; CHECK: [[ADDv8i16_:%[0-9]+]]:fpr128 = ADDv8i16 [[MULv8i16_]], [[COPY2]] + ; CHECK: $noreg = PATCHABLE_RET [[ADDv8i16_]] %4:fpr(<8 x s16>) = COPY $q2 %3:fpr(<8 x s16>) = COPY $q1 %2:fpr(<8 x s16>) = COPY $q0 @@ -1759,8 +1763,9 @@ ; CHECK: [[COPY:%[0-9]+]]:fpr64 = COPY $d2 ; CHECK: [[COPY1:%[0-9]+]]:fpr64 = COPY $d1 ; CHECK: [[COPY2:%[0-9]+]]:fpr64 = COPY $d0 - ; CHECK: [[MLAv8i8_:%[0-9]+]]:fpr64 = MLAv8i8 [[COPY2]], [[COPY1]], [[COPY]] - ; CHECK: $noreg = PATCHABLE_RET [[MLAv8i8_]] + ; CHECK: [[MULv8i8_:%[0-9]+]]:fpr64 = MULv8i8 [[COPY1]], [[COPY]] + ; CHECK: 
[[ADDv8i8_:%[0-9]+]]:fpr64 = ADDv8i8 [[COPY2]], [[MULv8i8_]] + ; CHECK: $noreg = PATCHABLE_RET [[ADDv8i8_]] %4:fpr(<8 x s8>) = COPY $d2 %3:fpr(<8 x s8>) = COPY $d1 %2:fpr(<8 x s8>) = COPY $d0 @@ -1794,8 +1799,9 @@ ; CHECK: [[COPY:%[0-9]+]]:fpr128 = COPY $q2 ; CHECK: [[COPY1:%[0-9]+]]:fpr128 = COPY $q1 ; CHECK: [[COPY2:%[0-9]+]]:fpr128 = COPY $q0 - ; CHECK: [[MLAv16i8_:%[0-9]+]]:fpr128 = MLAv16i8 [[COPY2]], [[COPY1]], [[COPY]] - ; CHECK: $noreg = PATCHABLE_RET [[MLAv16i8_]] + ; CHECK: [[MULv16i8_:%[0-9]+]]:fpr128 = MULv16i8 [[COPY1]], [[COPY]] + ; CHECK: [[ADDv16i8_:%[0-9]+]]:fpr128 = ADDv16i8 [[COPY2]], [[MULv16i8_]] + ; CHECK: $noreg = PATCHABLE_RET [[ADDv16i8_]] %4:fpr(<16 x s8>) = COPY $q2 %3:fpr(<16 x s8>) = COPY $q1 %2:fpr(<16 x s8>) = COPY $q0 @@ -1829,8 +1835,9 @@ ; CHECK: [[COPY:%[0-9]+]]:fpr64 = COPY $d2 ; CHECK: [[COPY1:%[0-9]+]]:fpr64 = COPY $d1 ; CHECK: [[COPY2:%[0-9]+]]:fpr64 = COPY $d0 - ; CHECK: [[MLAv4i16_:%[0-9]+]]:fpr64 = MLAv4i16 [[COPY2]], [[COPY1]], [[COPY]] - ; CHECK: $noreg = PATCHABLE_RET [[MLAv4i16_]] + ; CHECK: [[MULv4i16_:%[0-9]+]]:fpr64 = MULv4i16 [[COPY1]], [[COPY]] + ; CHECK: [[ADDv4i16_:%[0-9]+]]:fpr64 = ADDv4i16 [[COPY2]], [[MULv4i16_]] + ; CHECK: $noreg = PATCHABLE_RET [[ADDv4i16_]] %4:fpr(<4 x s16>) = COPY $d2 %3:fpr(<4 x s16>) = COPY $d1 %2:fpr(<4 x s16>) = COPY $d0 @@ -1864,8 +1871,9 @@ ; CHECK: [[COPY:%[0-9]+]]:fpr128 = COPY $q2 ; CHECK: [[COPY1:%[0-9]+]]:fpr128 = COPY $q1 ; CHECK: [[COPY2:%[0-9]+]]:fpr128 = COPY $q0 - ; CHECK: [[MLAv8i16_:%[0-9]+]]:fpr128 = MLAv8i16 [[COPY2]], [[COPY1]], [[COPY]] - ; CHECK: $noreg = PATCHABLE_RET [[MLAv8i16_]] + ; CHECK: [[MULv8i16_:%[0-9]+]]:fpr128 = MULv8i16 [[COPY1]], [[COPY]] + ; CHECK: [[ADDv8i16_:%[0-9]+]]:fpr128 = ADDv8i16 [[COPY2]], [[MULv8i16_]] + ; CHECK: $noreg = PATCHABLE_RET [[ADDv8i16_]] %4:fpr(<8 x s16>) = COPY $q2 %3:fpr(<8 x s16>) = COPY $q1 %2:fpr(<8 x s16>) = COPY $q0 @@ -2085,8 +2093,9 @@ ; CHECK: [[COPY:%[0-9]+]]:fpr64 = COPY $d2 ; CHECK: [[COPY1:%[0-9]+]]:fpr64 = COPY $d1 
; CHECK: [[COPY2:%[0-9]+]]:fpr64 = COPY $d0 - ; CHECK: [[MLSv8i8_:%[0-9]+]]:fpr64 = MLSv8i8 [[COPY2]], [[COPY1]], [[COPY]] - ; CHECK: $noreg = PATCHABLE_RET [[MLSv8i8_]] + ; CHECK: [[MULv8i8_:%[0-9]+]]:fpr64 = MULv8i8 [[COPY1]], [[COPY]] + ; CHECK: [[SUBv8i8_:%[0-9]+]]:fpr64 = SUBv8i8 [[COPY2]], [[MULv8i8_]] + ; CHECK: $noreg = PATCHABLE_RET [[SUBv8i8_]] %4:fpr(<8 x s8>) = COPY $d2 %3:fpr(<8 x s8>) = COPY $d1 %2:fpr(<8 x s8>) = COPY $d0 @@ -2120,8 +2129,9 @@ ; CHECK: [[COPY:%[0-9]+]]:fpr128 = COPY $q2 ; CHECK: [[COPY1:%[0-9]+]]:fpr128 = COPY $q1 ; CHECK: [[COPY2:%[0-9]+]]:fpr128 = COPY $q0 - ; CHECK: [[MLSv16i8_:%[0-9]+]]:fpr128 = MLSv16i8 [[COPY2]], [[COPY1]], [[COPY]] - ; CHECK: $noreg = PATCHABLE_RET [[MLSv16i8_]] + ; CHECK: [[MULv16i8_:%[0-9]+]]:fpr128 = MULv16i8 [[COPY1]], [[COPY]] + ; CHECK: [[SUBv16i8_:%[0-9]+]]:fpr128 = SUBv16i8 [[COPY2]], [[MULv16i8_]] + ; CHECK: $noreg = PATCHABLE_RET [[SUBv16i8_]] %4:fpr(<16 x s8>) = COPY $q2 %3:fpr(<16 x s8>) = COPY $q1 %2:fpr(<16 x s8>) = COPY $q0 @@ -2155,8 +2165,9 @@ ; CHECK: [[COPY:%[0-9]+]]:fpr64 = COPY $d2 ; CHECK: [[COPY1:%[0-9]+]]:fpr64 = COPY $d1 ; CHECK: [[COPY2:%[0-9]+]]:fpr64 = COPY $d0 - ; CHECK: [[MLSv4i16_:%[0-9]+]]:fpr64 = MLSv4i16 [[COPY2]], [[COPY1]], [[COPY]] - ; CHECK: $noreg = PATCHABLE_RET [[MLSv4i16_]] + ; CHECK: [[MULv4i16_:%[0-9]+]]:fpr64 = MULv4i16 [[COPY1]], [[COPY]] + ; CHECK: [[SUBv4i16_:%[0-9]+]]:fpr64 = SUBv4i16 [[COPY2]], [[MULv4i16_]] + ; CHECK: $noreg = PATCHABLE_RET [[SUBv4i16_]] %4:fpr(<4 x s16>) = COPY $d2 %3:fpr(<4 x s16>) = COPY $d1 %2:fpr(<4 x s16>) = COPY $d0 @@ -2190,8 +2201,9 @@ ; CHECK: [[COPY:%[0-9]+]]:fpr128 = COPY $q2 ; CHECK: [[COPY1:%[0-9]+]]:fpr128 = COPY $q1 ; CHECK: [[COPY2:%[0-9]+]]:fpr128 = COPY $q0 - ; CHECK: [[MLSv8i16_:%[0-9]+]]:fpr128 = MLSv8i16 [[COPY2]], [[COPY1]], [[COPY]] - ; CHECK: $noreg = PATCHABLE_RET [[MLSv8i16_]] + ; CHECK: [[MULv8i16_:%[0-9]+]]:fpr128 = MULv8i16 [[COPY1]], [[COPY]] + ; CHECK: [[SUBv8i16_:%[0-9]+]]:fpr128 = SUBv8i16 [[COPY2]], 
[[MULv8i16_]] + ; CHECK: $noreg = PATCHABLE_RET [[SUBv8i16_]] %4:fpr(<8 x s16>) = COPY $q2 %3:fpr(<8 x s16>) = COPY $q1 %2:fpr(<8 x s16>) = COPY $q0 diff --git a/llvm/test/CodeGen/AArch64/overeager_mla_fusing.ll b/llvm/test/CodeGen/AArch64/overeager_mla_fusing.ll --- a/llvm/test/CodeGen/AArch64/overeager_mla_fusing.ll +++ b/llvm/test/CodeGen/AArch64/overeager_mla_fusing.ll @@ -5,17 +5,17 @@ ; CHECK-LABEL: jsimd_idct_ifast_neon_intrinsic: ; CHECK: // %bb.0: // %entry ; CHECK-NEXT: ldr q0, [x1, #32] -; CHECK-NEXT: ldr q1, [x0, #32] -; CHECK-NEXT: ldr q2, [x1, #96] +; CHECK-NEXT: ldr q1, [x1, #96] +; CHECK-NEXT: ldr q2, [x0, #32] ; CHECK-NEXT: ldr q3, [x0, #96] ; CHECK-NEXT: ldr x8, [x2, #48] -; CHECK-NEXT: mul v0.8h, v1.8h, v0.8h -; CHECK-NEXT: mov v1.16b, v0.16b -; CHECK-NEXT: mla v1.8h, v3.8h, v2.8h ; CHECK-NEXT: mov w9, w3 -; CHECK-NEXT: str q1, [x8, x9] +; CHECK-NEXT: mul v0.8h, v2.8h, v0.8h +; CHECK-NEXT: mul v1.8h, v3.8h, v1.8h +; CHECK-NEXT: add v2.8h, v0.8h, v1.8h +; CHECK-NEXT: str q2, [x8, x9] ; CHECK-NEXT: ldr x8, [x2, #56] -; CHECK-NEXT: mls v0.8h, v3.8h, v2.8h +; CHECK-NEXT: sub v0.8h, v0.8h, v1.8h ; CHECK-NEXT: str q0, [x8, x9] ; CHECK-NEXT: ret entry: