diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationBase.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationBase.td
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationBase.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationBase.td
@@ -28,7 +28,8 @@ deallocation](/docs/BufferDeallocationInternals/).
   }];
 
   let dependentDialects = [
-    "affine::AffineDialect", "memref::MemRefDialect", "tensor::TensorDialect"
+    "affine::AffineDialect", "memref::MemRefDialect", "tensor::TensorDialect",
+    "arith::ArithDialect"
   ];
 
   let extraClassDeclaration = [{
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -514,6 +514,7 @@
   }];
 
   let hasVerifier = 1;
+  let hasCanonicalizer = 1;
 }
 
 #endif // BUFFERIZATION_OPS
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -766,6 +766,78 @@
   return success();
 }
 
+namespace {
+
+/// Removes duplicate operands from a `bufferization.dealloc` op:
+///  * a memref that appears more than once is kept only at its first position.
+///    If its occurrences carry different conditions, the conditions are
+///    combined with `arith.ori`, so the memref is deallocated whenever any of
+///    the original conditions holds;
+///  * a retained value that appears more than once is kept only once.
+/// Every result of the original op is redirected to the result of the new op
+/// at the deduplicated position of the corresponding memref.
+struct DeallocRemoveDuplicates : public OpRewritePattern<DeallocOp> {
+  using OpRewritePattern<DeallocOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DeallocOp deallocOp,
+                                PatternRewriter &rewriter) const override {
+    // Unique memrefs to be deallocated. `memrefToIndex` maps each memref to
+    // its position in `newMemrefs` / `newConditions`.
+    DenseMap<Value, unsigned> memrefToIndex;
+    SmallVector<Value> newMemrefs, newConditions, newRetained;
+    // For each result of the original op, the index of the replacing result.
+    SmallVector<unsigned> resultIndices;
+    for (auto [memref, cond] :
+         llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
+      // Single hash lookup: inserts the next free index if `memref` is new.
+      auto [it, inserted] =
+          memrefToIndex.try_emplace(memref, newConditions.size());
+      unsigned idx = it->second;
+      if (inserted) {
+        newMemrefs.push_back(memref);
+        newConditions.push_back(cond);
+      } else if (newConditions[idx] != cond) {
+        // If the dealloc conditions don't match, we need to make sure that the
+        // dealloc happens on the union of cases.
+        newConditions[idx] = rewriter.create<arith::OrIOp>(
+            deallocOp.getLoc(), newConditions[idx], cond);
+      }
+      resultIndices.push_back(idx);
+    }
+
+    // Unique retained values, keeping the first occurrence of each.
+    DenseSet<Value> seen;
+    for (Value retained : deallocOp.getRetained())
+      if (seen.insert(retained).second)
+        newRetained.push_back(retained);
+
+    // Return failure if we don't change anything such that we don't run into an
+    // infinite loop of pattern applications. No `arith.ori` has been created
+    // in that case, since that only happens for duplicate memrefs.
+    if (newMemrefs.size() == deallocOp.getMemrefs().size() &&
+        newRetained.size() == deallocOp.getRetained().size())
+      return failure();
+
+    // We need to create a new op because the number of results is always the
+    // same as the number of condition operands.
+    auto newDealloc = rewriter.create<DeallocOp>(deallocOp.getLoc(), newMemrefs,
+                                                 newConditions, newRetained);
+    SmallVector<Value> replacements;
+    replacements.reserve(resultIndices.size());
+    for (unsigned idx : resultIndices)
+      replacements.push_back(newDealloc.getResult(idx));
+    rewriter.replaceOp(deallocOp, replacements);
+    return success();
+  }
+};
+
+} // anonymous namespace
+
+void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                            MLIRContext *context) {
+  results.add<DeallocRemoveDuplicates>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
--- a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
@@ -14,6 +14,7 @@
 
   LINK_LIBS PUBLIC
   MLIRAffineDialect
+  MLIRArithDialect
   MLIRDestinationStyleOpInterface
   MLIRDialect
   MLIRFuncDialect
diff --git a/mlir/test/Dialect/Bufferization/canonicalize.mlir b/mlir/test/Dialect/Bufferization/canonicalize.mlir
--- a/mlir/test/Dialect/Bufferization/canonicalize.mlir
+++ b/mlir/test/Dialect/Bufferization/canonicalize.mlir
@@ -279,3 +279,18 @@
 // CHECK-NOT: bufferization.clone
 // CHECK-NOT: memref.dealloc
 // CHECK: return {{.*}}
+
+// -----
+
+func.func @dealloc_canonicalize_duplicates(%arg0: memref<2xi32>, %arg1: i1, %arg2: i1, %arg3: memref<2xi32>, %arg4: memref<2xi32>) -> (i1, i1, i1, i1, i1) {
+  %0:3 = bufferization.dealloc (%arg4, %arg0, %arg0 : memref<2xi32>, memref<2xi32>, memref<2xi32>) if (%arg1, %arg1, %arg1) retain (%arg3, %arg4, %arg3 : memref<2xi32>, memref<2xi32>, memref<2xi32>)
+  %1:2 = bufferization.dealloc (%arg0, %arg0 : memref<2xi32>, memref<2xi32>) if (%arg1, %arg2)
+  return %0#0, %0#1, %0#2, %1#0, %1#1 : i1, i1, i1, i1, i1
+}
+
+// CHECK-LABEL: func @dealloc_canonicalize_duplicates
+// CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: i1, [[ARG3:%.+]]: memref<2xi32>, [[ARG4:%.+]]: memref<2xi32>)
+// CHECK-NEXT: [[V0:%.+]]:2 = bufferization.dealloc ([[ARG4]], [[ARG0]] : memref<2xi32>, memref<2xi32>) if ([[ARG1]], [[ARG1]]) retain ([[ARG3]], [[ARG4]] : memref<2xi32>, memref<2xi32>)
+// CHECK-NEXT: [[NEW_COND:%.+]] = arith.ori [[ARG1]], [[ARG2]] : i1
+// CHECK-NEXT: [[V1:%.+]] = bufferization.dealloc ([[ARG0]] : memref<2xi32>) if ([[NEW_COND]])
+// CHECK-NEXT: return [[V0]]#0, [[V0]]#1, [[V0]]#1, [[V1]], [[V1]] :