diff --git a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
--- a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
+++ b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
@@ -18,6 +18,7 @@
 #include "llvm-c/Transforms/AggressiveInstCombine.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/BasicAliasAnalysis.h"
 #include "llvm/Analysis/GlobalsModRef.h"
 #include "llvm/Analysis/TargetLibraryInfo.h"
@@ -394,10 +395,11 @@
 
 /// This is the entry point for all transforms. Pass manager differences are
 /// handled in the callers of this function.
-static bool runImpl(Function &F, TargetLibraryInfo &TLI, DominatorTree &DT) {
+static bool runImpl(Function &F, AssumptionCache &AC, TargetLibraryInfo &TLI,
+                    DominatorTree &DT) {
   bool MadeChange = false;
   const DataLayout &DL = F.getParent()->getDataLayout();
-  TruncInstCombine TIC(TLI, DL, DT);
+  TruncInstCombine TIC(AC, TLI, DL, DT);
   MadeChange |= TIC.run(F);
   MadeChange |= foldUnusualPatterns(F, DT);
   return MadeChange;
@@ -406,6 +408,7 @@
 void AggressiveInstCombinerLegacyPass::getAnalysisUsage(
     AnalysisUsage &AU) const {
   AU.setPreservesCFG();
+  AU.addRequired<AssumptionCacheTracker>();
   AU.addRequired<DominatorTreeWrapperPass>();
   AU.addRequired<TargetLibraryInfoWrapperPass>();
   AU.addPreserved<AAResultsWrapperPass>();
@@ -415,16 +418,18 @@
 }
 
 bool AggressiveInstCombinerLegacyPass::runOnFunction(Function &F) {
+  auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
   auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
   auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
-  return runImpl(F, TLI, DT);
+  return runImpl(F, AC, TLI, DT);
 }
 
 PreservedAnalyses AggressiveInstCombinePass::run(Function &F,
                                                  FunctionAnalysisManager &AM) {
+  auto &AC = AM.getResult<AssumptionAnalysis>(F);
   auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
   auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
-  if (!runImpl(F, TLI, DT)) {
+  if (!runImpl(F, AC, TLI, DT)) {
     // No changes, all analyses are preserved.
     return PreservedAnalyses::all();
   }
@@ -438,6 +443,7 @@
 INITIALIZE_PASS_BEGIN(AggressiveInstCombinerLegacyPass,
                       "aggressive-instcombine",
                       "Combine pattern based expressions", false, false)
+INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
 INITIALIZE_PASS_END(AggressiveInstCombinerLegacyPass, "aggressive-instcombine",
diff --git a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombineInternal.h b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombineInternal.h
--- a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombineInternal.h
+++ b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombineInternal.h
@@ -17,6 +17,8 @@
 
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/Support/KnownBits.h"
 
 using namespace llvm;
 
@@ -39,16 +41,18 @@
 //===----------------------------------------------------------------------===//
 
 namespace llvm {
-  class DataLayout;
-  class DominatorTree;
-  class Function;
-  class Instruction;
-  class TargetLibraryInfo;
-  class TruncInst;
-  class Type;
-  class Value;
+class AssumptionCache;
+class DataLayout;
+class DominatorTree;
+class Function;
+class Instruction;
+class TargetLibraryInfo;
+class TruncInst;
+class Type;
+class Value;
 
 class TruncInstCombine {
+  AssumptionCache &AC;
   TargetLibraryInfo &TLI;
   const DataLayout &DL;
   const DominatorTree &DT;
@@ -75,9 +79,9 @@
   MapVector<Instruction *, Info> InstInfoMap;
 
 public:
-  TruncInstCombine(TargetLibraryInfo &TLI, const DataLayout &DL,
-                   const DominatorTree &DT)
-      : TLI(TLI), DL(DL), DT(DT), CurrentTruncInst(nullptr) {}
+  TruncInstCombine(AssumptionCache &AC, TargetLibraryInfo &TLI,
+                   const DataLayout &DL, const DominatorTree &DT)
+      : AC(AC), TLI(TLI), DL(DL), DT(DT), CurrentTruncInst(nullptr) {}
 
   /// Perform TruncInst pattern optimization on given function.
   bool run(Function &F);
@@ -104,6 +108,18 @@
   /// to be reduced.
   Type *getBestTruncatedType();
 
+  KnownBits computeKnownBits(const Value *V) const {
+    return llvm::computeKnownBits(V, DL, /*Depth=*/0, &AC,
+                                  /*CtxI=*/cast<Instruction>(CurrentTruncInst),
+                                  &DT);
+  }
+
+  unsigned ComputeNumSignBits(const Value *V) const {
+    return llvm::ComputeNumSignBits(
+        V, DL, /*Depth=*/0, &AC, /*CtxI=*/cast<Instruction>(CurrentTruncInst),
+        &DT);
+  }
+
   /// Given a \p V value and a \p SclTy scalar type return the generated reduced
   /// value of \p V based on the type \p SclTy.
   ///
diff --git a/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp b/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp
--- a/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp
+++ b/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp
@@ -29,7 +29,6 @@
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/ConstantFolding.h"
 #include "llvm/Analysis/TargetLibraryInfo.h"
-#include "llvm/Analysis/ValueTracking.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/IRBuilder.h"
@@ -288,19 +287,19 @@
   for (auto &Itr : InstInfoMap) {
     Instruction *I = Itr.first;
     if (I->isShift()) {
-      KnownBits KnownRHS = computeKnownBits(I->getOperand(1), DL);
+      KnownBits KnownRHS = computeKnownBits(I->getOperand(1));
       unsigned MinBitWidth = KnownRHS.getMaxValue()
                                  .uadd_sat(APInt(OrigBitWidth, 1))
                                  .getLimitedValue(OrigBitWidth);
       if (MinBitWidth == OrigBitWidth)
         return nullptr;
       if (I->getOpcode() == Instruction::LShr) {
-        KnownBits KnownLHS = computeKnownBits(I->getOperand(0), DL);
+        KnownBits KnownLHS = computeKnownBits(I->getOperand(0));
         MinBitWidth =
             std::max(MinBitWidth, KnownLHS.getMaxValue().getActiveBits());
       }
       if (I->getOpcode() == Instruction::AShr) {
-        unsigned NumSignBits = ComputeNumSignBits(I->getOperand(0), DL);
+        unsigned NumSignBits = ComputeNumSignBits(I->getOperand(0));
         MinBitWidth = std::max(MinBitWidth, OrigBitWidth - NumSignBits + 1);
       }
       if (MinBitWidth >= OrigBitWidth)
diff --git a/llvm/test/Transforms/AggressiveInstCombine/trunc_assume.ll b/llvm/test/Transforms/AggressiveInstCombine/trunc_assume.ll
--- a/llvm/test/Transforms/AggressiveInstCombine/trunc_assume.ll
+++ b/llvm/test/Transforms/AggressiveInstCombine/trunc_assume.ll
@@ -5,11 +5,8 @@
 ; CHECK-LABEL: @trunc_shl(
 ; CHECK-NEXT:    [[CMP0:%.*]] = icmp ult i16 [[Y:%.*]], 16
 ; CHECK-NEXT:    call void @llvm.assume(i1 [[CMP0]])
-; CHECK-NEXT:    [[ZEXTX:%.*]] = zext i16 [[X:%.*]] to i32
-; CHECK-NEXT:    [[ZEXTY:%.*]] = zext i16 [[Y]] to i32
-; CHECK-NEXT:    [[I0:%.*]] = shl i32 [[ZEXTX]], [[ZEXTY]]
-; CHECK-NEXT:    [[R:%.*]] = trunc i32 [[I0]] to i16
-; CHECK-NEXT:    ret i16 [[R]]
+; CHECK-NEXT:    [[I0:%.*]] = shl i16 [[X:%.*]], [[Y]]
+; CHECK-NEXT:    ret i16 [[I0]]
 ;
   %cmp0 = icmp ult i16 %y, 16
   call void @llvm.assume(i1 %cmp0)
@@ -28,11 +25,8 @@
 ; CHECK-NEXT:    [[CMP1:%.*]] = icmp ult i16 [[Y:%.*]], 16
 ; CHECK-NEXT:    call void @llvm.assume(i1 [[CMP0]])
 ; CHECK-NEXT:    call void @llvm.assume(i1 [[CMP1]])
-; CHECK-NEXT:    [[ZEXTX:%.*]] = zext i16 [[X]] to i32
-; CHECK-NEXT:    [[ZEXTY:%.*]] = zext i16 [[Y]] to i32
-; CHECK-NEXT:    [[I0:%.*]] = lshr i32 [[ZEXTX]], [[ZEXTY]]
-; CHECK-NEXT:    [[R:%.*]] = trunc i32 [[I0]] to i16
-; CHECK-NEXT:    ret i16 [[R]]
+; CHECK-NEXT:    [[I0:%.*]] = lshr i16 [[X]], [[Y]]
+; CHECK-NEXT:    ret i16 [[I0]]
 ;
   %cmp0 = icmp ult i16 %x, 65536
   %cmp1 = icmp ult i16 %y, 16
@@ -55,11 +49,8 @@
 ; CHECK-NEXT:    call void @llvm.assume(i1 [[CMP0]])
 ; CHECK-NEXT:    call void @llvm.assume(i1 [[CMP1]])
 ; CHECK-NEXT:    call void @llvm.assume(i1 [[CMP2]])
-; CHECK-NEXT:    [[ZEXTX:%.*]] = sext i16 [[X]] to i32
-; CHECK-NEXT:    [[ZEXTY:%.*]] = sext i16 [[Y]] to i32
-; CHECK-NEXT:    [[I0:%.*]] = ashr i32 [[ZEXTX]], [[ZEXTY]]
-; CHECK-NEXT:    [[R:%.*]] = trunc i32 [[I0]] to i16
-; CHECK-NEXT:    ret i16 [[R]]
+; CHECK-NEXT:    [[I0:%.*]] = ashr i16 [[X]], [[Y]]
+; CHECK-NEXT:    ret i16 [[I0]]
 ;
   %cmp0 = icmp slt i16 %x, 32767
   %cmp1 = icmp sge i16 %x, -32768