[ComplexDeinterleaving] Use Value* rather than Instruction* for composite node operands

ComplexDeinterleavingCompositeNode::Real/Imag, Product::Multiplier/Multiplicand
and PartialMulCandidate::Common now hold Value*. Operands that are not
Instructions are no longer rejected up front (the "Mul operand not Instruction"
bail-outs are removed, and non-Instruction leaves in identifyReassocNodes are
recorded as addends instead of failing the match). identifyNode now takes
Value* and only requires Instructions after the getContainingComposite lookup.
Code that needs Instruction APIs (getOperand, ReductionInfo lookups,
DeadInstrRoots) casts back explicitly.

diff --git a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp --- a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp +++ b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp @@ -130,7 +130,7 @@ struct ComplexDeinterleavingCompositeNode { ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op, - Instruction *R, Instruction *I) + Value *R, Value *I) : Operation(Op), Real(R), Imag(I) {} private: @@ -140,8 +140,8 @@ public: ComplexDeinterleavingOperation Operation; - Instruction *Real; - Instruction *Imag; + Value *Real; + Value *Imag; // This two members are required exclusively for generating // ComplexDeinterleavingOperation::Symmetric operations. @@ -192,19 +192,19 @@ class ComplexDeinterleavingGraph { public: struct Product { - Instruction *Multiplier; - Instruction *Multiplicand; + Value *Multiplier; + Value *Multiplicand; bool IsPositive; }; - using Addend = std::pair; + using Addend = std::pair; using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr; using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr; // Helper struct for holding info about potential partial multiplication // candidates struct PartialMulCandidate { - Instruction *Common; + Value *Common; NodePtr Node; unsigned RealIdx; unsigned ImagIdx; @@ -270,7 +270,7 @@ std::map OldToNewPHI; NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation, - Instruction *R, Instruction *I) { + Value *R, Value *I) { assert(((Operation != ComplexDeinterleavingOperation::ReductionPHI && Operation != ComplexDeinterleavingOperation::ReductionOperation) || (R && I)) && @@ -308,9 +308,9 @@ /// Identify the other branch of a Partial Mul, taking the CommonOperandI that /// is partially known from identifyPartialMul, filling in the other half of /// the complex pair.
- NodePtr identifyNodeWithImplicitAdd( - Instruction *I, Instruction *J, - std::pair &CommonOperandI); + NodePtr + identifyNodeWithImplicitAdd(Instruction *I, Instruction *J, + std::pair &CommonOperandI); /// Identifies a complex add pattern and its rotation, based on the following /// patterns. @@ -322,7 +322,7 @@ NodePtr identifyAdd(Instruction *Real, Instruction *Imag); NodePtr identifySymmetricOperation(Instruction *Real, Instruction *Imag); - NodePtr identifyNode(Instruction *I, Instruction *J); + NodePtr identifyNode(Value *R, Value *I); /// Determine if a sum of complex numbers can be formed from \p RealAddends /// and \p ImagAddens. If \p Accumulator is not null, add the result to it. @@ -521,7 +521,7 @@ ComplexDeinterleavingGraph::NodePtr ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd( Instruction *Real, Instruction *Imag, - std::pair &PartialMatch) { + std::pair &PartialMatch) { LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag << "\n"); @@ -536,52 +536,43 @@ return nullptr; } - Instruction *R0 = dyn_cast(Real->getOperand(0)); - Instruction *R1 = dyn_cast(Real->getOperand(1)); - Instruction *I0 = dyn_cast(Imag->getOperand(0)); - Instruction *I1 = dyn_cast(Imag->getOperand(1)); - if (!R0 || !R1 || !I0 || !I1) { - LLVM_DEBUG(dbgs() << " - Mul operand not Instruction\n"); - return nullptr; - } + Value *R0 = Real->getOperand(0); + Value *R1 = Real->getOperand(1); + Value *I0 = Imag->getOperand(0); + Value *I1 = Imag->getOperand(1); // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the // rotations and use the operand.
unsigned Negs = 0; SmallVector FNegs; - if (R0->getOpcode() == Instruction::FNeg || - R1->getOpcode() == Instruction::FNeg) { + Value *Op; + if (match(R0, m_FNeg(m_Value(Op)))) { Negs |= 1; - if (R0->getOpcode() == Instruction::FNeg) { - FNegs.push_back(R0); - R0 = dyn_cast(R0->getOperand(0)); - } else { - FNegs.push_back(R1); - R1 = dyn_cast(R1->getOperand(0)); - } - if (!R0 || !R1) - return nullptr; + FNegs.push_back(cast(R0)); + R0 = Op; + } else if (match(R1, m_FNeg(m_Value(Op)))) { + Negs |= 1; + FNegs.push_back(cast(R1)); + R1 = Op; } - if (I0->getOpcode() == Instruction::FNeg || - I1->getOpcode() == Instruction::FNeg) { + + if (match(I0, m_FNeg(m_Value(Op)))) { Negs |= 2; Negs ^= 1; - if (I0->getOpcode() == Instruction::FNeg) { - FNegs.push_back(I0); - I0 = dyn_cast(I0->getOperand(0)); - } else { - FNegs.push_back(I1); - I1 = dyn_cast(I1->getOperand(0)); - } - if (!I0 || !I1) - return nullptr; + FNegs.push_back(cast(I0)); + I0 = Op; + } else if (match(I1, m_FNeg(m_Value(Op)))) { + Negs |= 2; + Negs ^= 1; + FNegs.push_back(cast(I1)); + I1 = Op; } ComplexDeinterleavingRotation Rotation = (ComplexDeinterleavingRotation)Negs; - Instruction *CommonOperand; - Instruction *UncommonRealOp; - Instruction *UncommonImagOp; + Value *CommonOperand; + Value *UncommonRealOp; + Value *UncommonImagOp; if (R0 == I0 || R0 == I1) { CommonOperand = R0; @@ -676,18 +667,14 @@ return nullptr; } - Instruction *R0 = dyn_cast(RealMulI->getOperand(0)); - Instruction *R1 = dyn_cast(RealMulI->getOperand(1)); - Instruction *I0 = dyn_cast(ImagMulI->getOperand(0)); - Instruction *I1 = dyn_cast(ImagMulI->getOperand(1)); - if (!R0 || !R1 || !I0 || !I1) { - LLVM_DEBUG(dbgs() << " - Mul operand not Instruction\n"); - return nullptr; - } + Value *R0 = RealMulI->getOperand(0); + Value *R1 = RealMulI->getOperand(1); + Value *I0 = ImagMulI->getOperand(0); + Value *I1 = ImagMulI->getOperand(1); - Instruction *CommonOperand; - Instruction *UncommonRealOp; - Instruction *UncommonImagOp; + Value
*CommonOperand; + Value *UncommonRealOp; + Value *UncommonImagOp; if (R0 == I0 || R0 == I1) { CommonOperand = R0; @@ -705,7 +692,7 @@ Rotation == ComplexDeinterleavingRotation::Rotation_270) std::swap(UncommonRealOp, UncommonImagOp); - std::pair PartialMatch( + std::pair PartialMatch( (Rotation == ComplexDeinterleavingRotation::Rotation_0 || Rotation == ComplexDeinterleavingRotation::Rotation_180) ? CommonOperand @@ -840,11 +827,8 @@ !isInstructionPotentiallySymmetric(Imag)) return nullptr; - auto *R0 = dyn_cast(Real->getOperand(0)); - auto *I0 = dyn_cast(Imag->getOperand(0)); - - if (!R0 || !I0) - return nullptr; + auto *R0 = Real->getOperand(0); + auto *I0 = Imag->getOperand(0); NodePtr Op0 = identifyNode(R0, I0); NodePtr Op1 = nullptr; @@ -852,11 +836,8 @@ return nullptr; if (Real->isBinaryOp()) { - auto *R1 = dyn_cast(Real->getOperand(1)); - auto *I1 = dyn_cast(Imag->getOperand(1)); - if (!R1 || !I1) - return nullptr; - + auto *R1 = Real->getOperand(1); + auto *I1 = Imag->getOperand(1); Op1 = identifyNode(R1, I1); if (Op1 == nullptr) return nullptr; @@ -880,13 +861,18 @@ } ComplexDeinterleavingGraph::NodePtr -ComplexDeinterleavingGraph::identifyNode(Instruction *Real, Instruction *Imag) { - LLVM_DEBUG(dbgs() << "identifyNode on " << *Real << " / " << *Imag << "\n"); - if (NodePtr CN = getContainingComposite(Real, Imag)) { +ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I) { + LLVM_DEBUG(dbgs() << "identifyNode on " << *R << " / " << *I << "\n"); + if (NodePtr CN = getContainingComposite(R, I)) { LLVM_DEBUG(dbgs() << " - Folding to existing node\n"); return CN; } + auto *Real = dyn_cast(R); + auto *Imag = dyn_cast(I); + if (!Real || !Imag) + return nullptr; + if (NodePtr CN = identifyDeinterleave(Real, Imag)) return CN; @@ -929,6 +915,7 @@ ComplexDeinterleavingGraph::NodePtr ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real, Instruction *Imag) { + if ((Real->getOpcode() != Instruction::FAdd && Real->getOpcode() != Instruction::FSub &&
Real->getOpcode() != Instruction::FNeg) || @@ -965,8 +952,10 @@ continue; Instruction *I = dyn_cast(V); - if (!I) - return false; + if (!I) { + Addends.emplace_back(V, IsPositive); + continue; + } // If an instruction has more than one user, it indicates that it either // has an external user, which will be later checked by the checkNodes @@ -987,20 +976,18 @@ Worklist.emplace_back(I->getOperand(1), !IsPositive); Worklist.emplace_back(I->getOperand(0), IsPositive); } else if (I->getOpcode() == Instruction::FMul) { - auto *A = dyn_cast(I->getOperand(0)); - if (A && A->getOpcode() == Instruction::FNeg) { - A = dyn_cast(A->getOperand(0)); + Value *A, *B; + if (match(I->getOperand(0), m_FNeg(m_Value(A)))) { IsPositive = !IsPositive; + } else { + A = I->getOperand(0); } - if (!A) - return false; - auto *B = dyn_cast(I->getOperand(1)); - if (B && B->getOpcode() == Instruction::FNeg) { - B = dyn_cast(B->getOperand(0)); + + if (match(I->getOperand(1), m_FNeg(m_Value(B)))) { IsPositive = !IsPositive; + } else { + B = I->getOperand(1); } - if (!B) - return false; Muls.push_back(Product{A, B, IsPositive}); } else if (I->getOpcode() == Instruction::FNeg) { Worklist.emplace_back(I->getOperand(0), !IsPositive); @@ -1057,7 +1044,7 @@ std::vector &PartialMulCandidates) { // Helper function to extract a common operand from two products auto FindCommonInstruction = [](const Product &Real, - const Product &Imag) -> Instruction * { + const Product &Imag) -> Value * { if (Real.Multiplicand == Imag.Multiplicand || Real.Multiplicand == Imag.Multiplier) return Real.Multiplicand; @@ -1085,18 +1072,17 @@ auto *B = ImagMuls[j].Multiplicand == Common ?
ImagMuls[j].Multiplier : ImagMuls[j].Multiplicand; - bool Inverted = false; auto Node = identifyNode(A, B); - if (!Node) { - std::swap(A, B); - Inverted = true; - Node = identifyNode(A, B); + if (Node) { + FoundCommon = true; + PartialMulCandidates.push_back({Common, Node, i, j, false}); } - if (!Node) - continue; - FoundCommon = true; - PartialMulCandidates.push_back({Common, Node, i, j, Inverted}); + Node = identifyNode(B, A); + if (Node) { + FoundCommon = true; + PartialMulCandidates.push_back({Common, Node, i, j, true}); + } } if (!FoundCommon) return false; @@ -1116,7 +1102,7 @@ return nullptr; - // Map to store common instruction to node pointers - std::map CommonToNode; + // Map to store common value to node pointers + std::map CommonToNode; std::vector Processed(Info.size(), false); for (unsigned I = 0; I < Info.size(); ++I) { if (Processed[I]) @@ -1834,8 +1820,8 @@ processReductionOperation(ReplacementNode, Node); break; case ComplexDeinterleavingOperation::ReductionSelect: { - auto *MaskReal = Node->Real->getOperand(0); - auto *MaskImag = Node->Imag->getOperand(0); + auto *MaskReal = cast(Node->Real)->getOperand(0); + auto *MaskImag = cast(Node->Imag)->getOperand(0); auto *A = replaceNode(Builder, Node->Operands[0]); auto *B = replaceNode(Builder, Node->Operands[1]); auto *NewMaskTy = VectorType::getDoubleElementsVectorType( @@ -1860,11 +1846,13 @@ void ComplexDeinterleavingGraph::processReductionOperation( Value *OperationReplacement, RawNodePtr Node) { - auto *OldPHIReal = ReductionInfo[Node->Real].first; - auto *OldPHIImag = ReductionInfo[Node->Imag].first; + auto *Real = cast(Node->Real); + auto *Imag = cast(Node->Imag); + auto *OldPHIReal = ReductionInfo[Real].first; + auto *OldPHIImag = ReductionInfo[Imag].first; auto *NewPHI = OldToNewPHI[OldPHIReal]; - auto *VTy = cast(Node->Real->getType()); + auto *VTy = cast(Real->getType()); auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy); // We have to interleave initial origin values coming from IncomingBlock @@ -1880,8 +1868,8 @@ //
Deinterleave complex vector outside of loop so that it can be finally // reduced - auto *FinalReductionReal = ReductionInfo[Node->Real].second; - auto *FinalReductionImag = ReductionInfo[Node->Imag].second; + auto *FinalReductionReal = ReductionInfo[Real].second; + auto *FinalReductionImag = ReductionInfo[Imag].second; Builder.SetInsertPoint( &*FinalReductionReal->getParent()->getFirstInsertionPt()); @@ -1890,11 +1878,11 @@ OperationReplacement->getType(), OperationReplacement); auto *NewReal = Builder.CreateExtractValue(Deinterleave, (uint64_t)0); - FinalReductionReal->replaceUsesOfWith(Node->Real, NewReal); + FinalReductionReal->replaceUsesOfWith(Real, NewReal); Builder.SetInsertPoint(FinalReductionImag); auto *NewImag = Builder.CreateExtractValue(Deinterleave, 1); - FinalReductionImag->replaceUsesOfWith(Node->Imag, NewImag); + FinalReductionImag->replaceUsesOfWith(Imag, NewImag); } void ComplexDeinterleavingGraph::replaceNodes() { @@ -1911,10 +1899,12 @@ if (RootNode->Operation == ComplexDeinterleavingOperation::ReductionOperation) { - ReductionInfo[RootNode->Real].first->removeIncomingValue(BackEdge); - ReductionInfo[RootNode->Imag].first->removeIncomingValue(BackEdge); - DeadInstrRoots.push_back(RootNode->Real); - DeadInstrRoots.push_back(RootNode->Imag); + auto *RootReal = cast(RootNode->Real); + auto *RootImag = cast(RootNode->Imag); + ReductionInfo[RootReal].first->removeIncomingValue(BackEdge); + ReductionInfo[RootImag].first->removeIncomingValue(BackEdge); + DeadInstrRoots.push_back(RootReal); + DeadInstrRoots.push_back(RootImag); } else { assert(R && "Unable to find replacement for RootInstruction"); DeadInstrRoots.push_back(RootInstruction);