diff --git a/mlir/include/mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h b/mlir/include/mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h
--- a/mlir/include/mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h
+++ b/mlir/include/mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h
@@ -18,7 +18,7 @@
 #define GEN_PASS_DECL_CONVERTBUFFERIZATIONTOMEMREF
 #include "mlir/Conversion/Passes.h.inc"
 
-std::unique_ptr<OperationPass<ModuleOp>> createBufferizationToMemRefPass();
+std::unique_ptr<Pass> createBufferizationToMemRefPass();
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_BUFFERIZATIONTOMEMREF_BUFFERIZATIONTOMEMREF_H
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -187,8 +187,7 @@
 // BufferizationToMemRef
 //===----------------------------------------------------------------------===//
 
-def ConvertBufferizationToMemRef : Pass<"convert-bufferization-to-memref",
-                                        "mlir::ModuleOp"> {
+def ConvertBufferizationToMemRef : Pass<"convert-bufferization-to-memref"> {
   let summary = "Convert operations from the Bufferization dialect to the "
                 "MemRef dialect";
   let description = [{
diff --git a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
--- a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
+++ b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
@@ -409,6 +409,11 @@
     if (adaptor.getMemrefs().size() == 1)
       return rewriteOneMemrefMultipleRetainCase(op, adaptor, rewriter);
 
+    if (!deallocHelperFunc)
+      return op->emitError(
+          "library function required for generic lowering, but cannot be "
+          "automatically inserted when operating on functions");
+
     return rewriteGeneralCase(op, adaptor, rewriter);
   }
 
@@ -620,21 +625,29 @@
   BufferizationToMemRefPass() = default;
 
   void runOnOperation() override {
-    ModuleOp module = cast<ModuleOp>(getOperation());
-    OpBuilder builder =
-        OpBuilder::atBlockBegin(&module.getBodyRegion().front());
-    SymbolTable symbolTable(module);
+    if (!isa<ModuleOp, FunctionOpInterface>(getOperation())) {
+      emitError(getOperation()->getLoc(),
+                "root operation must be a builtin.module or a function");
+      signalPassFailure();
+      return;
+    }
 
-    // Build dealloc helper function if there are deallocs.
     func::FuncOp helperFuncOp;
-    getOperation()->walk([&](bufferization::DeallocOp deallocOp) {
-      if (deallocOp.getMemrefs().size() > 1) {
-        helperFuncOp = DeallocOpConversion::buildDeallocationHelperFunction(
-            builder, getOperation()->getLoc(), symbolTable);
-        return WalkResult::interrupt();
-      }
-      return WalkResult::advance();
-    });
+    if (auto module = dyn_cast<ModuleOp>(getOperation())) {
+      OpBuilder builder =
+          OpBuilder::atBlockBegin(&module.getBodyRegion().front());
+      SymbolTable symbolTable(module);
+
+      // Build dealloc helper function if there are deallocs.
+      getOperation()->walk([&](bufferization::DeallocOp deallocOp) {
+        if (deallocOp.getMemrefs().size() > 1) {
+          helperFuncOp = DeallocOpConversion::buildDeallocationHelperFunction(
+              builder, getOperation()->getLoc(), symbolTable);
+          return WalkResult::interrupt();
+        }
+        return WalkResult::advance();
+      });
+    }
 
     RewritePatternSet patterns(&getContext());
     patterns.add<CloneOpConversion>(patterns.getContext());
@@ -652,7 +665,6 @@
 };
 } // namespace
 
-std::unique_ptr<OperationPass<ModuleOp>>
-mlir::createBufferizationToMemRefPass() {
+std::unique_ptr<Pass> mlir::createBufferizationToMemRefPass() {
   return std::make_unique<BufferizationToMemRefPass>();
 }
diff --git a/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref-func.mlir b/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref-func.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref-func.mlir
@@ -0,0 +1,28 @@
+// RUN: mlir-opt -verify-diagnostics --pass-pipeline="builtin.module(func.func(convert-bufferization-to-memref))" -split-input-file %s | FileCheck %s
+
+// When nested under func.func, the pass must not insert the dealloc helper
+// function (it is only built when the pass runs on a builtin.module).
+// CHECK-NOT: @dealloc_helper
+// CHECK-LABEL: func @conversion_dealloc_simple
+// CHECK-SAME: [[ARG0:%.+]]: memref<2xf32>
+// CHECK-SAME: [[ARG1:%.+]]: i1
+func.func @conversion_dealloc_simple(%arg0: memref<2xf32>, %arg1: i1) {
+  bufferization.dealloc (%arg0 : memref<2xf32>) if (%arg1)
+  return
+}
+
+// CHECK: scf.if [[ARG1]] {
+// CHECK-NEXT: memref.dealloc [[ARG0]] : memref<2xf32>
+// CHECK-NEXT: }
+// CHECK-NEXT: return
+
+// -----
+
+// A dealloc of several memrefs needs the helper function for the generic
+// lowering; that helper is unavailable when running on functions, so it fails.
+func.func @conversion_dealloc_multiple_memrefs_and_retained(%arg0: memref<2xf32>, %arg1: memref<5xf32>, %arg2: memref<1xf32>, %arg3: i1, %arg4: i1, %arg5: memref<2xf32>) -> (i1, i1) {
+  // expected-error @below {{library function required for generic lowering, but cannot be automatically inserted when operating on functions}}
+  // expected-error @below {{failed to legalize operation 'bufferization.dealloc' that was explicitly marked illegal}}
+  %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<5xf32>) if (%arg3, %arg4) retain (%arg2, %arg5 : memref<1xf32>, memref<2xf32>)
+  return %0#0, %0#1 : i1, i1
+}