diff --git a/llvm/include/llvm/Transforms/Scalar/Reassociate.h b/llvm/include/llvm/Transforms/Scalar/Reassociate.h --- a/llvm/include/llvm/Transforms/Scalar/Reassociate.h +++ b/llvm/include/llvm/Transforms/Scalar/Reassociate.h @@ -48,6 +48,10 @@ Value *Op; ValueEntry(unsigned R, Value *O) : Rank(R), Op(O) {} + /// Entries compare equal when both the operand and its rank match. + bool operator==(const ValueEntry &Other) const { + return Op == Other.Op && Rank == Other.Rank; + } }; inline bool operator<(const ValueEntry &LHS, const ValueEntry &RHS) { diff --git a/llvm/lib/Transforms/Scalar/Reassociate.cpp b/llvm/lib/Transforms/Scalar/Reassociate.cpp --- a/llvm/lib/Transforms/Scalar/Reassociate.cpp +++ b/llvm/lib/Transforms/Scalar/Reassociate.cpp @@ -2274,49 +2274,91 @@ return; } + SmallVector, 4> ChosenPairs; if (Ops.size() > 2 && Ops.size() <= GlobalReassociateLimit) { - // Find the pair with the highest count in the pairmap and move it to the - // back of the list so that it can later be CSE'd. + // Find the pairs with more than one use in the pairmap and move them to the + // back of the list so that they can later be CSE'd. 
// example: // a*b*c*d*e - // if c*e is the most "popular" pair, we can express this as - // (((c*e)*d)*b)*a - unsigned Max = 1; - unsigned BestRank = 0; - std::pair BestPair; - unsigned Idx = I->getOpcode() - Instruction::BinaryOpsBegin; - for (unsigned i = 0; i < Ops.size() - 1; ++i) - for (unsigned j = i + 1; j < Ops.size(); ++j) { - unsigned Score = 0; - Value *Op0 = Ops[i].Op; - Value *Op1 = Ops[j].Op; - if (std::less()(Op1, Op0)) - std::swap(Op0, Op1); - auto it = PairMap[Idx].find({Op0, Op1}); - if (it != PairMap[Idx].end()) { + // if c*e and b*d have more than one use, we can express this as + // (((c*e)*(b*d))*a) + SmallVector OpsWorkList(Ops); + while (OpsWorkList.size() >= 2) { + unsigned Max = 1; + unsigned BestRank = 0; + std::pair BestPair; + unsigned Idx = I->getOpcode() - Instruction::BinaryOpsBegin; + for (unsigned i = 0; i < OpsWorkList.size() - 1; ++i) + for (unsigned j = i + 1; j < OpsWorkList.size(); ++j) { + unsigned Score = 0; + Value *Op0 = OpsWorkList[i].Op; + Value *Op1 = OpsWorkList[j].Op; + if (std::less()(Op1, Op0)) + std::swap(Op0, Op1); + auto it = PairMap[Idx].find({Op0, Op1}); + if (it == PairMap[Idx].end()) + continue; // Functions like BreakUpSubtract() can erase the Values we're using // as keys and create new Values after we built the PairMap. There's a // small chance that the new nodes can have the same address as // something already in the table. We shouldn't accumulate the stored // score in that case as it refers to the wrong Value. 
if (it->second.isValid()) - Score += it->second.Score; - } + Score = it->second.Score; - unsigned MaxRank = std::max(Ops[i].Rank, Ops[j].Rank); - if (Score > Max || (Score == Max && MaxRank < BestRank)) { - BestPair = {i, j}; - Max = Score; - BestRank = MaxRank; + unsigned MaxRank = std::max(OpsWorkList[i].Rank, OpsWorkList[j].Rank); + if (Score > Max || (Score == Max && MaxRank < BestRank)) { + BestPair = {i, j}; + Max = Score; + BestRank = MaxRank; + } } + if (Max == 1) + break; + if (Max > 1) { + auto Op0InWorkList = OpsWorkList[BestPair.first]; + auto Op1InWorkList = OpsWorkList[BestPair.second]; + + // Update OpsWorkList. + OpsWorkList.erase(&OpsWorkList[BestPair.second]); + OpsWorkList.erase(&OpsWorkList[BestPair.first]); + + Value *Op0 = Op0InWorkList.Op; + Value *Op1 = Op1InWorkList.Op; + if (std::less()(Op1, Op0)) + std::swap(Op0, Op1); + + // If current best pair is already chosen, skip. + if (std::find(ChosenPairs.begin(), ChosenPairs.end(), + std::make_pair(Op0, Op1)) != ChosenPairs.end()) + continue; + + // Get one best pair, put it into back of Ops. + // ChosenPairs.size() * 2 elements at tail position are selected pairs, + // do not touch them. + auto Op0InOpsIter = std::find( + Ops.begin(), Ops.end() - ChosenPairs.size() * 2, Op0InWorkList); + assert((Op0InOpsIter != (Ops.end() - ChosenPairs.size() * 2)) && + "ops wrong\n"); + auto Op0InOps = *Op0InOpsIter; + Ops.erase(Op0InOpsIter); + + auto Op1InOpsIter = std::find( + Ops.begin(), Ops.end() - ChosenPairs.size() * 2, Op1InWorkList); + assert((Op1InOpsIter != (Ops.end() - ChosenPairs.size() * 2)) && + "ops wrong\n"); + auto Op1InOps = *Op1InOpsIter; + Ops.erase(Op1InOpsIter); + + // FIXME: for now we want a NFC patch which keeps most 'popular' pair at + // the very end like before. But we should use push_back instead of + // insert because all pairs should be equal. 
+ Ops.insert(Ops.end() - ChosenPairs.size() * 2, Op0InOps); + Ops.insert(Ops.end() - ChosenPairs.size() * 2, Op1InOps); + + // Update ChosenPairs. + ChosenPairs.push_back(std::make_pair(Op0, Op1)); } - if (Max > 1) { - auto Op0 = Ops[BestPair.first]; - auto Op1 = Ops[BestPair.second]; - Ops.erase(&Ops[BestPair.second]); - Ops.erase(&Ops[BestPair.first]); - Ops.push_back(Op0); - Ops.push_back(Op1); } } // Now that we ordered and optimized the expressions, splat them back into