Index: llvm/include/llvm/Analysis/CFGPrinter.h
===================================================================
--- llvm/include/llvm/Analysis/CFGPrinter.h
+++ llvm/include/llvm/Analysis/CFGPrinter.h
@@ -28,6 +28,7 @@
 #include "llvm/IR/PassManager.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/GraphWriter.h"
+#include <functional>
 
 namespace llvm {
 class CFGViewerPass : public PassInfoMixin<CFGViewerPass> {
@@ -141,8 +142,29 @@
     return OS.str();
   }
 
-  static std::string getCompleteNodeLabel(const BasicBlock *Node,
-                                          DOTFuncInfo *) {
+  /// Default comment handler for getCompleteNodeLabel: erase the comment in
+  /// [I, Idx) from \p OutStr and step \p I back by one so the caller's loop
+  /// increment resumes at the first character after the erased text. \p Idx
+  /// comes from std::string::find, so it is clamped to the string length in
+  /// case no newline follows the comment (find returns npos).
+  static void eraseComment(std::string &OutStr, unsigned &I, unsigned Idx) {
+    size_t End = Idx < OutStr.size() ? Idx : OutStr.size();
+    OutStr.erase(OutStr.begin() + I, OutStr.begin() + End);
+    --I;
+  }
+
+  /// Build the DOT label for \p Node. \p HandleBasicBlock prints the block
+  /// into the label stream (default: `OS << Node`). \p HandleComment is
+  /// called for each ';' in the printed text with the string, the index of
+  /// the ';' (by reference, so the handler can rewind it) and the index of
+  /// the end of that line; the default, eraseComment, deletes the comment.
+  static std::string getCompleteNodeLabel(
+      const BasicBlock *Node, DOTFuncInfo *,
+      std::function<void(raw_string_ostream &, const BasicBlock &)>
+          HandleBasicBlock = [](raw_string_ostream &OS,
+                                const BasicBlock &BB) -> void { OS << BB; },
+      std::function<void(std::string &, unsigned &, unsigned)> HandleComment =
+          eraseComment) {
     enum { MaxColumns = 80 };
     std::string Str;
     raw_string_ostream OS(Str);
@@ -152,7 +174,7 @@
       OS << ":";
     }
 
-    OS << *Node;
+    HandleBasicBlock(OS, *Node);
     std::string OutStr = OS.str();
     if (OutStr[0] == '\n')
       OutStr.erase(OutStr.begin());
@@ -168,8 +190,7 @@
         LastSpace = 0;
       } else if (OutStr[i] == ';') {             // Delete comments!
         unsigned Idx = OutStr.find('\n', i + 1); // Find end of line
-        OutStr.erase(OutStr.begin() + i, OutStr.begin() + Idx);
-        --i;
+        HandleComment(OutStr, i, Idx);
       } else if (ColNum == MaxColumns) { // Wrap lines.
         // Wrap very long names even though we can't find a space.
         if (!LastSpace)
Index: llvm/lib/Analysis/MemorySSA.cpp
===================================================================
--- llvm/lib/Analysis/MemorySSA.cpp
+++ llvm/lib/Analysis/MemorySSA.cpp
@@ -24,6 +24,7 @@
 #include "llvm/ADT/iterator.h"
 #include "llvm/ADT/iterator_range.h"
 #include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/Analysis/CFGPrinter.h"
 #include "llvm/Analysis/IteratedDominanceFrontier.h"
 #include "llvm/Analysis/MemoryLocation.h"
 #include "llvm/Config/llvm-config.h"
@@ -59,6 +60,13 @@
 
 #define DEBUG_TYPE "memoryssa"
 
+// File name for the MSSA-annotated CFG in DOT format. When non-empty, the
+// MemorySSA printer passes write the graph there instead of printing text.
+static cl::opt<std::string>
+    DotCFGMSSA("cfg-dot-mssa",
+               cl::value_desc("file name for generated dot file"),
+               cl::desc("file name for generated dot file"), cl::init(""));
+
 INITIALIZE_PASS_BEGIN(MemorySSAWrapperPass, "memoryssa", "Memory SSA", false,
                       true)
 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
@@ -2262,9 +2270,109 @@
   AU.addRequired<MemorySSAWrapperPass>();
 }
 
+namespace llvm {
+
+/// Graph wrapper used to render a function's CFG with MemorySSA annotations:
+/// it pairs the function with a MemorySSAAnnotatedWriter so each basic block
+/// can be printed along with the writer's MemorySSA annotations.
+class DOTFuncMSSAInfo {
+private:
+  const Function &F;
+  MemorySSAAnnotatedWriter MSSAWriter;
+
+public:
+  DOTFuncMSSAInfo(const Function &F, MemorySSA &MSSA)
+      : F(F), MSSAWriter(&MSSA) {}
+
+  const Function *getFunction() { return &F; }
+  MemorySSAAnnotatedWriter &getWriter() { return MSSAWriter; }
+};
+
+/// Expose DOTFuncMSSAInfo as a graph whose nodes are the function's basic
+/// blocks, reusing the edge traits of const BasicBlock *.
+template <>
+struct GraphTraits<DOTFuncMSSAInfo *> : public GraphTraits<const BasicBlock *> {
+  static NodeRef getEntryNode(DOTFuncMSSAInfo *CFGInfo) {
+    return &(CFGInfo->getFunction()->getEntryBlock());
+  }
+
+  // nodes_iterator/begin/end - Allow iteration over all nodes in the graph
+  using nodes_iterator = pointer_iterator<Function::const_iterator>;
+
+  static nodes_iterator nodes_begin(DOTFuncMSSAInfo *CFGInfo) {
+    return nodes_iterator(CFGInfo->getFunction()->begin());
+  }
+
+  static nodes_iterator nodes_end(DOTFuncMSSAInfo *CFGInfo) {
+    return nodes_iterator(CFGInfo->getFunction()->end());
+  }
+
+  static size_t size(DOTFuncMSSAInfo *CFGInfo) {
+    return CFGInfo->getFunction()->size();
+  }
+};
+
+/// DOT rendering for DOTFuncMSSAInfo: node labels are the printed blocks with
+/// MemorySSA comments kept and all other ';' comments erased; blocks whose
+/// label keeps such a comment are filled in light pink.
+template <>
+struct DOTGraphTraits<DOTFuncMSSAInfo *> : public DefaultDOTGraphTraits {
+
+  DOTGraphTraits(bool isSimple = false) : DefaultDOTGraphTraits(isSimple) {}
+
+  static std::string getGraphName(DOTFuncMSSAInfo *CFGInfo) {
+    return "MSSA CFG for '" + CFGInfo->getFunction()->getName().str() +
+           "' function";
+  }
+
+  std::string getNodeLabel(const BasicBlock *Node, DOTFuncMSSAInfo *CFGInfo) {
+    return DOTGraphTraits<DOTFuncInfo *>::getCompleteNodeLabel(
+        Node, nullptr,
+        // Print the block through the MemorySSA annotation writer.
+        [CFGInfo](raw_string_ostream &OS, const BasicBlock &BB) -> void {
+          BB.print(OS, &CFGInfo->getWriter(), true, true);
+        },
+        // Keep MemorySSA annotation comments; erase every other comment.
+        [](std::string &S, unsigned &I, unsigned Idx) -> void {
+          StringRef SR = StringRef(S).substr(I, Idx - I);
+          if (SR.count(" = MemoryDef(") || SR.count(" = MemoryPhi(") ||
+              SR.count(" = MemoryUse("))
+            return;
+          DOTGraphTraits<DOTFuncInfo *>::eraseComment(S, I, Idx);
+        });
+  }
+
+  static std::string getEdgeSourceLabel(const BasicBlock *Node,
+                                        const_succ_iterator I) {
+    return DOTGraphTraits<DOTFuncInfo *>::getEdgeSourceLabel(Node, I);
+  }
+
+  /// No extra edge attributes are emitted for the MemorySSA graph.
+  std::string getEdgeAttributes(const BasicBlock *Node, const_succ_iterator I,
+                                DOTFuncMSSAInfo *CFGInfo) {
+    return "";
+  }
+
+  /// Fill blocks whose label still contains a ';' comment. Only MemorySSA
+  /// annotation comments survive getNodeLabel, so this marks annotated blocks.
+  std::string getNodeAttributes(const BasicBlock *Node,
+                                DOTFuncMSSAInfo *CFGInfo) {
+    return getNodeLabel(Node, CFGInfo).find(';') != std::string::npos
+               ? "style=filled, fillcolor=lightpink"
+               : "";
+  }
+};
+
+} // namespace llvm
+
 bool MemorySSAPrinterLegacyPass::runOnFunction(Function &F) {
   auto &MSSA = getAnalysis<MemorySSAWrapperPass>().getMSSA();
-  MSSA.print(dbgs());
+  if (!DotCFGMSSA.empty()) {
+    DOTFuncMSSAInfo CFGInfo(F, MSSA);
+    WriteGraph(&CFGInfo, "", false, "MSSA", DotCFGMSSA);
+  } else
+    MSSA.print(dbgs());
+
   if (VerifyMemorySSA)
     MSSA.verifyMemorySSA();
   return false;
@@ -2290,8 +2398,14 @@
 
 PreservedAnalyses MemorySSAPrinterPass::run(Function &F,
                                             FunctionAnalysisManager &AM) {
-  OS << "MemorySSA for function: " << F.getName() << "\n";
-  AM.getResult<MemorySSAAnalysis>(F).getMSSA().print(OS);
+  auto &MSSA = AM.getResult<MemorySSAAnalysis>(F).getMSSA();
+  if (!DotCFGMSSA.empty()) {
+    DOTFuncMSSAInfo CFGInfo(F, MSSA);
+    WriteGraph(&CFGInfo, "", false, "MSSA", DotCFGMSSA);
+  } else {
+    OS << "MemorySSA for function: " << F.getName() << "\n";
+    MSSA.print(OS);
+  }
 
   return PreservedAnalyses::all();
 }