Index: llvm/trunk/lib/Transforms/Scalar/MergeICmps.cpp =================================================================== --- llvm/trunk/lib/Transforms/Scalar/MergeICmps.cpp +++ llvm/trunk/lib/Transforms/Scalar/MergeICmps.cpp @@ -41,10 +41,6 @@ // //===----------------------------------------------------------------------===// -#include -#include -#include -#include #include "llvm/Analysis/Loads.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" @@ -53,6 +49,10 @@ #include "llvm/Pass.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BuildLibCalls.h" +#include +#include +#include +#include using namespace llvm; @@ -73,71 +73,87 @@ // that is a constant offset from a base value, e.g. `a` or `o.c` in the example // at the top. struct BCEAtom { - BCEAtom() : GEP(nullptr), LoadI(nullptr), Offset() {} - - const Value *Base() const { return GEP ? GEP->getPointerOperand() : nullptr; } - + BCEAtom() = default; + BCEAtom(GetElementPtrInst *GEP, LoadInst *LoadI, unsigned BaseId, + APInt Offset) + : GEP(GEP), LoadI(LoadI), BaseId(BaseId), Offset(std::move(Offset)) {} + + // We want to order BCEAtoms by (Base, Offset). However we cannot use + // the pointer values for Base because these are non-deterministic. + // Instead, each base value gets an id (`BaseId`) based on its order of + // first appearance in the chain of comparisons; id 0 means "invalid atom". + // For example, for: + // b[3] == c[2] && a[1] == d[1] && b[4] == c[3] + // | block 1 | | block 2 | | block 3 | + // b gets id 1, c id 2, a id 3 and d id 4 (blocks in order, LHS first). + // We then sort by (BaseId, Offset), which is deterministic. bool operator<(const BCEAtom &O) const { - assert(Base() && "invalid atom"); - assert(O.Base() && "invalid atom"); - // Just ordering by (Base(), Offset) is sufficient.
However because this - // means that the ordering will depend on the addresses of the base - // values, which are not reproducible from run to run. To guarantee - // stability, we use the names of the values if they exist; we sort by: - // (Base.getName(), Base(), Offset). - const int NameCmp = Base()->getName().compare(O.Base()->getName()); - if (NameCmp == 0) { - if (Base() == O.Base()) { - return Offset.slt(O.Offset); - } - return Base() < O.Base(); - } - return NameCmp < 0; + return BaseId != O.BaseId ? BaseId < O.BaseId : Offset.slt(O.Offset); } - GetElementPtrInst *GEP; - LoadInst *LoadI; + GetElementPtrInst *GEP = nullptr; + LoadInst *LoadI = nullptr; + unsigned BaseId = 0; APInt Offset; }; +// A class that assigns increasing ids to values in the order in which they are +// seen. See comment in `BCEAtom::operator<()`. +class BaseIdentifier { +public: + // Returns the id for value `Base`, after assigning one if `Base` has not been + // seen before. Ids start at 1, so 0 is never returned. + unsigned getBaseId(const Value *Base) { + assert(Base && "invalid base"); + const auto Insertion = BaseToIndex.try_emplace(Base, Order); + if (Insertion.second) + ++Order; + return Insertion.first->second; + } + +private: + unsigned Order = 1; + DenseMap<const Value *, unsigned> BaseToIndex; +}; + // If this value is a load from a constant offset w.r.t. a base address, and // there are no other users of the load or address, returns the base address and // the offset.
-BCEAtom visitICmpLoadOperand(Value *const Val) { - BCEAtom Result; - if (auto *const LoadI = dyn_cast(Val)) { - LLVM_DEBUG(dbgs() << "load\n"); - if (LoadI->isUsedOutsideOfBlock(LoadI->getParent())) { - LLVM_DEBUG(dbgs() << "used outside of block\n"); - return {}; - } - // Do not optimize atomic loads to non-atomic memcmp - if (!LoadI->isSimple()) { - LLVM_DEBUG(dbgs() << "volatile or atomic\n"); - return {}; - } - Value *const Addr = LoadI->getOperand(0); - if (auto *const GEP = dyn_cast(Addr)) { - LLVM_DEBUG(dbgs() << "GEP\n"); - if (GEP->isUsedOutsideOfBlock(LoadI->getParent())) { - LLVM_DEBUG(dbgs() << "used outside of block\n"); - return {}; - } - const auto &DL = GEP->getModule()->getDataLayout(); - if (!isDereferenceablePointer(GEP, DL)) { - LLVM_DEBUG(dbgs() << "not dereferenceable\n"); - // We need to make sure that we can do comparison in any order, so we - // require memory to be unconditionnally dereferencable. - return {}; - } - Result.Offset = APInt(DL.getPointerTypeSizeInBits(GEP->getType()), 0); - if (GEP->accumulateConstantOffset(DL, Result.Offset)) { - Result.GEP = GEP; - Result.LoadI = LoadI; - } - } +BCEAtom visitICmpLoadOperand(Value *const Val, BaseIdentifier &BaseId) { + auto *const LoadI = dyn_cast(Val); + if (!LoadI) + return {}; + LLVM_DEBUG(dbgs() << "load\n"); + if (LoadI->isUsedOutsideOfBlock(LoadI->getParent())) { + LLVM_DEBUG(dbgs() << "used outside of block\n"); + return {}; + } + // Do not optimize atomic loads to non-atomic memcmp + if (!LoadI->isSimple()) { + LLVM_DEBUG(dbgs() << "volatile or atomic\n"); + return {}; + } + Value *const Addr = LoadI->getOperand(0); + auto *const GEP = dyn_cast(Addr); + if (!GEP) + return {}; + LLVM_DEBUG(dbgs() << "GEP\n"); + if (GEP->isUsedOutsideOfBlock(LoadI->getParent())) { + LLVM_DEBUG(dbgs() << "used outside of block\n"); + return {}; + } + const auto &DL = GEP->getModule()->getDataLayout(); + if (!isDereferenceablePointer(GEP, DL)) { + LLVM_DEBUG(dbgs() << "not dereferenceable\n"); + // We
need to make sure that we can do comparison in any order, so we + // require memory to be unconditionally dereferenceable. + return {}; } - return Result; + APInt Offset = APInt(DL.getPointerTypeSizeInBits(GEP->getType()), 0); + if (!GEP->accumulateConstantOffset(DL, Offset)) + return {}; + return BCEAtom(GEP, LoadI, BaseId.getBaseId(GEP->getPointerOperand()), + Offset); } // A basic block with a comparison between two BCE atoms, e.g. `a == o.a` in the @@ -159,9 +175,7 @@ if (Rhs_ < Lhs_) std::swap(Rhs_, Lhs_); } - bool IsValid() const { - return Lhs_.Base() != nullptr && Rhs_.Base() != nullptr; - } + bool IsValid() const { return Lhs_.BaseId != 0 && Rhs_.BaseId != 0; } // Assert the block is consistent: If valid, it should also have // non-null members besides Lhs_ and Rhs_. @@ -287,7 +301,8 @@ // Visit the given comparison. If this is a comparison between two valid // BCE atoms, returns the comparison. BCECmpBlock visitICmp(const ICmpInst *const CmpI, - const ICmpInst::Predicate ExpectedPredicate) { + const ICmpInst::Predicate ExpectedPredicate, + BaseIdentifier &BaseId) { // The comparison can only be used once: // - For intermediate blocks, as a branch condition. // - For the final block, as an incoming value for the Phi. @@ -297,25 +312,27 @@ LLVM_DEBUG(dbgs() << "cmp has several uses\n"); return {}; } - if (CmpI->getPredicate() == ExpectedPredicate) { - LLVM_DEBUG(dbgs() << "cmp " - << (ExpectedPredicate == ICmpInst::ICMP_EQ ? "eq" : "ne") - << "\n"); - auto Lhs = visitICmpLoadOperand(CmpI->getOperand(0)); - if (!Lhs.Base()) return {}; - auto Rhs = visitICmpLoadOperand(CmpI->getOperand(1)); - if (!Rhs.Base()) return {}; - const auto &DL = CmpI->getModule()->getDataLayout(); - return BCECmpBlock(std::move(Lhs), std::move(Rhs), - DL.getTypeSizeInBits(CmpI->getOperand(0)->getType())); - } - return {}; + if (CmpI->getPredicate() != ExpectedPredicate) + return {}; + LLVM_DEBUG(dbgs() << "cmp " + << (ExpectedPredicate == ICmpInst::ICMP_EQ ?
"eq" : "ne") + << "\n"); + auto Lhs = visitICmpLoadOperand(CmpI->getOperand(0), BaseId); + if (!Lhs.BaseId) + return {}; + auto Rhs = visitICmpLoadOperand(CmpI->getOperand(1), BaseId); + if (!Rhs.BaseId) + return {}; + const auto &DL = CmpI->getModule()->getDataLayout(); + return BCECmpBlock(std::move(Lhs), std::move(Rhs), + DL.getTypeSizeInBits(CmpI->getOperand(0)->getType())); } // Visit the given comparison block. If this is a comparison between two valid // BCE atoms, returns the comparison. BCECmpBlock visitCmpBlock(Value *const Val, BasicBlock *const Block, - const BasicBlock *const PhiBlock) { + const BasicBlock *const PhiBlock, + BaseIdentifier &BaseId) { if (Block->empty()) return {}; auto *const BranchI = dyn_cast(Block->getTerminator()); if (!BranchI) return {}; @@ -328,7 +345,7 @@ auto *const CmpI = dyn_cast(Val); if (!CmpI) return {}; LLVM_DEBUG(dbgs() << "icmp\n"); - auto Result = visitICmp(CmpI, ICmpInst::ICMP_EQ); + auto Result = visitICmp(CmpI, ICmpInst::ICMP_EQ, BaseId); Result.CmpI = CmpI; Result.BranchI = BranchI; return Result; @@ -345,7 +362,8 @@ assert(BranchI->getNumSuccessors() == 2 && "expecting a cond branch"); BasicBlock *const FalseBlock = BranchI->getSuccessor(1); auto Result = visitICmp( - CmpI, FalseBlock == PhiBlock ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE); + CmpI, FalseBlock == PhiBlock ?
ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE, + BaseId); Result.CmpI = CmpI; Result.BranchI = BranchI; return Result; @@ -357,9 +375,9 @@ BCECmpBlock &Comparison) { LLVM_DEBUG(dbgs() << "Block '" << Comparison.BB->getName() << "': Found cmp of " << Comparison.SizeBits() - << " bits between " << Comparison.Lhs().Base() << " + " + << " bits between " << Comparison.Lhs().BaseId << " + " << Comparison.Lhs().Offset << " and " - << Comparison.Rhs().Base() << " + " + << Comparison.Rhs().BaseId << " + " << Comparison.Rhs().Offset << "\n"); LLVM_DEBUG(dbgs() << "\n"); Comparisons.push_back(Comparison); @@ -382,8 +400,8 @@ private: static bool IsContiguous(const BCECmpBlock &First, const BCECmpBlock &Second) { - return First.Lhs().Base() == Second.Lhs().Base() && - First.Rhs().Base() == Second.Rhs().Base() && + return First.Lhs().BaseId == Second.Lhs().BaseId && + First.Rhs().BaseId == Second.Rhs().BaseId && First.Lhs().Offset + First.SizeBits() / 8 == Second.Lhs().Offset && First.Rhs().Offset + First.SizeBits() / 8 == Second.Rhs().Offset; } @@ -407,11 +425,12 @@ assert(!Blocks.empty() && "a chain should have at least one block"); // Now look inside blocks to check for BCE comparisons. std::vector Comparisons; + BaseIdentifier BaseId; for (size_t BlockIdx = 0; BlockIdx < Blocks.size(); ++BlockIdx) { BasicBlock *const Block = Blocks[BlockIdx]; assert(Block && "invalid block"); BCECmpBlock Comparison = visitCmpBlock(Phi.getIncomingValueForBlock(Block), - Block, Phi.getParent()); + Block, Phi.getParent(), BaseId); Comparison.BB = Block; if (!Comparison.IsValid()) { LLVM_DEBUG(dbgs() << "chain with invalid BCECmpBlock, no merge.\n"); @@ -488,9 +507,10 @@ #endif // MERGEICMPS_DOT_ON // Reorder blocks by LHS. We can do that without changing the // semantics because we are only accessing dereferencable memory.
- llvm::sort(Comparisons_, [](const BCECmpBlock &a, const BCECmpBlock &b) { - return a.Lhs() < b.Lhs(); - }); + llvm::sort(Comparisons_, // By (LHS, RHS): llvm::sort is unstable, so break LHS ties. + [](const BCECmpBlock &LhsBlock, const BCECmpBlock &RhsBlock) { + return LhsBlock.Lhs() < RhsBlock.Lhs() || (!(RhsBlock.Lhs() < LhsBlock.Lhs()) && LhsBlock.Rhs() < RhsBlock.Rhs()); + }); #ifdef MERGEICMPS_DOT_ON errs() << "AFTER REORDERING:\n\n"; dump();