diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp @@ -791,11 +791,20 @@ LogicalResult matchAndRewrite(DeallocOp deallocOp, PatternRewriter &rewriter) const override { // Unique memrefs to be deallocated. + DenseSet retained(deallocOp.getRetained().begin(), + deallocOp.getRetained().end()); DenseMap memrefToCondition; SmallVector newMemrefs, newConditions, newRetained; - SmallVector resultIndices; - for (auto [memref, cond] : - llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) { + SmallVector resultIndices(deallocOp.getMemrefs().size(), -1); + for (auto [i, memref, cond] : + llvm::enumerate(deallocOp.getMemrefs(), deallocOp.getConditions())) { + if (retained.contains(memref)) { + // Memrefs that are also retained are not forwarded to the new dealloc; + // their result is replaced by the corresponding dealloc condition. + rewriter.replaceAllUsesWith(deallocOp.getResult(i), cond); + continue; + } + if (memrefToCondition.count(memref)) { // If the dealloc conditions don't match, we need to make sure that the // dealloc happens on the union of cases. @@ -808,7 +817,7 @@ newMemrefs.push_back(memref); newConditions.push_back(cond); } - resultIndices.push_back(memrefToCondition[memref]); + resultIndices[i] = memrefToCondition[memref]; } // Unique retained values @@ -831,19 +840,38 @@ auto newDealloc = rewriter.create(deallocOp.getLoc(), newMemrefs, newConditions, newRetained); for (auto [i, newIdx] : llvm::enumerate(resultIndices)) - rewriter.replaceAllUsesWith(deallocOp.getResult(i), - newDealloc.getResult(newIdx)); + if (newIdx != -1) + rewriter.replaceAllUsesWith(deallocOp.getResult(i), + newDealloc.getResult(newIdx)); rewriter.eraseOp(deallocOp); return success(); } }; +/// Erase deallocation operations where the variadic list of memrefs to +/// deallocate is empty.
Example: +/// ```mlir +/// bufferization.dealloc retain (%arg0: memref<2xi32>) +/// ``` +struct EraseEmptyDealloc : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(DeallocOp deallocOp, + PatternRewriter &rewriter) const override { + if (deallocOp.getMemrefs().empty()) { + rewriter.eraseOp(deallocOp); + return success(); + } + return failure(); + } +}; + } // anonymous namespace void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add(context); + results.add(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Bufferization/canonicalize.mlir b/mlir/test/Dialect/Bufferization/canonicalize.mlir --- a/mlir/test/Dialect/Bufferization/canonicalize.mlir +++ b/mlir/test/Dialect/Bufferization/canonicalize.mlir @@ -282,15 +282,30 @@ // ----- -func.func @dealloc_canonicalize_duplicates(%arg0: memref<2xi32>, %arg1: i1, %arg2: i1, %arg3: memref<2xi32>, %arg4: memref<2xi32>) -> (i1, i1, i1, i1, i1) { - %0:3 = bufferization.dealloc (%arg4, %arg0, %arg0 : memref<2xi32>, memref<2xi32>, memref<2xi32>) if (%arg1, %arg1, %arg1) retain (%arg3, %arg4, %arg3 : memref<2xi32>, memref<2xi32>, memref<2xi32>) +func.func @dealloc_canonicalize_duplicates(%arg0: memref<2xi32>, %arg1: i1, %arg2: i1, %arg3: memref<2xi32>, %arg4: memref<2xi32>, %arg5: memref<2xi32>) -> (i1, i1, i1, i1, i1) { + %0:3 = bufferization.dealloc (%arg4, %arg0, %arg0 : memref<2xi32>, memref<2xi32>, memref<2xi32>) if (%arg1, %arg1, %arg1) retain (%arg3, %arg5, %arg3 : memref<2xi32>, memref<2xi32>, memref<2xi32>) %1:2 = bufferization.dealloc (%arg0, %arg0 : memref<2xi32>, memref<2xi32>) if (%arg1, %arg2) return %0#0, %0#1, %0#2, %1#0, %1#1 : i1, i1, i1, i1, i1 } // CHECK-LABEL: func @dealloc_canonicalize_duplicates -// CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: i1, [[ARG3:%.+]]: memref<2xi32>, [[ARG4:%.+]]: 
memref<2xi32>) -// CHECK-NEXT: [[V0:%.+]]:2 = bufferization.dealloc ([[ARG4]], [[ARG0]] : memref<2xi32>, memref<2xi32>) if ([[ARG1]], [[ARG1]]) retain ([[ARG3]], [[ARG4]] : memref<2xi32>, memref<2xi32>) +// CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: i1, [[ARG3:%.+]]: memref<2xi32>, [[ARG4:%.+]]: memref<2xi32>, [[ARG5:%.+]]: memref<2xi32>) +// CHECK-NEXT: [[V0:%.+]]:2 = bufferization.dealloc ([[ARG4]], [[ARG0]] : memref<2xi32>, memref<2xi32>) if ([[ARG1]], [[ARG1]]) retain ([[ARG3]], [[ARG5]] : memref<2xi32>, memref<2xi32>) // CHECK-NEXT: [[NEW_COND:%.+]] = arith.ori [[ARG1]], [[ARG2]] : i1 // CHECK-NEXT: [[V1:%.+]] = bufferization.dealloc ([[ARG0]] : memref<2xi32>) if ([[NEW_COND]]) // CHECK-NEXT: return [[V0]]#0, [[V0]]#1, [[V0]]#1, [[V1]], [[V1]] : + +// ----- + +func.func @dealloc_canonicalize_retained_and_deallocated(%arg0: memref<2xi32>, %arg1: i1, %arg2: memref<2xi32>) -> (i1, i1, i1) { + %0 = bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg1) retain (%arg0 : memref<2xi32>) + %1:2 = bufferization.dealloc (%arg0, %arg2 : memref<2xi32>, memref<2xi32>) if (%arg1, %arg1) retain (%arg0 : memref<2xi32>) + bufferization.dealloc + bufferization.dealloc retain (%arg0 : memref<2xi32>) + return %0, %1#0, %1#1 : i1, i1, i1 +} + +// CHECK-LABEL: func @dealloc_canonicalize_retained_and_deallocated +// CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: memref<2xi32>) +// CHECK-NEXT: [[V0:%.+]] = bufferization.dealloc ([[ARG2]] : memref<2xi32>) if ([[ARG1]]) retain ([[ARG0]] : memref<2xi32>) +// CHECK-NEXT: return [[ARG1]], [[ARG1]], [[V0]] :