diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -58,6 +58,9 @@
             /*default=*/"false",
            "(Experimental) Try to eliminate init_tensor operations that are "
            "anchored at an insert_slice op">,
+    Option<"createDeallocs", "create-deallocs", "bool", /*default=*/"true",
+           "Specify if buffers should be deallocated. For compatibility with "
+           "core bufferization passes.">,
   ];
   let constructor = "mlir::createLinalgComprehensiveModuleBufferizePass()";
 }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -92,6 +92,7 @@
   options->analysisFuzzerSeed = analysisFuzzerSeed;
   options->testAnalysisOnly = testAnalysisOnly;
   options->printConflicts = printConflicts;
+  options->createDeallocs = createDeallocs;
 
   // Enable InitTensorOp elimination.
   if (initTensorElimination) {
diff --git a/mlir/test/Dialect/Linalg/comprehensive-function-bufferize-compat.mlir b/mlir/test/Dialect/Linalg/comprehensive-function-bufferize-compat.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/comprehensive-function-bufferize-compat.mlir
@@ -0,0 +1,35 @@
+// RUN: mlir-opt %s \
+// RUN:     -test-comprehensive-function-bufferize="allow-return-memref allow-unknown-ops create-deallocs=0" \
+// RUN:     -split-input-file | \
+// RUN: FileCheck %s --check-prefix=CHECK-NODEALLOC
+
+// RUN: mlir-opt %s \
+// RUN:     -test-comprehensive-function-bufferize="allow-return-memref allow-unknown-ops create-deallocs=0" \
+// RUN:     -buffer-deallocation | \
+// RUN: FileCheck %s --check-prefix=CHECK-BUFFERDEALLOC
+
+// With create-deallocs=0, comprehensive bufferization emits no memref.dealloc.
+// Running the core -buffer-deallocation pass afterwards must insert the
+// dealloc for the out-of-place allocation.
+
+// CHECK-NODEALLOC-LABEL: func @out_of_place_bufferization
+// CHECK-BUFFERDEALLOC-LABEL: func @out_of_place_bufferization
+func @out_of_place_bufferization(%t1 : tensor<?xf32>) -> (f32, f32) {
+  // CHECK-NODEALLOC: memref.alloc
+  // CHECK-NODEALLOC: memref.copy
+  // CHECK-NODEALLOC-NOT: memref.dealloc
+
+  // CHECK-BUFFERDEALLOC: %[[alloc:.*]] = memref.alloc
+  // CHECK-BUFFERDEALLOC: memref.copy
+  // CHECK-BUFFERDEALLOC: memref.dealloc %[[alloc]]
+
+  %cst = arith.constant 0.0 : f32
+  %idx = arith.constant 5 : index
+
+  // This bufferizes out-of-place. An allocation + copy will be inserted.
+  %0 = tensor.insert %cst into %t1[%idx] : tensor<?xf32>
+
+  %1 = tensor.extract %t1[%idx] : tensor<?xf32>
+  %2 = tensor.extract %0[%idx] : tensor<?xf32>
+  return %1, %2 : f32, f32
+}
diff --git a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
@@ -92,6 +92,10 @@
       *this, "dialect-filter",
       llvm::cl::desc("Bufferize only ops from the specified dialects"),
       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
+  Option<bool> createDeallocs{
+      *this, "create-deallocs",
+      llvm::cl::desc("Specify if buffers should be deallocated"),
+      llvm::cl::init(true)};
 };
 } // namespace
 
@@ -105,6 +109,7 @@
   options->allowUnknownOps = allowUnknownOps;
   options->testAnalysisOnly = testAnalysisOnly;
   options->analysisFuzzerSeed = analysisFuzzerSeed;
+  options->createDeallocs = createDeallocs;
 
   if (dialectFilter.hasValue()) {
     options->dialectFilter.emplace();