diff --git a/llvm/include/llvm/CodeGen/ComplexDeinterleavingPass.h b/llvm/include/llvm/CodeGen/ComplexDeinterleavingPass.h
--- a/llvm/include/llvm/CodeGen/ComplexDeinterleavingPass.h
+++ b/llvm/include/llvm/CodeGen/ComplexDeinterleavingPass.h
@@ -41,6 +41,7 @@
   Symmetric,
   ReductionPHI,
   ReductionOperation,
+  ReductionSelect,
 };
 
 enum class ComplexDeinterleavingRotation {
diff --git a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
--- a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
+++ b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
@@ -371,6 +371,10 @@
 
   NodePtr identifyPHINode(Instruction *Real, Instruction *Imag);
 
+  /// Identifies SelectInsts in a loop that has reduction with predication masks
+  /// and/or predicated tail folding.
+  NodePtr identifySelectNode(Instruction *Real, Instruction *Imag);
+
   Value *replaceNode(IRBuilderBase &Builder, RawNodePtr Node);
 
   /// Complete IR modifications after producing new reduction operation:
@@ -889,6 +893,9 @@
   if (NodePtr CN = identifyPHINode(Real, Imag))
     return CN;
 
+  if (NodePtr CN = identifySelectNode(Real, Imag))
+    return CN;
+
   auto *VTy = cast<VectorType>(Real->getType());
   auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
 
@@ -1713,6 +1720,45 @@
   return submitCompositeNode(PlaceholderNode);
 }
 
+ComplexDeinterleavingGraph::NodePtr
+ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real,
+                                               Instruction *Imag) {
+  auto *SelectReal = dyn_cast<SelectInst>(Real);
+  auto *SelectImag = dyn_cast<SelectInst>(Imag);
+  if (!SelectReal || !SelectImag)
+    return nullptr;
+
+  Instruction *MaskA, *MaskB;
+  Instruction *AR, *AI, *BR, *BI;
+  if (!match(Real, m_Select(m_Instruction(MaskA), m_Instruction(AR),
+                            m_Instruction(BR))) ||
+      !match(Imag, m_Select(m_Instruction(MaskB), m_Instruction(AI),
+                            m_Instruction(BI))))
+    return nullptr;
+
+  if (MaskA != MaskB && !MaskA->isIdenticalTo(MaskB))
+    return nullptr;
+
+  if (!MaskA->getType()->isVectorTy())
+    return nullptr;
+
+  auto NodeA = identifyNode(AR, AI);
+  if (!NodeA)
+    return nullptr;
+
+  auto NodeB = identifyNode(BR, BI);
+  if (!NodeB)
+    return nullptr;
+
+  NodePtr PlaceholderNode = prepareCompositeNode(
+      ComplexDeinterleavingOperation::ReductionSelect, Real, Imag);
+  PlaceholderNode->addOperand(NodeA);
+  PlaceholderNode->addOperand(NodeB);
+  FinalInstructions.insert(MaskA);
+  FinalInstructions.insert(MaskB);
+  return submitCompositeNode(PlaceholderNode);
+}
+
 static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode,
                                    FastMathFlags Flags, Value *InputA,
                                    Value *InputB) {
@@ -1787,6 +1833,19 @@
     ReplacementNode = replaceNode(Builder, Node->Operands[0]);
     processReductionOperation(ReplacementNode, Node);
     break;
+  case ComplexDeinterleavingOperation::ReductionSelect: {
+    auto *MaskReal = Node->Real->getOperand(0);
+    auto *MaskImag = Node->Imag->getOperand(0);
+    auto *A = replaceNode(Builder, Node->Operands[0]);
+    auto *B = replaceNode(Builder, Node->Operands[1]);
+    auto *NewMaskTy = VectorType::getDoubleElementsVectorType(
+        cast<VectorType>(MaskReal->getType()));
+    auto *NewMask =
+        Builder.CreateIntrinsic(Intrinsic::experimental_vector_interleave2,
+                                NewMaskTy, {MaskReal, MaskImag});
+    ReplacementNode = Builder.CreateSelect(NewMask, A, B);
+    break;
+  }
   }
 
   assert(ReplacementNode && "Target failed to create Intrinsic call.");
diff --git a/llvm/test/CodeGen/AArch64/complex-deinterleaving-reductions-predicated-scalable.ll b/llvm/test/CodeGen/AArch64/complex-deinterleaving-reductions-predicated-scalable.ll
--- a/llvm/test/CodeGen/AArch64/complex-deinterleaving-reductions-predicated-scalable.ll
+++ b/llvm/test/CodeGen/AArch64/complex-deinterleaving-reductions-predicated-scalable.ll
@@ -20,8 +20,9 @@
 ; CHECK-NEXT:    mov x11, x10
 ; CHECK-NEXT:    mov z1.d, #0 // =0x0
 ; CHECK-NEXT:    rdvl x12, #2
-; CHECK-NEXT:    mov z0.d, z1.d
 ; CHECK-NEXT:    whilelo p1.d, xzr, x9
+; CHECK-NEXT:    zip2 z0.d, z1.d, z1.d
+; CHECK-NEXT:    zip1 z1.d, z1.d, z1.d
 ; CHECK-NEXT:    ptrue p0.d
 ; CHECK-NEXT:  .LBB0_1: // %vector.body
 ; CHECK-NEXT:    // =>This Inner Loop Header: Depth=1
@@ -29,29 +30,27 @@
 ; CHECK-NEXT:    add x14, x1, x8
 ; CHECK-NEXT:    zip1 p2.d, p1.d, p1.d
 ; CHECK-NEXT:    zip2 p3.d, p1.d, p1.d
-; CHECK-NEXT:    add x8, x8, x12
+; CHECK-NEXT:    mov z6.d, z1.d
+; CHECK-NEXT:    mov z7.d, z0.d
 ; CHECK-NEXT:    ld1d { z2.d }, p3/z, [x13, #1, mul vl]
 ; CHECK-NEXT:    ld1d { z3.d }, p2/z, [x13]
 ; CHECK-NEXT:    ld1d { z4.d }, p3/z, [x14, #1, mul vl]
 ; CHECK-NEXT:    ld1d { z5.d }, p2/z, [x14]
-; CHECK-NEXT:    uzp2 z6.d, z3.d, z2.d
-; CHECK-NEXT:    uzp1 z2.d, z3.d, z2.d
-; CHECK-NEXT:    uzp1 z7.d, z5.d, z4.d
-; CHECK-NEXT:    uzp2 z4.d, z5.d, z4.d
-; CHECK-NEXT:    movprfx z3, z1
-; CHECK-NEXT:    fmla z3.d, p0/m, z7.d, z6.d
-; CHECK-NEXT:    fmad z7.d, p0/m, z2.d, z0.d
-; CHECK-NEXT:    fmad z2.d, p0/m, z4.d, z3.d
-; CHECK-NEXT:    movprfx z3, z7
-; CHECK-NEXT:    fmls z3.d, p0/m, z4.d, z6.d
-; CHECK-NEXT:    mov z1.d, p1/m, z2.d
-; CHECK-NEXT:    mov z0.d, p1/m, z3.d
 ; CHECK-NEXT:    whilelo p1.d, x11, x9
+; CHECK-NEXT:    add x8, x8, x12
 ; CHECK-NEXT:    add x11, x11, x10
+; CHECK-NEXT:    fcmla z6.d, p0/m, z5.d, z3.d, #0
+; CHECK-NEXT:    fcmla z7.d, p0/m, z4.d, z2.d, #0
+; CHECK-NEXT:    fcmla z6.d, p0/m, z5.d, z3.d, #90
+; CHECK-NEXT:    fcmla z7.d, p0/m, z4.d, z2.d, #90
+; CHECK-NEXT:    mov z0.d, p3/m, z7.d
+; CHECK-NEXT:    mov z1.d, p2/m, z6.d
 ; CHECK-NEXT:    b.mi .LBB0_1
 ; CHECK-NEXT:  // %bb.2: // %exit.block
+; CHECK-NEXT:    uzp2 z2.d, z1.d, z0.d
+; CHECK-NEXT:    uzp1 z0.d, z1.d, z0.d
 ; CHECK-NEXT:    faddv d0, p0, z0.d
-; CHECK-NEXT:    faddv d1, p0, z1.d
+; CHECK-NEXT:    faddv d1, p0, z2.d
 ; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $z0
 ; CHECK-NEXT:    // kill: def $d1 killed $d1 killed $z1
 ; CHECK-NEXT:    ret
@@ -122,39 +121,38 @@
 ; CHECK-NEXT:    and x11, x11, x12
 ; CHECK-NEXT:    mov z1.d, #0 // =0x0
 ; CHECK-NEXT:    rdvl x12, #2
-; CHECK-NEXT:    mov z0.d, z1.d
+; CHECK-NEXT:    zip2 z0.d, z1.d, z1.d
+; CHECK-NEXT:    zip1 z1.d, z1.d, z1.d
 ; CHECK-NEXT:    ptrue p0.d
 ; CHECK-NEXT:  .LBB1_1: // %vector.body
 ; CHECK-NEXT:    // =>This Inner Loop Header: Depth=1
 ; CHECK-NEXT:    ld1w { z2.d }, p0/z, [x2, x9, lsl #2]
 ; CHECK-NEXT:    add x13, x0, x8
 ; CHECK-NEXT:    add x14, x1, x8
+; CHECK-NEXT:    mov z6.d, z1.d
+; CHECK-NEXT:    mov z7.d, z0.d
 ; CHECK-NEXT:    add x9, x9, x10
 ; CHECK-NEXT:    add x8, x8, x12
-; CHECK-NEXT:    cmpne p1.d, p0/z, z2.d, #0
-; CHECK-NEXT:    zip1 p2.d, p1.d, p1.d
-; CHECK-NEXT:    zip2 p3.d, p1.d, p1.d
-; CHECK-NEXT:    ld1d { z2.d }, p3/z, [x13, #1, mul vl]
-; CHECK-NEXT:    ld1d { z3.d }, p2/z, [x13]
-; CHECK-NEXT:    ld1d { z4.d }, p3/z, [x14, #1, mul vl]
-; CHECK-NEXT:    ld1d { z5.d }, p2/z, [x14]
+; CHECK-NEXT:    cmpne p2.d, p0/z, z2.d, #0
+; CHECK-NEXT:    zip1 p1.d, p2.d, p2.d
+; CHECK-NEXT:    zip2 p2.d, p2.d, p2.d
+; CHECK-NEXT:    ld1d { z2.d }, p2/z, [x13, #1, mul vl]
+; CHECK-NEXT:    ld1d { z3.d }, p1/z, [x13]
+; CHECK-NEXT:    ld1d { z4.d }, p2/z, [x14, #1, mul vl]
+; CHECK-NEXT:    ld1d { z5.d }, p1/z, [x14]
 ; CHECK-NEXT:    cmp x11, x9
-; CHECK-NEXT:    uzp2 z6.d, z3.d, z2.d
-; CHECK-NEXT:    uzp1 z2.d, z3.d, z2.d
-; CHECK-NEXT:    uzp1 z3.d, z5.d, z4.d
-; CHECK-NEXT:    movprfx z7, z0
-; CHECK-NEXT:    fmla z7.d, p0/m, z3.d, z2.d
-; CHECK-NEXT:    fmad z3.d, p0/m, z6.d, z1.d
-; CHECK-NEXT:    uzp2 z4.d, z5.d, z4.d
-; CHECK-NEXT:    fmad z2.d, p0/m, z4.d, z3.d
-; CHECK-NEXT:    movprfx z5, z7
-; CHECK-NEXT:    fmls z5.d, p0/m, z4.d, z6.d
-; CHECK-NEXT:    mov z0.d, p1/m, z5.d
-; CHECK-NEXT:    mov z1.d, p1/m, z2.d
+; CHECK-NEXT:    fcmla z6.d, p0/m, z5.d, z3.d, #0
+; CHECK-NEXT:    fcmla z7.d, p0/m, z4.d, z2.d, #0
+; CHECK-NEXT:    fcmla z6.d, p0/m, z5.d, z3.d, #90
+; CHECK-NEXT:    fcmla z7.d, p0/m, z4.d, z2.d, #90
+; CHECK-NEXT:    mov z0.d, p2/m, z7.d
+; CHECK-NEXT:    mov z1.d, p1/m, z6.d
 ; CHECK-NEXT:    b.ne .LBB1_1
 ; CHECK-NEXT:  // %bb.2: // %exit.block
+; CHECK-NEXT:    uzp2 z2.d, z1.d, z0.d
+; CHECK-NEXT:    uzp1 z0.d, z1.d, z0.d
 ; CHECK-NEXT:    faddv d0, p0, z0.d
-; CHECK-NEXT:    faddv d1, p0, z1.d
+; CHECK-NEXT:    faddv d1, p0, z2.d
 ; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $z0
 ; CHECK-NEXT:    // kill: def $d1 killed $d1 killed $z1
 ; CHECK-NEXT:    ret
@@ -223,42 +221,41 @@
 ; CHECK-NEXT:    mov x8, xzr
 ; CHECK-NEXT:    mov x9, xzr
 ; CHECK-NEXT:    mov z1.d, #0 // =0x0
-; CHECK-NEXT:    mov z0.d, z1.d
 ; CHECK-NEXT:    cntd x11
-; CHECK-NEXT:    whilelo p1.d, xzr, x10
 ; CHECK-NEXT:    rdvl x12, #2
+; CHECK-NEXT:    whilelo p1.d, xzr, x10
+; CHECK-NEXT:    zip2 z0.d, z1.d, z1.d
+; CHECK-NEXT:    zip1 z1.d, z1.d, z1.d
 ; CHECK-NEXT:    ptrue p0.d
 ; CHECK-NEXT:  .LBB2_1: // %vector.body
 ; CHECK-NEXT:    // =>This Inner Loop Header: Depth=1
 ; CHECK-NEXT:    ld1w { z2.d }, p1/z, [x2, x9, lsl #2]
 ; CHECK-NEXT:    add x13, x0, x8
 ; CHECK-NEXT:    add x14, x1, x8
+; CHECK-NEXT:    mov z6.d, z1.d
+; CHECK-NEXT:    mov z7.d, z0.d
 ; CHECK-NEXT:    add x9, x9, x11
 ; CHECK-NEXT:    add x8, x8, x12
-; CHECK-NEXT:    cmpne p2.d, p1/z, z2.d, #0
-; CHECK-NEXT:    zip1 p1.d, p2.d, p2.d
-; CHECK-NEXT:    zip2 p3.d, p2.d, p2.d
+; CHECK-NEXT:    cmpne p1.d, p1/z, z2.d, #0
+; CHECK-NEXT:    zip1 p2.d, p1.d, p1.d
+; CHECK-NEXT:    zip2 p3.d, p1.d, p1.d
 ; CHECK-NEXT:    ld1d { z2.d }, p3/z, [x13, #1, mul vl]
-; CHECK-NEXT:    ld1d { z3.d }, p1/z, [x13]
+; CHECK-NEXT:    ld1d { z3.d }, p2/z, [x13]
 ; CHECK-NEXT:    ld1d { z4.d }, p3/z, [x14, #1, mul vl]
-; CHECK-NEXT:    ld1d { z5.d }, p1/z, [x14]
+; CHECK-NEXT:    ld1d { z5.d }, p2/z, [x14]
 ; CHECK-NEXT:    whilelo p1.d, x9, x10
-; CHECK-NEXT:    uzp1 z6.d, z3.d, z2.d
-; CHECK-NEXT:    uzp2 z2.d, z3.d, z2.d
-; CHECK-NEXT:    uzp1 z7.d, z5.d, z4.d
-; CHECK-NEXT:    uzp2 z4.d, z5.d, z4.d
-; CHECK-NEXT:    movprfx z3, z0
-; CHECK-NEXT:    fmla z3.d, p0/m, z7.d, z6.d
-; CHECK-NEXT:    fmad z7.d, p0/m, z2.d, z1.d
-; CHECK-NEXT:    fmsb z2.d, p0/m, z4.d, z3.d
-; CHECK-NEXT:    movprfx z3, z7
-; CHECK-NEXT:    fmla z3.d, p0/m, z4.d, z6.d
-; CHECK-NEXT:    mov z1.d, p2/m, z3.d
-; CHECK-NEXT:    mov z0.d, p2/m, z2.d
+; CHECK-NEXT:    fcmla z6.d, p0/m, z5.d, z3.d, #0
+; CHECK-NEXT:    fcmla z7.d, p0/m, z4.d, z2.d, #0
+; CHECK-NEXT:    fcmla z6.d, p0/m, z5.d, z3.d, #90
+; CHECK-NEXT:    fcmla z7.d, p0/m, z4.d, z2.d, #90
+; CHECK-NEXT:    mov z0.d, p3/m, z7.d
+; CHECK-NEXT:    mov z1.d, p2/m, z6.d
 ; CHECK-NEXT:    b.mi .LBB2_1
 ; CHECK-NEXT:  // %bb.2: // %exit.block
+; CHECK-NEXT:    uzp2 z2.d, z1.d, z0.d
+; CHECK-NEXT:    uzp1 z0.d, z1.d, z0.d
 ; CHECK-NEXT:    faddv d0, p0, z0.d
-; CHECK-NEXT:    faddv d1, p0, z1.d
+; CHECK-NEXT:    faddv d1, p0, z2.d
 ; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $z0
 ; CHECK-NEXT:    // kill: def $d1 killed $d1 killed $z1
 ; CHECK-NEXT:    ret