diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -10572,6 +10572,47 @@
 
     return true;
   }
+  case Instruction::Mul: {
+    // Sink a splatted operand of a vector multiply into the multiply's block,
+    // so that instruction selection can see the extended scalar behind the
+    // splat and form smull/umull instead of widening both operands and using
+    // a full-width mul. Only a lane-0 splat (shufflevector with an all-zero
+    // mask) of an insertelement of a sext/zext value qualifies; sinking other
+    // shuffles just duplicates them into the user's block for no gain.
+    bool IsProfitable = false;
+    for (auto &Op : I->operands()) {
+      // Make sure we are not already sinking this operand
+      if (any_of(Ops, [&](Use *U) { return U->get() == Op; }))
+        continue;
+
+      auto *Shuffle = dyn_cast<ShuffleVectorInst>(Op);
+      if (!Shuffle || !Shuffle->isZeroEltSplat())
+        continue;
+
+      // Every sunk Use must refer to an Instruction (CodeGenPrepare casts it),
+      // so the splat source has to be an insertelement rather than a constant
+      // or an argument.
+      auto *Insert = dyn_cast<InsertElementInst>(Shuffle->getOperand(0));
+      if (!Insert)
+        continue;
+
+      // The splat replicates lane 0, so the scalar must be inserted there.
+      auto *Lane = dyn_cast<ConstantInt>(Insert->getOperand(2));
+      if (!Lane || !Lane->isZero())
+        continue;
+
+      // smull/umull only apply when the splatted scalar is itself extended.
+      Value *Scalar = Insert->getOperand(1);
+      if (!isa<SExtInst>(Scalar) && !isa<ZExtInst>(Scalar))
+        continue;
+
+      Ops.push_back(&Shuffle->getOperandUse(0));
+      Ops.push_back(&Op);
+      IsProfitable = true;
+    }
+
+    return IsProfitable;
+  }
   default:
     return false;
   }
diff --git a/llvm/test/CodeGen/AArch64/aarch64-matrix-smull.ll b/llvm/test/CodeGen/AArch64/aarch64-matrix-smull.ll
--- a/llvm/test/CodeGen/AArch64/aarch64-matrix-smull.ll
+++ b/llvm/test/CodeGen/AArch64/aarch64-matrix-smull.ll
@@ -4,10 +4,10 @@
 define void @matrix_mul_signed(i32 %N, i32* nocapture %C, i16* nocapture readonly %A, i16 %val) {
 ; CHECK-LABEL: matrix_mul_signed:
 ; CHECK:       // %bb.0: // %vector.header
-; CHECK-NEXT:    dup v0.4h, w3
+; CHECK-NEXT:    sxth w9, w3
 ; CHECK-NEXT:    // kill: def $w0 killed $w0 def $x0
-; CHECK-NEXT:    sshll v0.4s, v0.4h, #0
 ; CHECK-NEXT:    and x8, x0, #0xfffffff8
+; CHECK-NEXT:    dup v0.4h, w9
 ; CHECK-NEXT:  .LBB0_1: // %vector.body
 ; CHECK-NEXT:    // =>This Inner Loop Header: Depth=1
 ; CHECK-NEXT:    add x9, x2, w0, sxtw #1
@@ -15,10 +15,8 @@
 ; CHECK-NEXT:    add x9, x1, w0, sxtw #2
 ; CHECK-NEXT:    subs x8, x8, #8 // =8
 ; CHECK-NEXT:    add w0, w0, #8 // =8
-; CHECK-NEXT:    sshll v1.4s, v1.4h, #0
-; CHECK-NEXT:    sshll v2.4s, v2.4h, #0
-; CHECK-NEXT:    mul v1.4s, v0.4s, v1.4s
-; CHECK-NEXT:    mul v2.4s, v0.4s, v2.4s
+; CHECK-NEXT:    smull v1.4s, v0.4h, v1.4h
+; CHECK-NEXT:    smull v2.4s, v0.4h, v2.4h
 ; CHECK-NEXT:    stp q1, q2, [x9]
 ; CHECK-NEXT:    b.ne .LBB0_1
 ; CHECK-NEXT:  // %bb.2: // %for.end12
diff --git a/llvm/test/CodeGen/AArch64/aarch64-matrix-umull.ll b/llvm/test/CodeGen/AArch64/aarch64-matrix-umull.ll
--- a/llvm/test/CodeGen/AArch64/aarch64-matrix-umull.ll
+++ b/llvm/test/CodeGen/AArch64/aarch64-matrix-umull.ll
@@ -7,7 +7,7 @@
 ; CHECK-NEXT:    and w9, w3, #0xffff
 ; CHECK-NEXT:    // kill: def $w0 killed $w0 def $x0
 ; CHECK-NEXT:    and x8, x0, #0xfffffff8
-; CHECK-NEXT:    dup v0.4s, w9
+; CHECK-NEXT:    dup v0.4h, w9
 ; CHECK-NEXT:  .LBB0_1: // %vector.body
 ; CHECK-NEXT:    // =>This Inner Loop Header: Depth=1
 ; CHECK-NEXT:    add x9, x2, w0, uxtw #1
@@ -15,10 +15,8 @@
 ; CHECK-NEXT:    add x9, x1, w0, uxtw #2
 ; CHECK-NEXT:    subs x8, x8, #8 // =8
 ; CHECK-NEXT:    add w0, w0, #8 // =8
-; CHECK-NEXT:    ushll v1.4s, v1.4h, #0
-; CHECK-NEXT:    ushll v2.4s, v2.4h, #0
-; CHECK-NEXT:    mul v1.4s, v0.4s, v1.4s
-; CHECK-NEXT:    mul v2.4s, v0.4s, v2.4s
+; CHECK-NEXT:    umull v1.4s, v0.4h, v1.4h
+; CHECK-NEXT:    umull v2.4s, v0.4h, v2.4h
 ; CHECK-NEXT:    stp q1, q2, [x9]
 ; CHECK-NEXT:    b.ne .LBB0_1
 ; CHECK-NEXT:  // %bb.2: // %for.end12