Index: lib/Target/ARM/ARMParallelDSP.cpp
===================================================================
--- lib/Target/ARM/ARMParallelDSP.cpp
+++ lib/Target/ARM/ARMParallelDSP.cpp
@@ -39,31 +39,62 @@
 #define DEBUG_TYPE "parallel-dsp"
 
 namespace {
+  struct OpChain;
   struct BinOpChain;
   struct Reduction;
 
-  using BinOpChainList = SmallVector<BinOpChain, 8>;
+  using OpChainList = SmallVector<std::unique_ptr<OpChain>, 8>;
   using ReductionList = SmallVector<Reduction, 8>;
-  using ValueList = SmallVector<Value*, 8>;
+  using ValueList = SmallSetVector<Value*, 8>;
   using MemInstList = SmallVector<Instruction*, 8>;
   using PMACPair = std::pair<BinOpChain*,BinOpChain*>;
   using PMACPairList = SmallVector<PMACPair, 8>;
   using Instructions = SmallVector<Instruction*,16>;
   using MemLocList = SmallVector<MemoryLocation, 4>;
 
+  // 'OpChain' is the common base of the candidate trees. It records the root
+  // instruction, all values in the tree, and the memory the tree reads.
+  struct OpChain {
+    Instruction *Root;
+    ValueList AllValues;
+    MemInstList VecLd; // List of all load instructions.
+    MemLocList MemLocs; // All memory locations read by this tree.
+
+    OpChain(Instruction *I, ValueList &vl) : Root(I), AllValues(vl) { }
+    // Chains are owned through std::unique_ptr<OpChain> (see OpChainList)
+    // while the objects are derived chains, so destruction must be virtual.
+    virtual ~OpChain() = default;
+
+    // Records a MemoryLocation for every load found among AllValues.
+    void SetMemoryLocations() {
+      const auto Size = MemoryLocation::UnknownSize;
+      for (auto *V : AllValues)
+        if (auto *Ld = dyn_cast<LoadInst>(V))
+          MemLocs.push_back(MemoryLocation(Ld->getPointerOperand(), Size));
+    }
+
+    // Returns true if V is the root or any other value of this tree.
+    bool contains(Value *V) const {
+      if (V == Root)
+        return true;
+      return AllValues.count(V) != 0;
+    }
+
+    unsigned size() const { return AllValues.size(); }
+  };
+
   // 'BinOpChain' and 'Reduction' are just some bookkeeping data structures.
   // 'Reduction' contains the phi-node and accumulator statement from where we
   // start pattern matching, and 'BinOpChain' the multiplication
   // instructions that are candidates for parallel execution.
-  struct BinOpChain {
-    Instruction *Root;
+  struct BinOpChain : public OpChain {
     ValueList LHS; // List of all (narrow) left hand operands.
     ValueList RHS; // List of all (narrow) right hand operands.
-    MemInstList VecLd; // List of all load instructions.
-    MemLocList MemLocs; // All memory locations read by this tree.
 
     BinOpChain(Instruction *I, ValueList &lhs, ValueList &rhs) :
-      Root(I), LHS(lhs), RHS(rhs) {};
+      OpChain(I, lhs), LHS(lhs), RHS(rhs) {
+      AllValues.insert(RHS.begin(), RHS.end());
+    }
   };
 
   struct Reduction {
@@ -87,7 +118,7 @@
   bool InsertParallelMACs(Reduction &Reduction, PMACPairList &PMACPairs);
   bool AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1, MemInstList &VecMem);
-  PMACPairList CreateParallelMACPairs(BinOpChainList &Candidates);
+  PMACPairList CreateParallelMACPairs(OpChainList &Candidates);
   Instruction *CreateSMLADCall(LoadInst *VecLd0, LoadInst *VecLd1,
                                Instruction *Acc,
                                Instruction *InsertAfter);
@@ -204,8 +235,8 @@
 
     if (match(Val, m_Load(m_Value()))) {
       LLVM_DEBUG(dbgs() << "Yes, found narrow Load:\t"; Val->dump());
-      VL.push_back(Val);
-      VL.push_back(I);
+      VL.insert(Val);
+      VL.insert(I);
       return true;
     }
   }
@@ -290,7 +321,7 @@
 }
 
 PMACPairList
-ARMParallelDSP::CreateParallelMACPairs(BinOpChainList &Candidates) {
+ARMParallelDSP::CreateParallelMACPairs(OpChainList &Candidates) {
   const unsigned Elems = Candidates.size();
   PMACPairList PMACPairs;
 
@@ -301,10 +332,10 @@
   // We can compare all elements, but then we need to compare and evaluate
   // different solutions.
   for(unsigned i=0; i<Elems-1; i+=2) {
-    BinOpChain &PMul0 = Candidates[i];
-    BinOpChain &PMul1 = Candidates[i+1];
-    const Instruction *Mul0 = PMul0.Root;
-    const Instruction *Mul1 = PMul1.Root;
+    BinOpChain *PMul0 = static_cast<BinOpChain*>(Candidates[i].get());
+    BinOpChain *PMul1 = static_cast<BinOpChain*>(Candidates[i+1].get());
+    const Instruction *Mul0 = PMul0->Root;
+    const Instruction *Mul1 = PMul1->Root;
     if (Mul0 == Mul1)
       continue;
 
@@ -313,10 +344,10 @@
                dbgs() << "- "; Mul0->dump();
                dbgs() << "- "; Mul1->dump());
 
-    const ValueList &Mul0_LHS = PMul0.LHS;
-    const ValueList &Mul0_RHS = PMul0.RHS;
-    const ValueList &Mul1_LHS = PMul1.LHS;
-    const ValueList &Mul1_RHS = PMul1.RHS;
+    const ValueList &Mul0_LHS = PMul0->LHS;
+    const ValueList &Mul0_RHS = PMul0->RHS;
+    const ValueList &Mul1_LHS = PMul1->LHS;
+    const ValueList &Mul1_RHS = PMul1->RHS;
 
     if (!AreSymmetrical(Mul0_LHS, Mul1_LHS) ||
         !AreSymmetrical(Mul0_RHS, Mul1_RHS))
@@ -340,10 +371,10 @@
                  dbgs() << "\t mul1: "; Mul0_RHS[x]->dump();
                  dbgs() << "\t mul2: "; Mul1_RHS[x]->dump());
 
-      if (AreSequentialLoads(Ld0, Ld1, Candidates[i].VecLd) &&
-          AreSequentialLoads(Ld2, Ld3, Candidates[i+1].VecLd)) {
+      if (AreSequentialLoads(Ld0, Ld1, PMul0->VecLd) &&
+          AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
         LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
-        PMACPairs.push_back(std::make_pair(&PMul0, &PMul1));
+        PMACPairs.push_back(std::make_pair(PMul0, PMul1));
       }
     }
   }
@@ -412,7 +443,7 @@
   );
 }
 
-static void AddCandidateMAC(BinOpChainList &Candidates,
+static void AddCandidateMAC(OpChainList &Candidates,
                             const Instruction *Acc,
                             Value *MulOp0, Value *MulOp1, int MulOpNum) {
   Instruction *Mul = dyn_cast<Instruction>(Acc->getOperand(MulOpNum));
@@ -422,12 +453,12 @@
   if (IsNarrowSequence<16>(MulOp0, LHS) &&
       IsNarrowSequence<16>(MulOp1, RHS)) {
     LLVM_DEBUG(dbgs() << "OK, found narrow mul: "; Mul->dump());
-    Candidates.push_back(BinOpChain(Mul, LHS, RHS));
+    Candidates.push_back(make_unique<BinOpChain>(Mul, LHS, RHS));
   }
 }
 
 static void MatchParallelMACSequences(Reduction &R,
-                                      BinOpChainList &Candidates) {
+                                      OpChainList &Candidates) {
   const Instruction *Acc = R.AccIntAdd;
   Value *A, *MulOp0, *MulOp1;
   LLVM_DEBUG(dbgs() << "\n- Analysing:\t"; Acc->dump());
@@ -458,20 +489,12 @@
 // Collects all instructions that are not part of the BinOp chains, which is
 // the set of instructions that can potentially alias with the MAC operands.
 static void AliasCandidates(BasicBlock *Header,
-                            BinOpChainList &Candidates,
+                            OpChainList &Candidates,
                             Instructions &Aliases) {
-  auto IsCandidate = [] (Instruction *I, BinOpChainList &Candidates) {
+  auto IsCandidate = [] (Instruction *I, OpChainList &Candidates) {
     for (auto &C : Candidates) {
-      if (I == C.Root)
+      if (C->contains(I))
         return true;
-      for (auto *Val : C.LHS) {
-        if (Val == I)
-          return true;
-      }
-      for (auto *Val : C.RHS) {
-        if (Val == I)
-          return true;
-      }
     }
     return false;
   };
@@ -487,14 +510,14 @@
 // instructions that are not part of the MAC-chain, with all instructions in
 // the MAC candidate set, to see if instructions are aliased.
 static bool AreAliased(AliasAnalysis *AA, Instructions AliasCandidates,
-                       BinOpChainList &Candidates) {
+                       OpChainList &Candidates) {
   LLVM_DEBUG(dbgs() << "Alias checks:\n");
   for (auto *I : AliasCandidates) {
     LLVM_DEBUG(dbgs() << "- "; I->dump());
     for (auto &C : Candidates) {
-      LLVM_DEBUG(dbgs() << "mul: "; C.Root->dump());
-      assert(C.MemLocs.size() >= 2 && "expecting at least 2 memlocs");
-      for (auto &MemLoc : C.MemLocs) {
+      LLVM_DEBUG(dbgs() << "mul: "; C->Root->dump());
+      assert(C->MemLocs.size() >= 2 && "expecting at least 2 memlocs");
+      for (auto &MemLoc : C->MemLocs) {
         if (isModOrRefSet(intersectModRef(AA->getModRefInfo(I, MemLoc),
             ModRefInfo::ModRef))) {
           LLVM_DEBUG(dbgs() << "Yes, aliases found\n");
@@ -507,24 +530,26 @@
   return false;
 }
 
-static bool SetMemoryLocations(BinOpChainList &Candidates) {
-  const auto Size = MemoryLocation::UnknownSize;
-  for (auto &C : Candidates) {
+// Returns true if every candidate has the operand shape this pass expects
+// (equally sized LHS/RHS lists whose even entries are loads), recording the
+// memory locations read by each candidate it inspects.
+static bool CheckMACMemory(OpChainList &Candidates) {
+  for (auto &C : Candidates) {
+    auto *PMul = static_cast<BinOpChain*>(C.get());
     // A mul has 2 operands, and a narrow op consist of sext and a load; thus
     // we expect at least 4 items in this operand value list.
-    if (C.LHS.size() < 2 || C.LHS.size() != C.RHS.size()) {
+    // LHS and RHS are indexed pairwise below, so their sizes must match.
+    if (C->size() < 4 || PMul->LHS.size() != PMul->RHS.size()) {
       LLVM_DEBUG(dbgs() << "Operand list too short.\n");
       return false;
     }
+    C->SetMemoryLocations();
+    ValueList &LHS = PMul->LHS;
+    ValueList &RHS = PMul->RHS;
 
-    for (unsigned i = 0, e = C.LHS.size(); i < e; i += 2) {
-      auto *LdOp0 = dyn_cast<LoadInst>(C.LHS[i]);
-      auto *LdOp1 = dyn_cast<LoadInst>(C.RHS[i]);
-      if (!LdOp0 || !LdOp1)
+    for (unsigned i = 0, e = LHS.size(); i < e; i += 2) {
+      if (!isa<LoadInst>(LHS[i]) || !isa<LoadInst>(RHS[i]))
         return false;
-
-      C.MemLocs.push_back(MemoryLocation(LdOp0->getPointerOperand(), Size));
-      C.MemLocs.push_back(MemoryLocation(LdOp1->getPointerOperand(), Size));
     }
   }
   return true;
@@ -575,9 +600,9 @@
   MatchReductions(F, L, Header, Reductions);
 
   for (auto &R : Reductions) {
-    BinOpChainList MACCandidates;
+    OpChainList MACCandidates;
     MatchParallelMACSequences(R, MACCandidates);
-    if (!SetMemoryLocations(MACCandidates))
+    if (!CheckMACMemory(MACCandidates))
       continue;
     Instructions Aliases;
     AliasCandidates(Header, MACCandidates, Aliases);