diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -246,12 +246,16 @@
     ```
   }];
 
-  let arguments = (ins MemRefRankOf<[AnyType], [1]>:$source,
+  // Note that we conceptually mark the `source` operand as freeing the
+  // incoming memref and the result as allocating the outgoing memref, even
+  // though this may not physically happen on each execution.
+
+  let arguments = (ins Arg<MemRefRankOf<[AnyType], [1]>, "", [MemFree]>:$source,
                    Optional<Index>:$dynamicResultSize,
                    ConfinedAttr<OptionalAttr<I64Attr>,
                                 [IntMinValue<0>]>:$alignment);
 
-  let results = (outs MemRefRankOf<[AnyType], [1]>);
+  let results = (outs Res<MemRefRankOf<[AnyType], [1]>, "", [MemAlloc]>);
 
   let builders = [
     OpBuilder<(ins "MemRefType":$resultType,
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
@@ -639,6 +639,16 @@
   }
 };
 
+struct DefaultReallocationInterface
+    : public bufferization::AllocationOpInterface::ExternalModel<
+          DefaultReallocationInterface, memref::ReallocOp> {
+  static std::optional<Operation *> buildDealloc(OpBuilder &builder,
+                                                 Value realloc) {
+    return builder.create<memref::DeallocOp>(realloc.getLoc(), realloc)
+        .getOperation();
+  }
+};
+
 /// The actual buffer deallocation pass that inserts and moves dealloc nodes
 /// into the right positions. Furthermore, it inserts additional clones if
 /// necessary. It uses the algorithm described at the top of the file.
@@ -703,6 +713,7 @@
     DialectRegistry &registry) {
   registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
     memref::AllocOp::attachInterface<DefaultAllocationInterface>(*ctx);
+    memref::ReallocOp::attachInterface<DefaultReallocationInterface>(*ctx);
   });
 }
 
diff --git a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir
--- a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation.mlir
@@ -1418,3 +1418,22 @@
 // CHECK-NEXT: %[[ALLOC:.*]] = memref.alloc() : memref<10xf32>
 // CHECK-NEXT: memref.dealloc %[[ALLOC]] : memref<10xf32>
 // CHECK-NEXT: affine.if
+
+// -----
+
+// Ensure we free the realloc, not the alloc.
+
+// CHECK-LABEL: func @auto_dealloc()
+func.func @auto_dealloc() {
+  %c10 = arith.constant 10 : index
+  %c100 = arith.constant 100 : index
+  %alloc = memref.alloc(%c10) : memref<?xi32>
+  %realloc = memref.realloc %alloc(%c100) : memref<?xi32> to memref<?xi32>
+  return
+}
+// CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index
+// CHECK-DAG: %[[C100:.*]] = arith.constant 100 : index
+// CHECK-NEXT: %[[A:.*]] = memref.alloc(%[[C10]]) : memref<?xi32>
+// CHECK-NEXT: %[[R:.*]] = memref.realloc %[[A]](%[[C100]]) : memref<?xi32> to memref<?xi32>
+// CHECK-NEXT: memref.dealloc %[[R]] : memref<?xi32>
+// CHECK-NEXT: return