diff --git a/llvm/include/llvm/CodeGen/ComplexDeinterleavingPass.h b/llvm/include/llvm/CodeGen/ComplexDeinterleavingPass.h --- a/llvm/include/llvm/CodeGen/ComplexDeinterleavingPass.h +++ b/llvm/include/llvm/CodeGen/ComplexDeinterleavingPass.h @@ -38,6 +38,7 @@ // The following 'operations' are used to represent internal states. Backends // are not expected to try and support these in any capacity. Deinterleave, + Splat, Symmetric, ReductionPHI, ReductionOperation, diff --git a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp --- a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp +++ b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp @@ -130,7 +130,7 @@ struct ComplexDeinterleavingCompositeNode { ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op, - Instruction *R, Instruction *I) + Value *R, Value *I) : Operation(Op), Real(R), Imag(I) {} private: @@ -140,8 +140,8 @@ public: ComplexDeinterleavingOperation Operation; - Instruction *Real; - Instruction *Imag; + Value *Real; + Value *Imag; // This two members are required exclusively for generating // ComplexDeinterleavingOperation::Symmetric operations. 
@@ -192,19 +192,19 @@ class ComplexDeinterleavingGraph { public: struct Product { - Instruction *Multiplier; - Instruction *Multiplicand; + Value *Multiplier; + Value *Multiplicand; bool IsPositive; }; - using Addend = std::pair; + using Addend = std::pair; using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr; using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr; // Helper struct for holding info about potential partial multiplication // candidates struct PartialMulCandidate { - Instruction *Common; + Value *Common; NodePtr Node; unsigned RealIdx; unsigned ImagIdx; @@ -270,7 +270,7 @@ std::map OldToNewPHI; NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation, - Instruction *R, Instruction *I) { + Value *R, Value *I) { assert(((Operation != ComplexDeinterleavingOperation::ReductionPHI && Operation != ComplexDeinterleavingOperation::ReductionOperation) || (R && I)) && @@ -308,9 +308,9 @@ /// Identify the other branch of a Partial Mul, taking the CommonOperandI that /// is partially known from identifyPartialMul, filling in the other half of /// the complex pair. - NodePtr identifyNodeWithImplicitAdd( - Instruction *I, Instruction *J, - std::pair &CommonOperandI); + NodePtr + identifyNodeWithImplicitAdd(Instruction *I, Instruction *J, + std::pair &CommonOperandI); /// Identifies a complex add pattern and its rotation, based on the following /// patterns. @@ -322,7 +322,7 @@ NodePtr identifyAdd(Instruction *Real, Instruction *Imag); NodePtr identifySymmetricOperation(Instruction *Real, Instruction *Imag); - NodePtr identifyNode(Instruction *I, Instruction *J); + NodePtr identifyNode(Value *R, Value *I); /// Determine if a sum of complex numbers can be formed from \p RealAddends /// and \p ImagAddens. If \p Accumulator is not null, add the result to it. 
@@ -369,6 +369,12 @@ /// intrinsic (for both fixed and scalable vectors) NodePtr identifyDeinterleave(Instruction *Real, Instruction *Imag); + /// Identifies the operation that represents a complex number repeated in a + /// splat vector. Accepted forms are any ConstantDataVector, or a shufflevector + /// (ConstantExpr or ShuffleVectorInst) whose mask elements are all zero; + /// fixed-width <1 x T> shufflevectors are rejected as ambiguous. + NodePtr identifySplat(Value *Real, Value *Imag); + NodePtr identifyPHINode(Instruction *Real, Instruction *Imag); /// Identifies SelectInsts in a loop that has reduction with predication masks @@ -521,7 +527,7 @@ ComplexDeinterleavingGraph::NodePtr ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd( Instruction *Real, Instruction *Imag, - std::pair &PartialMatch) { + std::pair &PartialMatch) { LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag << "\n"); @@ -536,52 +542,43 @@ return nullptr; } - Instruction *R0 = dyn_cast(Real->getOperand(0)); - Instruction *R1 = dyn_cast(Real->getOperand(1)); - Instruction *I0 = dyn_cast(Imag->getOperand(0)); - Instruction *I1 = dyn_cast(Imag->getOperand(1)); - if (!R0 || !R1 || !I0 || !I1) { - LLVM_DEBUG(dbgs() << " - Mul operand not Instruction\n"); - return nullptr; - } + Value *R0 = Real->getOperand(0); + Value *R1 = Real->getOperand(1); + Value *I0 = Imag->getOperand(0); + Value *I1 = Imag->getOperand(1); // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the // rotations and use the operand.
unsigned Negs = 0; SmallVector FNegs; - if (R0->getOpcode() == Instruction::FNeg || - R1->getOpcode() == Instruction::FNeg) { + Value *Op; + if (match(R0, m_FNeg(m_Value(Op)))) { Negs |= 1; - if (R0->getOpcode() == Instruction::FNeg) { - FNegs.push_back(R0); - R0 = dyn_cast(R0->getOperand(0)); - } else { - FNegs.push_back(R1); - R1 = dyn_cast(R1->getOperand(0)); - } - if (!R0 || !R1) - return nullptr; + FNegs.push_back(cast(R0)); + R0 = Op; + } else if (match(R1, m_FNeg(m_Value(Op)))) { + Negs |= 1; + FNegs.push_back(cast(R1)); + R1 = Op; } - if (I0->getOpcode() == Instruction::FNeg || - I1->getOpcode() == Instruction::FNeg) { + + if (match(I0, m_FNeg(m_Value(Op)))) { Negs |= 2; Negs ^= 1; - if (I0->getOpcode() == Instruction::FNeg) { - FNegs.push_back(I0); - I0 = dyn_cast(I0->getOperand(0)); - } else { - FNegs.push_back(I1); - I1 = dyn_cast(I1->getOperand(0)); - } - if (!I0 || !I1) - return nullptr; + FNegs.push_back(cast(I0)); + I0 = Op; + } else if (match(I1, m_FNeg(m_Value(Op)))) { + Negs |= 2; + Negs ^= 1; + FNegs.push_back(cast(I1)); + I1 = Op; } ComplexDeinterleavingRotation Rotation = (ComplexDeinterleavingRotation)Negs; - Instruction *CommonOperand; - Instruction *UncommonRealOp; - Instruction *UncommonImagOp; + Value *CommonOperand; + Value *UncommonRealOp; + Value *UncommonImagOp; if (R0 == I0 || R0 == I1) { CommonOperand = R0; @@ -676,18 +673,14 @@ return nullptr; } - Instruction *R0 = dyn_cast(RealMulI->getOperand(0)); - Instruction *R1 = dyn_cast(RealMulI->getOperand(1)); - Instruction *I0 = dyn_cast(ImagMulI->getOperand(0)); - Instruction *I1 = dyn_cast(ImagMulI->getOperand(1)); - if (!R0 || !R1 || !I0 || !I1) { - LLVM_DEBUG(dbgs() << " - Mul operand not Instruction\n"); - return nullptr; - } + Value *R0 = RealMulI->getOperand(0); + Value *R1 = RealMulI->getOperand(1); + Value *I0 = ImagMulI->getOperand(0); + Value *I1 = ImagMulI->getOperand(1); - Instruction *CommonOperand; - Instruction *UncommonRealOp; - Instruction *UncommonImagOp; + Value
*CommonOperand; + Value *UncommonRealOp; + Value *UncommonImagOp; if (R0 == I0 || R0 == I1) { CommonOperand = R0; @@ -705,7 +698,7 @@ Rotation == ComplexDeinterleavingRotation::Rotation_270) std::swap(UncommonRealOp, UncommonImagOp); - std::pair PartialMatch( + std::pair PartialMatch( (Rotation == ComplexDeinterleavingRotation::Rotation_0 || Rotation == ComplexDeinterleavingRotation::Rotation_180) ? CommonOperand @@ -840,11 +833,8 @@ !isInstructionPotentiallySymmetric(Imag)) return nullptr; - auto *R0 = dyn_cast(Real->getOperand(0)); - auto *I0 = dyn_cast(Imag->getOperand(0)); - - if (!R0 || !I0) - return nullptr; + auto *R0 = Real->getOperand(0); + auto *I0 = Imag->getOperand(0); NodePtr Op0 = identifyNode(R0, I0); NodePtr Op1 = nullptr; @@ -852,11 +842,8 @@ return nullptr; if (Real->isBinaryOp()) { - auto *R1 = dyn_cast(Real->getOperand(1)); - auto *I1 = dyn_cast(Imag->getOperand(1)); - if (!R1 || !I1) - return nullptr; - + auto *R1 = Real->getOperand(1); + auto *I1 = Imag->getOperand(1); Op1 = identifyNode(R1, I1); if (Op1 == nullptr) return nullptr; @@ -880,13 +867,21 @@ } ComplexDeinterleavingGraph::NodePtr -ComplexDeinterleavingGraph::identifyNode(Instruction *Real, Instruction *Imag) { - LLVM_DEBUG(dbgs() << "identifyNode on " << *Real << " / " << *Imag << "\n"); - if (NodePtr CN = getContainingComposite(Real, Imag)) { +ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I) { + LLVM_DEBUG(dbgs() << "identifyNode on " << *R << " / " << *I << "\n"); + if (NodePtr CN = getContainingComposite(R, I)) { LLVM_DEBUG(dbgs() << " - Folding to existing node\n"); return CN; } + if (NodePtr CN = identifySplat(R, I)) + return CN; + + auto *Real = dyn_cast(R); + auto *Imag = dyn_cast(I); + if (!Real || !Imag) + return nullptr; + if (NodePtr CN = identifyDeinterleave(Real, Imag)) return CN; @@ -929,6 +924,7 @@ ComplexDeinterleavingGraph::NodePtr ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real, Instruction *Imag) { + if ((Real->getOpcode() != 
Instruction::FAdd && Real->getOpcode() != Instruction::FSub && Real->getOpcode() != Instruction::FNeg) || @@ -965,8 +961,10 @@ continue; Instruction *I = dyn_cast(V); - if (!I) - return false; + if (!I) { + Addends.emplace_back(V, IsPositive); + continue; + } // If an instruction has more than one user, it indicates that it either // has an external user, which will be later checked by the checkNodes @@ -987,20 +985,18 @@ Worklist.emplace_back(I->getOperand(1), !IsPositive); Worklist.emplace_back(I->getOperand(0), IsPositive); } else if (I->getOpcode() == Instruction::FMul) { - auto *A = dyn_cast(I->getOperand(0)); - if (A && A->getOpcode() == Instruction::FNeg) { - A = dyn_cast(A->getOperand(0)); + Value *A, *B; + if (match(I->getOperand(0), m_FNeg(m_Value(A)))) { IsPositive = !IsPositive; + } else { + A = I->getOperand(0); } - if (!A) - return false; - auto *B = dyn_cast(I->getOperand(1)); - if (B && B->getOpcode() == Instruction::FNeg) { - B = dyn_cast(B->getOperand(0)); + + if (match(I->getOperand(1), m_FNeg(m_Value(B)))) { IsPositive = !IsPositive; + } else { + B = I->getOperand(1); } - if (!B) - return false; Muls.push_back(Product{A, B, IsPositive}); } else if (I->getOpcode() == Instruction::FNeg) { Worklist.emplace_back(I->getOperand(0), !IsPositive); @@ -1057,7 +1053,7 @@ std::vector &PartialMulCandidates) { // Helper function to extract a common operand from two products auto FindCommonInstruction = [](const Product &Real, - const Product &Imag) -> Instruction * { + const Product &Imag) -> Value * { if (Real.Multiplicand == Imag.Multiplicand || Real.Multiplicand == Imag.Multiplier) return Real.Multiplicand; @@ -1085,18 +1081,17 @@ auto *B = ImagMuls[j].Multiplicand == Common ? 
ImagMuls[j].Multiplier : ImagMuls[j].Multiplicand; - bool Inverted = false; auto Node = identifyNode(A, B); - if (!Node) { - std::swap(A, B); - Inverted = true; - Node = identifyNode(A, B); + if (Node) { + FoundCommon = true; + PartialMulCandidates.push_back({Common, Node, i, j, false}); } - if (!Node) - continue; - FoundCommon = true; - PartialMulCandidates.push_back({Common, Node, i, j, Inverted}); + Node = identifyNode(B, A); + if (Node) { + FoundCommon = true; + PartialMulCandidates.push_back({Common, Node, i, j, true}); + } } if (!FoundCommon) return false; @@ -1116,7 +1111,7 @@ return nullptr; // Map to store common instruction to node pointers - std::map CommonToNode; + std::map CommonToNode; std::vector Processed(Info.size(), false); for (unsigned I = 0; I < Info.size(); ++I) { if (Processed[I]) @@ -1709,6 +1704,59 @@ return submitCompositeNode(PlaceholderNode); } +ComplexDeinterleavingGraph::NodePtr +ComplexDeinterleavingGraph::identifySplat(Value *R, Value *I) { + auto IsSplat = [](Value *V) -> bool { + // Fixed-width vector with constants + if (isa(V)) + return true; + + VectorType *VTy; + ArrayRef Mask; + // Splats are represented differently depending on whether the repeated + // value is a constant or an Instruction + if (auto *Const = dyn_cast(V)) { + if (Const->getOpcode() != Instruction::ShuffleVector) + return false; + VTy = cast(Const->getType()); + Mask = Const->getShuffleMask(); + } else if (auto *Shuf = dyn_cast(V)) { + VTy = Shuf->getType(); + Mask = Shuf->getShuffleMask(); + } else { + return false; + } + + // When the data type is <1 x Type>, it's not possible to differentiate + // between the ComplexDeinterleaving::Deinterleave and + // ComplexDeinterleaving::Splat operations. 
+ if (!VTy->isScalableTy() && VTy->getElementCount().getKnownMinValue() == 1) + return false; + + return all_equal(Mask) && Mask[0] == 0; + }; + + if (!IsSplat(R) || !IsSplat(I)) + return nullptr; + + auto *Real = dyn_cast(R); + auto *Imag = dyn_cast(I); + if ((!Real && Imag) || (Real && !Imag)) + return nullptr; + + if (Real && Imag) { + // Non-constant splats should be in the same basic block + if (Real->getParent() != Imag->getParent()) + return nullptr; + + FinalInstructions.insert(Real); + FinalInstructions.insert(Imag); + } + NodePtr PlaceholderNode = + prepareCompositeNode(ComplexDeinterleavingOperation::Splat, R, I); + return submitCompositeNode(PlaceholderNode); +} + ComplexDeinterleavingGraph::NodePtr ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real, Instruction *Imag) { @@ -1819,6 +1867,25 @@ case ComplexDeinterleavingOperation::Deinterleave: llvm_unreachable("Deinterleave node should already have ReplacementNode"); break; + case ComplexDeinterleavingOperation::Splat: { + auto *NewTy = VectorType::getDoubleElementsVectorType( + cast(Node->Real->getType())); + auto *R = dyn_cast(Node->Real); + auto *I = dyn_cast(Node->Imag); + if (R && I) { + // Splats that are not constant are interleaved where they are located + Instruction *InsertPoint = (I->comesBefore(R) ? R : I)->getNextNode(); + IRBuilder<> IRB(InsertPoint); + ReplacementNode = + IRB.CreateIntrinsic(Intrinsic::experimental_vector_interleave2, NewTy, + {Node->Real, Node->Imag}); + } else { + ReplacementNode = + Builder.CreateIntrinsic(Intrinsic::experimental_vector_interleave2, + NewTy, {Node->Real, Node->Imag}); + } + break; + } case ComplexDeinterleavingOperation::ReductionPHI: { // If Operation is ReductionPHI, a new empty PHINode is created. // It is filled later when the ReductionOperation is processed. 
@@ -1834,8 +1901,8 @@ processReductionOperation(ReplacementNode, Node); break; case ComplexDeinterleavingOperation::ReductionSelect: { - auto *MaskReal = Node->Real->getOperand(0); - auto *MaskImag = Node->Imag->getOperand(0); + auto *MaskReal = cast(Node->Real)->getOperand(0); + auto *MaskImag = cast(Node->Imag)->getOperand(0); auto *A = replaceNode(Builder, Node->Operands[0]); auto *B = replaceNode(Builder, Node->Operands[1]); auto *NewMaskTy = VectorType::getDoubleElementsVectorType( @@ -1860,11 +1927,13 @@ void ComplexDeinterleavingGraph::processReductionOperation( Value *OperationReplacement, RawNodePtr Node) { - auto *OldPHIReal = ReductionInfo[Node->Real].first; - auto *OldPHIImag = ReductionInfo[Node->Imag].first; + auto *Real = cast(Node->Real); + auto *Imag = cast(Node->Imag); + auto *OldPHIReal = ReductionInfo[Real].first; + auto *OldPHIImag = ReductionInfo[Imag].first; auto *NewPHI = OldToNewPHI[OldPHIReal]; - auto *VTy = cast(Node->Real->getType()); + auto *VTy = cast(Real->getType()); auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy); // We have to interleave initial origin values coming from IncomingBlock @@ -1880,8 +1949,8 @@ // Deinterleave complex vector outside of loop so that it can be finally // reduced - auto *FinalReductionReal = ReductionInfo[Node->Real].second; - auto *FinalReductionImag = ReductionInfo[Node->Imag].second; + auto *FinalReductionReal = ReductionInfo[Real].second; + auto *FinalReductionImag = ReductionInfo[Imag].second; Builder.SetInsertPoint( &*FinalReductionReal->getParent()->getFirstInsertionPt()); @@ -1890,11 +1959,11 @@ OperationReplacement->getType(), OperationReplacement); auto *NewReal = Builder.CreateExtractValue(Deinterleave, (uint64_t)0); - FinalReductionReal->replaceUsesOfWith(Node->Real, NewReal); + FinalReductionReal->replaceUsesOfWith(Real, NewReal); Builder.SetInsertPoint(FinalReductionImag); auto *NewImag = Builder.CreateExtractValue(Deinterleave, 1); - 
FinalReductionImag->replaceUsesOfWith(Node->Imag, NewImag); + FinalReductionImag->replaceUsesOfWith(Imag, NewImag); } void ComplexDeinterleavingGraph::replaceNodes() { @@ -1911,10 +1980,12 @@ if (RootNode->Operation == ComplexDeinterleavingOperation::ReductionOperation) { - ReductionInfo[RootNode->Real].first->removeIncomingValue(BackEdge); - ReductionInfo[RootNode->Imag].first->removeIncomingValue(BackEdge); - DeadInstrRoots.push_back(RootNode->Real); - DeadInstrRoots.push_back(RootNode->Imag); + auto *RootReal = cast(RootNode->Real); + auto *RootImag = cast(RootNode->Imag); + ReductionInfo[RootReal].first->removeIncomingValue(BackEdge); + ReductionInfo[RootImag].first->removeIncomingValue(BackEdge); + DeadInstrRoots.push_back(cast(RootReal)); + DeadInstrRoots.push_back(cast(RootImag)); } else { assert(R && "Unable to find replacement for RootInstruction"); DeadInstrRoots.push_back(RootInstruction); diff --git a/llvm/test/CodeGen/AArch64/complex-deinterleaving-splat-scalable.ll b/llvm/test/CodeGen/AArch64/complex-deinterleaving-splat-scalable.ll new file mode 100644 --- /dev/null +++ b/llvm/test/CodeGen/AArch64/complex-deinterleaving-splat-scalable.ll @@ -0,0 +1,109 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py +; RUN: llc < %s --mattr=+sve -o - | FileCheck %s + +target triple = "aarch64-arm-none-eabi" + +; a[i] * b[i] * (11.0 + 3.0.i); +; +define @complex_mul_const( %a, %b) { +; CHECK-LABEL: complex_mul_const: +; CHECK: // %bb.0: // %entry +; CHECK-NEXT: mov z4.d, #0 // =0x0 +; CHECK-NEXT: ptrue p0.d +; CHECK-NEXT: mov z5.d, z4.d +; CHECK-NEXT: mov z6.d, z4.d +; CHECK-NEXT: fcmla z5.d, p0/m, z0.d, z2.d, #0 +; CHECK-NEXT: fcmla z6.d, p0/m, z1.d, z3.d, #0 +; CHECK-NEXT: fcmla z5.d, p0/m, z0.d, z2.d, #90 +; CHECK-NEXT: fcmla z6.d, p0/m, z1.d, z3.d, #90 +; CHECK-NEXT: fmov z1.d, #3.00000000 +; CHECK-NEXT: fmov z2.d, #11.00000000 +; CHECK-NEXT: zip2 z3.d, z2.d, z1.d +; CHECK-NEXT: mov z0.d, z4.d +; CHECK-NEXT: zip1 z1.d, z2.d, z1.d 
+; CHECK-NEXT: fcmla z4.d, p0/m, z6.d, z3.d, #0 +; CHECK-NEXT: fcmla z0.d, p0/m, z5.d, z1.d, #0 +; CHECK-NEXT: fcmla z4.d, p0/m, z6.d, z3.d, #90 +; CHECK-NEXT: fcmla z0.d, p0/m, z5.d, z1.d, #90 +; CHECK-NEXT: mov z1.d, z4.d +; CHECK-NEXT: ret +entry: + %strided.vec = tail call { , } @llvm.experimental.vector.deinterleave2.nxv4f64( %a) + %0 = extractvalue { , } %strided.vec, 0 + %1 = extractvalue { , } %strided.vec, 1 + %strided.vec48 = tail call { , } @llvm.experimental.vector.deinterleave2.nxv4f64( %b) + %2 = extractvalue { , } %strided.vec48, 0 + %3 = extractvalue { , } %strided.vec48, 1 + %4 = fmul fast %3, %0 + %5 = fmul fast %2, %1 + %6 = fadd fast %4, %5 + %7 = fmul fast %2, %0 + %8 = fmul fast %3, %1 + %9 = fsub fast %7, %8 + %10 = fmul fast %9, shufflevector ( insertelement ( poison, double 3.000000e+00, i64 0), poison, zeroinitializer) + %11 = fmul fast %6, shufflevector ( insertelement ( poison, double 1.100000e+01, i64 0), poison, zeroinitializer) + %12 = fadd fast %10, %11 + %13 = fmul fast %9, shufflevector ( insertelement ( poison, double 1.100000e+01, i64 0), poison, zeroinitializer) + %14 = fmul fast %6, shufflevector ( insertelement ( poison, double 3.000000e+00, i64 0), poison, zeroinitializer) + %15 = fsub fast %13, %14 + %interleaved.vec = tail call @llvm.experimental.vector.interleave2.nxv4f64( %15, %12) + ret %interleaved.vec +} + +; a[i] * b[i] * c; +; +define @complex_mul_non_const( %a, %b, [2 x double] %c) { +; CHECK-LABEL: complex_mul_non_const: +; CHECK: // %bb.0: // %entry +; CHECK-NEXT: mov z6.d, #0 // =0x0 +; CHECK-NEXT: // kill: def $d5 killed $d5 def $z5 +; CHECK-NEXT: // kill: def $d4 killed $d4 def $z4 +; CHECK-NEXT: ptrue p0.d +; CHECK-NEXT: mov z7.d, z6.d +; CHECK-NEXT: mov z24.d, z6.d +; CHECK-NEXT: mov z5.d, d5 +; CHECK-NEXT: mov z4.d, d4 +; CHECK-NEXT: fcmla z7.d, p0/m, z0.d, z2.d, #0 +; CHECK-NEXT: fcmla z24.d, p0/m, z1.d, z3.d, #0 +; CHECK-NEXT: fcmla z7.d, p0/m, z0.d, z2.d, #90 +; CHECK-NEXT: zip2 z2.d, z4.d, z5.d +; 
CHECK-NEXT: fcmla z24.d, p0/m, z1.d, z3.d, #90 +; CHECK-NEXT: mov z0.d, z6.d +; CHECK-NEXT: zip1 z4.d, z4.d, z5.d +; CHECK-NEXT: fcmla z6.d, p0/m, z24.d, z2.d, #0 +; CHECK-NEXT: fcmla z0.d, p0/m, z7.d, z4.d, #0 +; CHECK-NEXT: fcmla z6.d, p0/m, z24.d, z2.d, #90 +; CHECK-NEXT: fcmla z0.d, p0/m, z7.d, z4.d, #90 +; CHECK-NEXT: mov z1.d, z6.d +; CHECK-NEXT: ret +entry: + %c.coerce.fca.0.extract = extractvalue [2 x double] %c, 0 + %c.coerce.fca.1.extract = extractvalue [2 x double] %c, 1 + %broadcast.splatinsert = insertelement poison, double %c.coerce.fca.1.extract, i64 0 + %broadcast.splat = shufflevector %broadcast.splatinsert, poison, zeroinitializer + %broadcast.splatinsert49 = insertelement poison, double %c.coerce.fca.0.extract, i64 0 + %broadcast.splat50 = shufflevector %broadcast.splatinsert49, poison, zeroinitializer + %strided.vec = tail call { , } @llvm.experimental.vector.deinterleave2.nxv4f64( %a) + %0 = extractvalue { , } %strided.vec, 0 + %1 = extractvalue { , } %strided.vec, 1 + %strided.vec48 = tail call { , } @llvm.experimental.vector.deinterleave2.nxv4f64( %b) + %2 = extractvalue { , } %strided.vec48, 0 + %3 = extractvalue { , } %strided.vec48, 1 + %4 = fmul fast %3, %0 + %5 = fmul fast %2, %1 + %6 = fadd fast %4, %5 + %7 = fmul fast %2, %0 + %8 = fmul fast %3, %1 + %9 = fsub fast %7, %8 + %10 = fmul fast %9, %broadcast.splat + %11 = fmul fast %6, %broadcast.splat50 + %12 = fadd fast %10, %11 + %13 = fmul fast %9, %broadcast.splat50 + %14 = fmul fast %6, %broadcast.splat + %15 = fsub fast %13, %14 + %interleaved.vec = tail call @llvm.experimental.vector.interleave2.nxv4f64( %15, %12) + ret %interleaved.vec +} + +declare { , } @llvm.experimental.vector.deinterleave2.nxv4f64() +declare @llvm.experimental.vector.interleave2.nxv4f64(, ) diff --git a/llvm/test/CodeGen/AArch64/complex-deinterleaving-splat.ll b/llvm/test/CodeGen/AArch64/complex-deinterleaving-splat.ll new file mode 100644 --- /dev/null +++ 
b/llvm/test/CodeGen/AArch64/complex-deinterleaving-splat.ll @@ -0,0 +1,97 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py +; RUN: llc < %s --mattr=+complxnum -o - | FileCheck %s + +target triple = "aarch64-arm-none-eabi" + + +; a[i] * b[i] * (11.0 + 3.0.i); +; +define <4 x double> @complex_mul_const(<4 x double> %a, <4 x double> %b) { +; CHECK-LABEL: complex_mul_const: +; CHECK: // %bb.0: // %entry +; CHECK-NEXT: movi v6.2d, #0000000000000000 +; CHECK-NEXT: adrp x8, .LCPI0_0 +; CHECK-NEXT: movi v5.2d, #0000000000000000 +; CHECK-NEXT: movi v4.2d, #0000000000000000 +; CHECK-NEXT: fcmla v6.2d, v3.2d, v1.2d, #0 +; CHECK-NEXT: fcmla v5.2d, v2.2d, v0.2d, #0 +; CHECK-NEXT: fcmla v6.2d, v3.2d, v1.2d, #90 +; CHECK-NEXT: fcmla v5.2d, v2.2d, v0.2d, #90 +; CHECK-NEXT: ldr q2, [x8, :lo12:.LCPI0_0] +; CHECK-NEXT: movi v0.2d, #0000000000000000 +; CHECK-NEXT: fcmla v4.2d, v2.2d, v6.2d, #0 +; CHECK-NEXT: fcmla v0.2d, v2.2d, v5.2d, #0 +; CHECK-NEXT: fcmla v4.2d, v2.2d, v6.2d, #90 +; CHECK-NEXT: fcmla v0.2d, v2.2d, v5.2d, #90 +; CHECK-NEXT: mov v1.16b, v4.16b +; CHECK-NEXT: ret +entry: + %strided.vec = shufflevector <4 x double> %a, <4 x double> poison, <2 x i32> + %strided.vec47 = shufflevector <4 x double> %a, <4 x double> poison, <2 x i32> + %strided.vec49 = shufflevector <4 x double> %b, <4 x double> poison, <2 x i32> + %strided.vec50 = shufflevector <4 x double> %b, <4 x double> poison, <2 x i32> + %0 = fmul fast <2 x double> %strided.vec50, %strided.vec + %1 = fmul fast <2 x double> %strided.vec49, %strided.vec47 + %2 = fadd fast <2 x double> %0, %1 + %3 = fmul fast <2 x double> %strided.vec49, %strided.vec + %4 = fmul fast <2 x double> %strided.vec50, %strided.vec47 + %5 = fsub fast <2 x double> %3, %4 + %6 = fmul fast <2 x double> %5, + %7 = fmul fast <2 x double> %2, + %8 = fadd fast <2 x double> %6, %7 + %9 = fmul fast <2 x double> %5, + %10 = fmul fast <2 x double> %2, + %11 = fsub fast <2 x double> %9, %10 + %interleaved.vec = 
shufflevector <2 x double> %11, <2 x double> %8, <4 x i32> + ret <4 x double> %interleaved.vec +} + + +; a[i] * b[i] * c; +; +define <4 x double> @complex_mul_non_const(<4 x double> %a, <4 x double> %b, [2 x double] %c) { +; CHECK-LABEL: complex_mul_non_const: +; CHECK: // %bb.0: // %entry +; CHECK-NEXT: movi v6.2d, #0000000000000000 +; CHECK-NEXT: // kill: def $d4 killed $d4 def $q4 +; CHECK-NEXT: // kill: def $d5 killed $d5 def $q5 +; CHECK-NEXT: movi v7.2d, #0000000000000000 +; CHECK-NEXT: mov v4.d[1], v5.d[0] +; CHECK-NEXT: fcmla v6.2d, v2.2d, v0.2d, #0 +; CHECK-NEXT: fcmla v7.2d, v3.2d, v1.2d, #0 +; CHECK-NEXT: fcmla v6.2d, v2.2d, v0.2d, #90 +; CHECK-NEXT: movi v2.2d, #0000000000000000 +; CHECK-NEXT: fcmla v7.2d, v3.2d, v1.2d, #90 +; CHECK-NEXT: movi v0.2d, #0000000000000000 +; CHECK-NEXT: fcmla v2.2d, v4.2d, v7.2d, #0 +; CHECK-NEXT: fcmla v0.2d, v4.2d, v6.2d, #0 +; CHECK-NEXT: fcmla v2.2d, v4.2d, v7.2d, #90 +; CHECK-NEXT: fcmla v0.2d, v4.2d, v6.2d, #90 +; CHECK-NEXT: mov v1.16b, v2.16b +; CHECK-NEXT: ret +entry: + %c.coerce.fca.1.extract = extractvalue [2 x double] %c, 1 + %c.coerce.fca.0.extract = extractvalue [2 x double] %c, 0 + %broadcast.splatinsert = insertelement <2 x double> poison, double %c.coerce.fca.1.extract, i64 0 + %broadcast.splat = shufflevector <2 x double> %broadcast.splatinsert, <2 x double> poison, <2 x i32> zeroinitializer + %broadcast.splatinsert51 = insertelement <2 x double> poison, double %c.coerce.fca.0.extract, i64 0 + %broadcast.splat52 = shufflevector <2 x double> %broadcast.splatinsert51, <2 x double> poison, <2 x i32> zeroinitializer + %strided.vec = shufflevector <4 x double> %a, <4 x double> poison, <2 x i32> + %strided.vec47 = shufflevector <4 x double> %a, <4 x double> poison, <2 x i32> + %strided.vec49 = shufflevector <4 x double> %b, <4 x double> poison, <2 x i32> + %strided.vec50 = shufflevector <4 x double> %b, <4 x double> poison, <2 x i32> + %0 = fmul fast <2 x double> %strided.vec50, %strided.vec + %1 = fmul fast <2 
x double> %strided.vec49, %strided.vec47 + %2 = fadd fast <2 x double> %0, %1 + %3 = fmul fast <2 x double> %strided.vec49, %strided.vec + %4 = fmul fast <2 x double> %strided.vec50, %strided.vec47 + %5 = fsub fast <2 x double> %3, %4 + %6 = fmul fast <2 x double> %5, %broadcast.splat + %7 = fmul fast <2 x double> %2, %broadcast.splat52 + %8 = fadd fast <2 x double> %6, %7 + %9 = fmul fast <2 x double> %5, %broadcast.splat52 + %10 = fmul fast <2 x double> %2, %broadcast.splat + %11 = fsub fast <2 x double> %9, %10 + %interleaved.vec = shufflevector <2 x double> %11, <2 x double> %8, <4 x i32> + ret <4 x double> %interleaved.vec +}