diff --git a/llvm/include/llvm/InitializePasses.h b/llvm/include/llvm/InitializePasses.h
--- a/llvm/include/llvm/InitializePasses.h
+++ b/llvm/include/llvm/InitializePasses.h
@@ -310,6 +310,7 @@
 void initializeUniqueInternalLinkageNamesLegacyPassPass(PassRegistry &);
 void initializeNaryReassociateLegacyPassPass(PassRegistry&);
 void initializeNewGVNLegacyPassPass(PassRegistry&);
+void initializeNullCheckEliminationPass(PassRegistry &);
 void initializeObjCARCAAWrapperPassPass(PassRegistry&);
 void initializeObjCARCAPElimPass(PassRegistry&);
 void initializeObjCARCContractPass(PassRegistry&);
diff --git a/llvm/include/llvm/Transforms/Scalar.h b/llvm/include/llvm/Transforms/Scalar.h
--- a/llvm/include/llvm/Transforms/Scalar.h
+++ b/llvm/include/llvm/Transforms/Scalar.h
@@ -318,6 +318,13 @@
 //
 FunctionPass *createNewGVNPass();
 
+//===----------------------------------------------------------------------===//
+//
+// NullCheckElimination - Fold pointer null checks that are preceded by a
+// memory access to the same underlying object.
+//
+FunctionPass *createNullCheckEliminationPass();
+
 //===----------------------------------------------------------------------===//
 //
 // DivRemPairs - Hoist/decompose integer division and remainder instructions.
diff --git a/llvm/lib/Transforms/IPO/PassManagerBuilder.cpp b/llvm/lib/Transforms/IPO/PassManagerBuilder.cpp
--- a/llvm/lib/Transforms/IPO/PassManagerBuilder.cpp
+++ b/llvm/lib/Transforms/IPO/PassManagerBuilder.cpp
@@ -378,6 +378,7 @@
     MPM.add(createSpeculativeExecutionIfHasBranchDivergencePass());
 
     MPM.add(createJumpThreadingPass()); // Thread jumps.
+    MPM.add(createNullCheckEliminationPass());
     MPM.add(createCorrelatedValuePropagationPass()); // Propagate conditionals
   }
   MPM.add(createCFGSimplificationPass()); // Merge & remove BBs
@@ -454,6 +455,7 @@
   addExtensionsToPM(EP_Peephole, MPM);
   if (OptLevel > 1) {
     MPM.add(createJumpThreadingPass()); // Thread jumps
+    MPM.add(createNullCheckEliminationPass());
     MPM.add(createCorrelatedValuePropagationPass());
     MPM.add(createDeadStoreEliminationPass()); // Delete dead stores
     MPM.add(createLICMPass(LicmMssaOptCap, LicmMssaNoAccForPromotionCap));
@@ -761,6 +763,7 @@
   // dead (or speculatable) control flows or more combining opportunities.
   MPM.add(createEarlyCSEPass());
   MPM.add(createCorrelatedValuePropagationPass());
+  MPM.add(createNullCheckEliminationPass());
   MPM.add(createInstructionCombiningPass());
   MPM.add(createLICMPass(LicmMssaOptCap, LicmMssaNoAccForPromotionCap));
   MPM.add(createLoopUnswitchPass(SizeLevel || OptLevel < 3, DivergentTarget));
diff --git a/llvm/lib/Transforms/Scalar/CMakeLists.txt b/llvm/lib/Transforms/Scalar/CMakeLists.txt
--- a/llvm/lib/Transforms/Scalar/CMakeLists.txt
+++ b/llvm/lib/Transforms/Scalar/CMakeLists.txt
@@ -55,6 +55,7 @@
   MergedLoadStoreMotion.cpp
   NaryReassociate.cpp
   NewGVN.cpp
+  DereferenceNullCheckElimination.cpp
   PartiallyInlineLibCalls.cpp
   PlaceSafepoints.cpp
   Reassociate.cpp
diff --git a/llvm/lib/Transforms/Scalar/DereferenceNullCheckElimination.cpp b/llvm/lib/Transforms/Scalar/DereferenceNullCheckElimination.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/lib/Transforms/Scalar/DereferenceNullCheckElimination.cpp
@@ -0,0 +1,255 @@
+//===- DereferenceNullCheckElimination.cpp - Eliminate null checks --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass folds equality comparisons of a pointer against null to a constant
+// when a memory access to the same underlying object was already seen in the
+// same block or in a block that dominates the comparison.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/PointerIntPair.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/GlobalsModRef.h"
+#include "llvm/Analysis/MemoryLocation.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/PatternMatch.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugCounter.h"
+#include "llvm/Transforms/Scalar.h"
+
+using namespace llvm;
+using namespace PatternMatch;
+
+#define DEBUG_TYPE "nullcheck-elimination"
+
+STATISTIC(NumCmpsRemoved, "Number of cmps simplified");
+
+namespace {
+
+/// A worklist item: a null check or a memory access, tagged with the
+/// dominator-tree DFS interval of the block that contains it.
+struct CmpOrDeref {
+  unsigned NumIn;
+  unsigned NumOut;
+  Instruction *I;
+
+  CmpOrDeref(DomTreeNode *DTN, Instruction *I)
+      : NumIn(DTN->getDFSNumIn()), NumOut(DTN->getDFSNumOut()), I(I) {}
+};
+
+/// An object that has been accessed, together with the DFS interval of the
+/// block containing the access.
+struct StackEntry {
+  unsigned NumIn;
+  unsigned NumOut;
+  const Value *Obj;
+
+  StackEntry(unsigned NumIn, unsigned NumOut, const Value *Obj)
+      : NumIn(NumIn), NumOut(NumOut), Obj(Obj) {}
+};
+
+/// Memoizes the result of GetUnderlyingObject() for each queried pointer.
+struct UnderlyingObjCache {
+  const DataLayout &DL;
+  DenseMap<const Value *, const Value *> UnderlyingObjs;
+
+  explicit UnderlyingObjCache(const DataLayout &DL) : DL(DL) {}
+
+  /// Returns the underlying object of \p Ptr, computing it on first use.
+  const Value *getUnderlyingObjFor(const Value *Ptr) {
+    auto I = UnderlyingObjs.insert({Ptr, nullptr});
+    if (I.second)
+      I.first->second = GetUnderlyingObject(Ptr, DL);
+    return I.first->second;
+  }
+};
+
+} // end anonymous namespace
+
+/// Folds `icmp eq/ne ptr, null` instructions in \p F whose underlying object
+/// has been accessed in the same block or in a dominating block.
+///
+/// \returns true if any instruction was replaced.
+static bool eliminateNullChecks(Function &F, DominatorTree &DT) {
+  bool Changed = false;
+  DT.updateDFSNumbers();
+  SmallVector<CmpOrDeref, 64> WorkList;
+
+  // Bind by reference; the cache keeps a reference to the module's layout.
+  const DataLayout &DL = F.getParent()->getDataLayout();
+  UnderlyingObjCache ObjCache(DL);
+
+  // Collect the null checks and the first access per object in each block.
+  for (BasicBlock &BB : F) {
+    DomTreeNode *DTN = DT.getNode(&BB);
+    if (!DTN)
+      continue;
+
+    SmallPtrSet<const Value *, 8> UOInBlock;
+    for (Instruction &I : BB) {
+      CmpInst::Predicate Pred;
+      Value *Ptr;
+      if (match(&I, m_ICmp(Pred, m_Value(Ptr), m_Zero())) ||
+          match(&I, m_ICmp(Pred, m_Zero(), m_Value(Ptr)))) {
+        if (!I.getOperand(0)->getType()->isPointerTy())
+          continue;
+        if (Pred != CmpInst::ICMP_EQ && Pred != CmpInst::ICMP_NE)
+          continue;
+        if (!ObjCache.getUnderlyingObjFor(Ptr))
+          continue;
+
+        WorkList.emplace_back(DTN, &I);
+        continue;
+      }
+
+      auto MaybeMemLoc = MemoryLocation::getOrNone(&I);
+      if (!MaybeMemLoc)
+        continue;
+
+      // An access only proves the pointer non-null if null is not a valid
+      // address for this function and address space.
+      const Value *AccessedPtr = MaybeMemLoc->Ptr;
+      if (NullPointerIsDefined(
+              &F, AccessedPtr->getType()->getPointerAddressSpace()))
+        continue;
+
+      const Value *UO = ObjCache.getUnderlyingObjFor(AccessedPtr);
+      if (!UOInBlock.insert(UO).second)
+        continue;
+
+      WorkList.emplace_back(DTN, &I);
+    }
+  }
+
+  // Order by dominator-tree DFS number; within a block, by program order.
+  sort(WorkList.begin(), WorkList.end(),
+       [](const CmpOrDeref &A, const CmpOrDeref &B) {
+         if (A.NumIn == B.NumIn && A.NumOut == B.NumOut)
+           return A.I->comesBefore(B.I);
+         return A.NumIn < B.NumIn;
+       });
+
+  SmallVector<StackEntry, 16> DFSInStack;
+  SmallPtrSet<const Value *, 8> DereferencedObjs;
+
+  // Pushing an object only once keeps it in DereferencedObjs for as long as
+  // its outermost (dominating) access is on the stack.
+  auto RecordDeref = [&](const CmpOrDeref &CB, const Value *UO) {
+    if (DereferencedObjs.insert(UO).second)
+      DFSInStack.emplace_back(CB.NumIn, CB.NumOut, UO);
+  };
+
+  for (CmpOrDeref &CB : WorkList) {
+    // Drop stack entries whose block does not dominate the current item.
+    while (!DFSInStack.empty()) {
+      auto &E = DFSInStack.back();
+      LLVM_DEBUG(dbgs() << "Top of stack : " << E.NumIn << " " << E.NumOut
+                        << "\n");
+      LLVM_DEBUG(dbgs() << "CB: " << CB.NumIn << " " << CB.NumOut << "\n");
+      bool IsDom = CB.NumIn >= E.NumIn && CB.NumOut <= E.NumOut;
+      if (IsDom)
+        break;
+      LLVM_DEBUG(dbgs() << "Removing " << *E.Obj << " "
+                        << "\n");
+      DereferencedObjs.erase(E.Obj);
+      DFSInStack.pop_back();
+    }
+
+    CmpInst::Predicate Pred;
+    Value *Ptr;
+    if (match(CB.I, m_ICmp(Pred, m_Value(Ptr), m_Zero())) ||
+        match(CB.I, m_ICmp(Pred, m_Zero(), m_Value(Ptr)))) {
+      assert(Pred == CmpInst::ICMP_EQ || Pred == CmpInst::ICMP_NE);
+      const Value *UO = ObjCache.getUnderlyingObjFor(Ptr);
+      assert(UO);
+      if (!DereferencedObjs.count(UO))
+        continue;
+
+      // The object is known non-null here: `eq` is false and `ne` is true.
+      Constant *Folded = Pred == CmpInst::ICMP_NE
+                             ? ConstantInt::getTrue(CB.I->getType())
+                             : ConstantInt::getFalse(CB.I->getType());
+      CB.I->replaceAllUsesWith(Folded);
+      CB.I->eraseFromParent();
+      ++NumCmpsRemoved;
+      Changed = true;
+      continue;
+    }
+
+    if (auto *MTI = dyn_cast<MemTransferInst>(CB.I)) {
+      // NOTE(review): confirm the collection loop can enqueue memory transfer
+      // intrinsics; it only enqueues what MemoryLocation::getOrNone accepts.
+      if (const Value *UOSrc = ObjCache.getUnderlyingObjFor(
+              MemoryLocation::getForSource(MTI).Ptr))
+        RecordDeref(CB, UOSrc);
+      if (const Value *UODest = ObjCache.getUnderlyingObjFor(
+              MemoryLocation::getForDest(MTI).Ptr))
+        RecordDeref(CB, UODest);
+      continue;
+    }
+
+    const Value *UO =
+        ObjCache.getUnderlyingObjFor(MemoryLocation::get(CB.I).Ptr);
+    assert(UO);
+    RecordDeref(CB, UO);
+  }
+
+  return Changed;
+}
+
+namespace {
+
+/// Legacy pass manager wrapper around eliminateNullChecks().
+class NullCheckElimination : public FunctionPass {
+public:
+  static char ID;
+
+  NullCheckElimination() : FunctionPass(ID) {
+    initializeNullCheckEliminationPass(*PassRegistry::getPassRegistry());
+  }
+
+  bool runOnFunction(Function &F) override {
+    // Honor optnone and opt-bisect like the other legacy function passes.
+    if (skipFunction(F))
+      return false;
+
+    auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+    return eliminateNullChecks(F, DT);
+  }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.setPreservesCFG();
+    AU.addRequired<DominatorTreeWrapperPass>();
+    AU.addPreserved<GlobalsAAWrapperPass>();
+    AU.addPreserved<DominatorTreeWrapperPass>();
+  }
+};
+
+} // end anonymous namespace
+
+char NullCheckElimination::ID = 0;
+
+INITIALIZE_PASS_BEGIN(NullCheckElimination, "nullcheck-elimination",
+                      "Dereferenced Nullcheck Elimination", false, false)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_END(NullCheckElimination, "nullcheck-elimination",
+                    "Dereferenced Nullcheck Elimination", false, false)
+
+FunctionPass *llvm::createNullCheckEliminationPass() {
+  return new NullCheckElimination();
+}
diff --git a/llvm/lib/Transforms/Scalar/Scalar.cpp b/llvm/lib/Transforms/Scalar/Scalar.cpp
--- a/llvm/lib/Transforms/Scalar/Scalar.cpp
+++ b/llvm/lib/Transforms/Scalar/Scalar.cpp
@@ -49,6 +49,7 @@
   initializeLoopGuardWideningLegacyPassPass(Registry);
   initializeGVNLegacyPassPass(Registry);
   initializeNewGVNLegacyPassPass(Registry);
+  initializeNullCheckEliminationPass(Registry);
   initializeEarlyCSELegacyPassPass(Registry);
   initializeEarlyCSEMemSSALegacyPassPass(Registry);
   initializeMakeGuardsExplicitLegacyPassPass(Registry);