diff --git a/llvm/lib/Transforms/Scalar/NaryReassociate.cpp b/llvm/lib/Transforms/Scalar/NaryReassociate.cpp
--- a/llvm/lib/Transforms/Scalar/NaryReassociate.cpp
+++ b/llvm/lib/Transforms/Scalar/NaryReassociate.cpp
@@ -598,21 +598,24 @@
   MaxMinT m_MaxMin(m_Value(A), m_Value(B));
   for (unsigned int i = 0; i < 2; ++i) {
     if (!LHS->hasNUsesOrMore(3) && match(LHS, m_MaxMin)) {
+      Value *C = RHS;
       const SCEV *AExpr = SE->getSCEV(A), *BExpr = SE->getSCEV(B);
-      const SCEV *RHSExpr = SE->getSCEV(RHS);
+      const SCEV *CExpr = SE->getSCEV(C);
       for (unsigned int j = 0; j < 2; ++j) {
         if (j == 0) {
-          if (BExpr == RHSExpr)
+          if (BExpr == CExpr)
             continue;
-          // Transform 'I = (A op B) op RHS' to 'I = (A op RHS) op B' on the
+          // Transform 'I = (A op B) op C' to 'I = (A op C) op B' on the
           // first iteration.
-          std::swap(BExpr, RHSExpr);
+          std::swap(BExpr, CExpr);
+          std::swap(B, C);
         } else {
-          if (AExpr == RHSExpr)
+          if (AExpr == CExpr)
             continue;
-          // Transform 'I = (A op RHS) op B' 'I = (B op RHS) op A' on the second
+          // Transform 'I = (A op C) op B' to 'I = (B op C) op A' on the second
           // iteration.
-          std::swap(AExpr, RHSExpr);
+          std::swap(AExpr, CExpr);
+          std::swap(A, C);
         }
 
         // The optimization is profitable only if LHS can be removed in the end.
@@ -635,8 +638,8 @@
           LLVM_DEBUG(dbgs() << "NARY: Found common sub-expr: " << *R1MinMax
                             << "\n");
 
-          R1Expr = SE->getUnknown(R1MinMax);
-          SmallVector<const SCEV *, 2> Ops2{RHSExpr, R1Expr};
+          SmallVector<const SCEV *, 2> Ops2{SE->getUnknown(C),
+                                            SE->getUnknown(R1MinMax)};
           const SCEV *R2Expr = SE->getMinMaxExpr(SCEVType, Ops2);
 
           Value *NewMinMax = Expander.expandCodeFor(R2Expr, I->getType(), I);
diff --git a/llvm/test/Transforms/NaryReassociate/nary-req.ll b/llvm/test/Transforms/NaryReassociate/nary-req.ll
--- a/llvm/test/Transforms/NaryReassociate/nary-req.ll
+++ b/llvm/test/Transforms/NaryReassociate/nary-req.ll
@@ -3,7 +3,7 @@
 ; RUN: opt < %s -passes='nary-reassociate' -S | FileCheck %s
 
 declare i32 @llvm.smax.i32(i32 %a, i32 %b)
-declare i64 @llvm.umin.i64(i64, i64) 
+declare i64 @llvm.umin.i64(i64, i64)
 
 ; This is a negative test. We should not optimize if intermediate result
 ; has a use outside of optimizable pattern. In other words %smax2 has one
@@ -46,7 +46,8 @@
 ; CHECK-NEXT: [[E4:%.*]] = sub i64 [[ARG]], 0
 ; CHECK-NEXT: [[E5:%.*]] = call i64 @llvm.umin.i64(i64 [[E4]], i64 16384)
 ; CHECK-NEXT: [[E6:%.*]] = icmp ugt i64 [[E5]], 0
-; CHECK-NEXT: [[E10_NARY:%.*]] = call i64 @llvm.umin.i64(i64 [[E5]], i64 [[E]])
+; CHECK-NEXT: [[E7:%.*]] = sub i64 undef, 0
+; CHECK-NEXT: [[E10_NARY:%.*]] = call i64 @llvm.umin.i64(i64 [[E5]], i64 [[E7]])
 ; CHECK-NEXT: unreachable
 ;
 bb:
@@ -64,3 +65,40 @@
   unreachable
 }
 
+; Make sure we don't fall into an infinite loop optimizing %sel5.
+; The subtle point is that %sel3 is a min/max pattern as well, and
+; SCEVExpander performs an "unexpected" reassociation during the
+; rewrite of %sel5. That produces a new chain of min/max operations,
+; which is matched again on the next iteration.
+define i32 @nary_infinite_loop_minmax(i32 %d0, i32 %d1, i32 %d2, i32 %d3) {
+; CHECK-LABEL: @nary_infinite_loop_minmax(
+; CHECK-NEXT: [[CMP0:%.*]] = icmp slt i32 [[D2:%.*]], [[D1:%.*]]
+; CHECK-NEXT: [[SEL0:%.*]] = select i1 [[CMP0]], i32 [[D1]], i32 [[D2]]
+; CHECK-NEXT: [[CMP1:%.*]] = icmp slt i32 [[D3:%.*]], [[D0:%.*]]
+; CHECK-NEXT: [[SEL1:%.*]] = select i1 [[CMP1]], i32 [[D0]], i32 [[D3]]
+; CHECK-NEXT: [[CMP2:%.*]] = icmp slt i32 [[SEL1]], [[SEL0]]
+; CHECK-NEXT: [[SEL2:%.*]] = select i1 [[CMP2]], i32 [[SEL1]], i32 [[SEL0]]
+; CHECK-NEXT: [[CMP3:%.*]] = icmp slt i32 [[D3]], [[D0]]
+; CHECK-NEXT: [[SEL3:%.*]] = select i1 [[CMP3]], i32 [[D0]], i32 [[D3]]
+; CHECK-NEXT: [[SEL5_NARY:%.*]] = call i32 @llvm.smax.i32(i32 [[SEL0]], i32 [[SEL3]])
+; CHECK-NEXT: ret i32 [[SEL5_NARY]]
+;
+  %cmp0 = icmp slt i32 %d2, %d1
+  %sel0 = select i1 %cmp0, i32 %d1, i32 %d2
+
+  %cmp1 = icmp slt i32 %d3, %d0
+  %sel1 = select i1 %cmp1, i32 %d0, i32 %d3
+
+  %cmp2 = icmp slt i32 %sel1, %sel0
+  %sel2 = select i1 %cmp2, i32 %sel1, i32 %sel0
+
+  %cmp3 = icmp slt i32 %d3, %d0
+  %sel3 = select i1 %cmp3, i32 %d0, i32 %d3
+
+  %cmp4 = icmp slt i32 %sel3, %d2
+  %sel4 = select i1 %cmp4, i32 %d2, i32 %sel3
+
+  %cmp5 = icmp slt i32 %sel4, %d1
+  %sel5 = select i1 %cmp5, i32 %d1, i32 %sel4
+  ret i32 %sel5
+}
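
A note on the mechanism, for readers unfamiliar with the failure mode: the old `Ops2{RHSExpr, R1Expr}` fed the structured SCEV of C into `getMinMaxExpr`, whose canonicalization flattens nested smax operands; SCEVExpander then lowered the flattened n-ary node back into a fresh chain of binary smax instructions, i.e. exactly the `max(max(A, B), C)` shape that `tryReassociateMinOrMax` matches, so the next iteration rewrote the pass's own output. The standalone C++ sketch below models just that mechanism. It has no LLVM dependency; every name in it (`Expr`, `unknown`, `makeMax`, `binarize`, `matchesAgain`) is invented for illustration, with `makeMax` and `binarize` acting as simplified stand-ins for `ScalarEvolution::getMinMaxExpr` and `SCEVExpander`.

// toy_reassoc.cpp -- standalone sketch of the infinite-loop hazard;
// all names invented, no LLVM dependency.
#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct Expr;
using ExprPtr = std::shared_ptr<Expr>;

// An expression is an opaque leaf (Name set) or a max over MaxOps.
struct Expr {
  std::string Name;
  std::vector<ExprPtr> MaxOps;
};

// Opaque leaf -- the analogue of SE->getUnknown(V).
ExprPtr unknown(const std::string &N) {
  auto E = std::make_shared<Expr>();
  E->Name = N;
  return E;
}

// Canonicalizing constructor -- the analogue of SE->getMinMaxExpr:
// operands that are themselves max nodes get flattened into one n-ary max.
ExprPtr makeMax(const std::vector<ExprPtr> &Ops) {
  auto E = std::make_shared<Expr>();
  for (const ExprPtr &Op : Ops) {
    if (!Op->MaxOps.empty())
      E->MaxOps.insert(E->MaxOps.end(), Op->MaxOps.begin(), Op->MaxOps.end());
    else
      E->MaxOps.push_back(Op);
  }
  return E;
}

// "Expander" -- lowers an n-ary max to a chain of binary maxes, the way
// SCEVExpander materializes a chain of llvm.smax calls.
ExprPtr binarize(const ExprPtr &E) {
  if (E->MaxOps.empty())
    return E;
  ExprPtr Acc = E->MaxOps[0];
  for (size_t I = 1, N = E->MaxOps.size(); I < N; ++I) {
    auto Bin = std::make_shared<Expr>();
    Bin->MaxOps = {Acc, E->MaxOps[I]};
    Acc = Bin;
  }
  return Acc;
}

// The shape the pass matches: a binary max whose first operand is a max,
// i.e. I = max(max(A, B), C).
bool matchesAgain(const ExprPtr &E) {
  return E->MaxOps.size() == 2 && !E->MaxOps[0]->MaxOps.empty();
}

std::string str(const ExprPtr &E) {
  if (E->MaxOps.empty())
    return E->Name;
  std::string S = "max(";
  for (size_t I = 0, N = E->MaxOps.size(); I < N; ++I)
    S += (I ? ", " : "") + str(E->MaxOps[I]);
  return S + ")";
}

int main() {
  // C plays the role of %sel3 in the test: it is itself a max.
  ExprPtr C = makeMax({unknown("d0"), unknown("d3")});
  ExprPtr R1 = unknown("r1minmax"); // the common sub-expression found

  std::cout << std::boolalpha;

  // Old code (Ops2{RHSExpr, R1Expr}): the structured C flattens inside
  // makeMax, and binarize recreates max(max(d0, d3), r1minmax) -- the
  // matched shape, so the pass would fire on its own output again.
  ExprPtr Old = binarize(makeMax({C, R1}));
  std::cout << str(Old) << "  re-matched: " << matchesAgain(Old) << '\n';

  // Fixed code (Ops2{getUnknown(C), getUnknown(R1MinMax)}): both
  // operands are opaque, so nothing flattens and the result is final.
  ExprPtr New = binarize(makeMax({unknown("sel3"), R1}));
  std::cout << str(New) << "  re-matched: " << matchesAgain(New) << '\n';
}

The second call in `main` mirrors the fixed `Ops2{SE->getUnknown(C), SE->getUnknown(R1MinMax)}`: with both operands opaque there is nothing for the canonicalizer to flatten, so in this simplified model the expanded result no longer matches the pattern and the rewrite is a fixed point. In the pass itself, the fix likewise keeps SCEVExpander from materializing a fresh min/max chain for C, which is why %sel5 is rewritten once to smax(%sel0, %sel3) and the pass terminates.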