Index: llvm/lib/Target/AArch64/AArch64MIPeepholeOpt.cpp
===================================================================
--- llvm/lib/Target/AArch64/AArch64MIPeepholeOpt.cpp
+++ llvm/lib/Target/AArch64/AArch64MIPeepholeOpt.cpp
@@ -61,6 +61,27 @@
 // %6:fpr128 = IMPLICIT_DEF
 // %7:fpr128 = INSERT_SUBREG %6:fpr128(tied-def 0), killed %1:fpr64, %subreg.dsub
 //
+// 8. If an SMLSL's first operand is the result of another SMLSL, there is a
+//    chance for the transformation below.
+//
+// XTN + XTN + SMLSL + SMLSL ==> UZP1 + SMLSL + SMLSL2
+//
+// For example,
+//
+// %5:fpr64 = XTNv4i16 killed %4:fpr128
+// %7:fpr64 = XTNv4i16 killed %6:fpr128
+// %8:fpr64 = COPY %0.dsub:fpr128
+// %9:fpr128 = EXTv16i8 %0:fpr128, %0:fpr128, 8
+// %10:fpr64 = COPY %9.dsub:fpr128
+// %11:fpr128 = SMLSLv4i16_v4i32 %1:fpr128(tied-def 0), killed %8:fpr64, killed %5:fpr64
+// %12:fpr128 = SMLSLv4i16_v4i32 %11:fpr128(tied-def 0), killed %10:fpr64, killed %7:fpr64
+// ==>
+// %6:fpr128 = UZP1v8i16 killed %4:fpr128, killed %5:fpr128
+// %7:fpr64 = COPY %0.dsub:fpr128
+// %8:fpr64 = COPY %6.dsub:fpr128
+// %9:fpr128 = SMLSLv4i16_v4i32 %1:fpr128(tied-def 0), killed %7:fpr64, killed %8:fpr64
+// %10:fpr128 = SMLSLv8i16_v4i32 %9:fpr128(tied-def 0), %0:fpr128, %6:fpr128
+//
 //===----------------------------------------------------------------------===//
 
 #include "AArch64ExpandImm.h"
@@ -127,6 +148,7 @@
   bool visitINSERT(MachineInstr &MI);
   bool visitINSviGPR(MachineInstr &MI, unsigned Opc);
   bool visitINSvi64lane(MachineInstr &MI);
+  bool visitSMLSLv4i16_v4i32(MachineInstr &MI);
   bool runOnMachineFunction(MachineFunction &MF) override;
 
   StringRef getPassName() const override {
@@ -669,6 +691,111 @@
   return true;
 }
 
+bool AArch64MIPeepholeOpt::visitSMLSLv4i16_v4i32(MachineInstr &MI) {
+  // If an SMLSL's first operand is the result of another SMLSL, there is a
+  // chance for the transformation below.
+  //
+  // XTN + XTN + SMLSL + SMLSL ==> UZP1 + SMLSL + SMLSL2
+  //
+  // For example,
+  //
+  // %5:fpr64 = XTNv4i16 killed %4:fpr128
+  // %7:fpr64 = XTNv4i16 killed %6:fpr128
+  // %8:fpr64 = COPY %0.dsub:fpr128
+  // %9:fpr128 = EXTv16i8 %0:fpr128, %0:fpr128, 8
+  // %10:fpr64 = COPY %9.dsub:fpr128
+  // %11:fpr128 = SMLSLv4i16_v4i32 %1:fpr128(tied-def 0), killed %8:fpr64, killed %5:fpr64
+  // %12:fpr128 = SMLSLv4i16_v4i32 %11:fpr128(tied-def 0), killed %10:fpr64, killed %7:fpr64
+  // ==>
+  // %6:fpr128 = UZP1v8i16 killed %4:fpr128, killed %5:fpr128
+  // %7:fpr64 = COPY %0.dsub:fpr128
+  // %8:fpr64 = COPY %6.dsub:fpr128
+  // %9:fpr128 = SMLSLv4i16_v4i32 %1:fpr128(tied-def 0), killed %7:fpr64, killed %8:fpr64
+  // %10:fpr128 = SMLSLv8i16_v4i32 %9:fpr128(tied-def 0), %0:fpr128, %6:fpr128
+
+  MachineInstr *SecondSMLSLMI = &MI;
+  // Check that the SMLSL's first operand is defined by another SMLSL.
+  MachineInstr *FirstSMLSLMI =
+      MRI->getUniqueVRegDef(SecondSMLSLMI->getOperand(1).getReg());
+  if (FirstSMLSLMI->getOpcode() != AArch64::SMLSLv4i16_v4i32)
+    return false;
+
+  // Check that the first SMLSL's third operand is defined by an XTN.
+  MachineInstr *FirstXTNMI =
+      MRI->getUniqueVRegDef(FirstSMLSLMI->getOperand(3).getReg());
+  if (FirstXTNMI->getOpcode() != AArch64::XTNv4i16)
+    return false;
+
+  MachineInstr *SecondXTNMI =
+      MRI->getUniqueVRegDef(SecondSMLSLMI->getOperand(3).getReg());
+  if (SecondXTNMI->getOpcode() != AArch64::XTNv4i16)
+    return false;
+
+  // Check the second SMLSL's second operand is a COPY from an FPR128 dsub.
+  MachineInstr *SecondSMLSLOp2MI =
+      MRI->getUniqueVRegDef(SecondSMLSLMI->getOperand(2).getReg());
+  if (SecondSMLSLOp2MI->getOpcode() != TargetOpcode::COPY)
+    return false;
+  if (SecondSMLSLOp2MI->getOperand(1).getSubReg() != AArch64::dsub)
+    return false;
+
+  // The COPY's source must be an EXT that rotates the same register by 8
+  // bytes, i.e. the original SMLSL reads that register's high half.
+  Register SecondSMLSLOp2Reg = SecondSMLSLOp2MI->getOperand(1).getReg();
+  MachineInstr *EXTMI = MRI->getUniqueVRegDef(SecondSMLSLOp2Reg);
+  if (EXTMI->getOpcode() != AArch64::EXTv16i8 ||
+      EXTMI->getOperand(1).getReg() != EXTMI->getOperand(2).getReg() ||
+      EXTMI->getOperand(3).getImm() != 8)
+    return false;
+  SecondSMLSLOp2Reg = EXTMI->getOperand(1).getReg();
+  EXTMI->eraseFromParent();
+
+  // Let's create a UZP1 from the two XTNs' 128-bit source registers.
+  Register UZP1Low64Reg = FirstXTNMI->getOperand(1).getReg();
+  Register UZP1High64Reg = SecondXTNMI->getOperand(1).getReg();
+  Register UZP1DstReg = MRI->createVirtualRegister(&AArch64::FPR128RegClass);
+  BuildMI(*MI.getParent(), MI, MI.getDebugLoc(), TII->get(AArch64::UZP1v8i16),
+          UZP1DstReg)
+      .addUse(UZP1Low64Reg, getRegState(FirstXTNMI->getOperand(1)))
+      .addUse(UZP1High64Reg, getRegState(SecondXTNMI->getOperand(1)));
+
+  // Let's create the new SMLSL.
+  Register COPYLow64DstReg =
+      MRI->createVirtualRegister(&AArch64::FPR64RegClass);
+  BuildMI(*MI.getParent(), MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY),
+          COPYLow64DstReg)
+      .addUse(UZP1DstReg, 0, AArch64::dsub);
+
+  Register NewSMLSLDstReg = FirstSMLSLMI->getOperand(0).getReg();
+  Register NewSMLSLOp1Reg = FirstSMLSLMI->getOperand(1).getReg();
+  Register NewSMLSLOp2Reg = FirstSMLSLMI->getOperand(2).getReg();
+  Register NewSMLSLOp3Reg = COPYLow64DstReg;
+  BuildMI(*MI.getParent(), MI, MI.getDebugLoc(),
+          TII->get(AArch64::SMLSLv4i16_v4i32), NewSMLSLDstReg)
+      .addUse(NewSMLSLOp1Reg, getRegState(FirstSMLSLMI->getOperand(1)))
+      .addUse(NewSMLSLOp2Reg, getRegState(FirstSMLSLMI->getOperand(2)))
+      .addUse(NewSMLSLOp3Reg, getRegState(FirstSMLSLMI->getOperand(3)));
+
+  // Let's create the new SMLSL2.
+  Register NewSMLSL2DstReg = SecondSMLSLMI->getOperand(0).getReg();
+  Register NewSMLSL2Op1Reg = NewSMLSLDstReg;
+  Register NewSMLSL2Op2Reg = SecondSMLSLOp2Reg;
+  Register NewSMLSL2Op3Reg = UZP1DstReg;
+  BuildMI(*MI.getParent(), MI, MI.getDebugLoc(),
+          TII->get(AArch64::SMLSLv8i16_v4i32), NewSMLSL2DstReg)
+      .addUse(NewSMLSL2Op1Reg, getRegState(SecondSMLSLMI->getOperand(1)))
+      .addUse(NewSMLSL2Op2Reg, getRegState(SecondSMLSLMI->getOperand(2)))
+      .addUse(NewSMLSL2Op3Reg, getRegState(SecondSMLSLMI->getOperand(3)));
+
+  SecondSMLSLMI->eraseFromParent();
+  FirstSMLSLMI->eraseFromParent();
+  SecondSMLSLOp2MI->eraseFromParent();
+  SecondXTNMI->eraseFromParent();
+  FirstXTNMI->eraseFromParent();
+
+  return true;
+}
+
 bool AArch64MIPeepholeOpt::runOnMachineFunction(MachineFunction &MF) {
   if (skipFunction(MF.getFunction()))
     return false;
@@ -747,6 +874,9 @@
       case AArch64::INSvi64lane:
        Changed = visitINSvi64lane(MI);
        break;
+      case AArch64::SMLSLv4i16_v4i32:
+        Changed = visitSMLSLv4i16_v4i32(MI);
+        break;
      }
    }
  }
Index: llvm/test/CodeGen/AArch64/aarch64-smull.ll
===================================================================
--- llvm/test/CodeGen/AArch64/aarch64-smull.ll
+++ llvm/test/CodeGen/AArch64/aarch64-smull.ll
@@ -1115,3 +1115,30 @@
   %out = mul nsw <4 x i64> %in1, %broadcast.splat
   ret <4 x i64> %out
 }
+
+define void @smlsl_smlsl2(<8 x i16> %0, <4 x i32> %1, ptr %2, ptr %3, i32 %4) {
+; CHECK-LABEL: smlsl_smlsl2:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ldp q2, q3, [x1]
+; CHECK-NEXT:    uzp1 v2.8h, v2.8h, v3.8h
+; CHECK-NEXT:    smlsl v1.4s, v0.4h, v2.4h
+; CHECK-NEXT:    smlsl2 v1.4s, v0.8h, v2.8h
+; CHECK-NEXT:    str q1, [x0]
+; CHECK-NEXT:    ret
+entry:
+  %5 = load <4 x i32>, ptr %3, align 4
+  %6 = trunc <4 x i32> %5 to <4 x i16>
+  %7 = getelementptr inbounds i32, ptr %3, i64 4
+  %8 = load <4 x i32>, ptr %7, align 4
+  %9 = trunc <4 x i32> %8 to <4 x i16>
+  %10 = shufflevector <8 x i16> %0, <8 x i16> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %11 = tail call <4 x i32> @llvm.aarch64.neon.smull.v4i32(<4 x i16> %10, <4 x i16> %6)
+  %12 = shufflevector <8 x i16> %0, <8 x i16> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
+  %13 = tail call <4 x i32> @llvm.aarch64.neon.smull.v4i32(<4 x i16> %12, <4 x i16> %9)
+  %14 = add <4 x i32> %11, %13
+  %15 = sub <4 x i32> %1, %14
+  store <4 x i32> %15, ptr %2, align 16
+  ret void
+}
+
+declare <4 x i32> @llvm.aarch64.neon.smull.v4i32(<4 x i16>, <4 x i16>)
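
Note (not part of the patch): a minimal, standalone scalar sketch of why the rewrite preserves results. It assumes little-endian lane numbering, so UZP1 of the two pre-truncation v4i32 vectors is exactly the concatenation of the two XTN results, and SMLSL on the low halves plus SMLSL2 on the high halves accumulates the same products as the two original SMLSLs. The concrete values below are arbitrary.

#include <cassert>
#include <cstdint>

int main() {
  int32_t b0[4] = {100000, -7, 65536, 42};     // input to the first XTN
  int32_t b1[4] = {-100000, 9, 32768, -3};     // input to the second XTN
  int16_t a[8] = {1, -2, 3, -4, 5, -6, 7, -8}; // the v8i16 multiplicand
  int32_t before[4] = {10, 20, 30, 40};        // accumulator, original form
  int32_t after[4] = {10, 20, 30, 40};         // accumulator, rewritten form

  // Original sequence: XTN + XTN + SMLSL + SMLSL.
  int16_t n0[4], n1[4];
  for (int i = 0; i < 4; ++i) {
    n0[i] = (int16_t)b0[i]; // XTNv4i16
    n1[i] = (int16_t)b1[i]; // XTNv4i16
  }
  for (int i = 0; i < 4; ++i)
    before[i] -= (int32_t)a[i] * n0[i];     // SMLSL, low half of a
  for (int i = 0; i < 4; ++i)
    before[i] -= (int32_t)a[i + 4] * n1[i]; // SMLSL, high half of a via EXT+COPY

  // Rewritten sequence: UZP1 + SMLSL + SMLSL2.
  int16_t uzp1[8];
  for (int i = 0; i < 4; ++i) {
    uzp1[i] = (int16_t)b0[i];     // even h-lanes of b0 (== first XTN result)
    uzp1[i + 4] = (int16_t)b1[i]; // even h-lanes of b1 (== second XTN result)
  }
  for (int i = 0; i < 4; ++i)
    after[i] -= (int32_t)a[i] * uzp1[i];         // SMLSL on the low halves
  for (int i = 0; i < 4; ++i)
    after[i] -= (int32_t)a[i + 4] * uzp1[i + 4]; // SMLSL2 on the high halves

  for (int i = 0; i < 4; ++i)
    assert(before[i] == after[i]); // both forms compute the same accumulator
  return 0;
}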