Index: include/llvm/Transforms/Utils/Local.h
===================================================================
--- include/llvm/Transforms/Utils/Local.h
+++ include/llvm/Transforms/Utils/Local.h
@@ -172,6 +172,11 @@
 bool SimplifyInstructionsInBlock(BasicBlock *BB,
                                  const TargetLibraryInfo *TLI = nullptr);
 
+/// Reuse the result of an existing `and X, C1` in \p BB to compute
+/// `and (shift X, C3), C2`, and merge duplicated `and X, C1` instructions.
+/// Returns true if the block was modified.
+bool EliminateRedundantMasks(BasicBlock &BB);
+
 //===----------------------------------------------------------------------===//
 //  Control Flow Graph Restructuring.
 //
Index: lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
===================================================================
--- lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
+++ lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
@@ -165,6 +165,7 @@
 static bool foldUnusualPatterns(Function &F, DominatorTree &DT) {
   bool MadeChange = false;
   for (BasicBlock &BB : F) {
+    MadeChange |= EliminateRedundantMasks(BB);
     // Ignore unreachable basic blocks.
     if (!DT.isReachableFromEntry(&BB))
       continue;
Index: lib/Transforms/Utils/Local.cpp
===================================================================
--- lib/Transforms/Utils/Local.cpp
+++ lib/Transforms/Utils/Local.cpp
@@ -23,6 +23,7 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/ADT/TinyPtrVector.h"
@@ -611,6 +612,256 @@
   return MadeChange;
 }
 
+static cl::opt<unsigned> EliminateRedundantMasksSearchLimit(
+    "erm-search-limit", cl::desc("Maximum instructions to compare a mask to."),
+    cl::init(6), cl::Hidden);
+
+/// Eliminate redundant masks.
+///
+/// Given the pattern
+///   X1 = and(X, c1)
+///   X2 = and(shift(X, c3), c2)
+/// where c1, c2 and c3 are constants, try to infer whether X1 can be reused
+/// to compute X2, obtaining
+///   X2 = shift(X1, c3)
+/// Duplicated `and(X, c1)` instructions found in \p BB are merged as well.
+///
+/// \returns true if \p BB was modified.
+bool llvm::EliminateRedundantMasks(BasicBlock &BB) {
+  // Orders the mask keys. All the masks recorded for one value have that
+  // value's bit width, so APInt::ult can compare them.
+  struct APIntLT {
+    bool operator()(const APInt &A, const APInt &B) const { return A.ult(B); }
+  };
+
+  using MaskToAndMap = std::map<APInt, BinaryOperator *, APIntLT>;
+
+  const DataLayout &DL = BB.getParent()->getParent()->getDataLayout();
+  bool HasChanges = false;
+  // For every masked value, the bits that are not known to be zero.
+  std::map<Value *, APInt> ValuesNonZeroCache;
+  // A given mask should be extracted from a value only once. Cache the masks,
+  // keyed by (X, effective mask), so that duplicates can be detected.
+  std::map<Value *, MaskToAndMap> Ands;
+  // Masking operations that became dead. They are erased only once the walk
+  // that found them has finished.
+  SmallSet<Instruction *, 8> ToRemove;
+
+  // 1st stage: cache all (and X, c1) values, using key (X, "effective c1").
+  // When a duplicated AND is detected, the one that comes later is dropped.
+  for (Instruction &I : BB) {
+    // Only scalar integer values are handled.
+    if (!I.getType()->isIntegerTy())
+      continue;
+
+    const APInt *MaskR;
+    BinaryOperator *AndOp;
+    Value *X;
+
+    if (!match(&I,
+               m_CombineAnd(m_BinOp(AndOp), m_And(m_Value(X), m_APInt(MaskR)))))
+      continue;
+
+    auto NonZeroI = ValuesNonZeroCache.find(X);
+    if (NonZeroI == ValuesNonZeroCache.end()) {
+      KnownBits KB = computeKnownBits(X, DL);
+      NonZeroI = ValuesNonZeroCache.emplace(X, ~KB.Zero).first;
+    }
+    // The key must be the effective mask for every AND of X, not only for the
+    // first one seen; otherwise equivalent masks land under different keys.
+    const APInt Mask = *MaskR & NonZeroI->second;
+
+    auto Cached = Ands[X].emplace(Mask, AndOp);
+    if (!Cached.second) {
+      // Replace all uses of this AND with the already cached one.
+      BinaryOperator *ReferenceAnd = Cached.first->second;
+      LLVM_DEBUG(dbgs() << "Replacing: "; AndOp->dump(); dbgs() << " With: ";
+                 ReferenceAnd->dump());
+      AndOp->replaceAllUsesWith(ReferenceAnd);
+      ToRemove.insert(AndOp);
+      HasChanges = true;
+      continue;
+    }
+
+    LLVM_DEBUG(dbgs() << "Value: "; X->dump(); dbgs() << "\t is masked by ";
+               AndOp->dump();
+               dbgs() << "\tThe mask: 0x" << Mask.toString(16, false) << '\n');
+  }
+
+  // 2nd stage: erase the duplicated masking operations.
+  for (Instruction *Dead : ToRemove)
+    Dead->eraseFromParent();
+  ToRemove.clear();
+
+  // 3rd stage: walk backwards checking whether the masking of a shift extracts
+  // bits for which a mask of the unshifted value already exists; if so, make
+  // the shift read from that existing mask.
+
+  // Once a value is reused, every (and) must still dominate its uses.
+  // ToExecuteBefore holds the first user of each reused (and).
+  std::map<Instruction *, Instruction *> ToExecuteBefore;
+
+  for (auto II = BB.rbegin(); II != BB.rend(); ++II) {
+    Instruction *I = &*II;
+    if (!I->getType()->isIntegerTy())
+      continue;
+
+    ConstantInt *ShiftAmtC;
+    const APInt *SMask;
+    BinaryOperator *Shift, *DeadAnd;
+    Value *X;
+
+    if (!match(I, m_CombineAnd(
+                      m_BinOp(DeadAnd),
+                      m_And(m_CombineAnd(
+                                m_BinOp(Shift),
+                                m_OneUse(m_Shift(m_Value(X),
+                                                 m_ConstantInt(ShiftAmtC)))),
+                            m_APInt(SMask)))))
+      continue;
+    LLVM_DEBUG(dbgs() << "Matched: "; DeadAnd->dump(); Shift->dump(););
+
+    // An (and) already picked to feed a later shift must stay alive, because
+    // ToExecuteBefore keeps a pointer to it.
+    if (ToExecuteBefore.count(DeadAnd))
+      continue;
+
+    auto ValueAndsI = Ands.find(X);
+    auto NonZeroI = ValuesNonZeroCache.find(X);
+    if (ValueAndsI == Ands.end() || NonZeroI == ValuesNonZeroCache.end())
+      continue;
+    const MaskToAndMap &XAnds = ValueAndsI->second;
+    const APInt &NonZero = NonZeroI->second;
+
+    // A shift by the bit width or more cannot be modelled with APInt shifts.
+    const APInt &ShiftAmt = ShiftAmtC->getValue();
+    if (ShiftAmt.uge(SMask->getBitWidth()))
+      continue;
+    const unsigned VShiftAmt = ShiftAmt.getZExtValue();
+
+    // ashr behaves like lshr when X is known to be non-negative or when the
+    // mask clears all the bits that ashr shifts in.
+    const bool AShrToLShr = NonZero.isSignBitClear() ||
+                            SMask->countLeadingZeros() >= VShiftAmt;
+    // ashr may be kept on the masked value only when the mask preserves the
+    // sign bit of X together with every copy of it, i.e. when its
+    // VShiftAmt + 1 most significant bits are all ones.
+    const bool SafeAShr = SMask->countLeadingOnes() > VShiftAmt;
+    const auto Opcode = Shift->getOpcode();
+
+    // Looks a mask up without inserting an empty entry when it is missing.
+    auto FindAnd = [&XAnds](const APInt &Key) -> BinaryOperator * {
+      auto Found = XAnds.find(Key);
+      return Found == XAnds.end() ? nullptr : Found->second;
+    };
+
+    // First try a direct match of the masked value.
+    BinaryOperator *RAnd = nullptr;
+    if (Opcode == Instruction::Shl) {
+      APInt EffectiveMask = NonZero & SMask->lshr(VShiftAmt);
+      LLVM_DEBUG(dbgs() << "\tThe mask: 0x" << EffectiveMask.toString(16, false)
+                        << '\n');
+      RAnd = FindAnd(EffectiveMask);
+    } else if (Opcode == Instruction::LShr || SafeAShr || AShrToLShr) {
+      APInt EffectiveMask = NonZero & SMask->shl(VShiftAmt);
+      LLVM_DEBUG(dbgs() << "\tThe mask: 0x" << EffectiveMask.toString(16, false)
+                        << "\nNon Zero mask: 0x"
+                        << NonZero.toString(16, false) << '\n');
+      RAnd = FindAnd(EffectiveMask);
+    }
+
+    // If there is no direct match, explore the existing masks of X to see
+    // whether any of them suits the required bits. At most erm-search-limit
+    // (default = 6) masks are examined.
+    if (!RAnd) {
+      unsigned Remaining = EliminateRedundantMasksSearchLimit;
+      for (const auto &MaskAndOp : XAnds) {
+        if (Remaining == 0)
+          break;
+        --Remaining;
+
+        const APInt &Candidate = MaskAndOp.first;
+        bool Suits;
+        if (Opcode == Instruction::Shl)
+          Suits = Candidate.shl(VShiftAmt) == *SMask;
+        else if (Opcode == Instruction::LShr)
+          Suits = Candidate.lshr(VShiftAmt) == *SMask;
+        else // Instruction::AShr
+          Suits = Candidate.ashr(VShiftAmt) == *SMask ||
+                  (AShrToLShr && Candidate.lshr(VShiftAmt) == *SMask);
+
+        if (Suits) {
+          RAnd = MaskAndOp.second;
+          break;
+        }
+      }
+    }
+
+    if (!RAnd)
+      continue;
+
+    LLVM_DEBUG(dbgs() << "Reusing result of: "; RAnd->dump();
+               dbgs() << " To compute: "; Shift->dump();
+               dbgs() << " Eliminating: "; DeadAnd->dump());
+    HasChanges = true;
+    if (Opcode == Instruction::AShr && AShrToLShr) {
+      BinaryOperator *NewShift = BinaryOperator::CreateLShr(RAnd, ShiftAmtC);
+      NewShift->insertBefore(Shift);
+      Shift->replaceAllUsesWith(NewShift);
+      // The old shift is about to be freed; forget what was cached under it.
+      ValuesNonZeroCache.erase(Shift);
+      Ands.erase(Shift);
+      Shift->eraseFromParent();
+      Shift = NewShift;
+    } else {
+      Shift->setOperand(0, RAnd);
+      // Flags proven for X (e.g. nsw on a shl) need not hold for the masked
+      // value.
+      Shift->dropPoisonGeneratingFlags();
+    }
+#ifndef NDEBUG
+    // Sanity check: now that Shift reads from RAnd, masking it again with
+    // SMask (which is what DeadAnd does) must not change any known bit.
+    KnownBits Before = computeKnownBits(DeadAnd, DL);
+    KnownBits After = computeKnownBits(Shift, DL);
+    LLVM_DEBUG(if (!(Before.Zero == After.Zero && Before.One == After.One)) {
+      dbgs() << "Before zero: 0x" << Before.Zero.toString(16, false)
+             << "\nAfter zero: 0x" << After.Zero.toString(16, false)
+             << "\nBefore one: 0x" << Before.One.toString(16, false)
+             << "\nAfter one: 0x" << After.One.toString(16, false);
+      BB.dump();
+      errs() << "This transformation is invalid!";
+    });
+    assert(Before.Zero == After.Zero && Before.One == After.One);
+#endif
+    DeadAnd->replaceAllUsesWith(Shift);
+    ValuesNonZeroCache.erase(DeadAnd);
+    Ands.erase(DeadAnd);
+    ToRemove.insert(DeadAnd);
+    ToExecuteBefore[RAnd] = Shift;
+  }
+
+  for (Instruction *Dead : ToRemove)
+    Dead->eraseFromParent();
+  ToRemove.clear();
+
+  // A reused (and) that sits after its first new user is hoisted right before
+  // that user, so that it dominates it again.
+  const auto End = BB.end();
+  for (const auto &ProducerUser : ToExecuteBefore) {
+    Instruction *Producer = ProducerUser.first;
+    Instruction *User = ProducerUser.second;
+    // Scan forward from Producer: reaching the end of the block without
+    // meeting User means that User comes first.
+    auto Pos = Producer->getIterator();
+    while (++Pos != End && &*Pos != User)
+      ;
+    if (Pos == End)
+      Producer->moveBefore(User);
+  }
+  return HasChanges;
+}
+
 //===----------------------------------------------------------------------===//
 //  Control Flow Graph Restructuring.
 //