Add a `label` option to the PrintIR pass.

createPrintIRPass() now takes a (defaulted) PrintIRPassOptions, the pass
declares a string option named `label`, and runOnOperation() writes a
"// -----// IR Dump [label] //----- //" banner to llvm::dbgs() before
dumping the operation.

diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -31,10 +31,11 @@
 #define GEN_PASS_DECL_CANONICALIZER
 #define GEN_PASS_DECL_CONTROLFLOWSINK
 #define GEN_PASS_DECL_CSEPASS
+#define GEN_PASS_DECL_INLINER
 #define GEN_PASS_DECL_LOOPINVARIANTCODEMOTION
-#define GEN_PASS_DECL_STRIPDEBUGINFO
+#define GEN_PASS_DECL_PRINTIRPASS
 #define GEN_PASS_DECL_PRINTOPSTATS
-#define GEN_PASS_DECL_INLINER
+#define GEN_PASS_DECL_STRIPDEBUGINFO
 #define GEN_PASS_DECL_SCCP
 #define GEN_PASS_DECL_SYMBOLDCE
 #define GEN_PASS_DECL_SYMBOLPRIVATIZE
@@ -65,7 +66,7 @@
 std::unique_ptr<Pass> createCSEPass();
 
 /// Creates a pass to print IR on the debug stream.
-std::unique_ptr<Pass> createPrintIRPass();
+std::unique_ptr<Pass> createPrintIRPass(const PrintIRPassOptions & = {});
 
 /// Creates a pass that generates IR to verify ops at runtime.
 std::unique_ptr<Pass> createGenerateRuntimeVerificationPass();
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -92,6 +92,9 @@
     purposes to inspect the IR at a specific point in the pipeline.
   }];
   let constructor = "mlir::createPrintIRPass()";
+  let options = [
+    Option<"label", "label", "std::string", /*default=*/"", "Label">,
+  ];
 }
 
 def GenerateRuntimeVerification : Pass<"generate-runtime-verification"> {
diff --git a/mlir/lib/Transforms/PrintIR.cpp b/mlir/lib/Transforms/PrintIR.cpp
--- a/mlir/lib/Transforms/PrintIR.cpp
+++ b/mlir/lib/Transforms/PrintIR.cpp
@@ -7,6 +7,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/Passes.h"
+#include "llvm/Support/Debug.h"
 
 namespace mlir {
 namespace {
@@ -15,15 +17,23 @@
 #include "mlir/Transforms/Passes.h.inc"
 
 struct PrintIRPass : public impl::PrintIRPassBase<PrintIRPass> {
-  PrintIRPass() = default;
-
-  void runOnOperation() override { getOperation()->dump(); }
+  using impl::PrintIRPassBase<PrintIRPass>::PrintIRPassBase;
+
+  /// Writes an "IR Dump" banner (followed by `label` when it is non-empty)
+  /// to llvm::dbgs(), then dumps the operation this pass is running on.
+  void runOnOperation() override {
+    llvm::dbgs() << "// -----// IR Dump";
+    if (!this->label.empty())
+      llvm::dbgs() << " " << this->label;
+    llvm::dbgs() << " //----- //\n";
+    getOperation()->dump();
+  }
 };
 
 } // namespace
 
-std::unique_ptr<Pass> createPrintIRPass() {
-  return std::make_unique<PrintIRPass>();
+std::unique_ptr<Pass> createPrintIRPass(const PrintIRPassOptions &options) {
+  return std::make_unique<PrintIRPass>(options);
 }
 
 } // namespace mlir
\ No newline at end of file