diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -4198,6 +4198,40 @@
                           FMAInstKind::Accumulator);
 }
 
+/// genNeg - Helper to generate an intermediate negation of the second operand
+/// of Root
+static Register genNeg(MachineFunction &MF, MachineRegisterInfo &MRI,
+                       const TargetInstrInfo *TII, MachineInstr &Root,
+                       SmallVectorImpl<MachineInstr *> &InsInstrs,
+                       DenseMap<unsigned, unsigned> &InstrIdxForVirtReg,
+                       unsigned MnegOpc, const TargetRegisterClass *RC) {
+  Register NewVR = MRI.createVirtualRegister(RC);
+  MachineInstrBuilder MIB =
+      BuildMI(MF, Root.getDebugLoc(), TII->get(MnegOpc), NewVR)
+          .add(Root.getOperand(2));
+  InsInstrs.push_back(MIB);
+
+  assert(InstrIdxForVirtReg.empty());
+  InstrIdxForVirtReg.insert(std::make_pair(NewVR, 0));
+
+  return NewVR;
+}
+
+/// genFusedMultiplyAccNeg - Helper to generate fused multiply accumulate
+/// instructions with an additional negation of the accumulator
+static MachineInstr *genFusedMultiplyAccNeg(
+    MachineFunction &MF, MachineRegisterInfo &MRI, const TargetInstrInfo *TII,
+    MachineInstr &Root, SmallVectorImpl<MachineInstr *> &InsInstrs,
+    DenseMap<unsigned, unsigned> &InstrIdxForVirtReg, unsigned IdxMulOpd,
+    unsigned MaddOpc, unsigned MnegOpc, const TargetRegisterClass *RC) {
+  assert(IdxMulOpd == 1);
+
+  Register NewVR =
+      genNeg(MF, MRI, TII, Root, InsInstrs, InstrIdxForVirtReg, MnegOpc, RC);
+  return genFusedMultiply(MF, MRI, TII, Root, InsInstrs, IdxMulOpd, MaddOpc, RC,
+                          FMAInstKind::Accumulator, &NewVR);
+}
+
 /// genFusedMultiplyIdx - Helper to generate fused multiply accumulate
 /// instructions.
 ///
@@ -4210,6 +4244,22 @@
                           FMAInstKind::Indexed);
 }
 
+/// genFusedMultiplyIdxNeg - Helper to generate fused multiply accumulate
+/// instructions with an additional negation of the accumulator
+static MachineInstr *genFusedMultiplyIdxNeg(
+    MachineFunction &MF, MachineRegisterInfo &MRI, const TargetInstrInfo *TII,
+    MachineInstr &Root, SmallVectorImpl<MachineInstr *> &InsInstrs,
+    DenseMap<unsigned, unsigned> &InstrIdxForVirtReg, unsigned IdxMulOpd,
+    unsigned MaddOpc, unsigned MnegOpc, const TargetRegisterClass *RC) {
+  assert(IdxMulOpd == 1);
+
+  Register NewVR =
+      genNeg(MF, MRI, TII, Root, InsInstrs, InstrIdxForVirtReg, MnegOpc, RC);
+
+  return genFusedMultiply(MF, MRI, TII, Root, InsInstrs, IdxMulOpd, MaddOpc, RC,
+                          FMAInstKind::Indexed, &NewVR);
+}
+
 /// genMaddR - Generate madd instruction and combine mul and add using
 /// an extra virtual register
 /// Example - an ADD intermediate needs to be stored in a register:
@@ -4512,9 +4562,11 @@
     break;
   case MachineCombinerPattern::MULSUBv8i8_OP1:
-    Opc = AArch64::MLSv8i8;
+    Opc = AArch64::MLAv8i8;
     RC = &AArch64::FPR64RegClass;
-    MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
+    MUL = genFusedMultiplyAccNeg(MF, MRI, TII, Root, InsInstrs,
+                                 InstrIdxForVirtReg, 1, Opc, AArch64::NEGv8i8,
+                                 RC);
     break;
   case MachineCombinerPattern::MULSUBv8i8_OP2:
     Opc = AArch64::MLSv8i8;
@@ -4522,9 +4574,11 @@
     MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
     break;
   case MachineCombinerPattern::MULSUBv16i8_OP1:
-    Opc = AArch64::MLSv16i8;
+    Opc = AArch64::MLAv16i8;
     RC = &AArch64::FPR128RegClass;
-    MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
+    MUL = genFusedMultiplyAccNeg(MF, MRI, TII, Root, InsInstrs,
+                                 InstrIdxForVirtReg, 1, Opc, AArch64::NEGv16i8,
+                                 RC);
     break;
   case MachineCombinerPattern::MULSUBv16i8_OP2:
     Opc = AArch64::MLSv16i8;
@@ -4532,9 +4586,11 @@
     MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
     break;
   case MachineCombinerPattern::MULSUBv4i16_OP1:
-    Opc = AArch64::MLSv4i16;
+    Opc = AArch64::MLAv4i16;
     RC = &AArch64::FPR64RegClass;
-    MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
+    MUL = genFusedMultiplyAccNeg(MF, MRI, TII, Root, InsInstrs,
+                                 InstrIdxForVirtReg, 1, Opc, AArch64::NEGv4i16,
+                                 RC);
     break;
   case MachineCombinerPattern::MULSUBv4i16_OP2:
     Opc = AArch64::MLSv4i16;
@@ -4542,9 +4598,11 @@
     MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
     break;
   case MachineCombinerPattern::MULSUBv8i16_OP1:
-    Opc = AArch64::MLSv8i16;
+    Opc = AArch64::MLAv8i16;
     RC = &AArch64::FPR128RegClass;
-    MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
+    MUL = genFusedMultiplyAccNeg(MF, MRI, TII, Root, InsInstrs,
+                                 InstrIdxForVirtReg, 1, Opc, AArch64::NEGv8i16,
+                                 RC);
     break;
   case MachineCombinerPattern::MULSUBv8i16_OP2:
     Opc = AArch64::MLSv8i16;
@@ -4552,9 +4610,11 @@
     MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
     break;
   case MachineCombinerPattern::MULSUBv2i32_OP1:
-    Opc = AArch64::MLSv2i32;
+    Opc = AArch64::MLAv2i32;
     RC = &AArch64::FPR64RegClass;
-    MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
+    MUL = genFusedMultiplyAccNeg(MF, MRI, TII, Root, InsInstrs,
+                                 InstrIdxForVirtReg, 1, Opc, AArch64::NEGv2i32,
+                                 RC);
     break;
   case MachineCombinerPattern::MULSUBv2i32_OP2:
     Opc = AArch64::MLSv2i32;
@@ -4562,9 +4622,11 @@
     MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
     break;
   case MachineCombinerPattern::MULSUBv4i32_OP1:
-    Opc = AArch64::MLSv4i32;
+    Opc = AArch64::MLAv4i32;
     RC = &AArch64::FPR128RegClass;
-    MUL = genFusedMultiplyAcc(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
+    MUL = genFusedMultiplyAccNeg(MF, MRI, TII, Root, InsInstrs,
+                                 InstrIdxForVirtReg, 1, Opc, AArch64::NEGv4i32,
+                                 RC);
     break;
   case MachineCombinerPattern::MULSUBv4i32_OP2:
     Opc = AArch64::MLSv4i32;
@@ -4614,9 +4676,11 @@
     break;
   case MachineCombinerPattern::MULSUBv4i16_indexed_OP1:
-    Opc = AArch64::MLSv4i16_indexed;
+    Opc = AArch64::MLAv4i16_indexed;
     RC = &AArch64::FPR64RegClass;
-    MUL = genFusedMultiplyIdx(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
+    MUL = genFusedMultiplyIdxNeg(MF, MRI, TII, Root, InsInstrs,
+                                 InstrIdxForVirtReg, 1, Opc, AArch64::NEGv4i16,
+                                 RC);
     break;
   case MachineCombinerPattern::MULSUBv4i16_indexed_OP2:
     Opc = AArch64::MLSv4i16_indexed;
@@ -4624,9 +4688,11 @@
     MUL = genFusedMultiplyIdx(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
     break;
   case MachineCombinerPattern::MULSUBv8i16_indexed_OP1:
-    Opc = AArch64::MLSv8i16_indexed;
+    Opc = AArch64::MLAv8i16_indexed;
     RC = &AArch64::FPR128RegClass;
-    MUL = genFusedMultiplyIdx(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
+    MUL = genFusedMultiplyIdxNeg(MF, MRI, TII, Root, InsInstrs,
+                                 InstrIdxForVirtReg, 1, Opc, AArch64::NEGv8i16,
+                                 RC);
     break;
   case MachineCombinerPattern::MULSUBv8i16_indexed_OP2:
     Opc = AArch64::MLSv8i16_indexed;
@@ -4634,9 +4700,11 @@
     MUL = genFusedMultiplyIdx(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
     break;
   case MachineCombinerPattern::MULSUBv2i32_indexed_OP1:
-    Opc = AArch64::MLSv2i32_indexed;
+    Opc = AArch64::MLAv2i32_indexed;
     RC = &AArch64::FPR64RegClass;
-    MUL = genFusedMultiplyIdx(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
+    MUL = genFusedMultiplyIdxNeg(MF, MRI, TII, Root, InsInstrs,
+                                 InstrIdxForVirtReg, 1, Opc, AArch64::NEGv2i32,
+                                 RC);
     break;
   case MachineCombinerPattern::MULSUBv2i32_indexed_OP2:
     Opc = AArch64::MLSv2i32_indexed;
@@ -4644,9 +4712,11 @@
     MUL = genFusedMultiplyIdx(MF, MRI, TII, Root, InsInstrs, 2, Opc, RC);
     break;
   case MachineCombinerPattern::MULSUBv4i32_indexed_OP1:
-    Opc = AArch64::MLSv4i32_indexed;
+    Opc = AArch64::MLAv4i32_indexed;
     RC = &AArch64::FPR128RegClass;
-    MUL = genFusedMultiplyIdx(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC);
+    MUL = genFusedMultiplyIdxNeg(MF, MRI, TII, Root, InsInstrs,
+                                 InstrIdxForVirtReg, 1, Opc, AArch64::NEGv4i32,
+                                 RC);
     break;
   case MachineCombinerPattern::MULSUBv4i32_indexed_OP2:
     Opc = AArch64::MLSv4i32_indexed;
diff --git a/llvm/test/CodeGen/AArch64/neon-mla-mls.ll b/llvm/test/CodeGen/AArch64/neon-mla-mls.ll
--- a/llvm/test/CodeGen/AArch64/neon-mla-mls.ll
+++ b/llvm/test/CodeGen/AArch64/neon-mla-mls.ll
@@ -135,3 +135,75 @@
 }
+define <8 x i8> @mls2v8xi8(<8 x i8> %A, <8 x i8> %B, <8 x i8> %C) {
+; CHECK-LABEL: mls2v8xi8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    neg v2.8b, v2.8b
+; CHECK-NEXT:    mla v2.8b, v0.8b, v1.8b
+; CHECK-NEXT:    mov v0.16b, v2.16b
+; CHECK-NEXT:    ret
+  %tmp1 = mul <8 x i8> %A, %B;
+  %tmp2 = sub <8 x i8> %tmp1, %C;
+  ret <8 x i8> %tmp2
+}
+
+define <16 x i8> @mls2v16xi8(<16 x i8> %A, <16 x i8> %B, <16 x i8> %C) {
+; CHECK-LABEL: mls2v16xi8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    neg v2.16b, v2.16b
+; CHECK-NEXT:    mla v2.16b, v0.16b, v1.16b
+; CHECK-NEXT:    mov v0.16b, v2.16b
+; CHECK-NEXT:    ret
+  %tmp1 = mul <16 x i8> %A, %B;
+  %tmp2 = sub <16 x i8> %tmp1, %C;
+  ret <16 x i8> %tmp2
+}
+
+define <4 x i16> @mls2v4xi16(<4 x i16> %A, <4 x i16> %B, <4 x i16> %C) {
+; CHECK-LABEL: mls2v4xi16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    neg v2.4h, v2.4h
+; CHECK-NEXT:    mla v2.4h, v0.4h, v1.4h
+; CHECK-NEXT:    mov v0.16b, v2.16b
+; CHECK-NEXT:    ret
+  %tmp1 = mul <4 x i16> %A, %B;
+  %tmp2 = sub <4 x i16> %tmp1, %C;
+  ret <4 x i16> %tmp2
+}
+
+define <8 x i16> @mls2v8xi16(<8 x i16> %A, <8 x i16> %B, <8 x i16> %C) {
+; CHECK-LABEL: mls2v8xi16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    neg v2.8h, v2.8h
+; CHECK-NEXT:    mla v2.8h, v0.8h, v1.8h
+; CHECK-NEXT:    mov v0.16b, v2.16b
+; CHECK-NEXT:    ret
+  %tmp1 = mul <8 x i16> %A, %B;
+  %tmp2 = sub <8 x i16> %tmp1, %C;
+  ret <8 x i16> %tmp2
+}
+
+define <2 x i32> @mls2v2xi32(<2 x i32> %A, <2 x i32> %B, <2 x i32> %C) {
+; CHECK-LABEL: mls2v2xi32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    neg v2.2s, v2.2s
+; CHECK-NEXT:    mla v2.2s, v0.2s, v1.2s
+; CHECK-NEXT:    mov v0.16b, v2.16b
+; CHECK-NEXT:    ret
+  %tmp1 = mul <2 x i32> %A, %B;
+  %tmp2 = sub <2 x i32> %tmp1, %C;
+  ret <2 x i32> %tmp2
+}
+
+define <4 x i32> @mls2v4xi32(<4 x i32> %A, <4 x i32> %B, <4 x i32> %C) {
+; CHECK-LABEL: mls2v4xi32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    neg v2.4s, v2.4s
+; CHECK-NEXT:    mla v2.4s, v0.4s, v1.4s
+; CHECK-NEXT:    mov v0.16b, v2.16b
+; CHECK-NEXT:    ret
+  %tmp1 = mul <4 x i32> %A, %B;
+  %tmp2 = sub <4 x i32> %tmp1, %C;
+  ret <4 x i32> %tmp2
+}
+