diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -23017,19 +23017,20 @@
   // Check ExtractLow's user.
   if (HasFoundMULLow) {
     SDNode *ExtractLowUser = *ExtractLow.getNode()->use_begin();
-    if (ExtractLowUser->getOpcode() != N->getOpcode())
+    if (ExtractLowUser->getOpcode() != N->getOpcode()) {
       HasFoundMULLow = false;
-
-    if (ExtractLowUser->getOperand(0) == ExtractLow) {
-      if (ExtractLowUser->getOperand(1).getOpcode() == ISD::TRUNCATE)
-        TruncLow = ExtractLowUser->getOperand(1);
-      else
-        HasFoundMULLow = false;
     } else {
-      if (ExtractLowUser->getOperand(0).getOpcode() == ISD::TRUNCATE)
-        TruncLow = ExtractLowUser->getOperand(0);
-      else
-        HasFoundMULLow = false;
+      if (ExtractLowUser->getOperand(0) == ExtractLow) {
+        if (ExtractLowUser->getOperand(1).getOpcode() == ISD::TRUNCATE)
+          TruncLow = ExtractLowUser->getOperand(1);
+        else
+          HasFoundMULLow = false;
+      } else {
+        if (ExtractLowUser->getOperand(0).getOpcode() == ISD::TRUNCATE)
+          TruncLow = ExtractLowUser->getOperand(0);
+        else
+          HasFoundMULLow = false;
+      }
     }
   }
 
diff --git a/llvm/test/CodeGen/AArch64/aarch64-smull.ll b/llvm/test/CodeGen/AArch64/aarch64-smull.ll
--- a/llvm/test/CodeGen/AArch64/aarch64-smull.ll
+++ b/llvm/test/CodeGen/AArch64/aarch64-smull.ll
@@ -1330,8 +1330,27 @@
   ret void
 }
 
+define <2 x i32> @do_stuff(<2 x i64> %0, <2 x i64> %1) {
+; CHECK-LABEL: do_stuff:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    uzp1 v0.4s, v0.4s, v0.4s
+; CHECK-NEXT:    smull2 v0.2d, v1.4s, v0.4s
+; CHECK-NEXT:    xtn v0.2s, v0.2d
+; CHECK-NEXT:    add v0.2s, v0.2s, v1.2s
+; CHECK-NEXT:    ret
+  %bc.1 = bitcast <2 x i64> %1 to <4 x i32>
+  %trunc.0 = trunc <2 x i64> %0 to <2 x i32>
+  %shuff.hi = shufflevector <4 x i32> %bc.1, <4 x i32> zeroinitializer, <2 x i32> <i32 2, i32 3>
+  %shuff.lo = shufflevector <4 x i32> %bc.1, <4 x i32> zeroinitializer, <2 x i32> <i32 0, i32 1>
+  %smull = tail call <2 x i64> @llvm.aarch64.neon.smull.v2i64(<2 x i32> %shuff.hi, <2 x i32> %trunc.0)
+  %trunc.smull = trunc <2 x i64> %smull to <2 x i32>
+  %final = add <2 x i32> %trunc.smull, %shuff.lo
+  ret <2 x i32> %final
+}
+
 declare <8 x i16> @llvm.aarch64.neon.pmull.v8i16(<8 x i8>, <8 x i8>)
 declare <8 x i16> @llvm.aarch64.neon.smull.v8i16(<8 x i8>, <8 x i8>)
 declare <8 x i16> @llvm.aarch64.neon.umull.v8i16(<8 x i8>, <8 x i8>)
 declare <4 x i32> @llvm.aarch64.neon.smull.v4i32(<4 x i16>, <4 x i16>)
 declare <4 x i32> @llvm.aarch64.neon.umull.v4i32(<4 x i16>, <4 x i16>)
+declare <2 x i64> @llvm.aarch64.neon.smull.v2i64(<2 x i32>, <2 x i32>)