diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -473,4 +473,48 @@
   let hasCanonicalizer = 1;
 }
 
+def Bufferization_DeallocOp : Bufferization_Op<"dealloc", [
+    AttrSizedOperandSegments, DeclareOpInterfaceMethods<InferTypeOpInterface>
+  ]> {
+  let summary = "deallocates the given memrefs if no alias is retained";
+  let description = [{
+    This operation deallocates each of the given memrefs if there is no alias
+    to that memref in the list of retained memrefs and the corresponding
+    condition value is set. This condition can be used to indicate and pass on
+    ownership of memref values (or in other words, the responsibility of
+    deallocating that memref). If two memrefs alias each other, only one will be
+    deallocated to avoid double free situations.
+
+    The memrefs to be deallocated must be the originally allocated memrefs,
+    however, the memrefs to be retained may be arbitrary memrefs. Exactly one
+    condition must be provided for each memref to be deallocated.
+
+    Returns a list of conditions corresponding to the list of memrefs which
+    indicates the new ownerships, i.e., if the memref was deallocated the
+    ownership was dropped (set to 'false') and otherwise will be the same as the
+    input condition.
+
+    Example:
+    ```mlir
+    %0:2 = bufferization.dealloc (%a0, %a1 : memref<2xf32>, memref<4xi32>)
+      if (%cond0, %cond1) retain (%r0, %r1 : memref<?xf32>, memref<*xf64>)
+    ```
+    Deallocation will be called on `%a0` if `%cond0` is 'true' and neither `%r0`
+    nor `%r1` is an alias of `%a0`. `%a1` will be deallocated when `%cond1` is
+    set to 'true' and none of `%r0`, `%r1` and `%a0` are aliases of `%a1`.
+  }];
+
+  let arguments = (ins Variadic<AnyRankedOrUnrankedMemRef>:$memrefs,
+                       Variadic<I1>:$conditions,
+                       Variadic<AnyRankedOrUnrankedMemRef>:$retained);
+  let results = (outs Variadic<I1>:$updatedConditions);
+
+  let assemblyFormat = [{
+    (` ``(` $memrefs^ `:` type($memrefs) `)` `if` ` ` `(` $conditions `)` )?
+    (`retain` ` ` `(` $retained^ `:` type($retained) `)` )? attr-dict
+  }];
+
+  let hasVerifier = 1;
+}
+
 #endif // BUFFERIZATION_OPS
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -746,6 +746,31 @@
   return builder.create<CloneOp>(alloc.getLoc(), alloc).getResult();
 }
 
+//===----------------------------------------------------------------------===//
+// DeallocOp
+//===----------------------------------------------------------------------===//
+
+/// Infers one result per `conditions` operand, each with that operand's type,
+/// so the op returns one updated ownership indicator per condition.
+LogicalResult DeallocOp::inferReturnTypes(
+    MLIRContext *context, std::optional<::mlir::Location> location,
+    ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties,
+    RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
+  DeallocOpAdaptor adaptor(operands, attributes, properties, regions);
+  // Overwrite the output in place instead of building a temporary SmallVector.
+  auto conditionTypes = adaptor.getConditions().getTypes();
+  inferredReturnTypes.assign(conditionTypes.begin(), conditionTypes.end());
+  return success();
+}
+
+/// Each memref to deallocate must be paired with exactly one condition.
+LogicalResult DeallocOp::verify() {
+  if (getMemrefs().size() != getConditions().size())
+    return emitOpError(
+        "must have the same number of conditions as memrefs to deallocate");
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Bufferization/inlining.mlir b/mlir/test/Dialect/Bufferization/inlining.mlir
--- a/mlir/test/Dialect/Bufferization/inlining.mlir
+++ b/mlir/test/Dialect/Bufferization/inlining.mlir
@@ -4,13 +4,16 @@
 // CHECK-SAME: (%[[ARG:.*]]: memref<*xf32>)
 // CHECK-NOT: call
 // CHECK: %[[RES:.*]] = bufferization.clone %[[ARG]]
+// CHECK: bufferization.dealloc
 // CHECK: return %[[RES]]
-func.func @test_inline(%buf : memref<*xf32>) -> memref<*xf32> {
-  %0 = call @inner_func(%buf) : (memref<*xf32>) -> memref<*xf32>
-  return %0 : memref<*xf32>
+func.func @test_inline(%buf : memref<*xf32>) -> (memref<*xf32>, i1) {
+  %0:2 = call @inner_func(%buf) : (memref<*xf32>) -> (memref<*xf32>, i1)
+  return %0#0, %0#1 : memref<*xf32>, i1
 }
 
-func.func @inner_func(%buf : memref<*xf32>) -> memref<*xf32> {
+func.func @inner_func(%buf : memref<*xf32>) -> (memref<*xf32>, i1) {
+  %true = arith.constant true
   %clone = bufferization.clone %buf : memref<*xf32> to memref<*xf32>
-  return %clone : memref<*xf32>
+  %0 = bufferization.dealloc (%buf : memref<*xf32>) if (%true) retain (%clone : memref<*xf32>)
+  return %clone, %0 : memref<*xf32>, i1
 }
diff --git a/mlir/test/Dialect/Bufferization/invalid.mlir b/mlir/test/Dialect/Bufferization/invalid.mlir
--- a/mlir/test/Dialect/Bufferization/invalid.mlir
+++ b/mlir/test/Dialect/Bufferization/invalid.mlir
@@ -103,3 +103,11 @@
   // expected-error @below{{expects different type than prior uses: 'tensor<?xf32>' vs 'tensor<5xf32>'}}
   bufferization.copy_tensor %arg0, %arg1 : tensor<?xf32>
 }
+
+// -----
+
+func.func @invalid_dealloc_memref_condition_mismatch(%arg0: memref<2xf32>, %arg1: memref<4xi32>, %arg2: i1) -> i1 {
+  // expected-error @below{{must have the same number of conditions as memrefs to deallocate}}
+  %0 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<4xi32>) if (%arg2)
+  return %0 : i1
+}
diff --git a/mlir/test/Dialect/Bufferization/ops.mlir b/mlir/test/Dialect/Bufferization/ops.mlir
--- a/mlir/test/Dialect/Bufferization/ops.mlir
+++ b/mlir/test/Dialect/Bufferization/ops.mlir
@@ -65,3 +65,16 @@
   %1 = bufferization.copy_tensor %arg0, %arg1 : tensor<?xf32>
   return %1 : tensor<?xf32>
 }
+
+// CHECK-LABEL: func @test_dealloc_op
+func.func @test_dealloc_op(%arg0: memref<2xf32>, %arg1: memref<4xi32>,
+                           %arg2: i1, %arg3: i1, %arg4: memref<?xf32>,
+                           %arg5: memref<*xf64>) -> (i1, i1) {
+  // CHECK: bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<4xi32>) if (%arg2, %arg3) retain (%arg4, %arg5 : memref<?xf32>, memref<*xf64>)
+  %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<4xi32>) if (%arg2, %arg3) retain (%arg4, %arg5 : memref<?xf32>, memref<*xf64>)
+  // CHECK: bufferization.dealloc (%arg0 : memref<2xf32>) if (%arg2)
+  %1 = bufferization.dealloc (%arg0 : memref<2xf32>) if (%arg2)
+  // CHECK: bufferization.dealloc
+  bufferization.dealloc
+  return %0, %1 : i1, i1
+}