diff --git a/llvm/include/llvm/Analysis/ContextAnalysis.h b/llvm/include/llvm/Analysis/ContextAnalysis.h
new file mode 100644
--- /dev/null
+++ b/llvm/include/llvm/Analysis/ContextAnalysis.h
@@ -0,0 +1,114 @@
+//===- ContextAnalysis.h - Analysis of BB and Instructions ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares ContextInfo, which records a suggested name and a textual
+// description ("context") for the basic blocks and instructions of a function,
+// together with the new pass manager analysis and printer pass built on it.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_ANALYSIS_CONTEXTANALYSIS_H
+#define LLVM_ANALYSIS_CONTEXTANALYSIS_H
+
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/IR/PassManager.h"
+#include <optional>
+#include <string>
+#include <utility>
+
+namespace llvm {
+
+class BasicBlock;
+class Function;
+class Instruction;
+class Value;
+class raw_ostream;
+
+/// Result of ContextAnalysis: for each basic block and instruction of a
+/// function, a suggested name and a " | "-separated list of the properties
+/// that were recognized for it.
+class ContextInfo {
+public:
+  /// Bind the result to \p F and the analyses it is derived from. Nothing is
+  /// computed until createContext() is called. \p F, \p SE and \p LI must
+  /// outlive this object.
+  ContextInfo(const Function &F, const ScalarEvolution &SE, const LoopInfo &LI)
+      : F(F), SE(SE), LI(LI) {}
+
+  /// (Re)compute the name and context of every basic block and instruction
+  /// of the function, discarding any previous result.
+  void createContext();
+
+  /// Return the context recorded for \p V, or std::nullopt when \p V is
+  /// neither a basic block nor an instruction or has no recorded entry.
+  std::optional<StringRef> getContextForValue(const Value *V) const;
+  std::optional<StringRef> getContextForBB(const BasicBlock *BB) const;
+  std::optional<StringRef> getContextForInst(const Instruction *Inst) const;
+
+  /// Return the name suggested for \p V, or std::nullopt when \p V is
+  /// neither a basic block nor an instruction or has no recorded entry.
+  std::optional<StringRef> getNameForValue(const Value *V) const;
+  std::optional<StringRef> getNameForBB(const BasicBlock *BB) const;
+  std::optional<StringRef> getNameForInst(const Instruction *Inst) const;
+
+  /// Free the memory used by this class.
+  void releaseMemory();
+
+  /// Print the recorded names and contexts in function order.
+  void print(raw_ostream &OS) const;
+
+  /// Handle invalidation events in the new pass manager.
+  bool invalidate(Function &Fn, const PreservedAnalyses &PA,
+                  FunctionAnalysisManager::Invalidator &Inv);
+
+private:
+  /// A (suggested name, context) pair.
+  using NameAndContext = std::pair<std::string, std::string>;
+  using ContextMapBB = DenseMap<const BasicBlock *, NameAndContext>;
+  using ContextMapInst = DenseMap<const Instruction *, NameAndContext>;
+
+  ContextMapBB ContextBB;
+
+  ContextMapInst ContextInst;
+
+  const Function &F;
+
+  // Not consulted yet; kept so the constructor interface stays stable.
+  const ScalarEvolution &SE;
+
+  const LoopInfo &LI;
+
+  void createContextForBB(const BasicBlock *BB);
+
+  void createContextForInst(const Instruction *Inst);
+};
+
+/// Analysis pass that computes ContextInfo for a function.
+class ContextAnalysis : public AnalysisInfoMixin<ContextAnalysis> {
+  friend AnalysisInfoMixin<ContextAnalysis>;
+  static AnalysisKey Key;
+
+public:
+  using Result = ContextInfo;
+  ContextInfo run(Function &F, FunctionAnalysisManager &AM);
+};
+
+/// Printer pass for ContextAnalysis results.
+class ContextInfoPrinterPass : public PassInfoMixin<ContextInfoPrinterPass> {
+  raw_ostream &OS;
+
+public:
+  explicit ContextInfoPrinterPass(raw_ostream &OS) : OS(OS) {}
+  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
+};
+} // namespace llvm
+
+#endif // LLVM_ANALYSIS_CONTEXTANALYSIS_H
diff --git a/llvm/lib/Analysis/CMakeLists.txt b/llvm/lib/Analysis/CMakeLists.txt
--- a/llvm/lib/Analysis/CMakeLists.txt
+++ b/llvm/lib/Analysis/CMakeLists.txt
@@ -46,6 +46,7 @@
   CmpInstAnalysis.cpp
   CostModel.cpp
   CodeMetrics.cpp
+  ContextAnalysis.cpp
   ConstantFolding.cpp
   CycleAnalysis.cpp
   DDG.cpp
diff --git a/llvm/lib/Analysis/ContextAnalysis.cpp b/llvm/lib/Analysis/ContextAnalysis.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/lib/Analysis/ContextAnalysis.cpp
@@ -0,0 +1,246 @@
+//===- ContextAnalysis.cpp - Analysis of BB and Instructions --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements ContextInfo and the passes that compute and print it.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/ContextAnalysis.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace llvm;
+
+namespace {
+
+/// Accumulates the properties recognized for one basic block or instruction.
+///
+/// The builder owns its strings: a StringRef to a freshly concatenated
+/// temporary would dangle as soon as that temporary is destroyed.
+class ContextBuilder {
+public:
+  /// \p Subject is the IR name of the value being described; its storage must
+  /// outlive the builder.
+  explicit ContextBuilder(StringRef Subject) : Subject(Subject) {}
+
+  /// Record one property. The first property recorded provides the suggested
+  /// name; every property appends "<subject> is <Description>" to the context,
+  /// separated from earlier ones by " | ".
+  void add(StringRef SuggestedName, StringRef Description) {
+    if (Name.empty())
+      Name = SuggestedName.str();
+    if (!Ctx.empty())
+      Ctx += " | ";
+    Ctx += Subject;
+    Ctx += " is ";
+    Ctx += Description;
+  }
+
+  /// True while no property has been recorded.
+  bool empty() const { return Name.empty(); }
+
+  /// Move the accumulated (name, context) pair out of the builder.
+  std::pair<std::string, std::string> take() {
+    return {std::move(Name), std::move(Ctx)};
+  }
+
+private:
+  StringRef Subject;
+  std::string Name;
+  std::string Ctx;
+};
+
+} // end anonymous namespace
+
+/// Return the entry recorded for \p Key in \p Map, or nullptr when there is
+/// none. A single find() replaces the contains() + operator[] double lookup.
+template <typename MapT, typename KeyT>
+static const std::pair<std::string, std::string> *findEntry(const MapT &Map,
+                                                            const KeyT *Key) {
+  auto It = Map.find(Key);
+  return It == Map.end() ? nullptr : &It->second;
+}
+
+/// Print the line describing \p Entry; \p Kind is "BB" or "Inst".
+static void printEntry(raw_ostream &OS, StringRef Kind, StringRef IRName,
+                       const std::pair<std::string, std::string> &Entry) {
+  OS << Kind << " " << IRName;
+  if (Entry.first.empty()) {
+    OS << " hasn't information about this " << Kind << "\n";
+    return;
+  }
+  OS << " has possible name: " << Entry.first;
+  OS << ". And next context: " << Entry.second;
+  OS << "\n";
+}
+
+void ContextInfo::createContext() {
+  // Start from scratch so that a second call does not append every property
+  // to the context again.
+  releaseMemory();
+  for (const BasicBlock &BB : F) {
+    createContextForBB(&BB);
+    for (const Instruction &I : BB)
+      createContextForInst(&I);
+  }
+}
+
+std::optional<StringRef> ContextInfo::getContextForValue(const Value *V) const {
+  if (const auto *BB = dyn_cast<BasicBlock>(V))
+    return getContextForBB(BB);
+  if (const auto *Inst = dyn_cast<Instruction>(V))
+    return getContextForInst(Inst);
+  return std::nullopt;
+}
+
+std::optional<StringRef>
+ContextInfo::getContextForBB(const BasicBlock *BB) const {
+  if (const auto *Entry = findEntry(ContextBB, BB))
+    return StringRef(Entry->second);
+  return std::nullopt;
+}
+
+std::optional<StringRef>
+ContextInfo::getContextForInst(const Instruction *Inst) const {
+  if (const auto *Entry = findEntry(ContextInst, Inst))
+    return StringRef(Entry->second);
+  return std::nullopt;
+}
+
+std::optional<StringRef> ContextInfo::getNameForValue(const Value *V) const {
+  if (const auto *BB = dyn_cast<BasicBlock>(V))
+    return getNameForBB(BB);
+  if (const auto *Inst = dyn_cast<Instruction>(V))
+    return getNameForInst(Inst);
+  return std::nullopt;
+}
+
+std::optional<StringRef> ContextInfo::getNameForBB(const BasicBlock *BB) const {
+  if (const auto *Entry = findEntry(ContextBB, BB))
+    return StringRef(Entry->first);
+  return std::nullopt;
+}
+
+std::optional<StringRef>
+ContextInfo::getNameForInst(const Instruction *Inst) const {
+  if (const auto *Entry = findEntry(ContextInst, Inst))
+    return StringRef(Entry->first);
+  return std::nullopt;
+}
+
+void ContextInfo::releaseMemory() {
+  ContextInst.clear();
+  ContextBB.clear();
+}
+
+void ContextInfo::print(raw_ostream &OS) const {
+  // Walk the function rather than the maps: DenseMap iteration order depends
+  // on pointer values, which would make the output differ from run to run.
+  OS << "Basic blocks: \n";
+  for (const BasicBlock &BB : F)
+    if (const auto *Entry = findEntry(ContextBB, &BB))
+      printEntry(OS, "BB", BB.getName(), *Entry);
+  OS << "Instructions: \n";
+  for (const BasicBlock &BB : F)
+    for (const Instruction &I : BB)
+      if (const auto *Entry = findEntry(ContextInst, &I))
+        printEntry(OS, "Inst", I.getName(), *Entry);
+}
+
+bool ContextInfo::invalidate(Function &Fn, const PreservedAnalyses &PA,
+                             FunctionAnalysisManager::Invalidator &Inv) {
+  // ContextAnalysis is invalidated if it isn't preserved.
+  auto PAC = PA.getChecker<ContextAnalysis>();
+  if (!(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>()))
+    return true;
+
+  // This result holds references to LoopInfo and ScalarEvolution, so it must
+  // not outlive either of them.
+  return Inv.invalidate<LoopAnalysis>(Fn, PA) ||
+         Inv.invalidate<ScalarEvolutionAnalysis>(Fn, PA);
+}
+
+void ContextInfo::createContextForBB(const BasicBlock *BB) {
+  ContextBuilder Builder(BB->getName());
+  // getTerminator() returns null for a block that has no terminator yet.
+  const Instruction *Term = BB->getTerminator();
+
+  if (BB->isEntryBlock())
+    Builder.add("entry", "entry");
+  if (Term && isa<UnreachableInst>(Term))
+    Builder.add("unreachable", "unreachable");
+  if (Term && isa<ReturnInst>(Term))
+    Builder.add("exit", "exit");
+  if (LI.isLoopHeader(BB))
+    Builder.add("loop.header", "loop header");
+
+  if (const Loop *L = LI.getLoopFor(BB)) {
+    if (L->isLoopLatch(BB))
+      Builder.add("loop.latch", "loop latch");
+    if (L->isLoopExiting(BB))
+      Builder.add("loop.exit", "loop exiting");
+    // "loop body" is only the fallback for loop blocks with no other property.
+    if (Builder.empty())
+      Builder.add("loop.body", "loop body");
+  }
+
+  ContextBB[BB] = Builder.take();
+}
+
+void ContextInfo::createContextForInst(const Instruction *Inst) {
+  ContextBuilder Builder(Inst->getName());
+
+  if (isa<PHINode>(Inst))
+    Builder.add("phi.node", "phi node");
+  if (isa<ICmpInst>(Inst))
+    Builder.add("icmp", "icmp");
+  if (Inst->isTerminator())
+    Builder.add("terminator", "terminator");
+  if (Inst->isUnaryOp())
+    Builder.add("unary.operand", "unary operand");
+  if (Inst->isBinaryOp())
+    Builder.add("binary.operand", "binary operand");
+  if (Inst->isIntDivRem())
+    Builder.add("int.div.rem", "int div rem");
+  if (Inst->isShift())
+    Builder.add("shift", "shift");
+  if (Inst->isCast())
+    Builder.add("cast", "cast");
+  if (Inst->isFuncletPad())
+    Builder.add("funclet.pad", "funclet pad");
+  if (Inst->isExceptionalTerminator())
+    Builder.add("exceptional.terminator", "exceptional terminator");
+  if (Inst->isLogicalShift())
+    Builder.add("logical.shift", "logical shift");
+  if (Inst->isArithmeticShift())
+    Builder.add("ArithmeticShift", "arithmetic shift");
+  if (Inst->isBitwiseLogicOp())
+    Builder.add("BitwiseLogicOp", "bitwise logical operand");
+
+  ContextInst[Inst] = Builder.take();
+}
+
+AnalysisKey ContextAnalysis::Key;
+
+ContextInfo ContextAnalysis::run(Function &F, FunctionAnalysisManager &AM) {
+  auto &SE = AM.getResult<ScalarEvolutionAnalysis>(F);
+  auto &LI = AM.getResult<LoopAnalysis>(F);
+  ContextInfo CI(F, SE, LI);
+  CI.createContext();
+  return CI;
+}
+
+PreservedAnalyses ContextInfoPrinterPass::run(Function &F,
+                                              FunctionAnalysisManager &AM) {
+  OS << "Context info for function: " << F.getName() << "\n";
+  AM.getResult<ContextAnalysis>(F).print(OS);
+  return PreservedAnalyses::all();
+}
diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp
--- a/llvm/lib/Passes/PassBuilder.cpp
+++ b/llvm/lib/Passes/PassBuilder.cpp
@@ -28,6 +28,7 @@
 #include "llvm/Analysis/CallGraph.h"
 #include "llvm/Analysis/CallPrinter.h"
 #include "llvm/Analysis/CostModel.h"
+#include "llvm/Analysis/ContextAnalysis.h"
 #include "llvm/Analysis/CycleAnalysis.h"
 #include "llvm/Analysis/DDG.h"
 #include "llvm/Analysis/DDGPrinter.h"
diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def
--- a/llvm/lib/Passes/PassRegistry.def
+++ b/llvm/lib/Passes/PassRegistry.def
@@ -223,6 +223,7 @@
 FUNCTION_ANALYSIS("block-freq", BlockFrequencyAnalysis())
 FUNCTION_ANALYSIS("branch-prob", BranchProbabilityAnalysis())
 FUNCTION_ANALYSIS("cycles", CycleAnalysis())
+FUNCTION_ANALYSIS("context-info", ContextAnalysis())
 FUNCTION_ANALYSIS("domtree", DominatorTreeAnalysis())
 FUNCTION_ANALYSIS("postdomtree", PostDominatorTreeAnalysis())
 FUNCTION_ANALYSIS("demanded-bits", DemandedBitsAnalysis())
@@ -354,6 +355,7 @@
 FUNCTION_PASS("print<block-freq>", BlockFrequencyPrinterPass(dbgs()))
 FUNCTION_PASS("print<branch-prob>", BranchProbabilityPrinterPass(dbgs()))
 FUNCTION_PASS("print<cost-model>", CostModelPrinterPass(dbgs()))
+FUNCTION_PASS("print<context-info>", ContextInfoPrinterPass(dbgs()))
 FUNCTION_PASS("print<cycles>", CycleInfoPrinterPass(dbgs()))
 FUNCTION_PASS("print<da>", DependenceAnalysisPrinterPass(dbgs()))
 FUNCTION_PASS("print<divergence>", DivergenceAnalysisPrinterPass(dbgs()))
diff --git a/llvm/unittests/Analysis/CMakeLists.txt b/llvm/unittests/Analysis/CMakeLists.txt
--- a/llvm/unittests/Analysis/CMakeLists.txt
+++ b/llvm/unittests/Analysis/CMakeLists.txt
@@ -21,6 +21,7 @@
   CFGTest.cpp
   CGSCCPassManagerTest.cpp
   ConstraintSystemTest.cpp
+  ContextAnalysisTest.cpp
   DDGTest.cpp
   DivergenceAnalysisTest.cpp
   DomTreeUpdaterTest.cpp
diff --git a/llvm/unittests/Analysis/ContextAnalysisTest.cpp b/llvm/unittests/Analysis/ContextAnalysisTest.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/unittests/Analysis/ContextAnalysisTest.cpp
@@ -0,0 +1,147 @@
+//===- ContextAnalysisTest.cpp - Context analysis unit tests --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/ContextAnalysis.h"
+#include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Verifier.h"
+#include "llvm/Support/raw_ostream.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+namespace {
+
+/// Owns every analysis a ContextInfo refers to, so the references it stores
+/// stay valid for as long as the test queries it. Members are declared in
+/// dependency order because that is also their initialization order.
+struct ContextInfoFixture {
+  TargetLibraryInfoImpl TLII;
+  TargetLibraryInfo TLI;
+  AssumptionCache AC;
+  DominatorTree DT;
+  LoopInfo LI;
+  ScalarEvolution SE;
+  ContextInfo CI;
+
+  explicit ContextInfoFixture(Function &F)
+      : TLI(TLII), AC(F), DT(F), LI(DT), SE(F, TLI, AC, DT, LI),
+        CI(F, SE, LI) {
+    CI.createContext();
+  }
+};
+
+} // end anonymous namespace
+
+TEST(ContextAnalysisTest, SimpleFunction) {
+  LLVMContext Context;
+  Module M("ContextAnalysisTest", Context);
+
+  // Generate a function like below:
+  // define void @foo() {
+  // entry:
+  //   br label %for.cond
+  //
+  // for.cond:                                   ; preds = %for.end, %entry
+  //   %pnode = phi i32 [ 10, %entry ], [ %dec, %for.end ]
+  //   %cmp.unreach = icmp sgt i32 %pnode, -10
+  //   br i1 %cmp.unreach, label %for.inc, label %unreach
+  //
+  // for.inc:                                    ; preds = %for.cond
+  //   %dec = add nsw i32 %pnode, -1
+  //   br label %for.end
+  //
+  // unreach:                                    ; preds = %for.cond
+  //   unreachable
+  //
+  // for.end:                                    ; preds = %for.inc
+  //   %cmp.exit = icmp sgt i32 %dec, 0
+  //   br i1 %cmp.exit, label %exit, label %for.cond
+  //
+  // exit:                                       ; preds = %for.end
+  //   ret void
+  // }
+  FunctionType *FTy = FunctionType::get(Type::getVoidTy(Context), {}, false);
+  Function *F = Function::Create(FTy, Function::ExternalLinkage, "foo", M);
+
+  BasicBlock *EntryBB = BasicBlock::Create(Context, "entry", F);
+  BasicBlock *CondBB = BasicBlock::Create(Context, "for.cond", F);
+  BasicBlock *EndBB = BasicBlock::Create(Context, "for.end", F);
+  BasicBlock *IncBB = BasicBlock::Create(Context, "for.inc", F, EndBB);
+  BasicBlock *UnreachableBB = BasicBlock::Create(Context, "unreach", F, EndBB);
+  BasicBlock *ExitBB = BasicBlock::Create(Context, "exit", F);
+
+  Type *I32Ty = Type::getInt32Ty(Context);
+  auto *PN = PHINode::Create(I32Ty, 2, "pnode", CondBB);
+  PN->addIncoming(ConstantInt::get(Context, APInt(32, 10)), EntryBB);
+  // Negative immediates are marked signed: read as an unsigned 64-bit value,
+  // -1 does not fit in 32 bits.
+  auto *Dec = BinaryOperator::CreateNSWAdd(
+      PN, ConstantInt::get(Context, APInt(32, -1, /*isSigned=*/true)), "dec",
+      IncBB);
+  // The back edge comes from for.end; for.inc is not a predecessor of
+  // for.cond, so naming it here would produce invalid IR.
+  PN->addIncoming(Dec, EndBB);
+
+  auto *CmpUnreach = CmpInst::Create(
+      Instruction::ICmp, CmpInst::ICMP_SGT, PN,
+      ConstantInt::get(Context, APInt(32, -10, /*isSigned=*/true)),
+      "cmp.unreach", CondBB);
+  auto *CmpExit = CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_SGT, Dec,
+                                  ConstantInt::get(Context, APInt(32, 0)),
+                                  "cmp.exit", EndBB);
+
+  BranchInst::Create(CondBB, EntryBB);
+  BranchInst::Create(IncBB, UnreachableBB, CmpUnreach, CondBB);
+  BranchInst::Create(EndBB, IncBB);
+  new UnreachableInst(Context, UnreachableBB);
+  BranchInst::Create(ExitBB, CondBB, CmpExit, EndBB);
+  ReturnInst::Create(Context, nullptr, ExitBB);
+
+  // verifyFunction() returns true when the function is broken.
+  ASSERT_FALSE(verifyFunction(*F, &errs()));
+
+  ContextInfoFixture Fixture(*F);
+  const ContextInfo &CI = Fixture.CI;
+
+  EXPECT_EQ(CI.getContextForInst(PN), "pnode is phi node");
+  EXPECT_EQ(CI.getNameForInst(PN), "phi.node");
+
+  EXPECT_EQ(CI.getContextForInst(CmpExit), "cmp.exit is icmp");
+  EXPECT_EQ(CI.getNameForInst(CmpExit), "icmp");
+
+  EXPECT_EQ(CI.getContextForInst(Dec), "dec is binary operand");
+  EXPECT_EQ(CI.getNameForInst(Dec), "binary.operand");
+
+  EXPECT_EQ(CI.getContextForBB(EntryBB), "entry is entry");
+  EXPECT_EQ(CI.getNameForBB(EntryBB), "entry");
+
+  EXPECT_EQ(CI.getContextForBB(CondBB),
+            "for.cond is loop header | for.cond is loop exiting");
+  EXPECT_EQ(CI.getNameForBB(CondBB), "loop.header");
+
+  EXPECT_EQ(CI.getContextForBB(IncBB), "for.inc is loop body");
+  EXPECT_EQ(CI.getNameForBB(IncBB), "loop.body");
+
+  EXPECT_EQ(CI.getContextForBB(UnreachableBB), "unreach is unreachable");
+  EXPECT_EQ(CI.getNameForBB(UnreachableBB), "unreachable");
+
+  EXPECT_EQ(CI.getContextForBB(EndBB),
+            "for.end is loop latch | for.end is loop exiting");
+  EXPECT_EQ(CI.getNameForBB(EndBB), "loop.latch");
+
+  EXPECT_EQ(CI.getContextForBB(ExitBB), "exit is exit");
+  EXPECT_EQ(CI.getNameForBB(ExitBB), "exit");
+}