diff --git a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp --- a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp +++ b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp @@ -42,6 +42,19 @@ // step, it is expected to finish successfully, while any errors should be // caught via asserts. // +// Vector Splitting: +// Vector splitting is only employed when the target cannot perform an +// operation on the full vector width (e.g. because it would exceed a single +// vector register) but can on half of it. It is applied at the start of the +// Replacement step: each composite node is cloned, and the original and clone +// are annotated with split indices 0 and 1 respectively. The index records +// whether a node operates on the lower or upper half of the original vector, +// and restricts CompositeNode lookups to the same side of the split. Loads +// feeding the graph are followed by shuffles that extract only the data for +// the relevant half (e.g. elements <0, 1, 2, 3> of an 8-element vector for +// the lower half), and the two split graphs are finally rejoined into one +// through a concatenating shuffle. 
+// //===----------------------------------------------------------------------===// #include "llvm/CodeGen/ComplexDeinterleavingPass.h" @@ -134,6 +147,11 @@ friend class ComplexDeinterleavingGraph; public: + void SetSplit(unsigned Idx) { + HasSplit = true; + SplitIdx = Idx; + } + SmallVector getOperands() { SmallVector Ops; @@ -158,7 +176,6 @@ Value *OriginalInput0 = nullptr; Value *OriginalInput1 = nullptr; Value *ReplacementNode = nullptr; - bool IsTopLevel = false; ComplexDeinterleavingOperation Operation; bool UsesNegation = false; @@ -168,6 +185,9 @@ Value *Accumulator = nullptr; Value *Accumulatee = nullptr; + bool HasSplit = false; + int SplitIdx = -1; + void addInstruction(Instruction *I) { ContainedInstructions.push_back(I); } bool contains(Instruction *I) { if (I == ReplacementNode) @@ -178,9 +198,10 @@ }; class ComplexDeinterleavingGraph { -private: +public: using NodePtr = std::shared_ptr; +private: SmallVector Instructions; SmallVector CompositeNodes; @@ -188,6 +209,11 @@ llvm::TargetTransformInfo::TCK_Latency; InstructionCost CostOfIntrinsics; + bool NeedsSplit = false; + /// Shuffles created by splitLoadIfNecessary for each split half of a load. + std::map<Instruction *, std::array<Value *, 2>> ShuffleMapping; + Value *splitLoadIfNecessary(ComplexDeinterleavingGraph::NodePtr Node, + Value *V); /// Determines the operating component of the given Value.
/// This is achieved by looking at the operating component of the Value's @@ -333,7 +359,7 @@ return V; } - Value *getFinalInputReplacement(Instruction *I) { + Value *getFinalInputReplacement(Instruction *I, int SplitIdx = -1) { for (Value *V : I->operands()) { auto *Op = dyn_cast(V); while (Op && shouldIgnoreValue(Op)) @@ -341,7 +367,7 @@ if (Op == nullptr) continue; - auto CN = getContainingComposite(Op); + auto CN = getContainingComposite(Op, SplitIdx); if (CN == nullptr || CN->ReplacementNode == nullptr) continue; return followUseChain(CN->ReplacementNode); @@ -364,6 +390,23 @@ return std::make_shared(Operation); } + std::shared_ptr + cloneCompositeNode(NodePtr OtherNode) { + auto NewNode = prepareCompositeNode(OtherNode->Operation); + + NewNode->ContainedInstructions.append(OtherNode->ContainedInstructions); + NewNode->OutputNode = OtherNode->OutputNode; + NewNode->OriginalInput0 = OtherNode->OriginalInput0; + NewNode->OriginalInput1 = OtherNode->OriginalInput1; + + NewNode->UsesNegation = OtherNode->UsesNegation; + NewNode->Rotation = OtherNode->Rotation; + NewNode->Accumulator = OtherNode->Accumulator; + NewNode->Accumulatee = OtherNode->Accumulatee; + + return NewNode; + } + void submitCompositeNode(std::shared_ptr CN) { CompositeNodes.push_back(CN); @@ -417,10 +460,13 @@ return haveSharedUses(A, B); } - NodePtr getContainingComposite(Instruction *I) { + NodePtr getContainingComposite(Instruction *I, int SplitIdx = -1) { if (I == nullptr) return nullptr; for (const auto &CN : CompositeNodes) { + if (SplitIdx > -1 && CN->SplitIdx != SplitIdx) + continue; + if (CN->contains(I)) return CN; if (CN->ReplacementNode == I) @@ -445,6 +491,11 @@ /// Returns false if the deinterleaving operation should be cancelled for the /// current graph.
bool identifyNodes(const TargetLowering *TL); + + /// If necessary, splits the nodes so that the operations can fit within a + /// single vector + void splitNodes(const TargetLowering *TL); + /// Perform the actual replacement of the underlying instruction graph. /// Returns false if the deinterleaving operation should be cancelled for the /// current graph. @@ -585,12 +636,16 @@ FixedVectorType::get(VTy->getScalarType(), VTy->getNumElements() * 2); if (!TL->isComplexDeinterleavingOperationSupported( - ComplexDeinterleavingOperation::CMulPartial, NewVTy)) - return false; + ComplexDeinterleavingOperation::CMulPartial, NewVTy)) { + if (!TL->isComplexDeinterleavingOperationSupported( + ComplexDeinterleavingOperation::CMulPartial, VTy)) + return false; - LLVM_DEBUG(dbgs() << "Composite node built up from "; N->dump()); - auto CN = - prepareCompositeNode(llvm::ComplexDeinterleavingOperation::CMulPartial); + NeedsSplit = true; + } + + auto CN = prepareCompositeNode( + llvm::ComplexDeinterleavingOperation::CMulPartial); auto *Op0 = cast(I->getOperand(0)); auto *Op1 = cast(I->getOperand(1)); @@ -629,25 +684,19 @@ if (I->getOpcode() == Instruction::FSub) { if (isa(Use.getUser()) && Use.getOperandNo() != 0) { LLVM_DEBUG(dbgs() - << "First converging shuffle operand should be an FSub" - << ".\n"); - ContinueIdentification = false; + << "First converging shuffle operand should be an FSub.\n"); return false; } } else if (I->getOpcode() == Instruction::FAdd) { if (isa(Use.getUser()) && Use.getOperandNo() != 1) { LLVM_DEBUG(dbgs() - << "Second converging shuffle operand should be an FAdd" - << ".\n"); + << "Second converging shuffle operand should be an FAdd.\n"); return false; } } } - auto Pattern = m_BinOp(m_Shuffle(m_Value(), m_Value()), - m_Shuffle(m_Value(), m_Value())); - CN->IsTopLevel = match(CN->OriginalInput0, Pattern) && - match(CN->OriginalInput1, Pattern); + CN->UsesNegation = ContainsNeg; CN->OutputNode = I; @@ -728,10 +777,7 @@ } } - auto Pattern = 
m_BinOp(m_Shuffle(m_Value(), m_Value()), - m_Shuffle(m_Value(), m_Value())); - CN->IsTopLevel = match(CN->OriginalInput0, Pattern) && - match(CN->OriginalInput1, Pattern); + CN->OutputNode = J; submitCompositeNode(CN); return true; @@ -926,10 +972,76 @@ return true; } +void ComplexDeinterleavingGraph::splitNodes(const TargetLowering *TL) { + unsigned Cap = CompositeNodes.size(); + for (unsigned Idx = 0; Idx < Cap; Idx++) { + auto Item = CompositeNodes[Idx]; + auto NewNode = cloneCompositeNode(Item); + + Item->SetSplit(0); + NewNode->SetSplit(1); + + submitCompositeNode(NewNode); + } +} + +Value *ComplexDeinterleavingGraph::splitLoadIfNecessary( + ComplexDeinterleavingGraph::NodePtr Node, Value *V) { + if (!V) + return V; + + auto *I = dyn_cast(V); + if (!I) + return V; + + if (!isa(I) || !Node->HasSplit) + return I; + + IRBuilder<> B(I); + + auto *Ty = I->getType(); + auto *VTy = dyn_cast(Ty); + if (!VTy) + return I; + + unsigned Size = VTy->getNumElements() / 2; + SmallVector Mask = createArrayWithStep(Size, 1, Size * Node->SplitIdx); + + Value *Shuffle = nullptr; + + auto It = ShuffleMapping.find(I); + + if (It != ShuffleMapping.end()) { + Shuffle = (*It).second[Node->SplitIdx]; + } + + if (Shuffle == nullptr) { + Shuffle = B.CreateShuffleVector(I, Mask); + cast(Shuffle)->moveAfter(I); + LLVM_DEBUG(dbgs() << "Creating new shuffle:"; Shuffle->dump()); + + if (It == ShuffleMapping.end()) { + // try_emplace value-initializes the mapped std::array, so both split + // halves start out as nullptr, and the map owns that storage by value + // (nothing needs to be freed when the graph is destroyed). + It = ShuffleMapping.try_emplace(I).first; + } + (*It).second[Node->SplitIdx] = Shuffle; + } else { + + LLVM_DEBUG(dbgs() << "Reusing shuffle:"; Shuffle->dump()); + } + + return Shuffle; +} + bool ComplexDeinterleavingGraph::replaceNodes(const TargetLowering *TL) { if (CompositeNodes.empty()) return false; + if (NeedsSplit) + splitNodes(TL); + unsigned GeneratedIntrinsics = 0; auto *ConvergingI = Instructions[0]; 
@@ -939,22 +1051,24 @@ auto *N = cast(CN->OutputNode); // Wrangle the inputs - ///
If the given value is part of a CompositeNode, and said node is part of /// an accumulator chain, return the accumulator. Otherwise, returns the /// "best fit" value (the ReplacementNode of a containing CompositeNode, or /// the value itself) - auto FollowAccumulatorIfNecessary = [&](Value *V) -> Value * { + auto FollowAccumulatorIfNecessary = [&](NodePtr Node, Value *V) -> Value * { // Only Node->SplitIdx is used. + LLVM_DEBUG(dbgs() << "FollowAccumulatorIfNecessary" + << ".\n"); auto *I = dyn_cast(V); if (!I) return V; - auto CN = getContainingComposite(I); + auto CN = getContainingComposite(I, Node->SplitIdx); // SplitIdx -1 (unsplit) matches any node. if (!CN) return I; if (CN->Accumulatee) - CN = getContainingComposite(cast(CN->Accumulatee)); + CN = getContainingComposite(cast(CN->Accumulatee), + CN->SplitIdx); // Stay on CN's side of the split. return CN->ReplacementNode; }; @@ -962,21 +1076,18 @@ /// Given a value and an operand index, get said operand and return it. /// If the discovered operand is part of a composite node, return the /// replacement instead. - auto GetInputFromOriginalInput = [&](Value *OriginalInput, + auto GetInputFromOriginalInput = [&](NodePtr Node, Value *OriginalInput, // Only Node->SplitIdx is used. unsigned OpIdx) -> Value * { auto *OriginalI = cast(OriginalInput); if (OriginalI->getOpcode() == Instruction::FNeg) OpIdx = 0; - auto *Op = OriginalI->getOperand(OpIdx); if (auto *SVI = dyn_cast(Op)) Op = SVI->getOperand(0); - if (!Op) return nullptr; - if (auto *I = dyn_cast(Op)) { - if (auto Containing = getContainingComposite(I)) { + if (auto Containing = getContainingComposite(I, Node->SplitIdx)) { if (Containing->ReplacementNode) return Containing->ReplacementNode; } @@ -1000,34 +1111,39 @@ if (!Sub) return false; - CN->Input0 = - FollowAccumulatorIfNecessary(GetInputFromOriginalInput(Sub, 0)); - CN->Input1 = - FollowAccumulatorIfNecessary(GetInputFromOriginalInput(Sub, 1)); + CN->Input0 = FollowAccumulatorIfNecessary( + CN, GetInputFromOriginalInput(CN, Sub, 0)); + CN->Input1 = FollowAccumulatorIfNecessary( + CN, GetInputFromOriginalInput(CN, Sub, 1)); } else {
CN->Input0 = FollowAccumulatorIfNecessary( - GetInputFromOriginalInput(CN->OriginalInput0, 0)); + CN, GetInputFromOriginalInput(CN, CN->OriginalInput0, 0)); CN->Input1 = FollowAccumulatorIfNecessary( - GetInputFromOriginalInput(CN->OriginalInput1, 0)); + CN, GetInputFromOriginalInput(CN, CN->OriginalInput1, 0)); if (CN->OriginalInput0 != CN->OriginalInput1 && CN->Input0 == CN->Input1) CN->Input1 = FollowAccumulatorIfNecessary( - GetInputFromOriginalInput(CN->OriginalInput1, 1)); + CN, GetInputFromOriginalInput(CN, CN->OriginalInput1, 1)); } if (CN->Input0 == nullptr || CN->Input1 == nullptr) continue; + LLVM_DEBUG(dbgs() << "Splitting loads if necessary" + << ".\n"); + CN->Input0 = splitLoadIfNecessary(CN, CN->Input0); // Returns the input unchanged unless CN is split. + CN->Input1 = splitLoadIfNecessary(CN, CN->Input1); + if (CN->Accumulator) { - if (auto Node = - getContainingComposite(cast(CN->Accumulator))) + if (auto Node = getContainingComposite(cast(CN->Accumulator), + CN->SplitIdx)) CN->Accumulator = cast(Node->ReplacementNode); } if (CN->Operation == llvm::ComplexDeinterleavingOperation::CMulPartial && CN->Accumulator) { - if (auto Node = - getContainingComposite(cast(CN->Accumulator))) { + if (auto Node = getContainingComposite(cast(CN->Accumulator), + CN->SplitIdx)) { bool Valid90 = (Node->Rotation == 0 && CN->Rotation == 90) || (Node->Rotation == 90 && CN->Rotation == 0); bool Valid270 = (Node->Rotation == 180 && CN->Rotation == 270) || @@ -1045,10 +1161,7 @@ CN->ReplacementNode = TL->createComplexDeinterleavingIR( N, CN->Operation, CN->Rotation, CN->Input0, CN->Input1, CN->Accumulator); - if (!CN->ReplacementNode) { - LLVM_DEBUG(dbgs() << "Target failed to create Intrinsic call.\n"); - return false; - } + assert(CN->ReplacementNode && "Target failed to create Intrinsic call."); cast(CN->ReplacementNode) ->moveAfter(cast(CN->OutputNode)); @@ -1058,10 +1171,24 @@ GeneratedIntrinsics += 1; } - auto *R = getFinalInputReplacement(ConvergingI); - if (!R) { - LLVM_DEBUG(dbgs() << "Unable to find Final Input
Replacement.\n"); - return false; + Value *R = nullptr; + if (NeedsSplit) { + auto *R0 = getFinalInputReplacement(ConvergingI, 0); + auto *R1 = getFinalInputReplacement(ConvergingI, 1); + if (R0 && R1) { + // Concatenate the halves; the mask width follows the half-vector type. + auto *HalfTy = cast<FixedVectorType>(R0->getType()); + unsigned Size = HalfTy->getNumElements() * 2; + auto Mask = createArrayWithStep(Size, 1); + IRBuilder<> B(ConvergingI); + R = B.CreateShuffleVector(R0, R1, Mask); + } + } else { + R = getFinalInputReplacement(ConvergingI); + } + if (!R) { + LLVM_DEBUG(dbgs() << "Unable to find Final Input Replacement.\n"); + return false; } InstructionCost CostOfNodes;