diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h @@ -109,6 +109,13 @@ /// in terms of shapes of its input operands. std::unique_ptr createResolveShapedTypeResultDimsPass(); +/// Inserts CallOps to print IR and the content of memrefs for every buffer +/// argument of FuncOp and for every operation within the function depending on +/// the `extract_values_func` filter. +void embedMemRefPrints( + ModuleOp module, + function_ref(Operation *)> extract_values_func); + //===----------------------------------------------------------------------===// // Registration //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_dialect_library(MLIRMemRefTransforms ComposeSubView.cpp + EmbedMemRefPrints.cpp ExpandOps.cpp FoldSubViewOps.cpp MultiBuffer.cpp @@ -18,6 +19,7 @@ MLIRArithmetic MLIRFunc MLIRInferTypeOpInterface + MLIRLLVMIR MLIRLoopLikeInterface MLIRMemRef MLIRPass diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmbedMemRefPrints.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmbedMemRefPrints.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/MemRef/Transforms/EmbedMemRefPrints.cpp @@ -0,0 +1,182 @@ +//===- EmbedMemRefPrints.cpp - Utility to print content of memrefs -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" +#include "mlir/IR/Builders.h" +#include "llvm/Support/FormatVariadic.h" + +using namespace mlir; + +using LLVM::LLVMFuncOp; +using ExtractValuesFunc = function_ref(Operation *)>; + +constexpr StringRef kPrintStringFuncName = "print_c_string"; + +static FlatSymbolRefAttr getOrInsertLLVMFunction(StringRef func_name, + Type func_type, Operation *op, + OpBuilder *b) { + auto module = op->getParentOfType(); + auto tf_func = module.lookupSymbol(func_name); + if (!tf_func) { + OpBuilder::InsertionGuard guard(*b); + b->setInsertionPointToStart(module.getBody()); + tf_func = b->create(b->getUnknownLoc(), func_name, func_type); + } + return SymbolRefAttr::get(b->getContext(), func_name); +} + +static std::string getGlobalName(StringRef base, StringRef content) { + return llvm::formatv("{0}_{1}", base, llvm::hash_value(content)); +} + +static Value createOrFindGlobalStringConstant(Location loc, + StringRef global_name, + StringRef content, OpBuilder *b) { + auto module = + b->getInsertionBlock()->getParentOp()->getParentOfType(); + Operation *global_constant = SymbolTable::lookupNearestSymbolFrom( + module, b->getStringAttr(global_name)); + if (global_constant) { + Value global_ptr = b->create( + loc, cast(global_constant)); + Value c0 = + b->create(loc, b->getI64Type(), b->getIndexAttr(0)); + return b->create( + loc, LLVM::LLVMPointerType::get(b->getIntegerType(8)), global_ptr, + ValueRange{c0, c0}); + } + return LLVM::createGlobalString(loc, *b, global_name, content, + LLVM::Linkage::Internal); +} + +static Operation *emitMemRefPrint(Location loc, Type element_type, Value arg, + OpBuilder *b) { + StringRef func_name; + if (element_type.isF32()) + func_name = "print_memref_f32"; + if (element_type.isF64()) + func_name = "print_memref_f64"; + if (element_type.isInteger(32)) + func_name = "print_memref_i32"; + if (element_type.isInteger(64) || element_type.isIndex()) + func_name = "print_memref_i64"; + assert(!func_name.empty() && + "Did not find a print function for the element type"); + + auto caller_func = + b->getInsertionBlock()->getParent()->getParentOfType(); + auto func_name_attr = b->getStringAttr(func_name); + + auto callee_func = + SymbolTable::lookupNearestSymbolFrom(caller_func, func_name_attr); + if (!callee_func) { + OpBuilder::InsertionGuard insertGuard(*b); + + auto module = caller_func->getParentOfType(); + b->setInsertionPointToStart(module.getBody()); + auto func_type = FunctionType::get(b->getContext(), arg.getType(), + /*results=*/llvm::None); + callee_func = b->create(module.getLoc(), func_name, func_type); + callee_func.setPrivate(); + } + return b->create(loc, callee_func, arg); +} + +static bool isElementTypePrintalble(Type element_type) { + return element_type.isF32() || element_type.isF64() || + element_type.isInteger(32) || element_type.isInteger(64) || + element_type.isIndex(); +} + +static void emitValuePrint(Location loc, Value value, OpBuilder *b) { + auto memref_type = value.getType(); + if (auto unranked_type = memref_type.dyn_cast()) { + Type element_type = unranked_type.getElementType(); + if (!isElementTypePrintalble(element_type)) + return; + + emitMemRefPrint(loc, element_type, value, b); + } + if (auto ranked_type = memref_type.dyn_cast()) { + Type element_type = ranked_type.getElementType(); + if (!isElementTypePrintalble(element_type)) + return; + + if (element_type.isIndex()) { + element_type = b->getI64Type(); + ranked_type = MemRefType::get(ranked_type.getShape(), element_type, + ranked_type.getLayout(), + ranked_type.getMemorySpace()); + value = b->create(loc, ranked_type, value); + } + + auto unranked_type = UnrankedMemRefType::get( + element_type, ranked_type.getMemorySpaceAsInt()); + Value unranked_memref = + b->create(loc, unranked_type, value); + emitMemRefPrint(loc, element_type, unranked_memref, b); + } +} + +static void emitOperationPrint(Operation *op, OpBuilder *b) { + std::string debug_str = "\n\nPrint memref content after the following op\n"; + llvm::raw_string_ostream output_stream(debug_str); + + mlir::OpPrintingFlags flags; + op->print(output_stream, flags); + output_stream << "\n\n"; + + Location loc = op->getLoc(); + Value message_constant = createOrFindGlobalStringConstant( + loc, getGlobalName("debug_op", debug_str), debug_str, b); + + // Insert function call. + MLIRContext *ctx = op->getContext(); + auto func_type = LLVM::LLVMFunctionType::get( + LLVM::LLVMVoidType::get(op->getContext()), + {LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8))}); + FlatSymbolRefAttr tf_func_ref = + getOrInsertLLVMFunction(kPrintStringFuncName, func_type, op, b); + b->create(loc, llvm::None, tf_func_ref, + llvm::makeArrayRef({message_constant})); +} + +void mlir::memref::embedMemRefPrints(ModuleOp module, + ExtractValuesFunc extract_values_func) { + module.walk([&](FuncOp func) { + if (func.isDeclaration()) + return; + Block *body = &func.getBody().front(); + + // Print arguments. + OpBuilder b(module.getContext()); + b.setInsertionPointToStart(body); + Location loc = func.getLoc(); + auto args = func.getArguments(); + if (!args.empty()) + emitOperationPrint(func, &b); + for (auto arg : args) + emitValuePrint(loc, arg, &b); + + // Print values extracted from every op. + if (!extract_values_func) + return; + for (auto &op : func.getBody().front().getOperations()) { + b.setInsertionPointAfter(&op); + auto values = extract_values_func(&op); + if (!values.empty()) + emitOperationPrint(&op, &b); + for (Value value : values) + emitValuePrint(op.getLoc(), value, &b); + } + }); +} diff --git a/mlir/lib/Dialect/MemRef/Transforms/PassDetail.h b/mlir/lib/Dialect/MemRef/Transforms/PassDetail.h --- a/mlir/lib/Dialect/MemRef/Transforms/PassDetail.h +++ b/mlir/lib/Dialect/MemRef/Transforms/PassDetail.h @@ -27,6 +27,10 @@ class FuncDialect; } // namespace func +namespace LLVM { +class LLVMDialect; +} // namespace LLVM + namespace memref { class MemRefDialect; } // namespace memref diff --git a/mlir/test/Transforms/embed-memref-prints.mlir b/mlir/test/Transforms/embed-memref-prints.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Transforms/embed-memref-prints.mlir @@ -0,0 +1,53 @@ +// RUN: mlir-opt %s --test-embed-memref-prints | FileCheck %s + +func.func @print_memrefs(%input: memref<*xf32>) -> memref<*xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %rank = memref.rank %input : memref<*xf32> + %shape = memref.alloca(%rank) : memref + scf.for %i = %c0 to %rank step %c1 { + %dim = memref.dim %input, %i : memref<*xf32> + memref.store %dim, %shape[%i] : memref + } + + %c9000 = arith.constant 9000 : index + %num_elem = memref.alloca() : memref<1xindex> + memref.store %c9000, %num_elem[%c0] : memref<1xindex> + %flat_input = memref.reshape %input(%num_elem) + : (memref<*xf32>, memref<1xindex>) -> memref + + %flat_output = memref.alloc(%c9000) : memref + %output = memref.reshape %flat_output(%shape) + : (memref, memref) -> memref<*xf32> + func.return %output : memref<*xf32> +} + +// CHECK-DAG: global internal constant @[[STR0:debug_op_[0-9]+]]({{.*}} @print_memrefs +// CHECK-DAG: global internal constant @[[STR1:debug_op_[0-9]+]]({{.*}} -> memref +// CHECK-DAG: global internal constant @[[STR2:debug_op_[0-9]+]]({{.*}} -> memref<*xf32> +// CHECK-DAG: func private @print_memref_f32(memref<*xf32>) +// CHECK-DAG: llvm.func @print_c_string(!llvm.ptr) + +// CHECK: func @print_memrefs +// CHECK-SAME: (%[[ARG:.*]]: memref<*xf32>) +// Print debug info for the function arg. +// CHECK: %[[STR0_ADDR:.*]] = llvm.mlir.addressof @[[STR0]] +// CHECK: %[[STR0_PTR:.*]] = llvm.getelementptr %[[STR0_ADDR]] +// CHECK: llvm.call @print_c_string(%[[STR0_PTR]]) : (!llvm.ptr) +// CHECK: call @print_memref_f32(%[[ARG]]) : (memref<*xf32>) -> () + +// Print debug info for reshape from unranked to ranked. +// CHECK: %[[RESHAPE:.*]] = memref.reshape %[[ARG]] +// CHECK: %[[STR1_ADDR:.*]] = llvm.mlir.addressof @[[STR1]] +// CHECK: %[[STR1_PTR:.*]] = llvm.getelementptr %[[STR1_ADDR]] +// CHECK: llvm.call @print_c_string(%[[STR1_PTR]]) : (!llvm.ptr) +// CHECK: %[[UNRANKED_BUF:.*]] = memref.cast %[[RESHAPE]] +// CHECK: call @print_memref_f32(%[[UNRANKED_BUF]]) : (memref<*xf32>) + +// Print debug info for reshape from ranked to unranked. +// CHECK: %[[ALLOC:.*]] = memref.alloc +// CHECK: %[[RESHAPE_2:.*]] = memref.reshape %[[ALLOC]] +// CHECK: %[[STR2_ADDR:.*]] = llvm.mlir.addressof @[[STR2]] +// CHECK: %[[STR2_PTR:.*]] = llvm.getelementptr %[[STR2_ADDR]] +// CHECK: llvm.call @print_c_string(%[[STR2_PTR]]) : (!llvm.ptr) +// CHECK: call @print_memref_f32(%[[RESHAPE_2]]) : (memref<*xf32>) diff --git a/mlir/test/lib/Dialect/MemRef/CMakeLists.txt b/mlir/test/lib/Dialect/MemRef/CMakeLists.txt --- a/mlir/test/lib/Dialect/MemRef/CMakeLists.txt +++ b/mlir/test/lib/Dialect/MemRef/CMakeLists.txt @@ -1,5 +1,6 @@ # Exclude tests from libMLIR.so add_mlir_library(MLIRMemRefTestPasses + TestEmbedMemRefPrints.cpp TestComposeSubView.cpp TestMultiBuffer.cpp @@ -7,6 +8,7 @@ LINK_LIBS PUBLIC MLIRPass + MLIRLLVMIR MLIRMemRef MLIRMemRefTransforms MLIRTestDialect diff --git a/mlir/test/lib/Dialect/MemRef/TestEmbedMemRefPrints.cpp b/mlir/test/lib/Dialect/MemRef/TestEmbedMemRefPrints.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Dialect/MemRef/TestEmbedMemRefPrints.cpp @@ -0,0 +1,56 @@ +//===- TestEmbedMemRefPrints.cpp - Test embedded memref prints ------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to test embedded CallOps to print memref content. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" + +using namespace mlir; + +static SmallVector extractValuesToPrint(Operation *op) { + if (isa(op) || isa(op) || + isa(op) || isa(op)) + return {op->getResult(0)}; + if (auto copy = dyn_cast(op)) + return {copy.target()}; + return {}; +} + +namespace { + +struct TestEmbedMemRefPrintsPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestEmbedMemRefPrintsPass) + + StringRef getArgument() const final { return "test-embed-memref-prints"; } + StringRef getDescription() const final { + return "Test embedding memref prints"; + } + + void getDependentDialects(DialectRegistry ®istry) const { + registry.insert(); + } + + void runOnOperation() override { + memref::embedMemRefPrints(getOperation(), extractValuesToPrint); + } +}; + +} // namespace + +namespace mlir { +namespace test { +void registerTestEmbedMemRefPrints() { + PassRegistration(); +} +} // namespace test +} // namespace mlir diff --git a/mlir/test/mlir-cpu-runner/debug_memref_content.mlir b/mlir/test/mlir-cpu-runner/debug_memref_content.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-cpu-runner/debug_memref_content.mlir @@ -0,0 +1,40 @@ +// RUN: mlir-opt %s -pass-pipeline="test-embed-memref-prints,func.func(convert-linalg-to-loops,convert-scf-to-cf),convert-linalg-to-llvm,convert-math-to-llvm,convert-memref-to-llvm,func.func(convert-arith-to-llvm),convert-func-to-llvm,reconcile-unrealized-casts" |\ +// RUN: mlir-cpu-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext |\ +// RUN: FileCheck %s + +#map = affine_map<(d0) -> (d0)> +func.func @main() -> () { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + + // Initialize input. + %input = memref.alloca() : memref<6xf32> + %dim = memref.dim %input, %c0 : memref<6xf32> + scf.for %i = %c0 to %dim step %c1 { + %diff = arith.subi %c1, %i : index + %diff_i64 = arith.index_cast %diff : index to i64 + %diff_f32 = arith.sitofp %diff_i64 : i64 to f32 + memref.store %diff_f32, %input[%i] : memref<6xf32> + } + + %output = func.call @reshape(%input) + : (memref<6xf32>) -> (memref<2x3xf32>) + func.return +} + + +func.func @reshape(%in: memref<6xf32>) -> memref<2x3xf32> { + %out = memref.expand_shape %in [[0, 1]]: memref<6xf32> into memref<2x3xf32> + func.return %out : memref<2x3xf32> +} +// CHECK: Print memref content after the following op +// CHECK: @reshape +// CHECK: rank = 1 offset = 0 sizes = [6] strides = [1] +// CHECK: [1, 0, -1, -2, -3, -4] + +// CHECK: Print memref content after the following op +// CHECK: memref.expand_shape +// CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1] +// CHECK: [1, 0, -1] +// CHECK: [-2, -3, -4] diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -75,6 +75,7 @@ void registerTestDynamicPipelinePass(); void registerTestExpandTanhPass(); void registerTestComposeSubView(); +void registerTestEmbedMemRefPrints(); void registerTestMultiBuffering(); void registerTestGpuParallelLoopMappingPass(); void registerTestIRVisitorsPass(); @@ -164,6 +165,7 @@ mlir::test::registerTestDataLayoutQuery(); mlir::test::registerTestDominancePass(); mlir::test::registerTestDynamicPipelinePass(); + mlir::test::registerTestEmbedMemRefPrints(); mlir::test::registerTestExpandTanhPass(); mlir::test::registerTestComposeSubView(); mlir::test::registerTestMultiBuffering(); diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -8099,6 +8099,7 @@ ":FuncDialect", ":IR", ":InferTypeOpInterface", + ":LLVMDialect", ":LoopLikeInterface", ":MemRefDialect", ":MemRefPassIncGen", diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel @@ -441,6 +441,7 @@ deps = [ ":TestDialect", "//mlir:Affine", + "//mlir:LLVMDialect", "//mlir:MemRefDialect", "//mlir:MemRefTransforms", "//mlir:Pass",