Index: lib/Transforms/Scalar/Reassociate.cpp =================================================================== --- lib/Transforms/Scalar/Reassociate.cpp +++ lib/Transforms/Scalar/Reassociate.cpp @@ -12,16 +12,19 @@ // // For example: 4 + (x + 5) -> x + (4 + 5) // -// In the implementation of this algorithm, constants are assigned rank = 0, -// function arguments are rank = 1, and other values are assigned ranks -// corresponding to the reverse post order traversal of current function -// (starting at 2), which effectively gives values in deep loops higher rank -// than values not in loops. +// In the implementation of this algorithm, constants are assigned +// rank = 1 << 16, function arguments get distinct ranks starting from +// (1 << 16) + 3, and other unmovable values are assigned ranks corresponding to +// the reverse post order traversal of the current function (starting at +// (BlockRank << 16) + 1), which effectively gives values in deep loops higher +// rank than values not in loops. Ranks between 0 and (1 << 16) are reserved +// for pairing previously associated values. 
// //===----------------------------------------------------------------------===// #include "llvm/Transforms/Scalar.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/IndexedMap.h" #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" @@ -39,6 +42,7 @@ #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils/Local.h" #include +#include using namespace llvm; #define DEBUG_TYPE "reassociate" @@ -157,11 +161,50 @@ } namespace { + struct OpcodeToIndex : public std::unary_function { + unsigned operator()(unsigned op) const { + switch(op) { + case Instruction::Add: return 0; + case Instruction::Mul: return 1; + case Instruction::And: return 2; + case Instruction::Or: return 3; + case Instruction::Xor: return 4; + case Instruction::FAdd: return 5; + case Instruction::FMul: return 6; + // Any other opcode. + default: return 7; + } + } + }; + + // Utility class storing the last paired operands per basic block. + // It maps from (Value, Opcode) to (Value, Instruction). + // The mapped instruction is where these two values got paired. + class LastOperandPairsMap { + public: + // We don't want to walk through the map to find dangling Values and + // instructions, so just use WeakVH here and check the pointer before + // using it. 
+ typedef std::pair PairedValue; + void Clear() { PairingMap.clear(); } + void Add(Value *LHS, Value *RHS, Instruction *I); + void Remove(Instruction *I); + std::pair Lookup(Value *V, unsigned Opcode) const; + void Cleanup(Instruction *I); + + private: + DenseMap, IndexedMap> PairingMap; + }; +} + +namespace { class Reassociate : public FunctionPass { + const unsigned ConstantRank = (1 << 16); DenseMap RankMap; DenseMap, unsigned> ValueRankMap; SetVector > RedoInsts; bool MadeChange; + LastOperandPairsMap LastPairingMap; public: static char ID; // Pass identification, replacement for typeid Reassociate() : FunctionPass(ID) { @@ -226,6 +269,51 @@ isOr = true; } +void LastOperandPairsMap::Add(Value *LHS, Value *RHS, Instruction *I) { + auto Opcode = I->getOpcode(); + Value *Op[] = {LHS, RHS}; + for (unsigned i = 0; i != 2; ++i) { + auto &Map = PairingMap[Op[i]]; + Map.grow(Opcode); + Map[I->getOpcode()] = std::make_pair(WeakVH(Op[1 - i]), WeakVH(I)); + } +} + +void LastOperandPairsMap::Remove(Instruction *I) { + if (!isa(I)) + return; + + auto Opcode = I->getOpcode(); + for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) { + auto Iter = PairingMap.find(I->getOperand(i)); + if (Iter != PairingMap.end() && Iter->second.inBounds(Opcode) && + Iter->second[Opcode].second == I) { + Iter->second[Opcode].first = nullptr; + Iter->second[Opcode].second = nullptr; + } + } +} + +std::pair LastOperandPairsMap::Lookup(Value *V, + unsigned Opcode) const +{ + auto Iter = PairingMap.find(V); + if (Iter == PairingMap.end()) { + return PairedValue(); + } + auto &Map = Iter->second; + if (!Map.inBounds(Opcode)) { + return PairedValue(); + } + + return Map[Opcode]; +} + +void LastOperandPairsMap::Cleanup(Instruction *I) { + Remove(I); + PairingMap.erase(I); +} + char Reassociate::ID = 0; INITIALIZE_PASS(Reassociate, "reassociate", "Reassociate expressions", false, false) @@ -277,14 +365,16 @@ } void Reassociate::BuildRankMap(Function &F) { - unsigned i = 2; + unsigned i = 
ConstantRank + 2; - // Assign distinct ranks to function arguments. + // Assign distinct ranks to function arguments; each is larger than the rank + // of constants. for (Function::arg_iterator I = F.arg_begin(), E = F.arg_end(); I != E; ++I) { ValueRankMap[&*I] = ++i; DEBUG(dbgs() << "Calculated Rank[" << I->getName() << "] = " << i << "\n"); } + i = 1; ReversePostOrderTraversal RPOT(&F); for (ReversePostOrderTraversal::rpo_iterator I = RPOT.begin(), E = RPOT.end(); I != E; ++I) { @@ -304,7 +394,7 @@ Instruction *I = dyn_cast(V); if (!I) { if (isa(V)) return ValueRankMap[V]; // Function argument. - return 0; // Otherwise it's a global or constant, rank 0. + return ConstantRank; // Otherwise it's a global or constant. } if (unsigned Rank = ValueRankMap[I]) @@ -564,7 +654,8 @@ /// type and thus make the expression bigger. static bool LinearizeExprTree(BinaryOperator *I, - SmallVectorImpl &Ops) { + SmallVectorImpl &Ops, + LastOperandPairsMap &OpndPairsMap) { DEBUG(dbgs() << "LINEARIZE: " << *I << '\n'); unsigned Bitwidth = I->getType()->getScalarType()->getPrimitiveSizeInBits(); unsigned Opcode = I->getOpcode(); @@ -610,6 +701,8 @@ while (!Worklist.empty()) { std::pair P = Worklist.pop_back_val(); I = P.first; // We examine the operands of this binary operator. + // The pairs introduced by inner nodes should be ignored. + OpndPairsMap.Remove(I); for (unsigned OpIdx = 0; OpIdx < 2; ++OpIdx) { // Visit operands. Value *Op = I->getOperand(OpIdx); @@ -787,6 +880,7 @@ Value *NewRHS = Ops[i+1].Op; Value *OldLHS = Op->getOperand(0); Value *OldRHS = Op->getOperand(1); + LastPairingMap.Add(NewLHS, NewRHS, Op); if (NewLHS == OldLHS && NewRHS == OldRHS) // Nothing changed, leave it alone. 
@@ -843,6 +937,7 @@ Op->setOperand(1, NewRHS); ExpressionChanged = Op; } + LastPairingMap.Add(Op->getOperand(0), NewRHS, Op); DEBUG(dbgs() << "TO: " << *Op << '\n'); MadeChange = true; ++NumChanged; @@ -1110,7 +1205,7 @@ return nullptr; SmallVector Tree; - MadeChange |= LinearizeExprTree(BO, Tree); + MadeChange |= LinearizeExprTree(BO, Tree, LastPairingMap); SmallVector Factors; Factors.reserve(Tree.size()); for (unsigned i = 0, e = Tree.size(); i != e; ++i) { @@ -1898,7 +1993,7 @@ if (Cst && Cst != ConstantExpr::getBinOpIdentity(Opcode, I->getType())) { if (Cst == ConstantExpr::getBinOpAbsorber(Opcode, I->getType())) return Cst; - Ops.push_back(ValueEntry(0, Cst)); + Ops.push_back(ValueEntry(ConstantRank, Cst)); } if (Ops.size() == 1) return Ops[0].Op; @@ -1944,6 +2039,8 @@ // Erase the dead instruction. ValueRankMap.erase(I); RedoInsts.remove(I); + LastPairingMap.Cleanup(I); + I->eraseFromParent(); // Optimize its operands. SmallPtrSet Visited; // Detect self-referential nodes. @@ -2147,11 +2244,88 @@ ReassociateExpression(BO); } +static +std::pair FindPairedOperandIdx( + const LastOperandPairsMap& LastPairingMap, + unsigned Opcode, + const SmallVectorImpl &Ops, + unsigned Idx, + const std::unordered_multimap &Leaves) { + Value* PairedValue; + Value* Inst; + std::tie(PairedValue, Inst) = LastPairingMap.Lookup(Ops[Idx].Op, Opcode); + // Make sure the pair we found is valid. The pair may be invalid when: + // 1. The pair doesn't exist. + // 2. The instruction containing this pair has been erased. + // 3. The value in this pair has been erased. + if (!PairedValue || !Inst) { + return std::make_pair(false, 0U); + } + + // Find every occurrence of the paired value in Ops. + auto IterPair = Leaves.equal_range(PairedValue); + // Find an occurrence other than Ops[Idx]; a value can be paired with itself. + // For example, + // j = i * i + // ... 
+ // k = a * i + // l = k * i + // will introduce a pair (i, i), but if we want to pair them together later, + // we need to make sure that i pairs with another i but not itself. + auto Iter = std::find_if( + IterPair.first, IterPair.second, + [Idx](const std::pair &Item) { + // Reject the entry that is the current Op itself. + return Item.second != Idx; + }); + + if (Iter == IterPair.second) { + return std::make_pair(false, 0U); + } + + return std::make_pair(true, Iter->second); +} + +static void PairOperands(const LastOperandPairsMap& LastPairingMap, + unsigned Opcode, + SmallVectorImpl &Ops) +{ + SmallPtrSet Visited; // NOTE(review): never used in this function. + unsigned PairedRank = 0; + // First record each Value and its original position in Ops in Leaves. + std::unordered_multimap Leaves; + for (unsigned i = 0, e = Ops.size(); i != e ; ++i) { + Leaves.insert(std::make_pair(Ops[i].Op, i)); + } + for (unsigned i = 0, e = Ops.size(); i != e ; ++i) { + // Find a paired operand from Ops. + bool PairFound; + unsigned PairedOpIdx; + std::tie(PairFound, PairedOpIdx) = FindPairedOperandIdx(LastPairingMap, + Opcode, Ops, + i, Leaves); + if (!PairFound) { + continue; + } + + DEBUG(dbgs() << "PAIR OPERANDS: [" << *Ops[i].Op << "], [" + << *Ops[PairedOpIdx].Op << "]\n"); + // Assign the paired operands a low rank. + // TODO: RewriteExprTree can only guarantee that one pair is put together. If + // that changes, we can re-rank more pairs and will need to check whether + // an Op has already been paired. + Ops[i].Rank = PairedRank; + Ops[PairedOpIdx].Rank = PairedRank++; // NOTE(review): increment is dead; we return below. + std::stable_sort(Ops.begin(), Ops.end()); + return; + } +} + void Reassociate::ReassociateExpression(BinaryOperator *I) { // First, walk the expression tree, linearizing the tree, collecting the // operand information. 
SmallVector Tree; - MadeChange |= LinearizeExprTree(I, Tree); + MadeChange |= LinearizeExprTree(I, Tree, LastPairingMap); SmallVector Ops; Ops.reserve(Tree.size()); for (unsigned i = 0, e = Tree.size(); i != e; ++i) { @@ -2187,6 +2361,10 @@ return; } + // Pairing before OptimizeExpression may interfere with the constant merging + // and factorization done in OptimizeExpression, so we pair afterwards. + PairOperands(LastPairingMap, I->getOpcode(), Ops); + // We want to sink immediates as deeply as possible except in the case where // this is a multiply tree used only by an add, and the immediate is a -1. // In this case we reassociate to put the negation on the outside so that we @@ -2256,6 +2434,9 @@ else OptimizeInst(I); } + + // Clear LastPairingMap per basic block + LastPairingMap.Clear(); } // We are done with the rank map. Index: test/Transforms/Reassociate/pair.ll =================================================================== --- /dev/null +++ test/Transforms/Reassociate/pair.ll @@ -0,0 +1,35 @@ +; RUN: opt < %s -reassociate -early-cse -S | FileCheck %s + +target datalayout = "e-i64:64-v16:16-v32:32-n16:32:64" + +declare void @foo(i32) + +; foo(a * c); +; foo(a * (b * c)); +; => +; t = a * c; +; foo(t); +; foo(t * b); +define void @test1(i32 %a, i32 %b, i32 %c) { +; CHECK-LABEL: @test1( + %1 = mul i32 %a, %c +; CHECK: [[BASE:%[a-zA-Z0-9]+]] = mul i32 %c, %a + call void @foo(i32 %1) + %2 = mul i32 %b, %c + %3 = mul i32 %a, %2 +; CHECK: mul i32 [[BASE]], %b + call void @foo(i32 %3) + ret void +} + +; Test that we will not pair inner nodes that are processed during linearization +define i32 @test2(i32 %a, i32 %b, i32 %z) { +; CHECK-LABEL: test2 + + %d = mul i32 %z, 40 + %c = sub i32 0, %d + %e = mul i32 %a, %c +; CHECK-NOT: %e = mul i32 %z, 40 + %f = sub i32 0, %e + ret i32 %f +}