Index: include/llvm/Transforms/Scalar/Reassociate.h =================================================================== --- include/llvm/Transforms/Scalar/Reassociate.h +++ include/llvm/Transforms/Scalar/Reassociate.h @@ -23,77 +23,13 @@ #ifndef LLVM_TRANSFORMS_SCALAR_REASSOCIATE_H #define LLVM_TRANSFORMS_SCALAR_REASSOCIATE_H -#include "llvm/ADT/PostOrderIterator.h" -#include "llvm/ADT/SetVector.h" -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/Operator.h" #include "llvm/IR/PassManager.h" namespace llvm { - -/// A private "module" namespace for types and utilities used by Reassociate. -/// These are implementation details and should not be used by clients. -namespace reassociate { -struct ValueEntry { - unsigned Rank; - Value *Op; - ValueEntry(unsigned R, Value *O) : Rank(R), Op(O) {} -}; -inline bool operator<(const ValueEntry &LHS, const ValueEntry &RHS) { - return LHS.Rank > RHS.Rank; // Sort so that highest rank goes to start. -} - -/// \brief Utility class representing a base and exponent pair which form one -/// factor of some product. -struct Factor { - Value *Base; - unsigned Power; - Factor(Value *Base, unsigned Power) : Base(Base), Power(Power) {} -}; - -class XorOpnd; -} - /// Reassociate commutative expressions. class ReassociatePass : public PassInfoMixin { - DenseMap RankMap; - DenseMap, unsigned> ValueRankMap; - SetVector> RedoInsts; - bool MadeChange; - public: - PreservedAnalyses run(Function &F, FunctionAnalysisManager &); - -private: - void BuildRankMap(Function &F, ReversePostOrderTraversal &RPOT); - unsigned getRank(Value *V); - void canonicalizeOperands(Instruction *I); - void ReassociateExpression(BinaryOperator *I); - void RewriteExprTree(BinaryOperator *I, - SmallVectorImpl &Ops); - Value *OptimizeExpression(BinaryOperator *I, - SmallVectorImpl &Ops); - Value *OptimizeAdd(Instruction *I, - SmallVectorImpl &Ops); - Value *OptimizeXor(Instruction *I, - SmallVectorImpl &Ops); - bool CombineXorOpnd(Instruction *I, reassociate::XorOpnd *Opnd1, - APInt &ConstOpnd, Value *&Res); - bool CombineXorOpnd(Instruction *I, reassociate::XorOpnd *Opnd1, - reassociate::XorOpnd *Opnd2, APInt &ConstOpnd, - Value *&Res); - bool collectMultiplyFactors(SmallVectorImpl &Ops, - SmallVectorImpl &Factors); - Value *buildMinimalMultiplyDAG(IRBuilder<> &Builder, - SmallVectorImpl &Factors); - Value *OptimizeMul(BinaryOperator *I, - SmallVectorImpl &Ops); - Value *RemoveFactorFromExpression(Value *V, Value *Factor); - void EraseInst(Instruction *I); - void RecursivelyEraseDeadInsts(Instruction *I, - SetVector> &Insts); - void OptimizeInst(Instruction *I); - Instruction *canonicalizeNegConstExpr(Instruction *I); + PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM); }; } Index: lib/Transforms/Scalar/Reassociate.cpp =================================================================== --- lib/Transforms/Scalar/Reassociate.cpp +++ lib/Transforms/Scalar/Reassociate.cpp @@ -31,6 +31,7 @@ #include "llvm/IR/CFG.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" @@ -43,7 +44,6 @@ #include "llvm/Transforms/Utils/Local.h" #include using namespace llvm; -using namespace reassociate; #define DEBUG_TYPE "reassociate" @@ -51,20 +51,25 @@ STATISTIC(NumAnnihil, "Number of expr tree annihilated"); STATISTIC(NumFactor , "Number of multiplies factored"); -#ifndef NDEBUG -/// Print out the expression identified in the Ops list. -/// -static void PrintOps(Instruction *I, const SmallVectorImpl &Ops) { - Module *M = I->getModule(); - dbgs() << Instruction::getOpcodeName(I->getOpcode()) << " " - << *Ops[0].Op->getType() << '\t'; - for (unsigned i = 0, e = Ops.size(); i != e; ++i) { - dbgs() << "[ "; - Ops[i].Op->printAsOperand(dbgs(), false, M); - dbgs() << ", #" << Ops[i].Rank << "] "; - } +namespace llvm { +namespace reassociate { + +struct ValueEntry { + unsigned Rank; + Value *Op; + ValueEntry(unsigned R, Value *O) : Rank(R), Op(O) {} +}; +inline bool operator<(const ValueEntry &LHS, const ValueEntry &RHS) { + return LHS.Rank > RHS.Rank; // Sort so that highest rank goes to start. } -#endif + +/// \brief Utility class representing a base and exponent pair which form one +/// factor of some product. +struct Factor { + Value *Base; + unsigned Power; + Factor(Value *Base, unsigned Power) : Base(Base), Power(Power) {} +}; /// Utility class representing a non-constant Xor-operand. We classify /// non-constant Xor-Operands into two categories: @@ -74,7 +79,7 @@ /// constant. /// C2.2) Any operand E which doesn't fall into C1 and C2.1, we view this /// operand as "E | 0" -class llvm::reassociate::XorOpnd { +class XorOpnd { public: XorOpnd(Value *V); @@ -122,6 +127,25 @@ ConstPart = APInt::getNullValue(V->getType()->getIntegerBitWidth()); isOr = true; } +} // namespace reassociate +} // namespace llvm + +using namespace llvm::reassociate; + +#ifndef NDEBUG +/// Print out the expression identified in the Ops list. +/// +static void PrintOps(Instruction *I, const SmallVectorImpl &Ops) { + Module *M = I->getModule(); + dbgs() << Instruction::getOpcodeName(I->getOpcode()) << " " + << *Ops[0].Op->getType() << '\t'; + for (unsigned i = 0, e = Ops.size(); i != e; ++i) { + dbgs() << "[ "; + Ops[i].Op->printAsOperand(dbgs(), false, M); + dbgs() << ", #" << Ops[i].Rank << "] "; + } +} +#endif /// Return true if V is an instruction of the specified opcode and if it /// only has one use. @@ -145,8 +169,52 @@ return nullptr; } -void ReassociatePass::BuildRankMap(Function &F, - ReversePostOrderTraversal &RPOT) { +/// Reassociate commutative expressions. +class Reassociate { + DenseMap RankMap; + DenseMap, unsigned> ValueRankMap; + SetVector> RedoInsts; + bool MadeChange; + DominatorTree &DT; + +public: + Reassociate(DominatorTree &DT) : DT(DT) {} + PreservedAnalyses run(Function &F); + +private: + void BuildRankMap(Function &F, ReversePostOrderTraversal &RPOT); + unsigned getRank(Value *V); + void canonicalizeOperands(Instruction *I); + void ReassociateExpression(BinaryOperator *I); + void RewriteExprTree(BinaryOperator *I, + SmallVectorImpl &Ops); + Value *OptimizeExpression(BinaryOperator *I, + SmallVectorImpl &Ops); + Value *OptimizeAdd(Instruction *I, + SmallVectorImpl &Ops); + Value *OptimizeXor(Instruction *I, + SmallVectorImpl &Ops); + bool CombineXorOpnd(Instruction *I, reassociate::XorOpnd *Opnd1, + APInt &ConstOpnd, Value *&Res); + bool CombineXorOpnd(Instruction *I, reassociate::XorOpnd *Opnd1, + reassociate::XorOpnd *Opnd2, APInt &ConstOpnd, + Value *&Res); + bool collectMultiplyFactors(SmallVectorImpl &Ops, + SmallVectorImpl &Factors); + Value *buildMinimalMultiplyDAG(IRBuilder<> &Builder, + SmallVectorImpl &Factors); + Value *OptimizeMul(BinaryOperator *I, + SmallVectorImpl &Ops); + Value *RemoveFactorFromExpression(Value *V, Value *Factor); + void EraseInst(Instruction *I); + void RecursivelyEraseDeadInsts(Instruction *I, + SetVector> &Insts); + void OptimizeInst(Instruction *I); + Instruction *canonicalizeNegConstExpr(Instruction *I); +}; + +void Reassociate::BuildRankMap(Function &F, + ReversePostOrderTraversal &RPOT) { unsigned i = 2; // Assign distinct ranks to function arguments. @@ -168,7 +236,7 @@ } } -unsigned ReassociatePass::getRank(Value *V) { +unsigned Reassociate::getRank(Value *V) { Instruction *I = dyn_cast(V); if (!I) { if (isa(V)) return ValueRankMap[V]; // Function argument. @@ -199,7 +267,7 @@ } // Canonicalize constants to RHS. Otherwise, sort the operands by rank. -void ReassociatePass::canonicalizeOperands(Instruction *I) { +void Reassociate::canonicalizeOperands(Instruction *I) { assert(isa(I) && "Expected binary operator."); assert(I->isCommutative() && "Expected commutative operator."); @@ -610,7 +678,7 @@ /// Now that the operands for this expression tree are /// linearized and optimized, emit them in-order. -void ReassociatePass::RewriteExprTree(BinaryOperator *I, +void Reassociate::RewriteExprTree(BinaryOperator *I, SmallVectorImpl &Ops) { assert(Ops.size() > 1 && "Single values should be used directly!"); @@ -754,9 +822,11 @@ } // If the expression changed non-trivially then clear out all subclass data - // starting from the operator specified in ExpressionChanged, and compactify - // the operators to just before the expression root to guarantee that the - // expression tree is dominated by all of Ops. + // starting from the operator specified in ExpressionChanged. If the operators + // of ExpressionChanged are the only use, move it after the def of the operand + // to minimize live-range. Otherwise, compactify the operators to just before + // the expression root to guarantee that the expression tree is dominated by + // all of Ops. if (ExpressionChanged) do { // Preserve FastMathFlags. @@ -769,7 +839,24 @@ if (ExpressionChanged == I) break; - ExpressionChanged->moveBefore(I); + + Instruction *Insert = I; + for (Value *V : ExpressionChanged->operands()) { + // After the reassociation, there will be one dead use of V to be + // cleaned up. Thus check against 2 to see if it's the only use. + if (V->getNumUses() > 2) { + Insert = I; + break; + } + if (Instruction *Inst = dyn_cast(V)) + if (Insert == I || DT.dominates(Insert, Inst)) + Insert = Inst; + } + if (Insert != I) + do { + Insert = Insert->getNextNode(); + } while (isa(Insert)); + ExpressionChanged->moveBefore(Insert); ExpressionChanged = cast(*ExpressionChanged->user_begin()); } while (1); @@ -980,21 +1067,34 @@ } /// Emit a tree of add instructions, summing Ops together -/// and returning the result. Insert the tree before I. +/// and returning the result. If there is only one use of the operand, +/// insert the use right after the def to minimize live-range. +/// Otherwise, insert the tree before I. static Value *EmitAddTreeOfValues(Instruction *I, SmallVectorImpl &Ops){ - if (Ops.size() == 1) return Ops.back(); - - Value *V1 = Ops.back(); - Ops.pop_back(); - Value *V2 = EmitAddTreeOfValues(I, Ops); - return CreateAdd(V2, V1, "tmp", I, I); + Value *Prev = Ops[0]; + bool InsertToI = Prev->getNumUses() > 1; + for (unsigned i = 1; i < Ops.size(); i++) { + Instruction *Insert = I; + if (!InsertToI) { + if (Ops[i]->getNumUses() > 1) + InsertToI = true; + else if (Instruction *Inst = dyn_cast(Ops[i])) + Insert = Inst->getNextNode(); + else if (Instruction *Inst = dyn_cast(Prev)) + Insert = Inst->getNextNode(); + else + InsertToI = true; + } + Prev = CreateAdd(Prev, Ops[i], "tmp", Insert, I); + } + return Prev; } /// If V is an expression tree that is a multiplication sequence, /// and if this sequence contains a multiply by Factor, /// remove Factor from the tree and return the new tree. -Value *ReassociatePass::RemoveFactorFromExpression(Value *V, Value *Factor) { +Value *Reassociate::RemoveFactorFromExpression(Value *V, Value *Factor) { BinaryOperator *BO = isReassociableOp(V, Instruction::Mul, Instruction::FMul); if (!BO) return nullptr; @@ -1157,7 +1257,7 @@ // via "Res" and "ConstOpnd", respectively; otherwise, false is returned, // and both "Res" and "ConstOpnd" remain unchanged. // -bool ReassociatePass::CombineXorOpnd(Instruction *I, XorOpnd *Opnd1, +bool Reassociate::CombineXorOpnd(Instruction *I, XorOpnd *Opnd1, APInt &ConstOpnd, Value *&Res) { // Xor-Rule 1: (x | c1) ^ c2 = (x | c1) ^ (c1 ^ c1) ^ c2 // = ((x | c1) ^ c1) ^ (c1 ^ c2) @@ -1192,7 +1292,7 @@ // via "Res" and "ConstOpnd", respectively (If the entire expression is // evaluated to a constant, the Res is set to NULL); otherwise, false is // returned, and both "Res" and "ConstOpnd" remain unchanged. -bool ReassociatePass::CombineXorOpnd(Instruction *I, XorOpnd *Opnd1, +bool Reassociate::CombineXorOpnd(Instruction *I, XorOpnd *Opnd1, XorOpnd *Opnd2, APInt &ConstOpnd, Value *&Res) { Value *X = Opnd1->getSymbolicPart(); @@ -1268,7 +1368,7 @@ /// Optimize a series of operands to an 'xor' instruction. If it can be reduced /// to a single Value, it is returned, otherwise the Ops list is mutated as /// necessary. -Value *ReassociatePass::OptimizeXor(Instruction *I, +Value *Reassociate::OptimizeXor(Instruction *I, SmallVectorImpl &Ops) { if (Value *V = OptimizeAndOrXor(Instruction::Xor, Ops)) return V; @@ -1389,7 +1489,7 @@ /// Optimize a series of operands to an 'add' instruction. This /// optimizes based on identities. If it can be reduced to a single Value, it /// is returned, otherwise the Ops list is mutated as necessary. -Value *ReassociatePass::OptimizeAdd(Instruction *I, +Value *Reassociate::OptimizeAdd(Instruction *I, SmallVectorImpl &Ops) { // Scan the operand lists looking for X and -X pairs. If we find any, we // can simplify expressions like X+-X == 0 and X+~X ==-1. While we're at it, @@ -1585,6 +1685,16 @@ delete DummyInst; unsigned NumAddedValues = NewMulOps.size(); + std::stable_sort(NewMulOps.begin(), NewMulOps.end(), + [&](const WeakVH &a, const WeakVH &b) { + const Instruction *ia = dyn_cast(a); + const Instruction *ib = dyn_cast(b); + if (ia == nullptr) + return false; + if (ib == nullptr) + return true; + return DT.dominates(ia, ib); + }); Value *V = EmitAddTreeOfValues(I, NewMulOps); // Now that we have inserted the add tree, optimize it. This allows us to @@ -1627,8 +1737,8 @@ /// ((((x*y)*x)*y)*x) -> [(x, 3), (y, 2)] /// /// \returns Whether any factors have a power greater than one. -bool ReassociatePass::collectMultiplyFactors(SmallVectorImpl &Ops, - SmallVectorImpl &Factors) { +bool Reassociate::collectMultiplyFactors(SmallVectorImpl &Ops, + SmallVectorImpl &Factors) { // FIXME: Have Ops be (ValueEntry, Multiplicity) pairs, simplifying this. // Compute the sum of powers of simplifiable factors. unsigned FactorPowerSum = 0; @@ -1704,9 +1814,8 @@ /// equal and the powers are sorted in decreasing order, compute the minimal /// DAG of multiplies to compute the final product, and return that product /// value. -Value * -ReassociatePass::buildMinimalMultiplyDAG(IRBuilder<> &Builder, - SmallVectorImpl &Factors) { +Value *Reassociate::buildMinimalMultiplyDAG(IRBuilder<> &Builder, + SmallVectorImpl &Factors) { assert(Factors[0].Power); SmallVector OuterProduct; for (unsigned LastIdx = 0, Idx = 1, Size = Factors.size(); @@ -1762,8 +1871,8 @@ return V; } -Value *ReassociatePass::OptimizeMul(BinaryOperator *I, - SmallVectorImpl &Ops) { +Value *Reassociate::OptimizeMul(BinaryOperator *I, + SmallVectorImpl &Ops) { // We can only optimize the multiplies when there is a chain of more than // three, such that a balanced tree might require fewer total multiplies. if (Ops.size() < 4) @@ -1792,8 +1901,8 @@ return nullptr; } -Value *ReassociatePass::OptimizeExpression(BinaryOperator *I, - SmallVectorImpl &Ops) { +Value *Reassociate::OptimizeExpression(BinaryOperator *I, + SmallVectorImpl &Ops) { // Now that we have the linearized expression tree, try to optimize it. // Start by folding any constants that we found. Constant *Cst = nullptr; @@ -1853,7 +1962,7 @@ // Remove dead instructions and if any operands are trivially dead add them to // Insts so they will be removed as well. -void ReassociatePass::RecursivelyEraseDeadInsts( +void Reassociate::RecursivelyEraseDeadInsts( Instruction *I, SetVector> &Insts) { assert(isInstructionTriviallyDead(I) && "Trivially dead instructions only!"); SmallVector Ops(I->op_begin(), I->op_end()); @@ -1868,7 +1977,7 @@ } /// Zap the given instruction, adding interesting operands to the work list. -void ReassociatePass::EraseInst(Instruction *I) { +void Reassociate::EraseInst(Instruction *I) { assert(isInstructionTriviallyDead(I) && "Trivially dead instructions only!"); DEBUG(dbgs() << "Erasing dead inst: "; I->dump()); @@ -1894,7 +2003,7 @@ // Canonicalize expressions of the following form: // x + (-Constant * y) -> x - (Constant * y) // x - (-Constant * y) -> x + (Constant * y) -Instruction *ReassociatePass::canonicalizeNegConstExpr(Instruction *I) { +Instruction *Reassociate::canonicalizeNegConstExpr(Instruction *I) { if (!I->hasOneUse() || I->getType()->isVectorTy()) return nullptr; @@ -1971,7 +2080,7 @@ /// Inspect and optimize the given instruction. Note that erasing /// instructions is not allowed. -void ReassociatePass::OptimizeInst(Instruction *I) { +void Reassociate::OptimizeInst(Instruction *I) { // Only consider operations that we understand. if (!isa(I)) return; @@ -2098,7 +2207,7 @@ ReassociateExpression(BO); } -void ReassociatePass::ReassociateExpression(BinaryOperator *I) { +void Reassociate::ReassociateExpression(BinaryOperator *I) { // First, walk the expression tree, linearizing the tree, collecting the // operand information. SmallVector Tree; @@ -2180,7 +2289,7 @@ RewriteExprTree(I, Ops); } -PreservedAnalyses ReassociatePass::run(Function &F, FunctionAnalysisManager &) { +PreservedAnalyses Reassociate::run(Function &F) { // Get the functions basic blocks in Reverse Post Order. This order is used by // BuildRankMap to pre calculate ranks correctly. It also excludes dead basic // blocks (it has been seen that the analysis in this pass could hang when @@ -2244,34 +2353,43 @@ return PreservedAnalyses::all(); } +PreservedAnalyses ReassociatePass::run(Function &F, + FunctionAnalysisManager &AM) { + Reassociate impl(AM.getResult(F)); + return impl.run(F); +} + namespace { - class ReassociateLegacyPass : public FunctionPass { - ReassociatePass Impl; - public: - static char ID; // Pass identification, replacement for typeid - ReassociateLegacyPass() : FunctionPass(ID) { - initializeReassociateLegacyPassPass(*PassRegistry::getPassRegistry()); - } +class ReassociateLegacyPass : public FunctionPass { +public: + static char ID; // Pass identification, replacement for typeid + ReassociateLegacyPass() : FunctionPass(ID) { + initializeReassociateLegacyPassPass(*PassRegistry::getPassRegistry()); + } - bool runOnFunction(Function &F) override { - if (skipFunction(F)) - return false; + bool runOnFunction(Function &F) override { + if (skipFunction(F)) + return false; - FunctionAnalysisManager DummyFAM; - auto PA = Impl.run(F, DummyFAM); - return !PA.areAllPreserved(); - } + Reassociate impl(getAnalysis().getDomTree()); + auto PA = impl.run(F); + return !PA.areAllPreserved(); + } - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.setPreservesCFG(); - AU.addPreserved(); - } - }; -} + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesCFG(); + AU.addRequired(); + AU.addPreserved(); + } +}; +} // namespace char ReassociateLegacyPass::ID = 0; -INITIALIZE_PASS(ReassociateLegacyPass, "reassociate", - "Reassociate expressions", false, false) +INITIALIZE_PASS_BEGIN(ReassociateLegacyPass, "reassociate", + "Reassociate expressions", false, false) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_END(ReassociateLegacyPass, "reassociate", + "Reassociate expressions", false, false) // Public interface to the Reassociate pass FunctionPass *llvm::createReassociatePass() { Index: test/Transforms/Reassociate/hoist.ll =================================================================== --- /dev/null +++ test/Transforms/Reassociate/hoist.ll @@ -0,0 +1,93 @@ +; RUN: opt < %s -reassociate -S | FileCheck %s +; RUN: opt < %s -passes='reassociate' -S | FileCheck %s + +declare i32 @foo() +declare void @bar(i64) + +; The add should stay close of its operand def. +;CHECK-LABEL: @test1 +;CHECK: call +;CHECK: call +;CHECK: add +;CHECK: call +;CHECK: add +define i64 @test1() { + %1 = call i32 @foo() + %2 = zext i32 %1 to i64 + %3 = mul i64 4, %2 + %4 = add i64 0, %3 + %5 = call i32 @foo() + %6 = zext i32 %5 to i64 + %7 = mul i64 4, %6 + %8 = add i64 %4, %7 + %9 = call i32 @foo() + %10 = zext i32 %9 to i64 + %11 = mul i64 4, %10 + %12 = add i64 %8, %11 + ret i64 %12 +} + +; There are 2 uses of %2, thus the add will not be hoisted. +;CHECK-LABEL: @test1_nohoist +;CHECK: call +;CHECK: call +;CHECK: call +;CHECK: add +;CHECK: add +define i64 @test1_nohoist() { + %1 = call i32 @foo() + %2 = zext i32 %1 to i64 + %3 = mul i64 4, %2 + %4 = add i64 0, %3 + %5 = call i32 @foo() + %6 = zext i32 %5 to i64 + %7 = mul i64 4, %6 + %8 = add i64 %4, %7 + %9 = call i32 @foo() + %10 = zext i32 %9 to i64 + %11 = mul i64 4, %10 + %12 = add i64 %8, %11 + call void @bar(i64 %2) + ret i64 %12 +} + +; The add should stay close of its operand def. +;CHECK-LABEL: @test2 +;CHECK: call +;CHECK: call +;CHECK: add +;CHECK: call +;CHECK: add +define i64 @test2() { + %1 = call i32 @foo() + %2 = zext i32 %1 to i64 + %3 = add i64 0, %2 + %4 = call i32 @foo() + %5 = zext i32 %4 to i64 + %6 = add i64 %3, %5 + %7 = call i32 @foo() + %8 = zext i32 %7 to i64 + %9 = add i64 %6, %8 + ret i64 %9 +} + +; There are 2 uses of %2, thus the add will not be hoisted. +;CHECK-LABEL: @test2_nohoist +;CHECK: call +;CHECK: call +;CHECK: call +;CHECK: add +;CHECK: add +define i64 @test2_nohoist() { + %1 = call i32 @foo() + %2 = zext i32 %1 to i64 + %3 = add i64 0, %2 + %4 = call i32 @foo() + %5 = zext i32 %4 to i64 + %6 = add i64 %3, %5 + %7 = call i32 @foo() + %8 = zext i32 %7 to i64 + %9 = add i64 %6, %8 + call void @bar(i64 %2) + ret i64 %9 +}