Index: lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
===================================================================
--- lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
+++ lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
@@ -19,12 +19,15 @@
 #include "llvm/Analysis/BasicAliasAnalysis.h"
 #include "llvm/Analysis/GlobalsModRef.h"
 #include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/Utils/Local.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/Dominators.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/PatternMatch.h"
 #include "llvm/Pass.h"
 #include "llvm/Transforms/Scalar.h"
 using namespace llvm;
-
+using namespace PatternMatch;
 #define DEBUG_TYPE "aggressive-instcombine"
 
 namespace {
@@ -53,6 +56,88 @@
 };
 } // namespace
 
+/// This is a recursive helper for 'and X, 1' that walks through a chain of
+/// 'or' instructions looking for shift ops of a common source value (the
+/// first member of the pair). The second member of the pair is a mask
+/// constant for all of the bits that are being compared. So this:
+/// or (or (or X, (X >> 3)), (X >> 5)), (X >> 8)
+/// returns {X, 0x129} and those are the operands of an 'and' that is
+/// compared to zero.
+static bool matchMaskedCmpOp(Value *V, std::pair<Value *, APInt> &Result) {
+  // Recurse through a chain of 'or' operands.
+  Value *Op0, *Op1;
+  if (match(V, m_Or(m_Value(Op0), m_Value(Op1))))
+    return matchMaskedCmpOp(Op0, Result) && matchMaskedCmpOp(Op1, Result);
+
+  // We need a shift-right or a bare value representing a compare of bit 0 of
+  // the original source operand.
+  Value *Candidate;
+  uint64_t BitIndex = 0;
+  if (!match(V, m_LShr(m_Value(Candidate), m_ConstantInt(BitIndex))))
+    Candidate = V;
+
+  // Initialize the result source operand.
+  if (!Result.first)
+    Result.first = Candidate;
+
+  // Fill in the mask bit derived from the shift constant. Use setBit rather
+  // than or'ing in a shifted int, which would be undefined behavior for bit
+  // indexes >= 32.
+  Result.second.setBit(BitIndex);
+  return Result.first == Candidate;
+}
+
+/// Match an 'and' of a chain of or-shifted bits from a common source value
+/// into a masked compare:
+/// and (or (lshr X, C), ...), 1 --> (X & C') != 0
+static bool foldToMaskedCmp(Instruction &I) {
+  if (!match(&I, m_And(m_OneUse(m_Or(m_Value(), m_Value())), m_One())))
+    return false;
+
+  std::pair<Value *, APInt> MaskOps(
+      nullptr, APInt::getNullValue(I.getType()->getScalarSizeInBits()));
+  if (!matchMaskedCmpOp(cast<BinaryOperator>(&I)->getOperand(0), MaskOps))
+    return false;
+
+  IRBuilder<> Builder(&I);
+  Value *Mask = Builder.CreateAnd(MaskOps.first, MaskOps.second);
+  Value *CmpZero = Builder.CreateIsNotNull(Mask);
+  Value *Zext = Builder.CreateZExt(CmpZero, I.getType());
+  I.replaceAllUsesWith(Zext);
+  return true;
+}
+
+/// This is the entry point for folds that could be implemented in regular
+/// InstCombine, but they are separated here because they are not expected to
+/// occur frequently and/or have more than a constant-length pattern match.
+static bool foldUnusualPatterns(Function &F, DominatorTree &DT) {
+  bool MadeChange = false;
+  for (BasicBlock &BB : F) {
+    // Ignore unreachable basic blocks.
+    if (!DT.isReachableFromEntry(&BB))
+      continue;
+    // Do not delete instructions in this loop; that would invalidate the
+    // iterator.
+    for (Instruction &I : BB)
+      MadeChange |= foldToMaskedCmp(I);
+  }
+
+  // We're done with transforms, so remove dead instructions.
+  if (MadeChange)
+    for (BasicBlock &BB : F)
+      SimplifyInstructionsInBlock(&BB);
+
+  return MadeChange;
+}
+
+/// This is the entry point for all transforms. Pass manager differences are
+/// handled in the callers of this function.
+static bool runImpl(Function &F, TargetLibraryInfo &TLI, DominatorTree &DT) {
+  bool MadeChange = false;
+  const DataLayout &DL = F.getParent()->getDataLayout();
+  TruncInstCombine TIC(TLI, DL, DT);
+  MadeChange |= TIC.run(F);
+  MadeChange |= foldUnusualPatterns(F, DT);
+  return MadeChange;
+}
+
 void AggressiveInstCombinerLegacyPass::getAnalysisUsage(
     AnalysisUsage &AU) const {
   AU.setPreservesCFG();
@@ -65,35 +150,19 @@
 }
 
 bool AggressiveInstCombinerLegacyPass::runOnFunction(Function &F) {
-  auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
   auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
-  auto &DL = F.getParent()->getDataLayout();
-
-  bool MadeIRChange = false;
-
-  // Handle TruncInst patterns
-  TruncInstCombine TIC(TLI, DL, DT);
-  MadeIRChange |= TIC.run(F);
-
-  // TODO: add more patterns to handle...
-
-  return MadeIRChange;
+  auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+  return runImpl(F, TLI, DT);
 }
 
 PreservedAnalyses AggressiveInstCombinePass::run(Function &F,
                                                  FunctionAnalysisManager &AM) {
-  auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
   auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
-  auto &DL = F.getParent()->getDataLayout();
-  bool MadeIRChange = false;
-
-  // Handle TruncInst patterns
-  TruncInstCombine TIC(TLI, DL, DT);
-  MadeIRChange |= TIC.run(F);
-  if (!MadeIRChange)
+  auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
+  if (!runImpl(F, TLI, DT)) {
     // No changes, all analyses are preserved.
     return PreservedAnalyses::all();
-
+  }
   // Mark all the analyses that instcombine updates as preserved.
   PreservedAnalyses PA;
   PA.preserveSet<CFGAnalyses>();
Index: test/Transforms/AggressiveInstCombine/masked-cmp.ll
===================================================================
--- test/Transforms/AggressiveInstCombine/masked-cmp.ll
+++ test/Transforms/AggressiveInstCombine/masked-cmp.ll
@@ -5,10 +5,10 @@
 
 define i32 @two_bit_mask(i32 %x) {
 ; CHECK-LABEL: @two_bit_mask(
-; CHECK-NEXT:    [[S:%.*]] = lshr i32 [[X:%.*]], 3
-; CHECK-NEXT:    [[O:%.*]] = or i32 [[S]], [[X]]
-; CHECK-NEXT:    [[R:%.*]] = and i32 [[O]], 1
-; CHECK-NEXT:    ret i32 [[R]]
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[X:%.*]], 9
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp ne i32 [[TMP1]], 0
+; CHECK-NEXT:    [[TMP3:%.*]] = zext i1 [[TMP2]] to i32
+; CHECK-NEXT:    ret i32 [[TMP3]]
 ;
   %s = lshr i32 %x, 3
   %o = or i32 %s, %x
@@ -18,14 +18,10 @@
 
 define i32 @four_bit_mask(i32 %x) {
 ; CHECK-LABEL: @four_bit_mask(
-; CHECK-NEXT:    [[T1:%.*]] = lshr i32 [[X:%.*]], 3
-; CHECK-NEXT:    [[T2:%.*]] = lshr i32 [[X]], 5
-; CHECK-NEXT:    [[T3:%.*]] = lshr i32 [[X]], 8
-; CHECK-NEXT:    [[O1:%.*]] = or i32 [[T1]], [[X]]
-; CHECK-NEXT:    [[O2:%.*]] = or i32 [[T2]], [[T3]]
-; CHECK-NEXT:    [[O3:%.*]] = or i32 [[O1]], [[O2]]
-; CHECK-NEXT:    [[R:%.*]] = and i32 [[O3]], 1
-; CHECK-NEXT:    ret i32 [[R]]
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[X:%.*]], 297
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp ne i32 [[TMP1]], 0
+; CHECK-NEXT:    [[TMP3:%.*]] = zext i1 [[TMP2]] to i32
+; CHECK-NEXT:    ret i32 [[TMP3]]
 ;
   %t1 = lshr i32 %x, 3
   %t2 = lshr i32 %x, 5
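
For context, here is a minimal C++ sketch of the kind of source code that lowers to the matched IR. The function name is hypothetical and the snippet is illustrative, not part of the patch:

    // Hypothetical example: tests bits 0, 3, 5, and 8 of 'x' one bit at a
    // time. Compiled to IR, this is the or-of-lshr chain that
    // foldToMaskedCmp rewrites into a single masked compare:
    // (x & 0x129) != 0, where 0x129 == 297 is the mask constant seen in
    // the @four_bit_mask test above.
    unsigned anyOfBitsSet(unsigned x) {
      return ((x >> 3) | (x >> 5) | (x >> 8) | x) & 1;
    }

After the fold, a target can usually implement the whole chain as one 'and' plus one compare instead of a series of shifts and ors.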