diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -866,11 +866,71 @@
   }
 };
 
+/// Removes memrefs from the deallocation list if their associated condition is
+/// always 'false'. The result belonging to a removed memref is replaced by that
+/// (constant 'false') condition. If every condition is 'false', the op is
+/// replaced entirely instead of being rebuilt with an empty memref list.
+///
+/// Example:
+/// ```
+/// %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xi32>, memref<2xi32>)
+///                           if (%arg2, %false)
+/// ```
+/// becomes
+/// ```
+/// %0 = bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg2)
+/// ```
+struct EraseAlwaysFalseDealloc : public OpRewritePattern<DeallocOp> {
+  using OpRewritePattern<DeallocOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DeallocOp deallocOp,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<Value> newMemrefs, newConditions;
+    // One entry per original result: the 'false' condition of an erased
+    // memref, or a null Value that is filled in from the new op below.
+    SmallVector<Value> replacements;
+
+    for (auto [memref, cond] :
+         llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
+      if (matchPattern(cond, m_Zero())) {
+        replacements.push_back(cond);
+        continue;
+      }
+      newMemrefs.push_back(memref);
+      newConditions.push_back(cond);
+      replacements.push_back(Value());
+    }
+
+    // No condition is known to be 'false'; report that nothing changed.
+    if (newMemrefs.size() == deallocOp.getMemrefs().size())
+      return failure();
+
+    // All conditions are 'false': there is nothing left to deallocate, so map
+    // every result to its condition rather than creating an empty dealloc.
+    if (newMemrefs.empty()) {
+      rewriter.replaceOp(deallocOp, replacements);
+      return success();
+    }
+
+    auto newDeallocOp = rewriter.create<DeallocOp>(
+        deallocOp.getLoc(), newMemrefs, newConditions, deallocOp.getRetained());
+    unsigned nextResult = 0;
+    for (Value &repl : replacements)
+      if (!repl)
+        repl = newDeallocOp.getResult(nextResult++);
+
+    rewriter.replaceOp(deallocOp, replacements);
+    return success();
+  }
+};
+
 } // anonymous namespace
 
 void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add(context);
+  results
+      .add(
+          context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Bufferization/canonicalize.mlir b/mlir/test/Dialect/Bufferization/canonicalize.mlir
--- a/mlir/test/Dialect/Bufferization/canonicalize.mlir
+++ b/mlir/test/Dialect/Bufferization/canonicalize.mlir
@@ -309,3 +309,29 @@
 // CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: memref<2xi32>)
 // CHECK-NEXT: [[V0:%.+]] = bufferization.dealloc ([[ARG2]] : memref<2xi32>) if ([[ARG1]]) retain ([[ARG0]] : memref<2xi32>)
 // CHECK-NEXT: return [[ARG1]], [[ARG1]], [[V0]] :
+
+// -----
+
+func.func @dealloc_always_false_condition(%arg0: memref<2xi32>, %arg1: memref<2xi32>, %arg2: i1) -> (i1, i1) {
+  %false = arith.constant false
+  %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xi32>, memref<2xi32>) if (%false, %arg2)
+  return %0#0, %0#1 : i1, i1
+}
+
+// CHECK-LABEL: func @dealloc_always_false_condition
+// CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: memref<2xi32>, [[ARG2:%.+]]: i1)
+// CHECK-NEXT: [[FALSE:%.+]] = arith.constant false
+// CHECK-NEXT: [[V0:%.+]] = bufferization.dealloc ([[ARG1]] : {{.*}}) if ([[ARG2]])
+// CHECK-NEXT: return [[FALSE]], [[V0]] :
+
+// -----
+
+func.func @dealloc_all_conditions_false(%arg0: memref<2xi32>, %arg1: memref<2xi32>) -> (i1, i1) {
+  %false = arith.constant false
+  %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xi32>, memref<2xi32>) if (%false, %false)
+  return %0#0, %0#1 : i1, i1
+}
+
+// CHECK-LABEL: func @dealloc_all_conditions_false
+// CHECK-NEXT: [[FALSE:%.+]] = arith.constant false
+// CHECK-NEXT: return [[FALSE]], [[FALSE]] :