// Review notes for this patch to ReassociatePass::RewriteExprTree.
//
// Purpose: adds an optional `unsigned ComExprNum = 0` parameter to
// RewriteExprTree (the call site passes ChosenPairs.size()). When
// ComExprNum > 1, the last 2 * ComExprNum entries of Ops are emitted as
// separate pairs (Ops[k], Ops[k+1]). Those pair results are then combined and
// attached to the rest of the tree through BoundaryInst, instead of building
// the single linear chain.
//
// NOTE(review): this copy cannot be applied or compiled as-is. Line breaks
// were collapsed, and template arguments were stripped (e.g.
// `SmallVectorImpl &Ops`, `SmallPtrSet RewriteExprs`, `SmallVector PairResults`,
// `cast(...)`, `isa(I)`). Recover them from the original patch first.
//
// NOTE(review): possible regression in the clean-up.
//   - The removed do/while called clearSubclassOptionalData() on every node
//     from ExpressionChanged up to and including I. The context comment also
//     says flags are cleared "up to I inclusive". Only the debug-info drop and
//     moveBefore were skipped for I.
//   - The new CleanUpAfterRewrite returns immediately when Op == I, and the new
//     `while (ExpressionChanged != I)` loop never visits I.
//   - As a result, the root I keeps whatever optional data the old code would
//     have cleared.
//   - Suggest clearing flags for I as well, and skipping only
//     replaceDbgUsesWithUndef and moveBefore for it.
//
// NOTE(review): in the new `else if` branch, the inner
// `for (unsigned i = 0; i < ComExprNum; i++)` shadows the enclosing loop's `i`.
// Consider renaming it (e.g. PairIdx).
//
// NOTE(review): each `Op = NodesToRewrite.pop_back_val();` is followed by
// `assert(Op && ...)`. That assert cannot detect an empty NodesToRewrite, and
// nothing in this branch creates a new node when spare nodes run out. Please
// confirm that the pair selection always leaves enough nodes to reuse.
//
// NOTE(review): when !IsAllOpsInPairs, operand 0 of BoundaryInst's first user
// is replaced by the combined pair result.
//   - ExpressionChanged was set to BoundaryInst, and the clean-up loop then
//     follows `*ExpressionChanged->user_begin()`.
//   - If BoundaryInst had only that one use, it appears to have no users left
//     at that point. Verify this cannot dereference an empty user list.
//   - The "TO-Boundary" debug print also shows BoundaryInst, not the
//     instruction that was actually modified.
//
// NOTE(review): with ComExprNum > 1 the `i+2 == TotalEleNum` exit is disabled.
// If 2 * ComExprNum > Ops.size(), neither visible exit condition can become
// true. Nothing shown here guarantees ChosenPairs.size() <= Ops.size() / 2.
//
// NOTE(review): the @@ -684,7 +688,7 @@ hunk replaces the "Swap them." comment
// line with text that looks identical here. It is presumably a whitespace-only
// change; consider dropping it from the patch.
Index: llvm/include/llvm/Transforms/Scalar/Reassociate.h =================================================================== --- llvm/include/llvm/Transforms/Scalar/Reassociate.h +++ llvm/include/llvm/Transforms/Scalar/Reassociate.h @@ -105,7 +105,8 @@ void canonicalizeOperands(Instruction *I); void ReassociateExpression(BinaryOperator *I); void RewriteExprTree(BinaryOperator *I, - SmallVectorImpl &Ops); + SmallVectorImpl &Ops, + unsigned ComExprNum = 0); Value *OptimizeExpression(BinaryOperator *I, SmallVectorImpl &Ops); Value *OptimizeAdd(Instruction *I, Index: llvm/lib/Transforms/Scalar/Reassociate.cpp =================================================================== --- llvm/lib/Transforms/Scalar/Reassociate.cpp +++ llvm/lib/Transforms/Scalar/Reassociate.cpp @@ -633,7 +633,8 @@ /// Now that the operands for this expression tree are /// linearized and optimized, emit them in-order. void ReassociatePass::RewriteExprTree(BinaryOperator *I, - SmallVectorImpl &Ops) { + SmallVectorImpl &Ops, + unsigned ComExprNum) { assert(Ops.size() > 1 && "Single values should be used directly!"); // Since our optimizations should never increase the number of operations, the @@ -669,11 +670,14 @@ // original in some non-trivial way, requiring the clearing of optional flags. // Flags are cleared from the operator in ExpressionChanged up to I inclusive. BinaryOperator *ExpressionChanged = nullptr; + unsigned TotalEleNum = Ops.size(); + SmallPtrSet RewriteExprs; + SmallVector PairResults; for (unsigned i = 0; ; ++i) { // The last operation (which comes earliest in the IR) is special as both // operands will come from Ops, rather than just one with the other being // a subexpression. - if (i+2 == Ops.size()) { + if (ComExprNum <= 1 && i+2 == TotalEleNum) { Value *NewLHS = Ops[i].Op; Value *NewRHS = Ops[i+1].Op; Value *OldLHS = Op->getOperand(0); @@ -684,7 +688,7 @@ break; if (NewLHS == OldRHS && NewRHS == OldLHS) { - // The order of the operands was reversed. Swap them. 
+ // The order of the operands was reversed. Swap them. LLVM_DEBUG(dbgs() << "RA: " << *Op << '\n'); Op->swapOperands(); LLVM_DEBUG(dbgs() << "TO: " << *Op << '\n'); @@ -715,6 +719,126 @@ ++NumChanged; break; + } else if (ComExprNum > 1 && (i + 2 * ComExprNum) == TotalEleNum) { + bool IsAllOpsInPairs = (2 * ComExprNum == TotalEleNum); + ExpressionChanged = Op; + BinaryOperator *BoundaryInst = Op; + + // If all Ops are in pairs, this is the very beginning of rewriting. + if (IsAllOpsInPairs) { + Value *LHS = Op->getOperand(0); + Value *RHS = Op->getOperand(1); + BinaryOperator *LBO = isReassociableOp(LHS, Opcode); + BinaryOperator *RBO = isReassociableOp(RHS, Opcode); + if (LBO && !NotRewritable.count(LBO) && !RewriteExprs.count(LBO)) { + NodesToRewrite.push_back(LBO); + RewriteExprs.insert(LBO); + } + if (RBO && !NotRewritable.count(RBO) && !RewriteExprs.count(RBO)) { + NodesToRewrite.push_back(RBO); + RewriteExprs.insert(RBO); + } + } + + // First rewrite all pairs. + for (unsigned i = 0; i < ComExprNum; i++) { + Op = NodesToRewrite.pop_back_val(); + assert(Op && "There should be more rewriting nodes."); + Value *OldLHS = Op->getOperand(0); + Value *OldRHS = Op->getOperand(1); + unsigned CurPairIdx = TotalEleNum - 2 * (ComExprNum - i); + Value *NewLHS = Ops[CurPairIdx].Op; + Value *NewRHS = Ops[CurPairIdx + 1].Op; + if (NewLHS == OldLHS && NewRHS == OldRHS) { + PairResults.push_back(Op); + continue; + } + + if (NewLHS == OldRHS && NewRHS == OldLHS) { + // The order of the operands was reversed. Swap them. 
+ LLVM_DEBUG(dbgs() << "RA-Pair: " << *Op << '\n'); + Op->swapOperands(); + PairResults.push_back(Op); + MadeChange = true; + ++NumChanged; + LLVM_DEBUG(dbgs() << "TO-Pair: " << *Op << '\n'); + continue; + } + + LLVM_DEBUG(dbgs() << "RA-Pair: " << *Op << '\n'); + if (NewLHS != OldLHS) { + BinaryOperator *BO = isReassociableOp(OldLHS, Opcode); + if (BO && !NotRewritable.count(BO) && !RewriteExprs.count(BO)) { + NodesToRewrite.push_back(BO); + RewriteExprs.insert(BO); + } + Op->setOperand(0, NewLHS); + } + if (NewRHS != OldRHS) { + BinaryOperator *BO = isReassociableOp(OldRHS, Opcode); + if (BO && !NotRewritable.count(BO) && !RewriteExprs.count(BO)) { + NodesToRewrite.push_back(BO); + RewriteExprs.insert(BO); + } + Op->setOperand(1, NewRHS); + } + PairResults.push_back(Op); + + MadeChange = true; + ++NumChanged; + LLVM_DEBUG(dbgs() << "TO-Pair: " << *Op << '\n'); + } + + // Second, rewrite all expressions for pair result. + assert((PairResults.size() >= 2) && "There should be at least 2 pairs."); + SmallVector PairResultsWorklist(PairResults); + while (!(PairResultsWorklist.size() == 1 || + (PairResultsWorklist.size() == 2 && IsAllOpsInPairs))) { + Op = NodesToRewrite.pop_back_val(); + assert(Op && "There should be more rewriting nodes."); + MadeChange = true; + ++NumChanged; + Value *NewLHS = PairResultsWorklist.pop_back_val(); + Value *NewRHS = PairResultsWorklist.pop_back_val(); + LLVM_DEBUG(dbgs() << "RA-PairResult: " << *Op << '\n'); + Value *OldLHS = Op->getOperand(0); + Value *OldRHS = Op->getOperand(1); + BinaryOperator *BOL = isReassociableOp(OldLHS, Opcode); + if (BOL && !NotRewritable.count(BOL) && !RewriteExprs.count(BOL)) { + NodesToRewrite.push_back(BOL); + RewriteExprs.insert(BOL); + } + Op->setOperand(0, NewLHS); + BinaryOperator *BOR = isReassociableOp(OldRHS, Opcode); + if (BOR && !NotRewritable.count(BOR) && !RewriteExprs.count(BOR)) { + NodesToRewrite.push_back(BOR); + RewriteExprs.insert(BOR); + } + Op->setOperand(1, NewRHS); + LLVM_DEBUG(dbgs() 
<< "TO-PairResult: " << *Op << '\n'); + PairResultsWorklist.push_back(Op); + PairResults.push_back(Op); + } + + // Third, rewrite boundary instruction to connect pair instructions with + // non-pair instructions. + MadeChange = true; + ++NumChanged; + LLVM_DEBUG(dbgs() << "RA-Boundary: " << *BoundaryInst << '\n'); + if (IsAllOpsInPairs) { + assert((PairResultsWorklist.size() == 2) && + "Worklist size is not right."); + BoundaryInst->setOperand(0, PairResultsWorklist.pop_back_val()); + BoundaryInst->setOperand(1, PairResultsWorklist.pop_back_val()); + } else { + assert((PairResultsWorklist.size() == 1) && + "Worklist size is not right."); + cast(*BoundaryInst->user_begin()) + ->setOperand(0, PairResultsWorklist.pop_back_val()); + } + LLVM_DEBUG(dbgs() << "TO-Boundary: " << *BoundaryInst << '\n'); + + break; } // Not the last operation. The left-hand side will be a sub-expression @@ -775,31 +899,39 @@ Op = NewOp; } - // If the expression changed non-trivially then clear out all subclass data - // starting from the operator specified in ExpressionChanged, and compactify - // the operators to just before the expression root to guarantee that the - // expression tree is dominated by all of Ops. - if (ExpressionChanged) - do { + auto CleanUpAfterRewrite = [&] (BinaryOperator *Op) { + if (Op == I) + return; + // Preserve FastMathFlags. if (isa(I)) { FastMathFlags Flags = I->getFastMathFlags(); - ExpressionChanged->clearSubclassOptionalData(); - ExpressionChanged->setFastMathFlags(Flags); + Op->clearSubclassOptionalData(); + Op->setFastMathFlags(Flags); } else - ExpressionChanged->clearSubclassOptionalData(); - - if (ExpressionChanged == I) - break; + Op->clearSubclassOptionalData(); // Discard any debug info related to the expressions that has changed (we // can leave debug infor related to the root, since the result of the // expression tree should be the same even after reassociation). 
- replaceDbgUsesWithUndef(ExpressionChanged); + replaceDbgUsesWithUndef(Op); + + // Move Op just before I + Op->moveBefore(I); + }; - ExpressionChanged->moveBefore(I); + // Clean up for PairResults + for (auto* It : PairResults) + CleanUpAfterRewrite(It); + // If the expression changed non-trivially then clear out all subclass data + // starting from the operator specified in ExpressionChanged, and compactify + // the operators to just before the expression root to guarantee that the + // expression tree is dominated by all of Ops. + if (ExpressionChanged) + while (ExpressionChanged != I) { + CleanUpAfterRewrite(ExpressionChanged); ExpressionChanged = cast(*ExpressionChanged->user_begin()); - } while (true); // Throw away any left over nodes from the original expression. for (unsigned i = 0, e = NodesToRewrite.size(); i != e; ++i) @@ -2363,7 +2495,7 @@ } // Now that we ordered and optimized the expressions, splat them back into // the expression tree, removing any unneeded nodes. - RewriteExprTree(I, Ops); + RewriteExprTree(I, Ops, ChosenPairs.size()); } void