NOTE(review): the copy of this patch under review had lost its line breaks and
every template argument list (e.g. "std::pair PairedValue"). Both have been
reconstructed from usage and from the hunk headers; confirm them against the
original revision, in particular the key type of PairingMap and the return
type of Lookup.

Index: lib/Transforms/Scalar/Reassociate.cpp
===================================================================
--- lib/Transforms/Scalar/Reassociate.cpp
+++ lib/Transforms/Scalar/Reassociate.cpp
@@ -22,6 +22,7 @@
 
 #include "llvm/Transforms/Scalar.h"
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/IndexedMap.h"
 #include "llvm/ADT/PostOrderIterator.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
@@ -39,6 +40,7 @@
 #include "llvm/Support/raw_ostream.h"
 #include "llvm/Transforms/Utils/Local.h"
 #include <algorithm>
+#include <unordered_map>
 using namespace llvm;
 
 #define DEBUG_TYPE "reassociate"
@@ -157,11 +159,49 @@
 }
 
 namespace {
+  struct OpcodeToIndex : public std::unary_function<unsigned, unsigned> {
+    unsigned operator()(unsigned op) const {
+      switch(op) {
+        case Instruction::Add: return 0;
+        case Instruction::Mul: return 1;
+        case Instruction::And: return 2;
+        case Instruction::Or: return 3;
+        case Instruction::Xor: return 4;
+        case Instruction::FAdd: return 5;
+        case Instruction::FMul: return 6;
+        // Every other opcode shares one slot.
+        default: return 7;
+      }
+    }
+  };
+
+  // Utility class storing the last paired operands per basic block.
+  // It maps (Value, Opcode) to (Value, Instruction); the mapped instruction
+  // is where the two values were paired.
+  class LastOperandPairsMap {
+  public:
+    // We don't want to walk through the map to find dangling values and
+    // instructions, so store WeakVH here and check the pointer before
+    // using it.
+    typedef std::pair<WeakVH, WeakVH> PairedValue;
+    void Clear() { PairingMap.clear(); }
+    void Add(Value *LHS, Value *RHS, Instruction *I);
+    void Remove(Instruction *I);
+    std::pair<Value *, Value *> Lookup(Value *V, unsigned Opcode) const;
+    void Cleanup(Instruction *I);
+
+  private:
+    DenseMap<AssertingVH<Value>, IndexedMap<PairedValue, OpcodeToIndex>> PairingMap;
+  };
+}
+
+namespace {
   class Reassociate : public FunctionPass {
     DenseMap<BasicBlock*, unsigned> RankMap;
     DenseMap<AssertingVH<Value>, unsigned> ValueRankMap;
     SetVector<AssertingVH<Instruction> > RedoInsts;
     bool MadeChange;
+    LastOperandPairsMap LastPairingMap;
   public:
     static char ID; // Pass identification, replacement for typeid
     Reassociate() : FunctionPass(ID) {
@@ -226,6 +266,51 @@
   isOr = true;
 }
 
+void LastOperandPairsMap::Add(Value *LHS, Value *RHS, Instruction *I) {
+  auto Opcode = I->getOpcode();
+  Value *Op[] = {LHS, RHS};
+  for (unsigned i = 0; i != 2; ++i) {
+    auto &Map = PairingMap[Op[i]];
+    Map.grow(Opcode);
+    Map[Opcode] = std::make_pair(WeakVH(Op[1 - i]), WeakVH(I));
+  }
+}
+
+void LastOperandPairsMap::Remove(Instruction *I) {
+  if (!isa<BinaryOperator>(I))
+    return;
+
+  auto Opcode = I->getOpcode();
+  for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
+    auto Iter = PairingMap.find(I->getOperand(i));
+    if (Iter != PairingMap.end() && Iter->second.inBounds(Opcode) &&
+        Iter->second[Opcode].second == I) {
+      Iter->second[Opcode].first = nullptr;
+      Iter->second[Opcode].second = nullptr;
+    }
+  }
+}
+
+std::pair<Value *, Value *> LastOperandPairsMap::Lookup(Value *V,
+                                                        unsigned Opcode) const
+{
+  auto Iter = PairingMap.find(V);
+  if (Iter == PairingMap.end()) {
+    return PairedValue();
+  }
+  auto &Map = Iter->second;
+  if (!Map.inBounds(Opcode)) {
+    return PairedValue();
+  }
+
+  return Map[Opcode];
+}
+
+void LastOperandPairsMap::Cleanup(Instruction *I) {
+  Remove(I);
+  PairingMap.erase(I);
+}
+
 char Reassociate::ID = 0;
 INITIALIZE_PASS(Reassociate, "reassociate",
                 "Reassociate expressions", false, false)
@@ -564,7 +649,8 @@
 /// type and thus make the expression bigger.
 
 static bool LinearizeExprTree(BinaryOperator *I,
-                              SmallVectorImpl<RepeatedValue> &Ops) {
+                              SmallVectorImpl<RepeatedValue> &Ops,
+                              LastOperandPairsMap &OpndPairsMap) {
   DEBUG(dbgs() << "LINEARIZE: " << *I << '\n');
   unsigned Bitwidth = I->getType()->getScalarType()->getPrimitiveSizeInBits();
   unsigned Opcode = I->getOpcode();
@@ -610,6 +696,8 @@
   while (!Worklist.empty()) {
     std::pair<BinaryOperator*, APInt> P = Worklist.pop_back_val();
     I = P.first; // We examine the operands of this binary operator.
+    // The pairs introduced by inner nodes should be ignored.
+    OpndPairsMap.Remove(I);
 
     for (unsigned OpIdx = 0; OpIdx < 2; ++OpIdx) { // Visit operands.
       Value *Op = I->getOperand(OpIdx);
@@ -787,6 +875,7 @@
       Value *NewRHS = Ops[i+1].Op;
       Value *OldLHS = Op->getOperand(0);
       Value *OldRHS = Op->getOperand(1);
+      LastPairingMap.Add(NewLHS, NewRHS, Op);
 
       if (NewLHS == OldLHS && NewRHS == OldRHS)
         // Nothing changed, leave it alone.
@@ -843,6 +932,7 @@
         Op->setOperand(1, NewRHS);
         ExpressionChanged = Op;
       }
+      LastPairingMap.Add(Op->getOperand(0), NewRHS, Op);
       DEBUG(dbgs() << "TO: " << *Op << '\n');
       MadeChange = true;
       ++NumChanged;
@@ -1110,7 +1200,7 @@
     return nullptr;
 
   SmallVector<RepeatedValue, 8> Tree;
-  MadeChange |= LinearizeExprTree(BO, Tree);
+  MadeChange |= LinearizeExprTree(BO, Tree, LastPairingMap);
   SmallVector<ValueEntry, 8> Factors;
   Factors.reserve(Tree.size());
   for (unsigned i = 0, e = Tree.size(); i != e; ++i) {
@@ -1944,6 +2034,8 @@
   // Erase the dead instruction.
   ValueRankMap.erase(I);
   RedoInsts.remove(I);
+  LastPairingMap.Cleanup(I);
+
   I->eraseFromParent();
   // Optimize its operands.
   SmallPtrSet<Instruction *, 4> Visited; // Detect self-referential nodes.
@@ -2147,11 +2239,109 @@
   ReassociateExpression(BO);
 }
 
+/// Move the entries at OpIdxA and OpIdxB to the back of Ops, A's entry first,
+/// keeping the relative order of all other entries. Nothing is done when the
+/// two already occupy the last two slots. The indices must be distinct.
+static void MoveOperandsToEnd(unsigned OpIdxA, unsigned OpIdxB,
+                              SmallVectorImpl<ValueEntry> &Ops) {
+  assert(OpIdxA != OpIdxB && OpIdxA < Ops.size() && OpIdxB < Ops.size() &&
+         "expected two distinct, valid operand indices");
+
+  // If the operands already occupy the last two slots, do nothing.
+  if ((OpIdxA + 2 == Ops.size() && OpIdxB + 1 == Ops.size()) ||
+      (OpIdxA + 1 == Ops.size() && OpIdxB + 2 == Ops.size()))
+    return;
+
+  // Copy both entries out before touching the vector: appending a reference
+  // to one of Ops' own elements is unsafe because growing the vector
+  // invalidates that reference.
+  const ValueEntry EntryA = Ops[OpIdxA];
+  const ValueEntry EntryB = Ops[OpIdxB];
+
+  // Erase the higher index first so that the lower one stays valid.
+  if (OpIdxA > OpIdxB)
+    std::swap(OpIdxA, OpIdxB);
+  Ops.erase(Ops.begin() + OpIdxB);
+  Ops.erase(Ops.begin() + OpIdxA);
+
+  Ops.push_back(EntryA);
+  Ops.push_back(EntryB);
+}
+
+/// Look up the value most recently paired with Ops[Idx] under Opcode and
+/// find another entry of Ops that holds that value.
+///
+/// \param Leaves maps every value in Ops to the indices at which it occurs.
+/// \returns (true, index of the partner) on success, (false, 0) otherwise.
+static std::pair<bool, unsigned>
+FindPairedOperandIdx(const LastOperandPairsMap &LastPairingMap,
+                     unsigned Opcode, const SmallVectorImpl<ValueEntry> &Ops,
+                     unsigned Idx,
+                     const std::unordered_multimap<Value *, unsigned> &Leaves) {
+  const auto Pairing = LastPairingMap.Lookup(Ops[Idx].Op, Opcode);
+  Value *PairedValue = Pairing.first;
+  Value *Inst = Pairing.second;
+  // The pairing is unusable when:
+  // 1. no pair was recorded,
+  // 2. the instruction that contained the pair has been erased, or
+  // 3. the paired value itself has been erased.
+  if (!PairedValue || !Inst)
+    return std::make_pair(false, 0U);
+
+  // Collect every position in Ops that holds the paired value.
+  auto Range = Leaves.equal_range(PairedValue);
+  // A value can be paired with itself. For example,
+  //   j = i * i
+  //   ...
+  //   k = a * i
+  //   l = k * i
+  // records the pair (i, i). When pairing again later, i must be matched
+  // with a different occurrence of i, not with the entry at Idx itself.
+  auto Iter = std::find_if(
+      Range.first, Range.second,
+      [Idx](const std::pair<Value *const, unsigned> &Item) {
+        return Item.second != Idx;
+      });
+  if (Iter == Range.second)
+    return std::make_pair(false, 0U);
+
+  return std::make_pair(true, Iter->second);
+}
+
+/// If some operand in Ops was previously paired (under Opcode) with another
+/// operand that is also in Ops, move that pair to the end of Ops so that
+/// RewriteExprTree pairs the two together again. At most one pair is moved
+/// per call.
+static void PairOperands(const LastOperandPairsMap &LastPairingMap,
+                         unsigned Opcode,
+                         SmallVectorImpl<ValueEntry> &Ops) {
+  // Record each value together with its position in Ops.
+  std::unordered_multimap<Value *, unsigned> Leaves;
+  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
+    Leaves.insert(std::make_pair(Ops[i].Op, i));
+
+  for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
+    const std::pair<bool, unsigned> Found =
+        FindPairedOperandIdx(LastPairingMap, Opcode, Ops, i, Leaves);
+    if (!Found.first)
+      continue;
+
+    const unsigned PairedOpIdx = Found.second;
+    DEBUG(dbgs() << "PAIR OPERANDS: [" << *Ops[i].Op << "], ["
+                 << *Ops[PairedOpIdx].Op << "]\n");
+    // Move the paired operands to the end of Ops so that RewriteExprTree
+    // pairs them together, then stop: the indices in Leaves are stale once
+    // Ops has been reordered.
+    MoveOperandsToEnd(i, PairedOpIdx, Ops);
+    return;
+  }
+}
+
 void Reassociate::ReassociateExpression(BinaryOperator *I) {
   // First, walk the expression tree, linearizing the tree, collecting the
   // operand information.
   SmallVector<RepeatedValue, 8> Tree;
-  MadeChange |= LinearizeExprTree(I, Tree);
+  MadeChange |= LinearizeExprTree(I, Tree, LastPairingMap);
   SmallVector<ValueEntry, 8> Ops;
   Ops.reserve(Tree.size());
   for (unsigned i = 0, e = Tree.size(); i != e; ++i) {
@@ -2187,6 +2377,10 @@
     return;
   }
 
+  // Pair operands only after OptimizeExpression: pairing earlier may mess up
+  // the constant merging and factorization done by OptimizeExpression.
+  PairOperands(LastPairingMap, I->getOpcode(), Ops);
+
   // We want to sink immediates as deeply as possible except in the case where
   // this is a multiply tree used only by an add, and the immediate is a -1.
   // In this case we reassociate to put the negation on the outside so that we
@@ -2256,6 +2450,9 @@
       else
         OptimizeInst(I);
     }
+
+    // Clear LastPairingMap per basic block
+    LastPairingMap.Clear();
   }
 
   // We are done with the rank map.
Index: test/Transforms/Reassociate/pair.ll
===================================================================
--- /dev/null
+++ test/Transforms/Reassociate/pair.ll
@@ -0,0 +1,35 @@
+; RUN: opt < %s -reassociate -early-cse -S | FileCheck %s
+
+target datalayout = "e-i64:64-v16:16-v32:32-n16:32:64"
+
+declare void @foo(i32)
+
+; foo(a * c);
+; foo(a * (b * c));
+; =>
+; t = a * c;
+; foo(t);
+; foo(t * b);
+define void @test1(i32 %a, i32 %b, i32 %c) {
+; CHECK-LABEL: @test1(
+  %1 = mul i32 %a, %c
+; CHECK: [[BASE:%[a-zA-Z0-9]+]] = mul i32 %c, %a
+  call void @foo(i32 %1)
+  %2 = mul i32 %b, %c
+  %3 = mul i32 %a, %2
+; CHECK: mul i32 [[BASE]], %b
+  call void @foo(i32 %3)
+  ret void
+}
+
+; Test that we will not pair inner nodes that are processed during linearization
+define i32 @test2(i32 %a, i32 %b, i32 %z) {
+; CHECK-LABEL: @test2(
+
+  %d = mul i32 %z, 40
+  %c = sub i32 0, %d
+  %e = mul i32 %a, %c
+; CHECK-NOT: %e = mul i32 %z, 40
+  %f = sub i32 0, %e
+  ret i32 %f
+}