diff --git a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp --- a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp +++ b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp @@ -137,19 +137,12 @@ Instruction *Real; Instruction *Imag; - // Instructions that should only exist within this node, there should be no - // users of these instructions outside the node. An example of these would be - // the multiply instructions of a partial multiply operation. - SmallVector InternalInstructions; ComplexDeinterleavingRotation Rotation; SmallVector Operands; Value *ReplacementNode = nullptr; - void addInstruction(Instruction *I) { InternalInstructions.push_back(I); } void addOperand(NodePtr Node) { Operands.push_back(Node.get()); } - bool hasAllInternalUses(SmallPtrSet &AllInstructions); - void dump() { dump(dbgs()); } void dump(raw_ostream &OS) { auto PrintValue = [&](Value *V) { @@ -181,12 +174,6 @@ OS << " - "; PrintNodeRef(Op); } - OS << " InternalInstructions:\n"; - for (const auto &I : InternalInstructions) { - OS << " - \""; - I->print(OS, true); - OS << "\"\n"; - } } }; @@ -194,14 +181,22 @@ public: using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr; using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr; - explicit ComplexDeinterleavingGraph(const TargetLowering *tl) : TL(tl) {} + explicit ComplexDeinterleavingGraph(const TargetLowering *TL, + const TargetLibraryInfo *TLI) + : TL(TL), TLI(TLI) {} private: const TargetLowering *TL; - Instruction *RootValue; - NodePtr RootNode; + const TargetLibraryInfo *TLI; SmallVector CompositeNodes; - SmallPtrSet AllInstructions; + + SmallPtrSet FinalInstructions; + + /// Root instructions are instructions from which complex computation starts + std::map RootToNode; + + /// Topologically sorted root instructions + SmallVector OrderedRoots; NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation, Instruction *R, Instruction *I) { @@ -211,10 +206,6 @@ NodePtr
submitCompositeNode(NodePtr Node) { CompositeNodes.push_back(Node); - AllInstructions.insert(Node->Real); - AllInstructions.insert(Node->Imag); - for (auto *I : Node->InternalInstructions) - AllInstructions.insert(I); return Node; } @@ -271,6 +262,10 @@ /// current graph. bool identifyNodes(Instruction *RootI); + /// Erase every root connected to an instruction used outside the collected + /// chains; return true if at least one root remains to be replaced. + bool checkNodes(); + /// Perform the actual replacement of the underlying instruction graph. void replaceNodes(); }; @@ -368,9 +363,7 @@ } bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) { - bool Changed = false; - - SmallVector DeadInstrRoots; + ComplexDeinterleavingGraph Graph(TL, TLI); for (auto &I : *B) { auto *SVI = dyn_cast(&I); @@ -382,22 +375,15 @@ if (!isInterleavingMask(SVI->getShuffleMask())) continue; - ComplexDeinterleavingGraph Graph(TL); - if (!Graph.identifyNodes(SVI)) - continue; - - Graph.replaceNodes(); - DeadInstrRoots.push_back(SVI); - Changed = true; + Graph.identifyNodes(SVI); } - for (const auto &I : DeadInstrRoots) { - if (!I || I->getParent() == nullptr) - continue; - llvm::RecursivelyDeleteTriviallyDeadInstructions(I, TLI); + if (Graph.checkNodes()) { + Graph.replaceNodes(); + return true; } - return Changed; + return false; } ComplexDeinterleavingGraph::NodePtr @@ -511,7 +497,6 @@ Node->Rotation = Rotation; Node->addOperand(CommonNode); Node->addOperand(UncommonNode); - Node->InternalInstructions.append(FNegs); return submitCompositeNode(Node); } @@ -627,8 +612,6 @@ NodePtr Node = prepareCompositeNode( ComplexDeinterleavingOperation::CMulPartial, Real, Imag); - Node->addInstruction(RealMulI); - Node->addInstruction(ImagMulI); Node->Rotation = Rotation; Node->addOperand(CommonRes); Node->addOperand(UncommonRes); @@ -846,6 +829,8 @@ prepareCompositeNode(llvm::ComplexDeinterleavingOperation::Shuffle, RealShuffle, ImagShuffle); PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0); + FinalInstructions.insert(RealShuffle);
+ FinalInstructions.insert(ImagShuffle); return submitCompositeNode(PlaceholderNode); } if (RealShuffle || ImagShuffle) { @@ -881,9 +866,7 @@ if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag)))) return false; - RootValue = RootI; - AllInstructions.insert(RootI); - RootNode = identifyNode(Real, Imag); + auto RootNode = identifyNode(Real, Imag); LLVM_DEBUG({ Function *F = RootI->getFunction(); @@ -894,14 +877,99 @@ dbgs() << "\n"; }); - // Check all instructions have internal uses - for (const auto &Node : CompositeNodes) { - if (!Node->hasAllInternalUses(AllInstructions)) { - LLVM_DEBUG(dbgs() << " - Invalid internal uses\n"); - return false; + if (RootNode) { + RootToNode[RootI] = RootNode; + OrderedRoots.push_back(RootI); + return true; + } + + return false; +} + +bool ComplexDeinterleavingGraph::checkNodes() { + // Collect all instructions from roots to leaves + SmallPtrSet AllInstructions; + SmallVector ToDo; + for (auto *I : OrderedRoots) + ToDo.push_back(I); + + // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG + // chains + while (!ToDo.empty()) { + auto *I = ToDo.back(); + ToDo.pop_back(); + + if (!AllInstructions.insert(I).second) + continue; + + if (!FinalInstructions.count(I)) { + for (Value *Op : I->operands()) { + if (auto *OpI = dyn_cast(Op)) + ToDo.emplace_back(OpI); + } + } } - return RootNode != nullptr; + + // Find instructions that have users outside of chain + SmallVector OuterInstructions; + for (auto *I : AllInstructions) { + // Skip root nodes + if (RootToNode.count(I)) + continue; + + for (User *U : I->users()) { + if (auto *OpI = dyn_cast(U)) { + if (AllInstructions.count(OpI)) + continue; + + // Found an instruction that is not used by XCMLA/XCADD chain + OuterInstructions.emplace_back(OpI); + } + } + } + + // If any instructions are found to be used outside, find and remove roots + // that somehow connect to those instructions.
+ SmallPtrSet Visited; + for (Instruction *I : OuterInstructions) { + for (Value *Op : I->operands()) { + if (auto *OpI = dyn_cast(Op)) + ToDo.emplace_back(OpI); + } + } + + while (!ToDo.empty()) { + auto *I = ToDo.back(); + ToDo.pop_back(); + if (!Visited.insert(I).second) + continue; + + // Found an impacted root node. Removing it from the nodes to be + // deinterleaved + if (RootToNode.count(I)) { + LLVM_DEBUG(dbgs() << "Instruction " << *I + << " could be deinterleaved but its chain of complex " + "operations have an outside user\n"); + RootToNode.erase(I); + } + + if (!AllInstructions.count(I)) + continue; + + for (User *U : I->users()) { + if (auto *OpI = dyn_cast(U)) + ToDo.emplace_back(OpI); + } + + if (FinalInstructions.count(I)) + continue; + + for (Value *Op : I->operands()) { + if (auto *OpI = dyn_cast(Op)) + ToDo.emplace_back(OpI); + } + } + return !RootToNode.empty(); } static Value *replaceSymmetricNode(ComplexDeinterleavingGraph::RawNodePtr Node, @@ -958,29 +1026,20 @@ } void ComplexDeinterleavingGraph::replaceNodes() { - Value *R = replaceNode(RootNode.get()); - assert(R && "Unable to find replacement for RootValue"); - RootValue->replaceAllUsesWith(R); -} - -bool ComplexDeinterleavingCompositeNode::hasAllInternalUses( - SmallPtrSet &AllInstructions) { - if (Operation == ComplexDeinterleavingOperation::Shuffle) - return true; + SmallVector DeadInstrRoots; + for (auto *RootInstruction : OrderedRoots) { + // Roots rejected by checkNodes() were erased from RootToNode; only the + // survivors are replaced. A single find() avoids a second lookup. + auto It = RootToNode.find(RootInstruction); + if (It == RootToNode.end()) + continue; - for (auto *User : Real->users()) { - if (!AllInstructions.contains(cast(User))) - return false; + Value *R = replaceNode(It->second.get()); + assert(R && "Unable to find replacement for RootInstruction"); + DeadInstrRoots.push_back(RootInstruction); + RootInstruction->replaceAllUsesWith(R); } - for (auto *User :
Imag->users()) { - if (!AllInstructions.contains(cast(User))) - return false; - } - for (auto *I : InternalInstructions) { - for (auto *User : I->users()) { - if (!AllInstructions.contains(cast(User))) - return false; - } - } - return true; + + for (auto *I : DeadInstrRoots) + RecursivelyDeleteTriviallyDeadInstructions(I, TLI); } diff --git a/llvm/test/CodeGen/AArch64/complex-deinterleaving-multiuses.ll b/llvm/test/CodeGen/AArch64/complex-deinterleaving-multiuses.ll new file mode 100644 --- /dev/null +++ b/llvm/test/CodeGen/AArch64/complex-deinterleaving-multiuses.ll @@ -0,0 +1,161 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py +; RUN: llc < %s --mattr=+complxnum,+neon,+fullfp16 -o - | FileCheck %s + +target triple = "aarch64-arm-none-eabi" +; Expected to transform +; *p = (a * b); +; return (a * b) * a +define <4 x float> @mul_triangle(<4 x float> %a, <4 x float> %b, ptr %p) { +; CHECK-LABEL: mul_triangle: +; CHECK: // %bb.0: // %entry +; CHECK-NEXT: movi v3.2d, #0000000000000000 +; CHECK-NEXT: movi v2.2d, #0000000000000000 +; CHECK-NEXT: fcmla v3.4s, v1.4s, v0.4s, #0 +; CHECK-NEXT: fcmla v3.4s, v1.4s, v0.4s, #90 +; CHECK-NEXT: fcmla v2.4s, v0.4s, v3.4s, #0 +; CHECK-NEXT: str q3, [x0] +; CHECK-NEXT: fcmla v2.4s, v0.4s, v3.4s, #90 +; CHECK-NEXT: mov v0.16b, v2.16b +; CHECK-NEXT: ret +entry: + %strided.vec = shufflevector <4 x float> %a, <4 x float> poison, <2 x i32> + %strided.vec35 = shufflevector <4 x float> %a, <4 x float> poison, <2 x i32> + %strided.vec37 = shufflevector <4 x float> %b, <4 x float> poison, <2 x i32> + %strided.vec38 = shufflevector <4 x float> %b, <4 x float> poison, <2 x i32> + %0 = fmul fast <2 x float> %strided.vec37, %strided.vec + %1 = fmul fast <2 x float> %strided.vec38, %strided.vec35 + %2 = fsub fast <2 x float> %0, %1 + %3 = fmul fast <2 x float> %2, %strided.vec35 + %4 = fmul fast <2 x float> %strided.vec38, %strided.vec + %5 = fmul fast <2 x float> %strided.vec35, %strided.vec37 + %6 = fadd fast
<2 x float> %4, %5 + %otheruse = shufflevector <2 x float> %2, <2 x float> %6, <4 x i32> + store <4 x float> %otheruse, ptr %p + %7 = fmul fast <2 x float> %6, %strided.vec + %8 = fadd fast <2 x float> %3, %7 + %9 = fmul fast <2 x float> %2, %strided.vec + %10 = fmul fast <2 x float> %6, %strided.vec35 + %11 = fsub fast <2 x float> %9, %10 + %interleaved.vec = shufflevector <2 x float> %11, <2 x float> %8, <4 x i32> + ret <4 x float> %interleaved.vec +} + +; Expected to not transform +; *p = (a * b).real(); +; return (a * b) * a +define <4 x float> @mul_triangle_external_use(<4 x float> %a, <4 x float> %b, ptr %p) { +; CHECK-LABEL: mul_triangle_external_use: +; CHECK: // %bb.0: // %entry +; CHECK-NEXT: ext v2.16b, v0.16b, v0.16b, #8 +; CHECK-NEXT: ext v3.16b, v1.16b, v1.16b, #8 +; CHECK-NEXT: zip2 v4.2s, v0.2s, v2.2s +; CHECK-NEXT: zip1 v0.2s, v0.2s, v2.2s +; CHECK-NEXT: zip1 v5.2s, v1.2s, v3.2s +; CHECK-NEXT: zip2 v1.2s, v1.2s, v3.2s +; CHECK-NEXT: fmul v2.2s, v4.2s, v5.2s +; CHECK-NEXT: fmul v3.2s, v1.2s, v4.2s +; CHECK-NEXT: fmla v2.2s, v0.2s, v1.2s +; CHECK-NEXT: fneg v1.2s, v3.2s +; CHECK-NEXT: fmul v3.2s, v2.2s, v4.2s +; CHECK-NEXT: fmla v1.2s, v0.2s, v5.2s +; CHECK-NEXT: fmul v5.2s, v2.2s, v0.2s +; CHECK-NEXT: str d2, [x0] +; CHECK-NEXT: fneg v3.2s, v3.2s +; CHECK-NEXT: fmla v5.2s, v4.2s, v1.2s +; CHECK-NEXT: fmla v3.2s, v0.2s, v1.2s +; CHECK-NEXT: zip1 v0.4s, v3.4s, v5.4s +; CHECK-NEXT: ret +entry: + %strided.vec = shufflevector <4 x float> %a, <4 x float> poison, <2 x i32> + %strided.vec35 = shufflevector <4 x float> %a, <4 x float> poison, <2 x i32> + %strided.vec37 = shufflevector <4 x float> %b, <4 x float> poison, <2 x i32> + %strided.vec38 = shufflevector <4 x float> %b, <4 x float> poison, <2 x i32> + %0 = fmul fast <2 x float> %strided.vec37, %strided.vec + %1 = fmul fast <2 x float> %strided.vec38, %strided.vec35 + %2 = fsub fast <2 x float> %0, %1 + %3 = fmul fast <2 x float> %2, %strided.vec35 + %4 = fmul fast <2 x float> %strided.vec38,
%strided.vec + %5 = fmul fast <2 x float> %strided.vec35, %strided.vec37 + %6 = fadd fast <2 x float> %4, %5 + store <2 x float> %6, ptr %p + %7 = fmul fast <2 x float> %6, %strided.vec + %8 = fadd fast <2 x float> %3, %7 + %9 = fmul fast <2 x float> %2, %strided.vec + %10 = fmul fast <2 x float> %6, %strided.vec35 + %11 = fsub fast <2 x float> %9, %10 + %interleaved.vec = shufflevector <2 x float> %11, <2 x float> %8, <4 x i32> + ret <4 x float> %interleaved.vec +} + +; Expected to not transform. Shows that external use prevents deinterleaving whole chain. +; *p1 = (a * b).real(); +; *p2 = (a * b) * c; +; return d * c +define <4 x float> @monster(<4 x float> %a, <4 x float> %b, <4 x float> %c, <4 x float> %d, ptr %p1, ptr %p2) { +; CHECK-LABEL: monster: +; CHECK: // %bb.0: // %entry +; CHECK-NEXT: ext v5.16b, v0.16b, v0.16b, #8 +; CHECK-NEXT: ext v6.16b, v1.16b, v1.16b, #8 +; CHECK-NEXT: ext v4.16b, v3.16b, v3.16b, #8 +; CHECK-NEXT: ext v7.16b, v2.16b, v2.16b, #8 +; CHECK-NEXT: zip2 v16.2s, v0.2s, v5.2s +; CHECK-NEXT: zip1 v17.2s, v1.2s, v6.2s +; CHECK-NEXT: zip2 v1.2s, v1.2s, v6.2s +; CHECK-NEXT: zip1 v0.2s, v0.2s, v5.2s +; CHECK-NEXT: zip2 v18.2s, v3.2s, v4.2s +; CHECK-NEXT: zip2 v6.2s, v2.2s, v7.2s +; CHECK-NEXT: zip1 v2.2s, v2.2s, v7.2s +; CHECK-NEXT: zip1 v3.2s, v3.2s, v4.2s +; CHECK-NEXT: fmul v5.2s, v17.2s, v16.2s +; CHECK-NEXT: fmul v16.2s, v1.2s, v16.2s +; CHECK-NEXT: fmul v4.2s, v18.2s, v6.2s +; CHECK-NEXT: fmul v7.2s, v3.2s, v6.2s +; CHECK-NEXT: fmla v5.2s, v0.2s, v1.2s +; CHECK-NEXT: fneg v1.2s, v16.2s +; CHECK-NEXT: fneg v4.2s, v4.2s +; CHECK-NEXT: fmla v7.2s, v2.2s, v18.2s +; CHECK-NEXT: fmla v1.2s, v0.2s, v17.2s +; CHECK-NEXT: fmul v17.2s, v2.2s, v5.2s +; CHECK-NEXT: fmul v0.2s, v6.2s, v5.2s +; CHECK-NEXT: fmla v4.2s, v2.2s, v3.2s +; CHECK-NEXT: fmla v17.2s, v1.2s, v6.2s +; CHECK-NEXT: str d1, [x0] +; CHECK-NEXT: fneg v16.2s, v0.2s +; CHECK-NEXT: zip1 v0.4s, v4.4s, v7.4s +; CHECK-NEXT: fmla v16.2s, v1.2s, v2.2s +; CHECK-NEXT: st2 { v16.2s, v17.2s },
[x1] +; CHECK-NEXT: ret +entry: + %strided.vec = shufflevector <4 x float> %a, <4 x float> poison, <2 x i32> + %strided.vec88 = shufflevector <4 x float> %a, <4 x float> poison, <2 x i32> + %strided.vec90 = shufflevector <4 x float> %b, <4 x float> poison, <2 x i32> + %strided.vec91 = shufflevector <4 x float> %b, <4 x float> poison, <2 x i32> + %0 = fmul fast <2 x float> %strided.vec91, %strided.vec + %1 = fmul fast <2 x float> %strided.vec90, %strided.vec88 + %2 = fadd fast <2 x float> %0, %1 + %3 = fmul fast <2 x float> %strided.vec90, %strided.vec + %4 = fmul fast <2 x float> %strided.vec91, %strided.vec88 + %5 = fsub fast <2 x float> %3, %4 + store <2 x float> %5, ptr %p1 + %strided.vec93 = shufflevector <4 x float> %c, <4 x float> poison, <2 x i32> + %strided.vec94 = shufflevector <4 x float> %c, <4 x float> poison, <2 x i32> + %6 = fmul fast <2 x float> %strided.vec94, %5 + %7 = fmul fast <2 x float> %strided.vec93, %2 + %8 = fadd fast <2 x float> %6, %7 + %9 = fmul fast <2 x float> %strided.vec93, %5 + %10 = fmul fast <2 x float> %strided.vec94, %2 + %11 = fsub fast <2 x float> %9, %10 + %interleaved.vec = shufflevector <2 x float> %11, <2 x float> %8, <4 x i32> + store <4 x float> %interleaved.vec, ptr %p2 + %strided.vec96 = shufflevector <4 x float> %d, <4 x float> poison, <2 x i32> + %strided.vec97 = shufflevector <4 x float> %d, <4 x float> poison, <2 x i32> + %12 = fmul fast <2 x float> %strided.vec96, %strided.vec94 + %13 = fmul fast <2 x float> %strided.vec97, %strided.vec93 + %14 = fadd fast <2 x float> %13, %12 + %15 = fmul fast <2 x float> %strided.vec96, %strided.vec93 + %16 = fmul fast <2 x float> %strided.vec97, %strided.vec94 + %17 = fsub fast <2 x float> %15, %16 + %interleaved.vec98 = shufflevector <2 x float> %17, <2 x float> %14, <4 x i32> + ret <4 x float> %interleaved.vec98 +} +