Index: include/llvm/Transforms/Scalar/SCCP.h =================================================================== --- include/llvm/Transforms/Scalar/SCCP.h +++ include/llvm/Transforms/Scalar/SCCP.h @@ -20,15 +20,15 @@ #ifndef LLVM_TRANSFORMS_SCALAR_SCCP_H #define LLVM_TRANSFORMS_SCALAR_SCCP_H - +#include "llvm/ADT/STLExtras.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/Function.h" #include "llvm/IR/Module.h" #include "llvm/IR/PassManager.h" +#include "llvm/Transforms/Utils/PredicateInfo.h" namespace llvm { - class Function; /// This pass performs function-level constant propagation and merging. @@ -37,7 +37,9 @@ PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM); }; -bool runIPSCCP(Module &M, const DataLayout &DL, const TargetLibraryInfo *TLI); +bool runIPSCCP( + Module &M, const DataLayout &DL, const TargetLibraryInfo *TLI, + function_ref(Function &)> getPredicateInfo); } // end namespace llvm #endif // LLVM_TRANSFORMS_SCALAR_SCCP_H Index: lib/Transforms/IPO/SCCP.cpp =================================================================== --- lib/Transforms/IPO/SCCP.cpp +++ lib/Transforms/IPO/SCCP.cpp @@ -1,14 +1,23 @@ #include "llvm/Transforms/IPO/SCCP.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/IR/Dominators.h" #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/Scalar/SCCP.h" - using namespace llvm; PreservedAnalyses IPSCCPPass::run(Module &M, ModuleAnalysisManager &AM) { const DataLayout &DL = M.getDataLayout(); auto &TLI = AM.getResult(M); - if (!runIPSCCP(M, DL, &TLI)) + auto &FAM = AM.getResult(M).getManager(); + auto getPredicateInfo = + [&FAM](Function &F) -> std::unique_ptr { + return make_unique(F, + FAM.getResult(F), + FAM.getResult(F)); + }; + + if (!runIPSCCP(M, DL, &TLI, getPredicateInfo)) return PreservedAnalyses::all(); return PreservedAnalyses::none(); } @@ -34,10 +43,20 @@ const DataLayout &DL = M.getDataLayout(); const TargetLibraryInfo *TLI = &getAnalysis().getTLI(); - return runIPSCCP(M, DL, TLI); + + auto getPredicateInfo = + [this](Function &F) -> std::unique_ptr { + return make_unique( + F, this->getAnalysis(F).getDomTree(), + this->getAnalysis().getAssumptionCache(F)); + }; + + return runIPSCCP(M, DL, TLI, getPredicateInfo); } void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired(); + AU.addRequired(); AU.addRequired(); } }; @@ -49,6 +68,7 @@ INITIALIZE_PASS_BEGIN(IPSCCPLegacyPass, "ipsccp", "Interprocedural Sparse Conditional Constant Propagation", false, false) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) INITIALIZE_PASS_END(IPSCCPLegacyPass, "ipsccp", "Interprocedural Sparse Conditional Constant Propagation", Index: lib/Transforms/Scalar/SCCP.cpp =================================================================== --- lib/Transforms/Scalar/SCCP.cpp +++ lib/Transforms/Scalar/SCCP.cpp @@ -55,6 +55,7 @@ #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/PredicateInfo.h" #include #include #include @@ -248,7 +249,21 @@ using Edge = std::pair; DenseSet KnownFeasibleEdges; + DenseMap> PredInfos; + DenseMap> AdditionalUsers; + public: + void addPredInfo(Function &F, std::unique_ptr PI) { + PredInfos[&F] = std::move(PI); + } + + const PredicateBase *getPredicateInfoFor(Instruction *I) { + auto PI = PredInfos.find(I->getParent()->getParent()); + if (PI == PredInfos.end()) + return nullptr; + return PI->second->getPredicateInfoFor(I); + } + SCCPSolver(const DataLayout &DL, const TargetLibraryInfo *tli) : DL(DL), TLI(tli) {} @@ -1158,6 +1173,63 @@ Function *F = CS.getCalledFunction(); Instruction *I = CS.getInstruction(); + if (auto *II = dyn_cast(I)) { + if (II->getIntrinsicID() == Intrinsic::ssa_copy) { + LatticeVal &IV = ValueState[I]; + if (IV.isOverdefined()) + return; + + auto *PI = getPredicateInfoFor(I); + if (!PI) + return; + + auto *PBranch = dyn_cast(getPredicateInfoFor(I)); + if (!PBranch) + return mergeInValue(IV, I, getValueState(PI->OriginalOp)); + + Value *CopyOf = I->getOperand(0); + Value *Cond = PBranch->Condition; + + // Everything below relies on the condition being a comparison. + auto *Cmp = dyn_cast(Cond); + if (!Cmp) + return mergeInValue(IV, I, getValueState(PI->OriginalOp)); + + Value *CmpOp0 = Cmp->getOperand(0); + Value *CmpOp1 = Cmp->getOperand(1); + if (CopyOf != CmpOp0 && CopyOf != CmpOp1) { + // DEBUG(dbgs() << "Copy is not of any condition operands!\n"); + return mergeInValue(IV, I, getValueState(PI->OriginalOp)); + } + + if (CmpOp0 != CopyOf) + std::swap(CmpOp0, CmpOp1); + + LatticeVal &OriginalVal = getValueState(CopyOf); + LatticeVal &EqVal = getValueState(CmpOp1); + if (PBranch->TrueEdge && Cmp->getPredicate() == CmpInst::ICMP_EQ) { + auto Iter = AdditionalUsers.insert({CmpOp1, {}}); + Iter.first->second.insert(I); + if (OriginalVal.isConstant()) + mergeInValue(IV, I, OriginalVal); + else + mergeInValue(IV, I, EqVal); + return; + } + if (!PBranch->TrueEdge && Cmp->getPredicate() == CmpInst::ICMP_NE) { + auto Iter = AdditionalUsers.insert({CmpOp1, {}}); + Iter.first->second.insert(I); + if (OriginalVal.isConstant()) + mergeInValue(IV, I, OriginalVal); + else + mergeInValue(IV, I, EqVal); + return; + } + + return mergeInValue(IV, I, getValueState(PBranch->OriginalOp)); + } + } + // The common case is that we aren't tracking the callee, either because we // are not doing interprocedural analysis or the callee is indirect, or is // external. Handle these cases first. @@ -1272,6 +1344,12 @@ for (User *U : I->users()) if (auto *UI = dyn_cast(U)) OperandChangedState(UI); + auto Iter = AdditionalUsers.find(I); + if (Iter != AdditionalUsers.end()) { + for (User *U : Iter->second) + if (auto *UI = dyn_cast(U)) + OperandChangedState(UI); + } } // Process the instruction work list. @@ -1291,6 +1369,12 @@ for (User *U : I->users()) if (auto *UI = dyn_cast(U)) OperandChangedState(UI); + auto Iter = AdditionalUsers.find(I); + if (Iter != AdditionalUsers.end()) { + for (User *U : Iter->second) + if (auto *UI = dyn_cast(U)) + OperandChangedState(UI); + } } // Process the basic block work list. @@ -1856,8 +1940,9 @@ } } -bool llvm::runIPSCCP(Module &M, const DataLayout &DL, - const TargetLibraryInfo *TLI) { +bool llvm::runIPSCCP( + Module &M, const DataLayout &DL, const TargetLibraryInfo *TLI, + function_ref(Function &)> getPredicateInfo) { SCCPSolver Solver(DL, TLI); // Loop over all functions, marking arguments to those with their addresses @@ -1866,6 +1951,7 @@ if (F.isDeclaration()) continue; + Solver.addPredInfo(F, getPredicateInfo(F)); // Determine if we can track the function's return values. If so, add the // function to the solver's set of return-tracked functions. if (canTrackReturnsInterprocedurally(&F)) @@ -1984,6 +2070,24 @@ F.getBasicBlockList().erase(DeadBB); } BlocksToErase.clear(); + + for (BasicBlock &BB : F) { + for (BasicBlock::iterator BI = BB.begin(), E = BB.end(); BI != E;) { + Instruction *Inst = &*BI++; + if (const PredicateBase *PI = Solver.getPredicateInfoFor(Inst)) { + if (auto *II = dyn_cast(Inst)) { + if (II->getIntrinsicID() == Intrinsic::ssa_copy) { + Value *Op = II->getOperand(0); + Inst->replaceAllUsesWith(Op); + Inst->eraseFromParent(); + continue; + } + } + Inst->replaceAllUsesWith(PI->OriginalOp); + Inst->eraseFromParent(); + } + } + } } // If we inferred constant or undef return values for a function, we replaced Index: test/Transforms/SCCP/ipsccp-predicated.ll =================================================================== --- /dev/null +++ test/Transforms/SCCP/ipsccp-predicated.ll @@ -0,0 +1,68 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt < %s -ipsccp -S | FileCheck %s + +define i32 @test1(i32 %v) { +; CHECK-LABEL: @test1( +; CHECK-NEXT: Entry: +; CHECK-NEXT: [[TOBOOL1:%.*]] = icmp eq i32 [[V:%.*]], 10 +; CHECK-NEXT: br i1 [[TOBOOL1]], label [[T:%.*]], label [[F:%.*]] +; CHECK: T: +; CHECK-NEXT: [[R:%.*]] = call i32 @callee(i32 20) +; CHECK-NEXT: ret i32 [[R]] +; CHECK: F: +; CHECK-NEXT: [[X:%.*]] = call i32 @callee(i32 [[V]]) +; CHECK-NEXT: ret i32 [[X]] +; +Entry: + %tobool1 = icmp eq i32 %v, 10 + br i1 %tobool1, label %T, label %F + +T: + %a = add i32 %v, 10 + %r = call i32 @callee(i32 %a) + ret i32 %r + +F: + %x = call i32 @callee(i32 %v) + ret i32 %x +} + + +define internal i32 @test2(i32 %v, i32 %c) { +; CHECK-LABEL: @test2( +; CHECK-NEXT: Entry: +; CHECK-NEXT: [[TOBOOL1:%.*]] = icmp eq i32 [[V:%.*]], 99 +; CHECK-NEXT: br i1 [[TOBOOL1]], label [[T:%.*]], label [[F:%.*]] +; CHECK: T: +; CHECK-NEXT: [[R:%.*]] = call i32 @callee(i32 109) +; CHECK-NEXT: ret i32 [[R]] +; CHECK: F: +; CHECK-NEXT: [[X:%.*]] = call i32 @callee(i32 [[V]]) +; CHECK-NEXT: ret i32 [[X]] +; +Entry: + %tobool1 = icmp eq i32 %v, %c + br i1 %tobool1, label %T, label %F + +T: + %a = add i32 %v, 10 + %r = call i32 @callee(i32 %a) + ret i32 %r + +F: + %x = call i32 @callee(i32 %v) + ret i32 %x +} + +define i32 @caller_test2(i32 %v) { +; CHECK-LABEL: @caller_test2( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[R:%.*]] = call i32 @test2(i32 [[V:%.*]], i32 99) +; CHECK-NEXT: ret i32 [[R]] +; +entry: + %r = call i32 @test2(i32 %v, i32 99) + ret i32 %r +} + +declare i32 @callee(i32)