diff --git a/llvm/include/llvm/CodeGen/MachineCombinerPattern.h b/llvm/include/llvm/CodeGen/MachineCombinerPattern.h
--- a/llvm/include/llvm/CodeGen/MachineCombinerPattern.h
+++ b/llvm/include/llvm/CodeGen/MachineCombinerPattern.h
@@ -80,6 +80,7 @@
   FMLAv4i32_indexed_OP2,
   FMLSv1i32_indexed_OP2,
   FMLSv1i64_indexed_OP2,
+  FMLSv4f16_OP1,
   FMLSv4f16_OP2,
   FMLSv8f16_OP1,
   FMLSv8f16_OP2,
@@ -87,6 +88,7 @@
   FMLSv2f32_OP2,
   FMLSv2f64_OP1,
   FMLSv2f64_OP2,
+  FMLSv4i16_indexed_OP1,
   FMLSv4i16_indexed_OP2,
   FMLSv8i16_indexed_OP1,
   FMLSv8i16_indexed_OP2,
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -3806,8 +3806,8 @@
     Found |= Match(AArch64::FMULv4i16_indexed, 2, MCP::FMLSv4i16_indexed_OP2) ||
              Match(AArch64::FMULv4f16, 2, MCP::FMLSv4f16_OP2);
-    Found |= Match(AArch64::FMULv4i16_indexed, 1, MCP::FMLSv2i32_indexed_OP1) ||
-             Match(AArch64::FMULv4f16, 1, MCP::FMLSv2f32_OP1);
+    Found |= Match(AArch64::FMULv4i16_indexed, 1, MCP::FMLSv4i16_indexed_OP1) ||
+             Match(AArch64::FMULv4f16, 1, MCP::FMLSv4f16_OP1);
     break;
   case AArch64::FSUBv8f16:
     Found |= Match(AArch64::FMULv8i16_indexed, 2, MCP::FMLSv8i16_indexed_OP2) ||
@@ -3888,6 +3888,7 @@
   case MachineCombinerPattern::FMLAv4f32_OP2:
   case MachineCombinerPattern::FMLAv4i32_indexed_OP1:
   case MachineCombinerPattern::FMLAv4i32_indexed_OP2:
+  case MachineCombinerPattern::FMLSv4i16_indexed_OP1:
   case MachineCombinerPattern::FMLSv4i16_indexed_OP2:
   case MachineCombinerPattern::FMLSv8i16_indexed_OP1:
   case MachineCombinerPattern::FMLSv8i16_indexed_OP2:
@@ -3895,6 +3896,7 @@
   case MachineCombinerPattern::FMLSv1i64_indexed_OP2:
   case MachineCombinerPattern::FMLSv2i32_indexed_OP2:
   case MachineCombinerPattern::FMLSv2i64_indexed_OP2:
+  case MachineCombinerPattern::FMLSv4f16_OP1:
   case MachineCombinerPattern::FMLSv4f16_OP2:
   case MachineCombinerPattern::FMLSv8f16_OP1:
   case MachineCombinerPattern::FMLSv8f16_OP2:
@@ -4497,6 +4499,26 @@
                            FMAInstKind::Indexed);
     break;
 
+  case MachineCombinerPattern::FMLSv4f16_OP1:
+  case MachineCombinerPattern::FMLSv4i16_indexed_OP1: {
+    RC = &AArch64::FPR64RegClass;
+    Register NewVR = MRI.createVirtualRegister(RC);
+    MachineInstrBuilder MIB1 =
+        BuildMI(MF, Root.getDebugLoc(), TII->get(AArch64::FNEGv4f16), NewVR)
+            .add(Root.getOperand(2));
+    InsInstrs.push_back(MIB1);
+    InstrIdxForVirtReg.insert(std::make_pair(NewVR, 0));
+    if (Pattern == MachineCombinerPattern::FMLSv4f16_OP1) {
+      Opc = AArch64::FMLAv4f16;
+      MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC,
+                             FMAInstKind::Accumulator, &NewVR);
+    } else {
+      Opc = AArch64::FMLAv4i16_indexed;
+      MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC,
+                             FMAInstKind::Indexed, &NewVR);
+    }
+    break;
+  }
   case MachineCombinerPattern::FMLSv4f16_OP2:
     RC = &AArch64::FPR64RegClass;
     Opc = AArch64::FMLSv4f16;
@@ -4525,18 +4547,25 @@
     break;
 
   case MachineCombinerPattern::FMLSv8f16_OP1:
+  case MachineCombinerPattern::FMLSv8i16_indexed_OP1: {
     RC = &AArch64::FPR128RegClass;
-    Opc = AArch64::FMLSv8f16;
-    MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC,
-                           FMAInstKind::Accumulator);
-    break;
-  case MachineCombinerPattern::FMLSv8i16_indexed_OP1:
-    RC = &AArch64::FPR128RegClass;
-    Opc = AArch64::FMLSv8i16_indexed;
-    MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC,
-                           FMAInstKind::Indexed);
+    Register NewVR = MRI.createVirtualRegister(RC);
+    MachineInstrBuilder MIB1 =
+        BuildMI(MF, Root.getDebugLoc(), TII->get(AArch64::FNEGv8f16), NewVR)
+            .add(Root.getOperand(2));
+    InsInstrs.push_back(MIB1);
+    InstrIdxForVirtReg.insert(std::make_pair(NewVR, 0));
+    if (Pattern == MachineCombinerPattern::FMLSv8f16_OP1) {
+      Opc = AArch64::FMLAv8f16;
+      MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC,
+                             FMAInstKind::Accumulator, &NewVR);
+    } else {
+      Opc = AArch64::FMLAv8i16_indexed;
+      MUL = genFusedMultiply(MF, MRI, TII, Root, InsInstrs, 1, Opc, RC,
+                             FMAInstKind::Indexed, &NewVR);
+    }
     break;
-
+  }
   case MachineCombinerPattern::FMLSv8f16_OP2:
     RC = &AArch64::FPR128RegClass;
     Opc = AArch64::FMLSv8f16;
diff --git a/llvm/test/CodeGen/AArch64/fp16-fmla.ll b/llvm/test/CodeGen/AArch64/fp16-fmla.ll
--- a/llvm/test/CodeGen/AArch64/fp16-fmla.ll
+++ b/llvm/test/CodeGen/AArch64/fp16-fmla.ll
@@ -138,6 +138,16 @@
   ret <8 x half> %add
 }
 
+define <4 x half> @test_FMLSv4f16_OP1(<4 x half> %a, <4 x half> %b, <4 x half> %c) {
+; CHECK-LABEL: test_FMLSv4f16_OP1:
+; CHECK: fneg {{v[0-9]+}}.4h, {{v[0-9]+}}.4h
+; CHECK: fmla {{v[0-9]+}}.4h, {{v[0-9]+}}.4h, {{v[0-9]+}}.4h
+entry:
+  %mul = fmul fast <4 x half> %c, %b
+  %sub = fsub fast <4 x half> %mul, %a
+  ret <4 x half> %sub
+}
+
 define <4 x half> @test_FMLSv4f16_OP2(<4 x half> %a, <4 x half> %b, <4 x half> %c) {
 ; CHECK-LABEL: test_FMLSv4f16_OP2:
 ; CHECK: fmls {{v[0-9]+}}.4h, {{v[0-9]+}}.4h, {{v[0-9]+}}.4h
@@ -149,7 +159,8 @@
 
 define <8 x half> @test_FMLSv8f16_OP1(<8 x half> %a, <8 x half> %b, <8 x half> %c) {
 ; CHECK-LABEL: test_FMLSv8f16_OP1:
-; CHECK: fmls {{v[0-9]+}}.8h, {{v[0-9]+}}.8h, {{v[0-9]+}}.8h
+; CHECK: fneg {{v[0-9]+}}.8h, {{v[0-9]+}}.8h
+; CHECK: fmla {{v[0-9]+}}.8h, {{v[0-9]+}}.8h, {{v[0-9]+}}.8h
 entry:
   %mul = fmul fast <8 x half> %c, %b
   %sub = fsub fast <8 x half> %mul, %a
@@ -185,7 +196,8 @@
 ; CHECK: mul
 ; CHECK: fsub
 ; CHECK-FIXME: It should instead produce the following instruction:
-; CHECK-FIXME: fmls {{v[0-9]+}}.8h, {{v[0-9]+}}.8h, {{v[0-9]+}}.8h
+; CHECK-FIXME: fneg {{v[0-9]+}}.8h, {{v[0-9]+}}.8h
+; CHECK-FIXME: fmla {{v[0-9]+}}.8h, {{v[0-9]+}}.8h, {{v[0-9]+}}.8h
 entry:
   %mul = mul <8 x i16> %c, %b
   %m = bitcast <8 x i16> %mul to <8 x half>