diff --git a/llvm/include/llvm/Analysis/TBAAPrinter.h b/llvm/include/llvm/Analysis/TBAAPrinter.h
new file mode 100644
--- /dev/null
+++ b/llvm/include/llvm/Analysis/TBAAPrinter.h
@@ -0,0 +1,38 @@
+//===-- TBAAPrinter.h - TBAA printer external interface ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the 'view-tbaa' and 'dot-tbaa' passes, which allow users
+// to see graphs of the TBAA type tree for each function in the program.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_ANALYSIS_TBAAPRINTER_H
+#define LLVM_ANALYSIS_TBAAPRINTER_H
+
+#include "llvm/IR/Function.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/IR/TBAA.h"
+
+namespace llvm {
+
+/// Shows the TBAA graph of each selected function in a graph viewer.
+class TBAAViewerPass : public PassInfoMixin<TBAAViewerPass> {
+public:
+  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
+};
+
+/// Writes the TBAA graph of each selected function to a 'dot' file.
+class TBAAPrinterPass : public PassInfoMixin<TBAAPrinterPass> {
+public:
+  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
+};
+
+} // end namespace llvm
+
+#endif // LLVM_ANALYSIS_TBAAPRINTER_H
diff --git a/llvm/include/llvm/InitializePasses.h b/llvm/include/llvm/InitializePasses.h
--- a/llvm/include/llvm/InitializePasses.h
+++ b/llvm/include/llvm/InitializePasses.h
@@ -445,6 +445,8 @@
 void initializeTargetLibraryInfoWrapperPassPass(PassRegistry&);
 void initializeTargetPassConfigPass(PassRegistry&);
 void initializeTargetTransformInfoWrapperPassPass(PassRegistry&);
+void initializeTBAAPrinterLegacyPassPass(PassRegistry&);
+void initializeTBAAViewerLegacyPassPass(PassRegistry&);
 void initializeThreadSanitizerLegacyPassPass(PassRegistry&);
 void initializeTLSVariableHoistLegacyPassPass(PassRegistry &);
 void initializeTwoAddressInstructionPassPass(PassRegistry&);
diff --git a/llvm/lib/Analysis/CMakeLists.txt b/llvm/lib/Analysis/CMakeLists.txt
--- a/llvm/lib/Analysis/CMakeLists.txt
+++ b/llvm/lib/Analysis/CMakeLists.txt
@@ -130,6 +130,7 @@
   StackSafetyAnalysis.cpp
   SyncDependenceAnalysis.cpp
   SyntheticCountsUtils.cpp
+  TBAAPrinter.cpp
   TFUtils.cpp
   TargetLibraryInfo.cpp
   TargetTransformInfo.cpp
diff --git a/llvm/lib/Analysis/TBAAPrinter.cpp b/llvm/lib/Analysis/TBAAPrinter.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/lib/Analysis/TBAAPrinter.cpp
@@ -0,0 +1,359 @@
+//===- TBAAPrinter.cpp - Graphical viewer for TBAA metadata ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Trying to make sense of TBAA metadata with dot and LLVM. Because somebody
+// needs to.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/TBAAPrinter.h"
+#include "llvm/ADT/GraphTraits.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/TBAA.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/DOTGraphTraits.h"
+#include "llvm/Support/GraphWriter.h"
+
+#include <vector>
+
+using namespace llvm;
+
+static cl::opt<std::string>
+    TBAAFuncName("tbaa-func-name", cl::Hidden,
+                 cl::desc("The name of a function (or its substring)"
+                          " whose TBAA is viewed/printed."));
+
+static cl::opt<std::string> TBAADotFilenamePrefix(
+    "tbaa-dot-filename-prefix", cl::Hidden,
+    cl::desc("The prefix used for the TBAA dot file names."));
+
+static cl::opt<bool> ShowTBAAInference(
+    "tbaa-show-inference", cl::Hidden,
+    cl::desc("Show edges indicating inference of struct access tags"));
+
+namespace {
+
+struct TbaaGraph;
+
+/// A node of the printed graph: either an instruction that carries TBAA
+/// metadata, or a metadata node reachable from such an instruction.
+struct TbaaGraphNode {
+  /// The instruction this node stands for; null for metadata nodes.
+  const Instruction *I = nullptr;
+  /// The metadata this node stands for; null for instruction nodes.
+  const MDNode *TbaaNode = nullptr;
+  /// True if TbaaNode is an access tag rather than a type node.
+  bool IsAccessTag = false;
+  /// Outgoing edges. The pointees are owned by the enclosing TbaaGraph.
+  std::vector<TbaaGraphNode *> Children;
+
+  explicit TbaaGraphNode(const Instruction &I) : I(&I) {}
+  TbaaGraphNode(const MDNode *N, bool Access)
+      : TbaaNode(N), IsAccessTag(Access) {}
+
+  /// Creates the node for \p I with its access tag as the only child, or
+  /// returns null if \p I has no TBAA metadata. The caller owns the result.
+  static TbaaGraphNode *makeNode(TbaaGraph &Graph, const Instruction &I);
+};
+
+/// The TBAA graph of a single function. Owns every node stored in Nodes.
+struct TbaaGraph {
+  /// All nodes of the graph, instruction and metadata nodes alike.
+  std::vector<TbaaGraphNode *> Nodes;
+  /// Maps each visited metadata node to its graph node so that metadata
+  /// shared between accesses is only added once.
+  DenseMap<const MDNode *, TbaaGraphNode *> MetadataIndices;
+  const Function &F;
+
+  explicit TbaaGraph(const Function &F) : F(F) {
+    for (const Instruction &I : instructions(F))
+      if (TbaaGraphNode *Node = TbaaGraphNode::makeNode(*this, I))
+        Nodes.push_back(Node);
+  }
+
+  // Every node is allocated with 'new' and recorded in Nodes exactly once, so
+  // releasing the contents of Nodes frees the whole graph.
+  ~TbaaGraph() {
+    for (TbaaGraphNode *N : Nodes)
+      delete N;
+  }
+
+  // Copying would lead to a double delete of the nodes.
+  TbaaGraph(const TbaaGraph &) = delete;
+  TbaaGraph &operator=(const TbaaGraph &) = delete;
+
+  /// Returns the graph node for \p Tbaa, creating it together with everything
+  /// reachable from it on first use. \p IsAccessTag only matters when the node
+  /// is created.
+  TbaaGraphNode *getTbaaNode(const MDNode *Tbaa, bool IsAccessTag = false) {
+    auto It = MetadataIndices.find(Tbaa);
+    if (It != MetadataIndices.end())
+      return It->second;
+
+    // Register the node before recursing so that revisiting it finds it.
+    auto *Node = new TbaaGraphNode(Tbaa, IsAccessTag);
+    Nodes.push_back(Node);
+    MetadataIndices[Tbaa] = Node;
+
+    if (IsAccessTag) {
+      assert(isStructPathTBAA(Tbaa) && "access tags should be struct-path");
+      TBAAStructTagNode Tag(Tbaa);
+      Node->Children.push_back(getTbaaNode(Tag.getBaseType()));
+      Node->Children.push_back(getTbaaNode(Tag.getAccessType()));
+      if (ShowTBAAInference) {
+        uint64_t Offset = Tag.getOffset();
+        TBAAStructTypeNode Base(Tag.getBaseType());
+        TBAAStructTypeNode Scalar(Tag.getAccessType());
+        if (Base != Scalar) {
+          TBAAStructTypeNode Infer = Base.getField(Offset);
+          // Also stop on a null node: if the walk never reaches the access
+          // type, getTbaaNode() must not be handed a null MDNode.
+          while (Infer.getNode() && Infer != Scalar) {
+            Node->Children.push_back(getTbaaNode(Infer.getNode()));
+            Infer = Infer.getField(Offset);
+          }
+        }
+      }
+      return Node;
+    }
+
+    TBAANode Wrapper(Tbaa);
+    if (Wrapper.isScalar() || Wrapper.isRoot()) {
+      if (const auto *ParentNode = Wrapper.getParent().getNode())
+        Node->Children.push_back(getTbaaNode(ParentNode));
+      return Node;
+    }
+
+    TBAAStructTypeNode StructTy(Tbaa);
+    for (unsigned Idx = 0, E = StructTy.getNumFields(); Idx != E; ++Idx)
+      Node->Children.push_back(
+          getTbaaNode(StructTy.getFieldType(Idx).getNode()));
+    return Node;
+  }
+};
+
+TbaaGraphNode *TbaaGraphNode::makeNode(TbaaGraph &Graph, const Instruction &I) {
+  const MDNode *TBAA = I.getAAMetadata().TBAA;
+  if (!TBAA)
+    return nullptr;
+  auto *Node = new TbaaGraphNode(I);
+  Node->Children.push_back(Graph.getTbaaNode(TBAA, /*IsAccessTag=*/true));
+  return Node;
+}
+
+} // end anonymous namespace
+
+namespace llvm {
+
+/// Lets the generic graph algorithms iterate over a TbaaGraph.
+template <> struct GraphTraits<TbaaGraph> {
+  using GraphType = const TbaaGraph;
+  using NodeRef = const TbaaGraphNode *;
+  using ChildIteratorType = std::vector<TbaaGraphNode *>::const_iterator;
+  using nodes_iterator = std::vector<TbaaGraphNode *>::const_iterator;
+
+  static ChildIteratorType child_begin(NodeRef N) {
+    return N->Children.begin();
+  }
+  static ChildIteratorType child_end(NodeRef N) { return N->Children.end(); }
+  static nodes_iterator nodes_begin(GraphType &G) { return G.Nodes.begin(); }
+  static nodes_iterator nodes_end(GraphType &G) { return G.Nodes.end(); }
+  static unsigned size(GraphType *G) { return G->Nodes.size(); }
+};
+
+/// Describes how the nodes and edges of a TbaaGraph are rendered by dot.
+template <> struct DOTGraphTraits<TbaaGraph> : public DefaultDOTGraphTraits {
+  using GraphTraits = llvm::GraphTraits<TbaaGraph>;
+
+  DOTGraphTraits(bool ShortNames = false) : DefaultDOTGraphTraits(ShortNames) {}
+
+  static std::string getGraphName(const TbaaGraph &G) {
+    return "TBAA for '" + G.F.getName().str() + "' function";
+  }
+
+  /// Labels an instruction node with the printed instruction, and a metadata
+  /// node with its kind, its operand name and the attributes it carries.
+  static std::string getNodeLabel(GraphTraits::NodeRef N, const TbaaGraph &G) {
+    std::string Str;
+    raw_string_ostream OS(Str);
+    if (N->I) {
+      OS << *N->I;
+      return Str;
+    }
+
+    const Module *M = G.F.getParent();
+    if (N->IsAccessTag) {
+      OS << "access tag ";
+      N->TbaaNode->printAsOperand(OS, M);
+      TBAAStructTagNode Tag(N->TbaaNode);
+      OS << "(offset " << Tag.getOffset();
+      if (Tag.isNewFormat())
+        OS << ", size " << Tag.getSize();
+      if (Tag.isTypeImmutable())
+        OS << ", immutable";
+      OS << ")";
+      return Str;
+    }
+
+    TBAANode Node(N->TbaaNode);
+    if (Node.isRoot())
+      OS << "root ";
+    else if (Node.isScalar())
+      OS << "scalar ";
+    else
+      OS << "struct ";
+    Node.getNode()->printAsOperand(OS, M);
+    OS << "(" << Node.getId();
+    if (Node.isNewFormat() && !Node.isRoot())
+      OS << ", size " << Node.getSize();
+    OS << ")";
+    return Str;
+  }
+
+  /// Returns zero if the outgoing edges of \p N carry no labels. The only
+  /// caller in this file, getEdgeSourceLabel(), tests just that distinction.
+  static unsigned numEdgeSourceLabels(GraphTraits::NodeRef N) {
+    if (!N->TbaaNode)
+      return 0;
+    if (N->IsAccessTag)
+      return ShowTBAAInference ? 3 : 2;
+    TBAAStructTypeNode Node(N->TbaaNode);
+    if (Node.isRoot() || Node.isScalar())
+      return 0;
+    return Node.getNumFields();
+  }
+
+  /// Labels the edges of an access tag "base", "access" and "inferred", and
+  /// the edges of a struct type with the offset of the corresponding field.
+  static std::string
+  getEdgeSourceLabel(GraphTraits::NodeRef N,
+                     const GraphTraits::ChildIteratorType &E) {
+    if (numEdgeSourceLabels(N) == 0)
+      return "";
+    unsigned Index = E - N->Children.begin();
+    if (N->IsAccessTag) {
+      if (Index == 0)
+        return "base";
+      if (Index == 1)
+        return "access";
+      return "inferred";
+    }
+    std::string Str;
+    raw_string_ostream OS(Str);
+    OS << TBAAStructTypeNode(N->TbaaNode).getFieldOffset(Index);
+    return Str;
+  }
+
+  static bool renderGraphFromBottomUp() { return true; }
+};
+
+} // end namespace llvm
+
+/// Returns true if the -tbaa-func-name filter selects \p F; an empty filter
+/// selects every function.
+static bool isFunctionSelected(const Function &F) {
+  return TBAAFuncName.empty() || F.getName().contains(TBAAFuncName);
+}
+
+/// Writes the TBAA graph of \p F to "<prefix>.<function name>.dot" and
+/// reports progress, or the failure to open the file, on stderr.
+static void writeTBAAToDotFile(Function &F) {
+  std::string Filename =
+      (TBAADotFilenamePrefix + "." + F.getName() + ".dot").str();
+  errs() << "Writing '" << Filename << "'...";
+
+  std::error_code EC;
+  raw_fd_ostream File(Filename, EC, sys::fs::OF_Text);
+  if (EC) {
+    errs() << " error opening file for writing!\n";
+    return;
+  }
+
+  // Only build the graph once it is known that it can be written.
+  TbaaGraph G(F);
+  WriteGraph(File, G);
+  errs() << "\n";
+}
+
+/// Displays the TBAA graph of \p F by handing it to ViewGraph().
+static void viewTBAAForFunction(Function &F) {
+  TbaaGraph G(F);
+  ViewGraph(G, Twine("TBAA for function ") + F.getName());
+}
+
+namespace {
+/// Legacy pass manager wrapper around viewTBAAForFunction().
+struct TBAAViewerLegacyPass : public FunctionPass {
+  static char ID; // Pass identification, replacement for typeid
+  TBAAViewerLegacyPass() : FunctionPass(ID) {
+    initializeTBAAViewerLegacyPassPass(*PassRegistry::getPassRegistry());
+  }
+
+  bool runOnFunction(Function &F) override {
+    if (isFunctionSelected(F))
+      viewTBAAForFunction(F);
+    return false;
+  }
+
+  void print(raw_ostream &OS, const Module * = nullptr) const override {}
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    FunctionPass::getAnalysisUsage(AU);
+    AU.setPreservesAll();
+  }
+};
+} // end anonymous namespace
+
+char TBAAViewerLegacyPass::ID = 0;
+INITIALIZE_PASS(TBAAViewerLegacyPass, "view-tbaa", "View TBAA of function",
+                false, true)
+
+PreservedAnalyses TBAAViewerPass::run(Function &F,
+                                      FunctionAnalysisManager &AM) {
+  if (isFunctionSelected(F))
+    viewTBAAForFunction(F);
+  return PreservedAnalyses::all();
+}
+
+namespace {
+/// Legacy pass manager wrapper around writeTBAAToDotFile().
+struct TBAAPrinterLegacyPass : public FunctionPass {
+  static char ID; // Pass identification, replacement for typeid
+  TBAAPrinterLegacyPass() : FunctionPass(ID) {
+    initializeTBAAPrinterLegacyPassPass(*PassRegistry::getPassRegistry());
+  }
+
+  bool runOnFunction(Function &F) override {
+    if (isFunctionSelected(F))
+      writeTBAAToDotFile(F);
+    return false;
+  }
+
+  void print(raw_ostream &OS, const Module * = nullptr) const override {}
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    FunctionPass::getAnalysisUsage(AU);
+    AU.setPreservesAll();
+  }
+};
+} // end anonymous namespace
+
+char TBAAPrinterLegacyPass::ID = 0;
+INITIALIZE_PASS(TBAAPrinterLegacyPass, "dot-tbaa",
+                "Print TBAA of function to 'dot' file", false, true)
+
+PreservedAnalyses TBAAPrinterPass::run(Function &F,
+                                       FunctionAnalysisManager &AM) {
+  if (isFunctionSelected(F))
+    writeTBAAToDotFile(F);
+  return PreservedAnalyses::all();
+}
diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp
--- a/llvm/lib/Passes/PassBuilder.cpp
+++ b/llvm/lib/Passes/PassBuilder.cpp
@@ -70,6 +70,7 @@
 #include "llvm/Analysis/StackSafetyAnalysis.h"
 #include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/Analysis/TBAAPrinter.h"
 #include "llvm/Analysis/TypeBasedAliasAnalysis.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/IRPrintingPasses.h"
diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def
--- a/llvm/lib/Passes/PassRegistry.def
+++ b/llvm/lib/Passes/PassRegistry.def
@@ -261,6 +261,7 @@
 FUNCTION_PASS("dot-cfg-only", CFGOnlyPrinterPass())
 FUNCTION_PASS("dot-dom", DomTreePrinterPass())
 FUNCTION_PASS("dot-dom-only", DomTreeOnlyPrinterPass())
+FUNCTION_PASS("dot-tbaa", TBAAPrinterPass())
 FUNCTION_PASS("fix-irreducible", FixIrreduciblePass())
 FUNCTION_PASS("flattencfg", FlattenCFGPass())
 FUNCTION_PASS("make-guards-explicit", MakeGuardsExplicitPass())
@@ -363,6 +364,7 @@
 FUNCTION_PASS("verify<scalar-evolution>", ScalarEvolutionVerifierPass())
 FUNCTION_PASS("view-cfg", CFGViewerPass())
 FUNCTION_PASS("view-cfg-only", CFGOnlyViewerPass())
+FUNCTION_PASS("view-tbaa", TBAAViewerPass())
 FUNCTION_PASS("tlshoist", TLSVariableHoistPass())
 FUNCTION_PASS("transform-warning", WarnMissedTransformationsPass())
 FUNCTION_PASS("tsan", ThreadSanitizerPass())