Index: mlir/include/mlir/IR/Block.h
===================================================================
--- mlir/include/mlir/IR/Block.h
+++ mlir/include/mlir/IR/Block.h
@@ -13,6 +13,7 @@
 #ifndef MLIR_IR_BLOCK_H
 #define MLIR_IR_BLOCK_H
 
 #include "mlir/IR/BlockSupport.h"
+#include "mlir/IR/PrinterHook.h"
 #include "mlir/IR/Visitors.h"
 
@@ -336,7 +337,9 @@
   }
 
   void print(raw_ostream &os);
-  void print(raw_ostream &os, AsmState &state);
+  /// Print this block using `state`. If `hook` is non-null, it is used to
+  /// print annotation comments into the output (see PrinterHookBase).
+  void print(raw_ostream &os, AsmState &state, PrinterHookBase *hook = nullptr);
   void dump();
 
   /// Print out the name of the block without printing its body.
Index: mlir/include/mlir/IR/Operation.h
===================================================================
--- mlir/include/mlir/IR/Operation.h
+++ mlir/include/mlir/IR/Operation.h
@@ -241,7 +241,9 @@
   bool isBeforeInBlock(Operation *other);
 
   void print(raw_ostream &os, const OpPrintingFlags &flags = llvm::None);
-  void print(raw_ostream &os, AsmState &state);
+  /// Print this operation using `state`. If `hook` is non-null, it is used to
+  /// print annotation comments into the output (see PrinterHookBase).
+  void print(raw_ostream &os, AsmState &state, PrinterHookBase *hook = nullptr);
   void dump();
 
   //===--------------------------------------------------------------------===//
Index: mlir/include/mlir/IR/PrinterHook.h
===================================================================
--- /dev/null
+++ mlir/include/mlir/IR/PrinterHook.h
@@ -0,0 +1,48 @@
+//===- PrinterHook.h - Printer hook for annotations -------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares PrinterHookBase, the interface that lets clients inject
+// annotation comments into the textual IR emitted by the printer.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_PRINTER_HOOK_H
+#define MLIR_IR_PRINTER_HOOK_H
+
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+
+class Block;
+class Operation;
+
+/// Interface for printing annotations before IR constructs. It allows an
+/// analysis to print its results directly within the IR, making them easier
+/// for humans to read.
+///
+/// Every annotation must be a self-contained comment: it must start with "//"
+/// (possibly after indentation) and end with a newline. Otherwise, the
+/// generated IR will likely no longer be parseable.
+class PrinterHookBase {
+public:
+  virtual ~PrinterHookBase();
+
+  /// Called right before `op` is printed. `currentIndent` is the printer's
+  /// current indentation level.
+  virtual void printCommentBeforeOp(Operation *op, raw_ostream &os,
+                                    unsigned currentIndent) = 0;
+
+  /// Called right before `block` is printed, ahead of its label if any.
+  /// `currentIndent` is the printer's current indentation level.
+  virtual void printCommentBeforeBlock(Block *block, raw_ostream &os,
+                                       unsigned currentIndent) = 0;
+};
+
+} // namespace mlir
+
+#endif // MLIR_IR_PRINTER_HOOK_H
Index: mlir/lib/IR/AsmPrinter.cpp
===================================================================
--- mlir/lib/IR/AsmPrinter.cpp
+++ mlir/lib/IR/AsmPrinter.cpp
@@ -25,6 +25,7 @@
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/Operation.h"
+#include "mlir/IR/PrinterHook.h"
 #include "mlir/IR/SubElementInterfaces.h"
 #include "mlir/IR/Verifier.h"
 #include "llvm/ADT/APFloat.h"
@@ -2718,8 +2719,10 @@
   using Impl = AsmPrinter::Impl;
   using Impl::printType;
 
-  explicit OperationPrinter(raw_ostream &os, AsmStateImpl &state)
-      : Impl(os, state), OpAsmPrinter(static_cast<Impl &>(*this)) {}
+  explicit OperationPrinter(raw_ostream &os, AsmStateImpl &state,
+                            PrinterHookBase *hook = nullptr)
+      : Impl(os, state), OpAsmPrinter(static_cast<Impl &>(*this)),
+        annotationHook(hook) {}
 
   /// Print the given top-level operation.
   void printTopLevelOperation(Operation *op);
@@ -2905,6 +2908,23 @@
 
   // This is the current indentation level for nested structures.
   unsigned currentIndent = 0;
+
+  /// Optional hook used to print annotations before blocks and operations.
+  /// May be null, in which case no annotations are printed.
+  PrinterHookBase *annotationHook = nullptr;
+
+  /// Run `printFn` against a temporary buffer and forward the result to `os`.
+  /// Annotation text bypasses `newLine`, so the newlines it contains are
+  /// counted here to keep `newLine.curLine` (and thus the line numbers passed
+  /// to `registerOperationLocation`) in sync with the emitted text.
+  void printAnnotation(function_ref<void(raw_ostream &)> printFn) {
+    std::string buffer;
+    llvm::raw_string_ostream bufferOS(buffer);
+    printFn(bufferOS);
+    StringRef text = bufferOS.str();
+    newLine.curLine += static_cast<unsigned>(text.count('\n'));
+    os << text;
+  }
 };
 } // namespace
 
@@ -3010,6 +3030,14 @@
 }
 
 void OperationPrinter::printFullOpWithIndentAndLoc(Operation *op) {
+  // Emit the annotation before registering the location so that the recorded
+  // line number is the one the operation itself is printed on.
+  if (annotationHook) {
+    printAnnotation([&](raw_ostream &hookOS) {
+      annotationHook->printCommentBeforeOp(op, hookOS, currentIndent);
+    });
+  }
+
   // Track the location of this operation.
   state.registerOperationLocation(op, newLine.curLine, currentIndent);
 
@@ -3194,6 +3222,12 @@
 
 void OperationPrinter::print(Block *block, bool printBlockArgs,
                              bool printBlockTerminator) {
+  if (annotationHook) {
+    printAnnotation([&](raw_ostream &hookOS) {
+      annotationHook->printCommentBeforeBlock(block, hookOS, currentIndent);
+    });
+  }
+
   // Print the block label and argument list if requested.
   if (printBlockArgs) {
     os.indent(currentIndent);
@@ -3491,8 +3525,8 @@
   AsmState state(op, printerFlags);
   print(os, state);
 }
-void Operation::print(raw_ostream &os, AsmState &state) {
-  OperationPrinter printer(os, state.getImpl());
+void Operation::print(raw_ostream &os, AsmState &state, PrinterHookBase *hook) {
+  OperationPrinter printer(os, state.getImpl(), hook);
   if (!getParent() && !state.getPrinterFlags().shouldUseLocalScope()) {
     state.getImpl().initializeAliases(this);
     printer.printTopLevelOperation(this);
@@ -3519,8 +3553,8 @@
   AsmState state(parentOp);
   print(os, state);
 }
-void Block::print(raw_ostream &os, AsmState &state) {
-  OperationPrinter(os, state.getImpl()).print(this);
+void Block::print(raw_ostream &os, AsmState &state, PrinterHookBase *hook) {
+  OperationPrinter(os, state.getImpl(), hook).print(this);
 }
 void Block::dump() { print(llvm::errs()); }
 
@@ -3539,3 +3573,5 @@
   OperationPrinter printer(os, state.getImpl());
   printer.printBlockName(this);
 }
+
+PrinterHookBase::~PrinterHookBase() = default;