diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp
@@ -7,12 +7,31 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Transforms/InliningUtils.h"
 
 using namespace mlir;
 using namespace mlir::bufferization;
 
 #include "mlir/Dialect/Bufferization/IR/BufferizationOpsDialect.cpp.inc"
 
+//===----------------------------------------------------------------------===//
+// Bufferization Dialect Interfaces
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Inliner interface for the Bufferization dialect.
+struct BufferizationInlinerInterface : public DialectInlinerInterface {
+  using DialectInlinerInterface::DialectInlinerInterface;
+
+  /// Operations in Bufferization dialect are always legal to inline,
+  /// independent of the destination region or value mapping.
+  bool isLegalToInline(Operation *, Region *, bool,
+                       BlockAndValueMapping &) const final {
+    return true;
+  }
+};
+} // end anonymous namespace
+
 //===----------------------------------------------------------------------===//
 // Bufferization Dialect
 //===----------------------------------------------------------------------===//
@@ -22,4 +41,5 @@
 #define GET_OP_LIST
 #include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"
       >();
+  addInterfaces<BufferizationInlinerInterface>();
 }
diff --git a/mlir/test/Dialect/Bufferization/inlining.mlir b/mlir/test/Dialect/Bufferization/inlining.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Bufferization/inlining.mlir
@@ -0,0 +1,16 @@
+// RUN: mlir-opt %s -inline | FileCheck %s
+
+// CHECK-LABEL: func @test_inline
+// CHECK-SAME: (%[[ARG:.*]]: memref<*xf32>)
+// CHECK-NOT: call
+// CHECK: %[[RES:.*]] = bufferization.clone %[[ARG]]
+// CHECK: return %[[RES]]
+func @test_inline(%buf : memref<*xf32>) -> memref<*xf32> {
+  %0 = call @inner_func(%buf) : (memref<*xf32>) -> memref<*xf32>
+  return %0 : memref<*xf32>
+}
+
+func @inner_func(%buf : memref<*xf32>) -> memref<*xf32> {
+  %clone = bufferization.clone %buf : memref<*xf32> to memref<*xf32>
+  return %clone : memref<*xf32>
+}