diff --git a/llvm/include/llvm/Transforms/Scalar/Reassociate.h b/llvm/include/llvm/Transforms/Scalar/Reassociate.h --- a/llvm/include/llvm/Transforms/Scalar/Reassociate.h +++ b/llvm/include/llvm/Transforms/Scalar/Reassociate.h @@ -105,7 +105,8 @@ void canonicalizeOperands(Instruction *I); void ReassociateExpression(BinaryOperator *I); void RewriteExprTree(BinaryOperator *I, - SmallVectorImpl &Ops); + SmallVectorImpl &Ops, + unsigned ComExprNum = 0); Value *OptimizeExpression(BinaryOperator *I, SmallVectorImpl &Ops); Value *OptimizeAdd(Instruction *I, diff --git a/llvm/lib/Transforms/Scalar/Reassociate.cpp b/llvm/lib/Transforms/Scalar/Reassociate.cpp --- a/llvm/lib/Transforms/Scalar/Reassociate.cpp +++ b/llvm/lib/Transforms/Scalar/Reassociate.cpp @@ -633,7 +633,8 @@ /// Now that the operands for this expression tree are /// linearized and optimized, emit them in-order. void ReassociatePass::RewriteExprTree(BinaryOperator *I, - SmallVectorImpl &Ops) { + SmallVectorImpl &Ops, + unsigned ComExprNum) { assert(Ops.size() > 1 && "Single values should be used directly!"); // Since our optimizations should never increase the number of operations, the @@ -669,11 +670,14 @@ // original in some non-trivial way, requiring the clearing of optional flags. // Flags are cleared from the operator in ExpressionChanged up to I inclusive. BinaryOperator *ExpressionChanged = nullptr; + unsigned TotalEleNum = Ops.size(); + SmallPtrSet RewriteExprs; + SmallVector PairResults; for (unsigned i = 0; ; ++i) { // The last operation (which comes earliest in the IR) is special as both // operands will come from Ops, rather than just one with the other being // a subexpression. - if (i+2 == Ops.size()) { + if (ComExprNum <= 1 && i+2 == TotalEleNum) { Value *NewLHS = Ops[i].Op; Value *NewRHS = Ops[i+1].Op; Value *OldLHS = Op->getOperand(0); @@ -684,7 +688,7 @@ break; if (NewLHS == OldRHS && NewRHS == OldLHS) { - // The order of the operands was reversed. Swap them. 
+ // The order of the operands was reversed. Swap them. LLVM_DEBUG(dbgs() << "RA: " << *Op << '\n'); Op->swapOperands(); LLVM_DEBUG(dbgs() << "TO: " << *Op << '\n'); @@ -715,6 +719,136 @@ ++NumChanged; break; + } else if (ComExprNum > 1 && (i + 2 * ComExprNum) == TotalEleNum) { + bool IsAllOpsInPairs = (2 * ComExprNum == TotalEleNum); + BinaryOperator *BoundaryInst = Op; + BinaryOperator *BoundaryInstUser = nullptr; + ExpressionChanged = Op; + MadeChange = true; + + // If all Ops are in pairs, this is the very beginning of rewriting. + if (IsAllOpsInPairs) { + Value *LHS = Op->getOperand(0); + Value *RHS = Op->getOperand(1); + BinaryOperator *LBO = isReassociableOp(LHS, Opcode); + BinaryOperator *RBO = isReassociableOp(RHS, Opcode); + if (LBO && !NotRewritable.count(LBO) && !RewriteExprs.count(LBO)) { + NodesToRewrite.push_back(LBO); + RewriteExprs.insert(LBO); + } + if (RBO && !NotRewritable.count(RBO) && !RewriteExprs.count(RBO)) { + NodesToRewrite.push_back(RBO); + RewriteExprs.insert(RBO); + } + } else { + // Record BoundaryInstUser in case it is transformed during rewriting + // expressions for common pairs. + BoundaryInstUser = cast(*BoundaryInst->user_begin()); + // If this is not the very beginning, push Op into NodesToRewrite for + // later handling. + if (!NotRewritable.count(Op) && !RewriteExprs.count(Op)) { + NodesToRewrite.push_back(Op); + RewriteExprs.insert(Op); + } + } + + // First, rewrite one node per common pair so that it computes that pair. + for (unsigned PairNo = 0; PairNo < ComExprNum; ++PairNo) { + assert(!NodesToRewrite.empty() && "No more nodes to rewrite."); + Op = NodesToRewrite.pop_back_val(); + Value *OldLHS = Op->getOperand(0); + Value *OldRHS = Op->getOperand(1); + unsigned CurPairIdx = TotalEleNum - 2 * (ComExprNum - PairNo); + Value *NewLHS = Ops[CurPairIdx].Op; + Value *NewRHS = Ops[CurPairIdx + 1].Op; + if (NewLHS == OldLHS && NewRHS == OldRHS) { + PairResults.push_back(Op); + continue; + } + + if (NewLHS == OldRHS && NewRHS == OldLHS) { + // The order of the operands was reversed. Swap them. 
+ LLVM_DEBUG(dbgs() << "RA-Pair: " << *Op << '\n'); + Op->swapOperands(); + PairResults.push_back(Op); + ++NumChanged; + LLVM_DEBUG(dbgs() << "TO-Pair: " << *Op << '\n'); + continue; + } + + LLVM_DEBUG(dbgs() << "RA-Pair: " << *Op << '\n'); + if (NewLHS != OldLHS) { + BinaryOperator *BO = isReassociableOp(OldLHS, Opcode); + if (BO && !NotRewritable.count(BO) && !RewriteExprs.count(BO)) { + NodesToRewrite.push_back(BO); + RewriteExprs.insert(BO); + } + Op->setOperand(0, NewLHS); + } + if (NewRHS != OldRHS) { + BinaryOperator *BO = isReassociableOp(OldRHS, Opcode); + if (BO && !NotRewritable.count(BO) && !RewriteExprs.count(BO)) { + NodesToRewrite.push_back(BO); + RewriteExprs.insert(BO); + } + Op->setOperand(1, NewRHS); + } + PairResults.push_back(Op); + + ++NumChanged; + LLVM_DEBUG(dbgs() << "TO-Pair: " << *Op << '\n'); + } + + // Second, combine pair results two at a time into further nodes. + assert((PairResults.size() >= 2) && "There should be at least 2 pairs."); + SmallVector PairResultsWorklist(PairResults); + while (!(PairResultsWorklist.size() == 1 || + (PairResultsWorklist.size() == 2 && IsAllOpsInPairs))) { + assert(!NodesToRewrite.empty() && "No more nodes to rewrite."); + Op = NodesToRewrite.pop_back_val(); + ++NumChanged; + Value *NewLHS = PairResultsWorklist.pop_back_val(); + Value *NewRHS = PairResultsWorklist.pop_back_val(); + LLVM_DEBUG(dbgs() << "RA-PairResult: " << *Op << '\n'); + Value *OldLHS = Op->getOperand(0); + Value *OldRHS = Op->getOperand(1); + BinaryOperator *BOL = isReassociableOp(OldLHS, Opcode); + if (BOL && !NotRewritable.count(BOL) && !RewriteExprs.count(BOL)) { + NodesToRewrite.push_back(BOL); + RewriteExprs.insert(BOL); + } + Op->setOperand(0, NewLHS); + BinaryOperator *BOR = isReassociableOp(OldRHS, Opcode); + if (BOR && !NotRewritable.count(BOR) && !RewriteExprs.count(BOR)) { + NodesToRewrite.push_back(BOR); + RewriteExprs.insert(BOR); + } + Op->setOperand(1, NewRHS); + LLVM_DEBUG(dbgs() << "TO-PairResult: " << *Op << '\n'); 
+ PairResultsWorklist.push_back(Op); + PairResults.push_back(Op); + } + + // Third, rewrite the boundary instruction to connect the pair + // instructions with the non-pair instructions. + ++NumChanged; + if (IsAllOpsInPairs) { + LLVM_DEBUG(dbgs() << "RA-Boundary: " << *BoundaryInst << '\n'); + assert((PairResultsWorklist.size() == 2) && + "Worklist size is not right."); + BoundaryInst->setOperand(0, PairResultsWorklist.pop_back_val()); + BoundaryInst->setOperand(1, PairResultsWorklist.pop_back_val()); + LLVM_DEBUG(dbgs() << "TO-Boundary: " << *BoundaryInst << '\n'); + } else { + assert((PairResultsWorklist.size() == 1) && + "Worklist size is not right."); + assert(BoundaryInstUser && "BoundaryInst has no user."); + LLVM_DEBUG(dbgs() << "RA-Boundary: " << *BoundaryInstUser << '\n'); + BoundaryInstUser->setOperand(0, PairResultsWorklist.pop_back_val()); + LLVM_DEBUG(dbgs() << "TO-Boundary: " << *BoundaryInstUser << '\n'); + } + + break; } // Not the last operation. The left-hand side will be a sub-expression @@ -775,29 +909,45 @@ Op = NewOp; } - // If the expression changed non-trivially then clear out all subclass data - // starting from the operator specified in ExpressionChanged, and compactify - // the operators to just before the expression root to guarantee that the - // expression tree is dominated by all of Ops. - if (ExpressionChanged) - do { + auto CleanUpAfterRewrite = [&](BinaryOperator *Node) { // Preserve FastMathFlags. if (isa(I)) { FastMathFlags Flags = I->getFastMathFlags(); - ExpressionChanged->clearSubclassOptionalData(); - ExpressionChanged->setFastMathFlags(Flags); + Node->clearSubclassOptionalData(); + Node->setFastMathFlags(Flags); } else - ExpressionChanged->clearSubclassOptionalData(); + Node->clearSubclassOptionalData(); - if (ExpressionChanged == I) - break; + // Cleanup is finished once the root I itself has been handled. 
+ if (Node == I) + return true; // Discard any debug info related to the expressions that has changed (we // can leave debug infor related to the root, since the result of the // expression tree should be the same even after reassociation). - replaceDbgUsesWithUndef(ExpressionChanged); + replaceDbgUsesWithUndef(Node); + + // Move Node just before I. + Node->moveBefore(I); + return false; + }; + + // Clean up every node produced while rewriting the common pairs. + for (auto *PairOp : PairResults) { + // I is never in PairResults, so the cleanup cannot finish here. + bool IsFinished = CleanUpAfterRewrite(PairOp); + assert(!IsFinished && "Should not finish in pair result."); + (void)IsFinished; + } + // If the expression changed non-trivially then clear out all subclass data + // starting from the operator specified in ExpressionChanged, and compactify + // the operators to just before the expression root to guarantee that the + // expression tree is dominated by all of Ops. + if (ExpressionChanged) + do { + if (CleanUpAfterRewrite(ExpressionChanged)) + break; - ExpressionChanged->moveBefore(I); ExpressionChanged = cast(*ExpressionChanged->user_begin()); } while (true); @@ -2361,9 +2511,10 @@ } } } + LLVM_DEBUG(dbgs() << "Common pair number is " << ChosenPairs.size() << '\n'); // Now that we ordered and optimized the expressions, splat them back into // the expression tree, removing any unneeded nodes. 
- RewriteExprTree(I, Ops); + RewriteExprTree(I, Ops, ChosenPairs.size()); } void diff --git a/llvm/test/Transforms/Reassociate/cse-pairs.ll b/llvm/test/Transforms/Reassociate/cse-pairs.ll --- a/llvm/test/Transforms/Reassociate/cse-pairs.ll +++ b/llvm/test/Transforms/Reassociate/cse-pairs.ll @@ -9,13 +9,12 @@ define signext i32 @twoPairs(i32 signext %0, i32 signext %1, i32 signext %2, i32 signext %3, i32 signext %4) { ; CHECK-LABEL: @twoPairs( ; CHECK-NEXT: [[TMP6:%.*]] = add i32 [[TMP2:%.*]], [[TMP0:%.*]] -; CHECK-NEXT: [[TMP7:%.*]] = add i32 [[TMP6]], [[TMP1:%.*]] -; CHECK-NEXT: [[TMP8:%.*]] = add i32 [[TMP7]], [[TMP3:%.*]] +; CHECK-NEXT: [[TMP7:%.*]] = add i32 [[TMP3:%.*]], [[TMP1:%.*]] +; CHECK-NEXT: [[TMP8:%.*]] = add i32 [[TMP6]], [[TMP7]] ; CHECK-NEXT: [[TMP9:%.*]] = add i32 [[TMP8]], [[TMP4:%.*]] ; CHECK-NEXT: store i32 [[TMP9]], i32* @num1, align 4 ; CHECK-NEXT: store i32 [[TMP6]], i32* @num2, align 4 -; CHECK-NEXT: [[TMP10:%.*]] = add nsw i32 [[TMP3]], [[TMP1]] -; CHECK-NEXT: store i32 [[TMP10]], i32* @num3, align 4 +; CHECK-NEXT: store i32 [[TMP7]], i32* @num3, align 4 ; CHECK-NEXT: ret i32 undef ; %6 = add i32 %2, %0 @@ -32,13 +31,12 @@ define signext i32 @twoPairsAllOpInPairs(i32 signext %0, i32 signext %1, i32 signext %2, i32 signext %3) { ; CHECK-LABEL: @twoPairsAllOpInPairs( -; CHECK-NEXT: [[TMP5:%.*]] = add i32 [[TMP2:%.*]], [[TMP1:%.*]] -; CHECK-NEXT: [[TMP6:%.*]] = add i32 [[TMP5]], [[TMP0:%.*]] -; CHECK-NEXT: [[TMP7:%.*]] = add i32 [[TMP6]], [[TMP3:%.*]] +; CHECK-NEXT: [[TMP5:%.*]] = add i32 [[TMP3:%.*]], [[TMP0:%.*]] +; CHECK-NEXT: [[TMP6:%.*]] = add i32 [[TMP2:%.*]], [[TMP1:%.*]] +; CHECK-NEXT: [[TMP7:%.*]] = add i32 [[TMP6]], [[TMP5]] ; CHECK-NEXT: store i32 [[TMP7]], i32* @num1, align 4 -; CHECK-NEXT: store i32 [[TMP5]], i32* @num2, align 4 -; CHECK-NEXT: [[TMP8:%.*]] = add nsw i32 [[TMP3]], [[TMP0]] -; CHECK-NEXT: store i32 [[TMP8]], i32* @num3, align 4 +; CHECK-NEXT: store i32 [[TMP6]], i32* @num2, align 4 +; CHECK-NEXT: store i32 [[TMP5]], 
i32* @num3, align 4 ; CHECK-NEXT: ret i32 undef ; %5 = add nsw i32 %0, %1 @@ -54,17 +52,15 @@ define signext i32 @threePairsAllOpInPairs(i32 signext %0, i32 signext %1, i32 signext %2, i32 signext %3, i32 signext %4, i32 signext %5) { ; CHECK-LABEL: @threePairsAllOpInPairs( -; CHECK-NEXT: [[TMP7:%.*]] = add i32 [[TMP3:%.*]], [[TMP2:%.*]] -; CHECK-NEXT: [[TMP8:%.*]] = add i32 [[TMP7]], [[TMP0:%.*]] -; CHECK-NEXT: [[TMP9:%.*]] = add i32 [[TMP8]], [[TMP1:%.*]] -; CHECK-NEXT: [[TMP10:%.*]] = add i32 [[TMP9]], [[TMP4:%.*]] -; CHECK-NEXT: [[TMP11:%.*]] = add i32 [[TMP10]], [[TMP5:%.*]] +; CHECK-NEXT: [[TMP7:%.*]] = add i32 [[TMP5:%.*]], [[TMP0:%.*]] +; CHECK-NEXT: [[TMP8:%.*]] = add i32 [[TMP4:%.*]], [[TMP1:%.*]] +; CHECK-NEXT: [[TMP9:%.*]] = add i32 [[TMP3:%.*]], [[TMP2:%.*]] +; CHECK-NEXT: [[TMP10:%.*]] = add i32 [[TMP9]], [[TMP8]] +; CHECK-NEXT: [[TMP11:%.*]] = add i32 [[TMP10]], [[TMP7]] ; CHECK-NEXT: store i32 [[TMP11]], i32* @num1, align 4 -; CHECK-NEXT: [[TMP12:%.*]] = add nsw i32 [[TMP5]], [[TMP0]] -; CHECK-NEXT: store i32 [[TMP12]], i32* @num2, align 4 -; CHECK-NEXT: [[TMP13:%.*]] = add nsw i32 [[TMP4]], [[TMP1]] -; CHECK-NEXT: store i32 [[TMP13]], i32* @num3, align 4 -; CHECK-NEXT: store i32 [[TMP7]], i32* @num4, align 4 +; CHECK-NEXT: store i32 [[TMP7]], i32* @num2, align 4 +; CHECK-NEXT: store i32 [[TMP8]], i32* @num3, align 4 +; CHECK-NEXT: store i32 [[TMP9]], i32* @num4, align 4 ; CHECK-NEXT: ret i32 undef ; %7 = add nsw i32 %0, %1