Index: include/llvm/Transforms/Scalar/Reassociate.h =================================================================== --- include/llvm/Transforms/Scalar/Reassociate.h +++ include/llvm/Transforms/Scalar/Reassociate.h @@ -23,77 +23,13 @@ #ifndef LLVM_TRANSFORMS_SCALAR_REASSOCIATE_H #define LLVM_TRANSFORMS_SCALAR_REASSOCIATE_H -#include "llvm/ADT/PostOrderIterator.h" -#include "llvm/ADT/SetVector.h" -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/Operator.h" #include "llvm/IR/PassManager.h" namespace llvm { - -/// A private "module" namespace for types and utilities used by Reassociate. -/// These are implementation details and should not be used by clients. -namespace reassociate { -struct ValueEntry { - unsigned Rank; - Value *Op; - ValueEntry(unsigned R, Value *O) : Rank(R), Op(O) {} -}; -inline bool operator<(const ValueEntry &LHS, const ValueEntry &RHS) { - return LHS.Rank > RHS.Rank; // Sort so that highest rank goes to start. -} - -/// \brief Utility class representing a base and exponent pair which form one -/// factor of some product. -struct Factor { - Value *Base; - unsigned Power; - Factor(Value *Base, unsigned Power) : Base(Base), Power(Power) {} -}; - -class XorOpnd; -} - /// Reassociate commutative expressions. 
class ReassociatePass : public PassInfoMixin { - DenseMap RankMap; - DenseMap, unsigned> ValueRankMap; - SetVector> RedoInsts; - bool MadeChange; - public: - PreservedAnalyses run(Function &F, FunctionAnalysisManager &); - -private: - void BuildRankMap(Function &F, ReversePostOrderTraversal &RPOT); - unsigned getRank(Value *V); - void canonicalizeOperands(Instruction *I); - void ReassociateExpression(BinaryOperator *I); - void RewriteExprTree(BinaryOperator *I, - SmallVectorImpl &Ops); - Value *OptimizeExpression(BinaryOperator *I, - SmallVectorImpl &Ops); - Value *OptimizeAdd(Instruction *I, - SmallVectorImpl &Ops); - Value *OptimizeXor(Instruction *I, - SmallVectorImpl &Ops); - bool CombineXorOpnd(Instruction *I, reassociate::XorOpnd *Opnd1, - APInt &ConstOpnd, Value *&Res); - bool CombineXorOpnd(Instruction *I, reassociate::XorOpnd *Opnd1, - reassociate::XorOpnd *Opnd2, APInt &ConstOpnd, - Value *&Res); - bool collectMultiplyFactors(SmallVectorImpl &Ops, - SmallVectorImpl &Factors); - Value *buildMinimalMultiplyDAG(IRBuilder<> &Builder, - SmallVectorImpl &Factors); - Value *OptimizeMul(BinaryOperator *I, - SmallVectorImpl &Ops); - Value *RemoveFactorFromExpression(Value *V, Value *Factor); - void EraseInst(Instruction *I); - void RecursivelyEraseDeadInsts(Instruction *I, - SetVector> &Insts); - void OptimizeInst(Instruction *I); - Instruction *canonicalizeNegConstExpr(Instruction *I); + PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM); }; } Index: lib/Transforms/Scalar/Reassociate.cpp =================================================================== --- lib/Transforms/Scalar/Reassociate.cpp +++ lib/Transforms/Scalar/Reassociate.cpp @@ -31,6 +31,7 @@ #include "llvm/IR/CFG.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" @@ -43,7 +44,6 @@ #include "llvm/Transforms/Utils/Local.h" #include using namespace 
llvm; -using namespace reassociate; #define DEBUG_TYPE "reassociate" @@ -66,6 +66,25 @@ } #endif +namespace llvm { +namespace reassociate { +struct ValueEntry { + unsigned Rank; + Value *Op; + ValueEntry(unsigned R, Value *O) : Rank(R), Op(O) {} +}; +inline bool operator<(const ValueEntry &LHS, const ValueEntry &RHS) { + return LHS.Rank > RHS.Rank; // Sort so that highest rank goes to start. +} + +/// \brief Utility class representing a base and exponent pair which form one +/// factor of some product. +struct Factor { + Value *Base; + unsigned Power; + Factor(Value *Base, unsigned Power) : Base(Base), Power(Power) {} +}; + /// Utility class representing a non-constant Xor-operand. We classify /// non-constant Xor-Operands into two categories: /// C1) The operand is in the form "X & C", where C is a constant and C != ~0 @@ -74,7 +93,7 @@ /// constant. /// C2.2) Any operand E which doesn't fall into C1 and C2.1, we view this /// operand as "E | 0" -class llvm::reassociate::XorOpnd { +class XorOpnd { public: XorOpnd(Value *V); @@ -122,6 +141,9 @@ ConstPart = APInt::getNullValue(V->getType()->getIntegerBitWidth()); isOr = true; } +}} // namespace llvm::reassociate + +using namespace llvm::reassociate; /// Return true if V is an instruction of the specified opcode and if it /// only has one use. @@ -145,7 +167,51 @@ return nullptr; } -void ReassociatePass::BuildRankMap(Function &F, +/// Reassociate commutative expressions. 
+class Reassociate { + DenseMap RankMap; + DenseMap, unsigned> ValueRankMap; + SetVector> RedoInsts; + bool MadeChange; + DominatorTree &DT; + +public: + Reassociate(DominatorTree &DT) : DT(DT) {} + PreservedAnalyses run(Function &F); + +private: + void BuildRankMap(Function &F, ReversePostOrderTraversal &RPOT); + unsigned getRank(Value *V); + void canonicalizeOperands(Instruction *I); + void ReassociateExpression(BinaryOperator *I); + void RewriteExprTree(BinaryOperator *I, + SmallVectorImpl &Ops); + Value *OptimizeExpression(BinaryOperator *I, + SmallVectorImpl &Ops); + Value *OptimizeAdd(Instruction *I, + SmallVectorImpl &Ops); + Value *OptimizeXor(Instruction *I, + SmallVectorImpl &Ops); + bool CombineXorOpnd(Instruction *I, reassociate::XorOpnd *Opnd1, + APInt &ConstOpnd, Value *&Res); + bool CombineXorOpnd(Instruction *I, reassociate::XorOpnd *Opnd1, + reassociate::XorOpnd *Opnd2, APInt &ConstOpnd, + Value *&Res); + bool collectMultiplyFactors(SmallVectorImpl &Ops, + SmallVectorImpl &Factors); + Value *buildMinimalMultiplyDAG(IRBuilder<> &Builder, + SmallVectorImpl &Factors); + Value *OptimizeMul(BinaryOperator *I, + SmallVectorImpl &Ops); + Value *RemoveFactorFromExpression(Value *V, Value *Factor); + void EraseInst(Instruction *I); + void RecursivelyEraseDeadInsts(Instruction *I, + SetVector> &Insts); + void OptimizeInst(Instruction *I); + Instruction *canonicalizeNegConstExpr(Instruction *I); +}; + +void Reassociate::BuildRankMap(Function &F, ReversePostOrderTraversal &RPOT) { unsigned i = 2; @@ -168,7 +234,7 @@ } } -unsigned ReassociatePass::getRank(Value *V) { +unsigned Reassociate::getRank(Value *V) { Instruction *I = dyn_cast(V); if (!I) { if (isa(V)) return ValueRankMap[V]; // Function argument. @@ -199,7 +265,7 @@ } // Canonicalize constants to RHS. Otherwise, sort the operands by rank. 
-void ReassociatePass::canonicalizeOperands(Instruction *I) { +void Reassociate::canonicalizeOperands(Instruction *I) { assert(isa(I) && "Expected binary operator."); assert(I->isCommutative() && "Expected commutative operator."); @@ -610,7 +676,7 @@ /// Now that the operands for this expression tree are /// linearized and optimized, emit them in-order. -void ReassociatePass::RewriteExprTree(BinaryOperator *I, +void Reassociate::RewriteExprTree(BinaryOperator *I, SmallVectorImpl &Ops) { assert(Ops.size() > 1 && "Single values should be used directly!"); @@ -754,9 +820,11 @@ } // If the expression changed non-trivially then clear out all subclass data - // starting from the operator specified in ExpressionChanged, and compactify - // the operators to just before the expression root to guarantee that the - // expression tree is dominated by all of Ops. + // starting from the operator specified in ExpressionChanged. If an operator + // is the only user of its operands, move it right after the last-defined + // operand to minimize live ranges. Otherwise, compactify the operators to + // just before the expression root to guarantee that the expression tree is + // dominated by all of Ops. if (ExpressionChanged) do { // Preserve FastMathFlags. @@ -769,7 +837,22 @@ if (ExpressionChanged == I) break; - ExpressionChanged->moveBefore(I); + + Instruction *Insert = I; + for (Value *V : ExpressionChanged->operands()) { + // After the reassociation, there will be one dead use of V to be + // cleaned up. Thus check against 2 to see if it's the only use. 
+ if (V->getNumUses() > 2) { + Insert = I; + break; + } + if (Instruction *Inst = dyn_cast(V)) + if (Insert == I || DT.dominates(Insert, Inst)) + Insert = Inst; + } + if (Insert != I) + Insert = Insert->getNextNode(); + ExpressionChanged->moveBefore(Insert); ExpressionChanged = cast(*ExpressionChanged->user_begin()); } while (1); @@ -980,21 +1063,32 @@ } /// Emit a tree of add instructions, summing Ops together -/// and returning the result. Insert the tree before I. +/// and returning the result. If there is only one use of the operand, +/// insert the use right after the def to minimize live-range. +/// Otherwise, insert the tree before I. static Value *EmitAddTreeOfValues(Instruction *I, SmallVectorImpl &Ops){ - if (Ops.size() == 1) return Ops.back(); - - Value *V1 = Ops.back(); - Ops.pop_back(); - Value *V2 = EmitAddTreeOfValues(I, Ops); - return CreateAdd(V2, V1, "tmp", I, I); + Value *Prev = Ops.back(); + bool InsertToI = Prev->getNumUses() > 1; + for (int i = Ops.size() - 2; i >= 0; i--) { + Instruction *Insert = I; + if (!InsertToI) { + if (Ops[i]->getNumUses() > 1) + InsertToI = true; + else if (Instruction *Inst = dyn_cast(Ops[i])) + Insert = Inst->getNextNode(); + else + InsertToI = true; + } + Prev = CreateAdd(Prev, Ops[i], "tmp", Insert, I); + } + return Prev; } /// If V is an expression tree that is a multiplication sequence, /// and if this sequence contains a multiply by Factor, /// remove Factor from the tree and return the new tree. -Value *ReassociatePass::RemoveFactorFromExpression(Value *V, Value *Factor) { +Value *Reassociate::RemoveFactorFromExpression(Value *V, Value *Factor) { BinaryOperator *BO = isReassociableOp(V, Instruction::Mul, Instruction::FMul); if (!BO) return nullptr; @@ -1157,7 +1251,7 @@ // via "Res" and "ConstOpnd", respectively; otherwise, false is returned, // and both "Res" and "ConstOpnd" remain unchanged. 
// -bool ReassociatePass::CombineXorOpnd(Instruction *I, XorOpnd *Opnd1, +bool Reassociate::CombineXorOpnd(Instruction *I, XorOpnd *Opnd1, APInt &ConstOpnd, Value *&Res) { // Xor-Rule 1: (x | c1) ^ c2 = (x | c1) ^ (c1 ^ c1) ^ c2 // = ((x | c1) ^ c1) ^ (c1 ^ c2) @@ -1192,7 +1286,7 @@ // via "Res" and "ConstOpnd", respectively (If the entire expression is // evaluated to a constant, the Res is set to NULL); otherwise, false is // returned, and both "Res" and "ConstOpnd" remain unchanged. -bool ReassociatePass::CombineXorOpnd(Instruction *I, XorOpnd *Opnd1, +bool Reassociate::CombineXorOpnd(Instruction *I, XorOpnd *Opnd1, XorOpnd *Opnd2, APInt &ConstOpnd, Value *&Res) { Value *X = Opnd1->getSymbolicPart(); @@ -1268,7 +1362,7 @@ /// Optimize a series of operands to an 'xor' instruction. If it can be reduced /// to a single Value, it is returned, otherwise the Ops list is mutated as /// necessary. -Value *ReassociatePass::OptimizeXor(Instruction *I, +Value *Reassociate::OptimizeXor(Instruction *I, SmallVectorImpl &Ops) { if (Value *V = OptimizeAndOrXor(Instruction::Xor, Ops)) return V; @@ -1389,7 +1483,7 @@ /// Optimize a series of operands to an 'add' instruction. This /// optimizes based on identities. If it can be reduced to a single Value, it /// is returned, otherwise the Ops list is mutated as necessary. -Value *ReassociatePass::OptimizeAdd(Instruction *I, +Value *Reassociate::OptimizeAdd(Instruction *I, SmallVectorImpl &Ops) { // Scan the operand lists looking for X and -X pairs. If we find any, we // can simplify expressions like X+-X == 0 and X+~X ==-1. While we're at it, @@ -1627,7 +1721,7 @@ /// ((((x*y)*x)*y)*x) -> [(x, 3), (y, 2)] /// /// \returns Whether any factors have a power greater than one. -bool ReassociatePass::collectMultiplyFactors(SmallVectorImpl &Ops, +bool Reassociate::collectMultiplyFactors(SmallVectorImpl &Ops, SmallVectorImpl &Factors) { // FIXME: Have Ops be (ValueEntry, Multiplicity) pairs, simplifying this. 
// Compute the sum of powers of simplifiable factors. @@ -1705,7 +1799,7 @@ /// DAG of multiplies to compute the final product, and return that product /// value. Value * -ReassociatePass::buildMinimalMultiplyDAG(IRBuilder<> &Builder, +Reassociate::buildMinimalMultiplyDAG(IRBuilder<> &Builder, SmallVectorImpl &Factors) { assert(Factors[0].Power); SmallVector OuterProduct; @@ -1762,7 +1856,7 @@ return V; } -Value *ReassociatePass::OptimizeMul(BinaryOperator *I, +Value *Reassociate::OptimizeMul(BinaryOperator *I, SmallVectorImpl &Ops) { // We can only optimize the multiplies when there is a chain of more than // three, such that a balanced tree might require fewer total multiplies. @@ -1792,7 +1886,7 @@ return nullptr; } -Value *ReassociatePass::OptimizeExpression(BinaryOperator *I, +Value *Reassociate::OptimizeExpression(BinaryOperator *I, SmallVectorImpl &Ops) { // Now that we have the linearized expression tree, try to optimize it. // Start by folding any constants that we found. @@ -1853,7 +1947,7 @@ // Remove dead instructions and if any operands are trivially dead add them to // Insts so they will be removed as well. -void ReassociatePass::RecursivelyEraseDeadInsts( +void Reassociate::RecursivelyEraseDeadInsts( Instruction *I, SetVector> &Insts) { assert(isInstructionTriviallyDead(I) && "Trivially dead instructions only!"); SmallVector Ops(I->op_begin(), I->op_end()); @@ -1868,7 +1962,7 @@ } /// Zap the given instruction, adding interesting operands to the work list. 
-void ReassociatePass::EraseInst(Instruction *I) { +void Reassociate::EraseInst(Instruction *I) { assert(isInstructionTriviallyDead(I) && "Trivially dead instructions only!"); DEBUG(dbgs() << "Erasing dead inst: "; I->dump()); @@ -1894,7 +1988,7 @@ // Canonicalize expressions of the following form: // x + (-Constant * y) -> x - (Constant * y) // x - (-Constant * y) -> x + (Constant * y) -Instruction *ReassociatePass::canonicalizeNegConstExpr(Instruction *I) { +Instruction *Reassociate::canonicalizeNegConstExpr(Instruction *I) { if (!I->hasOneUse() || I->getType()->isVectorTy()) return nullptr; @@ -1971,7 +2065,7 @@ /// Inspect and optimize the given instruction. Note that erasing /// instructions is not allowed. -void ReassociatePass::OptimizeInst(Instruction *I) { +void Reassociate::OptimizeInst(Instruction *I) { // Only consider operations that we understand. if (!isa(I)) return; @@ -2098,7 +2192,7 @@ ReassociateExpression(BO); } -void ReassociatePass::ReassociateExpression(BinaryOperator *I) { +void Reassociate::ReassociateExpression(BinaryOperator *I) { // First, walk the expression tree, linearizing the tree, collecting the // operand information. SmallVector Tree; @@ -2180,7 +2274,7 @@ RewriteExprTree(I, Ops); } -PreservedAnalyses ReassociatePass::run(Function &F, FunctionAnalysisManager &) { +PreservedAnalyses Reassociate::run(Function &F) { // Get the functions basic blocks in Reverse Post Order. This order is used by // BuildRankMap to pre calculate ranks correctly. 
It also excludes dead basic // blocks (it has been seen that the analysis in this pass could hang when @@ -2244,9 +2338,13 @@ return PreservedAnalyses::all(); } +PreservedAnalyses ReassociatePass::run(Function &F, FunctionAnalysisManager &AM) { + Reassociate impl(AM.getResult(F)); + return impl.run(F); +} + namespace { class ReassociateLegacyPass : public FunctionPass { - ReassociatePass Impl; public: static char ID; // Pass identification, replacement for typeid ReassociateLegacyPass() : FunctionPass(ID) { @@ -2257,21 +2355,25 @@ if (skipFunction(F)) return false; - FunctionAnalysisManager DummyFAM; - auto PA = Impl.run(F, DummyFAM); + Reassociate impl(getAnalysis().getDomTree()); + auto PA = impl.run(F); return !PA.areAllPreserved(); } void getAnalysisUsage(AnalysisUsage &AU) const override { AU.setPreservesCFG(); + AU.addRequired(); AU.addPreserved(); } }; } char ReassociateLegacyPass::ID = 0; -INITIALIZE_PASS(ReassociateLegacyPass, "reassociate", - "Reassociate expressions", false, false) +INITIALIZE_PASS_BEGIN(ReassociateLegacyPass, "reassociate", + "Reassociate expressions", false, false) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_END(ReassociateLegacyPass, "reassociate", + "Reassociate expressions", false, false) // Public interface to the Reassociate pass FunctionPass *llvm::createReassociatePass() { Index: test/Transforms/Reassociate/hoist.ll =================================================================== --- /dev/null +++ test/Transforms/Reassociate/hoist.ll @@ -0,0 +1,93 @@ +; RUN: opt < %s -reassociate -S | FileCheck %s +; RUN: opt < %s -passes='reassociate' -S | FileCheck %s + +declare i32 @foo() +declare void @bar(i64) + +; The add should stay close to its operand def. 
+;CHECK-LABEL: @test1 +;CHECK: call +;CHECK: call +;CHECK: add +;CHECK: call +;CHECK: add +define i64 @test1() { + %1 = call i32 @foo() + %2 = zext i32 %1 to i64 + %3 = mul i64 4, %2 + %4 = add i64 0, %3 + %5 = call i32 @foo() + %6 = zext i32 %5 to i64 + %7 = mul i64 4, %6 + %8 = add i64 %4, %7 + %9 = call i32 @foo() + %10 = zext i32 %9 to i64 + %11 = mul i64 4, %10 + %12 = add i64 %8, %11 + ret i64 %12 +} + +; There are 2 uses of %2, thus the add will not be hoisted. +;CHECK-LABEL: @test1_nohoist +;CHECK: call +;CHECK: call +;CHECK: call +;CHECK: add +;CHECK: add +define i64 @test1_nohoist() { + %1 = call i32 @foo() + %2 = zext i32 %1 to i64 + %3 = mul i64 4, %2 + %4 = add i64 0, %3 + %5 = call i32 @foo() + %6 = zext i32 %5 to i64 + %7 = mul i64 4, %6 + %8 = add i64 %4, %7 + %9 = call i32 @foo() + %10 = zext i32 %9 to i64 + %11 = mul i64 4, %10 + %12 = add i64 %8, %11 + call void @bar(i64 %2) + ret i64 %12 +} + +; The add should stay close to its operand def. +;CHECK-LABEL: @test2 +;CHECK: call +;CHECK: call +;CHECK: add +;CHECK: call +;CHECK: add +define i64 @test2() { + %1 = call i32 @foo() + %2 = zext i32 %1 to i64 + %3 = add i64 0, %2 + %4 = call i32 @foo() + %5 = zext i32 %4 to i64 + %6 = add i64 %3, %5 + %7 = call i32 @foo() + %8 = zext i32 %7 to i64 + %9 = add i64 %6, %8 + ret i64 %9 +} + +; There are 2 uses of %2, thus the add will not be hoisted. +;CHECK-LABEL: @test2_nohoist +;CHECK: call +;CHECK: call +;CHECK: call +;CHECK: add +;CHECK: add +define i64 @test2_nohoist() { + %1 = call i32 @foo() + %2 = zext i32 %1 to i64 + %3 = add i64 0, %2 + %4 = call i32 @foo() + %5 = zext i32 %4 to i64 + %6 = add i64 %3, %5 + %7 = call i32 @foo() + %8 = zext i32 %7 to i64 + %9 = add i64 %6, %8 + call void @bar(i64 %2) + ret i64 %9 +}