diff --git a/llvm/lib/Transforms/Scalar/NaryReassociate.cpp b/llvm/lib/Transforms/Scalar/NaryReassociate.cpp
--- a/llvm/lib/Transforms/Scalar/NaryReassociate.cpp
+++ b/llvm/lib/Transforms/Scalar/NaryReassociate.cpp
@@ -282,8 +282,12 @@
           m_Value(LHS), m_Value(RHS));
   if (match(I, MinMaxMatcher)) {
     OrigSCEV = SE->getSCEV(I);
-    return dyn_cast_or_null<Instruction>(
-        tryReassociateMinOrMax(I, MinMaxMatcher, LHS, RHS));
+    if (auto *NewMinMax = dyn_cast_or_null<Instruction>(
+            tryReassociateMinOrMax(I, MinMaxMatcher, LHS, RHS)))
+      return NewMinMax;
+    if (auto *NewMinMax = dyn_cast_or_null<Instruction>(
+            tryReassociateMinOrMax(I, MinMaxMatcher, RHS, LHS)))
+      return NewMinMax;
   }
   return nullptr;
 }
@@ -596,61 +600,61 @@
                                                    Value *LHS, Value *RHS) {
   Value *A = nullptr, *B = nullptr;
   MaxMinT m_MaxMin(m_Value(A), m_Value(B));
-  for (unsigned int i = 0; i < 2; ++i) {
-    if (!LHS->hasNUsesOrMore(3) && match(LHS, m_MaxMin)) {
-      Value *C = RHS;
-      const SCEV *AExpr = SE->getSCEV(A), *BExpr = SE->getSCEV(B);
-      const SCEV *CExpr = SE->getSCEV(C);
-      for (unsigned int j = 0; j < 2; ++j) {
-        if (j == 0) {
-          if (BExpr == CExpr)
-            continue;
-          // Transform 'I = (A op B) op C' to 'I = (A op C) op B' on the
-          // first iteration.
-          std::swap(BExpr, CExpr);
-          std::swap(B, C);
-        } else {
-          if (AExpr == CExpr)
-            continue;
-          // Transform 'I = (A op C) op B' to 'I = (B op C) op A' on the second
-          // iteration.
-          std::swap(AExpr, CExpr);
-          std::swap(A, C);
-        }
-
-        // The optimization is profitable only if LHS can be removed in the end.
-        // In other words LHS should be used (directly or indirectly) by I only.
-        if (llvm::any_of(LHS->users(), [&](auto *U) {
-              return U != I && !(U->hasOneUser() && *U->users().begin() == I);
-            }))
-          continue;
-
-        SCEVExpander Expander(*SE, *DL, "nary-reassociate");
-        SmallVector<const SCEV *, 2> Ops1{ BExpr, AExpr };
-        const SCEVTypes SCEVType = convertToSCEVype(m_MaxMin);
-        const SCEV *R1Expr = SE->getMinMaxExpr(SCEVType, Ops1);
-
-        Instruction *R1MinMax = findClosestMatchingDominator(R1Expr, I);
-
-        if (!R1MinMax)
-          continue;
-
-        LLVM_DEBUG(dbgs() << "NARY: Found common sub-expr: " << *R1MinMax
-                          << "\n");
-
-        SmallVector<const SCEV *, 2> Ops2{SE->getUnknown(C),
-                                          SE->getUnknown(R1MinMax)};
-        const SCEV *R2Expr = SE->getMinMaxExpr(SCEVType, Ops2);
-
-        Value *NewMinMax = Expander.expandCodeFor(R2Expr, I->getType(), I);
-        NewMinMax->setName(Twine(I->getName()).concat(".nary"));
-
-        LLVM_DEBUG(dbgs() << "NARY: Deleting: " << *I << "\n"
-                          << "NARY: Inserting: " << *NewMinMax << "\n");
-        return NewMinMax;
-      }
-    }
-    std::swap(LHS, RHS);
+
+  if (LHS->hasNUsesOrMore(3) ||
+      // The optimization is profitable only if LHS can be removed in the end.
+      // In other words LHS should be used (directly or indirectly) by I only.
+      llvm::any_of(LHS->users(),
+                   [&](auto *U) {
+                     return U != I &&
+                            !(U->hasOneUser() && *U->users().begin() == I);
+                   }) ||
+      !match(LHS, m_MaxMin))
+    return nullptr;
+
+  // If an instruction computing (XExpr op YExpr) dominates I, expand
+  // Z op <that instruction> at I and return it; otherwise return nullptr.
+  auto tryCombination = [&](const SCEV *XExpr, const SCEV *YExpr,
+                            Value *Z) -> Value * {
+    SmallVector<const SCEV *, 2> Ops1{YExpr, XExpr};
+    const SCEVTypes SCEVType = convertToSCEVype(m_MaxMin);
+    const SCEV *R1Expr = SE->getMinMaxExpr(SCEVType, Ops1);
+
+    Instruction *R1MinMax = findClosestMatchingDominator(R1Expr, I);
+
+    if (!R1MinMax)
+      return nullptr;
+
+    LLVM_DEBUG(dbgs() << "NARY: Found common sub-expr: " << *R1MinMax << "\n");
+
+    SmallVector<const SCEV *, 2> Ops2{SE->getUnknown(Z),
+                                      SE->getUnknown(R1MinMax)};
+    const SCEV *R2Expr = SE->getMinMaxExpr(SCEVType, Ops2);
+
+    SCEVExpander Expander(*SE, *DL, "nary-reassociate");
+    Value *NewMinMax = Expander.expandCodeFor(R2Expr, I->getType(), I);
+    NewMinMax->setName(Twine(I->getName()).concat(".nary"));
+
+    LLVM_DEBUG(dbgs() << "NARY: Deleting: " << *I << "\n"
+                      << "NARY: Inserting: " << *NewMinMax << "\n");
+    return NewMinMax;
+  };
+
+  const SCEV *AExpr = SE->getSCEV(A);
+  const SCEV *BExpr = SE->getSCEV(B);
+  const SCEV *RHSExpr = SE->getSCEV(RHS);
+
+  if (BExpr != RHSExpr) {
+    // Try (A op RHS) op B
+    if (auto *NewMinMax = tryCombination(AExpr, RHSExpr, B))
+      return NewMinMax;
+  }
+
+  if (AExpr != RHSExpr) {
+    // Try (RHS op B) op A
+    if (auto *NewMinMax = tryCombination(RHSExpr, BExpr, A))
+      return NewMinMax;
   }
+
   return nullptr;
 }