diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -34,6 +34,7 @@
 #define GEN_PASS_DECL_LOOPINVARIANTCODEMOTION
 #define GEN_PASS_DECL_STRIPDEBUGINFO
 #define GEN_PASS_DECL_PRINTOPSTATS
+#define GEN_PASS_DECL_PRINTOPS
 #define GEN_PASS_DECL_INLINER
 #define GEN_PASS_DECL_SCCP
 #define GEN_PASS_DECL_SYMBOLDCE
@@ -71,6 +72,10 @@
 /// Creates a pass to strip debug information from a function.
 std::unique_ptr<Pass> createStripDebugInfoPass();
 
+/// Creates a pass which prints the operation it is run on to `os`. The stream
+/// is held by reference and must outlive the pass.
+std::unique_ptr<Pass> createPrintOpsPass(raw_ostream &os = llvm::errs());
+
 /// Creates a pass which prints the list of ops and the number of occurrences in
 /// the module.
 std::unique_ptr<Pass> createPrintOpStatsPass(raw_ostream &os = llvm::errs());
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -142,6 +142,16 @@
   let constructor = "mlir::createLoopInvariantCodeMotionPass()";
 }
 
+def PrintOps : Pass<"print-ops"> {
+  let summary = "Print operations";
+  let description = [{
+    This pass prints the operation it is run on, including all nested
+    operations, to the stream passed to `createPrintOpsPass` (stderr by
+    default). It does not modify the IR and preserves all analyses.
+  }];
+  let constructor = "mlir::createPrintOpsPass()";
+}
+
 def PrintOpStats : Pass<"print-op-stats"> {
   let summary = "Print statistics of operations";
   let constructor = "mlir::createPrintOpStatsPass()";
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -8,6 +8,7 @@
   LocationSnapshot.cpp
   LoopInvariantCodeMotion.cpp
   OpStats.cpp
+  PrintOps.cpp
   SCCP.cpp
   StripDebugInfo.cpp
   SymbolDCE.cpp
diff --git a/mlir/lib/Transforms/PrintOps.cpp b/mlir/lib/Transforms/PrintOps.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Transforms/PrintOps.cpp
@@ -0,0 +1,58 @@
+//===- PrintOps.cpp - Print operations for debugging ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the `print-ops` pass, a debugging aid that prints the
+// operation it is run on to a stream without modifying the IR.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/Passes.h"
+
+#include "mlir/IR/Operation.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_PRINTOPS
+#include "mlir/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+/// Prints the operation the pass is run on, including everything nested in it,
+/// to a caller-provided stream. The stream is held by reference and is not
+/// owned, so it must outlive the pass and any clones the pass manager makes.
+///
+/// NOTE: the stream is not synchronized. When the pass manager runs several
+/// instances of this pass in parallel (e.g. when it is nested under
+/// `func.func`), their output may interleave; pass `--mlir-disable-threading`
+/// to get a readable dump.
+struct PrintOpsPass : public impl::PrintOpsBase<PrintOpsPass> {
+  explicit PrintOpsPass(raw_ostream &os) : os(os) {}
+
+  /// Prints the current operation to `os`; the IR is left untouched.
+  void runOnOperation() override;
+
+private:
+  /// Destination for the printed IR (not owned).
+  raw_ostream &os;
+};
+} // namespace
+
+void PrintOpsPass::runOnOperation() {
+  getOperation()->print(os);
+  // `Operation::print` does not end its output with a newline; add one so that
+  // dumps from consecutive runs of this pass do not run together.
+  os << "\n";
+  // Printing never mutates the IR, so every cached analysis is still valid.
+  markAllAnalysesPreserved();
+}
+
+std::unique_ptr<Pass> mlir::createPrintOpsPass(raw_ostream &os) {
+  return std::make_unique<PrintOpsPass>(os);
+}
diff --git a/mlir/test/IR/print-ops.mlir b/mlir/test/IR/print-ops.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/IR/print-ops.mlir
@@ -0,0 +1,18 @@
+// RUN: mlir-opt -allow-unregistered-dialect -print-ops %s -o=/dev/null 2>&1 | FileCheck %s
+
+// The pass prints the operation it is run on (here the top-level module) to
+// stderr; the regular output is discarded so only the dump is checked.
+
+func.func @main(tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) {
+^bb0(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>):
+  %0 = arith.addf %arg0, %arg1 : tensor<4xf32>
+  %unregistered = "unregistered_op"(%0, %arg1) : (tensor<4xf32>,tensor<4xf32>)-> tensor<4xf32>
+  return %0, %unregistered : tensor<4xf32>, tensor<4xf32>
+}
+
+// CHECK: module {
+// CHECK: func.func @main(%[[ARG0:.*]]: tensor<4xf32>, %[[ARG1:.*]]: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) {
+// CHECK: %[[ADD:.*]] = arith.addf %[[ARG0]], %[[ARG1]] : tensor<4xf32>
+// CHECK: %[[UNREGISTERED:.*]] = "unregistered_op"(%[[ADD]], %[[ARG1]]) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+// CHECK: return %[[ADD]], %[[UNREGISTERED]] : tensor<4xf32>, tensor<4xf32>
+// CHECK: }