diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h @@ -24,6 +24,11 @@ /// buffers. std::unique_ptr createBufferDeallocationPass(); +/// Creates a pass that optimizes `bufferization.dealloc` operations. For +/// example, it reduces the number of alias checks needed at runtime using +/// static alias analysis. +std::unique_ptr createBufferDeallocationSimplificationPass(); + /// Run buffer deallocation. LogicalResult deallocateBuffers(Operation *op); diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td @@ -88,6 +88,26 @@ let constructor = "mlir::bufferization::createBufferDeallocationPass()"; } +def BufferDeallocationSimplification : + Pass<"buffer-deallocation-simplification", "func::FuncOp"> { + let summary = "Optimizes `bufferization.dealloc` operation for more " + "efficient codegen"; + let description = [{ + This pass uses static alias analysis to reduce the number of alias checks + required at runtime. Such checks are sometimes necessary to make sure that + memrefs aren't deallocated before their last usage (use after free) or that + some memref isn't deallocated twice (double free). 
+ }]; + + let constructor = + "mlir::bufferization::createBufferDeallocationSimplificationPass()"; + + let dependentDialects = [ + "mlir::bufferization::BufferizationDialect", "mlir::arith::ArithDialect", + "mlir::memref::MemRefDialect" + ]; +} + def BufferHoisting : Pass<"buffer-hoisting", "func::FuncOp"> { let summary = "Optimizes placement of allocation operations by moving them " "into common dominators and out of nested regions"; diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp @@ -869,57 +869,6 @@ } }; -/// Remove memrefs to be deallocated that are also present in the retained list -/// since they will always alias and thus never actually be deallocated. -/// Example: -/// ```mlir -/// %0 = bufferization.dealloc (%arg0 : ...) if (%arg1) retain (%arg0 : ...) -/// ``` -/// is canonicalized to -/// ```mlir -/// %0 = bufferization.dealloc retain (%arg0 : ...) -/// ``` -struct DeallocRemoveDeallocMemrefsContainedInRetained - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(DeallocOp deallocOp, - PatternRewriter &rewriter) const override { - // Unique memrefs to be deallocated. - DenseMap retained; - for (auto [i, ret] : llvm::enumerate(deallocOp.getRetained())) - retained[ret] = i; - - // There must not be any duplicates in the retain list anymore because we - // would miss updating one of the result values otherwise. 
- if (retained.size() != deallocOp.getRetained().size()) - return failure(); - - SmallVector newMemrefs, newConditions; - for (auto [memref, cond] : - llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) { - if (retained.contains(memref)) { - rewriter.setInsertionPointAfter(deallocOp); - auto orOp = rewriter.create( - deallocOp.getLoc(), - deallocOp.getUpdatedConditions()[retained[memref]], cond); - rewriter.replaceAllUsesExcept( - deallocOp.getUpdatedConditions()[retained[memref]], - orOp.getResult(), orOp); - continue; - } - - newMemrefs.push_back(memref); - newConditions.push_back(cond); - } - - // Return failure if we don't change anything such that we don't run into an - // infinite loop of pattern applications. - return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions, - rewriter); - } -}; - /// Erase deallocation operations where the variadic list of memrefs to /// deallocate is empty. Example: /// ```mlir diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp @@ -0,0 +1,188 @@ +//===- BufferDeallocationSimplification.cpp -------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements logic for optimizing `bufferization.dealloc` operations +// that requires more analysis than what can be supported by regular +// canonicalization patterns. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/AliasAnalysis.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/Transforms/Passes.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir { +namespace bufferization { +#define GEN_PASS_DEF_BUFFERDEALLOCATIONSIMPLIFICATION +#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc" +} // namespace bufferization +} // namespace mlir + +using namespace mlir; +using namespace mlir::bufferization; + +//===----------------------------------------------------------------------===// +// Helpers +//===----------------------------------------------------------------------===// + +static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp, + ValueRange memrefs, + ValueRange conditions, + PatternRewriter &rewriter) { + if (deallocOp.getMemrefs() == memrefs && + deallocOp.getConditions() == conditions) + return failure(); + + rewriter.updateRootInPlace(deallocOp, [&]() { + deallocOp.getMemrefsMutable().assign(memrefs); + deallocOp.getConditionsMutable().assign(conditions); + }); + return success(); +} + +//===----------------------------------------------------------------------===// +// Patterns +//===----------------------------------------------------------------------===// + +namespace { + +/// Remove values from the `memref` operand list that are also present in the +/// `retained` list since they will always alias and thus never actually be +/// deallocated. However, we also need to be certain that no other value in the +/// `retained` list can alias, for which we use a static alias analysis. 
This is +/// necessary because the `dealloc` operation is defined to return one `i1` +/// value per memref in the `retained` list which represents the disjunction of +/// the condition values corresponding to all aliasing values in the `memref` +/// list. In particular, this means that if there is some value R in the +/// `retained` list which aliases with a value M in the `memref` list (but can +/// only be statically determined to may-alias) and M is also present in the +/// `retained` list, then it would be illegal to remove M because the result +/// corresponding to R would be computed incorrectly afterwards. +/// Because we require an alias analysis, this pattern cannot be applied as a +/// regular canonicalization pattern. +/// +/// Example: +/// ```mlir +/// %0:3 = bufferization.dealloc (%m0 : ...) if (%cond0) +/// retain (%m0, %r0, %r1 : ...) +/// ``` +/// is canonicalized to +/// ```mlir +/// // bufferization.dealloc without memrefs and conditions returns %false for +/// // every retained value +/// %0:3 = bufferization.dealloc retain (%m0, %r0, %r1 : ...) +/// %1 = arith.ori %0#0, %cond0 : i1 +/// // replace %0#0 with %1 +/// ``` +/// given that `%r0` and `%r1` are statically known not to alias with `%m0`. +struct DeallocRemoveDeallocMemrefsContainedInRetained + : public OpRewritePattern { + DeallocRemoveDeallocMemrefsContainedInRetained(MLIRContext *context, + AliasAnalysis &aliasAnalysis) + : OpRewritePattern(context), aliasAnalysis(aliasAnalysis) {} + + LogicalResult matchAndRewrite(DeallocOp deallocOp, + PatternRewriter &rewriter) const override { + // Map each retained memref to its index in the `retained` list. + DenseMap retained; + for (auto [i, ret] : llvm::enumerate(deallocOp.getRetained())) + retained[ret] = i; + + // There must not be any duplicates in the retain list anymore because we + // would miss updating one of the result values otherwise. 
+ if (retained.size() != deallocOp.getRetained().size()) + return failure(); + + SmallVector newMemrefs, newConditions; + for (auto memrefAndCond : + llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) { + Value memref = std::get<0>(memrefAndCond); + Value cond = std::get<1>(memrefAndCond); + + auto replaceResultsIfNoInvalidAliasing = [&](Value memref) -> bool { + Value retainedMemref = deallocOp.getRetained()[retained[memref]]; + // The current memref must not have a may-alias relation to any retained + // memref, and exactly one must-alias relation. + // TODO: it is possible to extend this pattern to allow an arbitrary + // number of must-alias relations as long as there is no may-alias. If + // it's no-alias, then just proceed (only supported case as of now), if + // it's must-alias, we also need to update the condition for that alias. + if (llvm::all_of(deallocOp.getRetained(), [&](Value mr) { + return aliasAnalysis.alias(mr, memref).isNo() || + mr == retainedMemref; + })) { + rewriter.setInsertionPointAfter(deallocOp); + auto orOp = rewriter.create( + deallocOp.getLoc(), + deallocOp.getUpdatedConditions()[retained[memref]], cond); + rewriter.replaceAllUsesExcept( + deallocOp.getUpdatedConditions()[retained[memref]], + orOp.getResult(), orOp); + return true; + } + return false; + }; + + if (retained.contains(memref) && + replaceResultsIfNoInvalidAliasing(memref)) + continue; + + auto extractOp = memref.getDefiningOp(); + if (extractOp && retained.contains(extractOp.getOperand()) && + replaceResultsIfNoInvalidAliasing(extractOp.getOperand())) + continue; + + newMemrefs.push_back(memref); + newConditions.push_back(cond); + } + + // Return failure if we don't change anything such that we don't run into an + // infinite loop of pattern applications. 
+ return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions, + rewriter); + } + +private: + AliasAnalysis &aliasAnalysis; +}; + +} // namespace + +//===----------------------------------------------------------------------===// +// BufferDeallocationSimplificationPass +//===----------------------------------------------------------------------===// + +namespace { + +/// Simplifies `bufferization.dealloc` operations by greedily applying the +/// rewrite patterns defined in this file. The patterns query the alias +/// analysis of this pass, so they cannot run as regular canonicalizations. +struct BufferDeallocationSimplificationPass + : public bufferization::impl::BufferDeallocationSimplificationBase< + BufferDeallocationSimplificationPass> { + void runOnOperation() override { + AliasAnalysis &aliasAnalysis = getAnalysis(); + RewritePatternSet patterns(&getContext()); + patterns.add(&getContext(), + aliasAnalysis); + + if (failed( + applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + signalPassFailure(); + } +}; + +} // namespace + +std::unique_ptr +mlir::bufferization::createBufferDeallocationSimplificationPass() { + return std::make_unique(); +} diff --git a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_dialect_library(MLIRBufferizationTransforms Bufferize.cpp BufferDeallocation.cpp + BufferDeallocationSimplification.cpp BufferOptimizations.cpp BufferResultsToOutParams.cpp BufferUtils.cpp diff --git a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir @@ -0,0 +1,39 @@ +// RUN: mlir-opt 
%s --buffer-deallocation-simplification --split-input-file | FileCheck %s + +func.func @dealloc_deallocated_in_retained(%arg0: memref<2xi32>, %arg1: i1, %arg2: memref<2xi32>) -> (i1, i1, i1, i1) { + %0 = bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg1) retain (%arg0 : memref<2xi32>) + %1 = bufferization.dealloc (%arg0, %arg2 : memref<2xi32>, memref<2xi32>) if (%arg1, %arg1) retain (%arg0 : memref<2xi32>) + %2:2 = bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg1) retain (%arg0, %arg2 : memref<2xi32>, memref<2xi32>) + return %0, %1, %2#0, %2#1 : i1, i1, i1, i1 +} + +// CHECK-LABEL: func @dealloc_deallocated_in_retained +// CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: memref<2xi32>) +// CHECK-NEXT: [[V0:%.+]] = bufferization.dealloc retain ([[ARG0]] : memref<2xi32>) +// CHECK-NEXT: [[O0:%.+]] = arith.ori [[V0]], [[ARG1]] +// CHECK-NEXT: [[V1:%.+]] = bufferization.dealloc ([[ARG2]] : memref<2xi32>) if ([[ARG1]]) retain ([[ARG0]] : memref<2xi32>) +// CHECK-NEXT: [[O1:%.+]] = arith.ori [[V1]], [[ARG1]] +// CHECK-NEXT: [[V2:%.+]]:2 = bufferization.dealloc ([[ARG0]] : memref<2xi32>) if ([[ARG1]]) retain ([[ARG0]], [[ARG2]] : memref<2xi32>, memref<2xi32>) +// CHECK-NEXT: return [[O0]], [[O1]], [[V2]]#0, [[V2]]#1 : + +// ----- + +func.func @dealloc_deallocated_in_retained_extract_base_memref(%arg0: memref<2xi32>, %arg1: i1, %arg2: memref<2xi32>) -> (i1, i1, i1, i1) { + %base_buffer, %offset, %size, %stride = memref.extract_strided_metadata %arg0 : memref<2xi32> -> memref, index, index, index + %base_buffer0, %offset0, %size0, %stride0 = memref.extract_strided_metadata %arg2 : memref<2xi32> -> memref, index, index, index + %0 = bufferization.dealloc (%base_buffer : memref) if (%arg1) retain (%arg0 : memref<2xi32>) + %1 = bufferization.dealloc (%base_buffer, %base_buffer0 : memref, memref) if (%arg1, %arg1) retain (%arg0 : memref<2xi32>) + %2:2 = bufferization.dealloc (%base_buffer : memref) if (%arg1) retain (%arg0, %arg2 : 
memref<2xi32>, memref<2xi32>) + return %0, %1, %2#0, %2#1 : i1, i1, i1, i1 +} + +// CHECK-LABEL: func @dealloc_deallocated_in_retained_extract_base_memref +// CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: memref<2xi32>) +// CHECK-NEXT: [[BASE0:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ARG0]] : +// CHECK-NEXT: [[BASE1:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ARG2]] : +// CHECK-NEXT: [[V0:%.+]] = bufferization.dealloc retain ([[ARG0]] : memref<2xi32>) +// CHECK-NEXT: [[O0:%.+]] = arith.ori [[V0]], [[ARG1]] +// CHECK-NEXT: [[V1:%.+]] = bufferization.dealloc ([[BASE1]] : memref) if ([[ARG1]]) retain ([[ARG0]] : memref<2xi32>) +// CHECK-NEXT: [[O1:%.+]] = arith.ori [[V1]], [[ARG1]] +// CHECK-NEXT: [[V2:%.+]]:2 = bufferization.dealloc ([[BASE0]] : memref) if ([[ARG1]]) retain ([[ARG0]], [[ARG2]] : memref<2xi32>, memref<2xi32>) +// CHECK-NEXT: return [[O0]], [[O1]], [[V2]]#0, [[V2]]#1 : diff --git a/mlir/test/Dialect/Bufferization/canonicalize.mlir b/mlir/test/Dialect/Bufferization/canonicalize.mlir --- a/mlir/test/Dialect/Bufferization/canonicalize.mlir +++ b/mlir/test/Dialect/Bufferization/canonicalize.mlir @@ -297,19 +297,16 @@ // ----- -func.func @dealloc_canonicalize_retained_and_deallocated(%arg0: memref<2xi32>, %arg1: i1, %arg2: memref<2xi32>) -> (i1, i1) { - %0 = bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg1) retain (%arg0 : memref<2xi32>) - %1 = bufferization.dealloc (%arg0, %arg2 : memref<2xi32>, memref<2xi32>) if (%arg1, %arg1) retain (%arg0 : memref<2xi32>) +func.func @dealloc_erase_empty(%arg0: memref<2xi32>, %arg1: i1, %arg2: memref<2xi32>) -> i1 { bufferization.dealloc - bufferization.dealloc retain (%arg0 : memref<2xi32>) - return %0, %1 : i1, i1 + %0 = bufferization.dealloc retain (%arg0 : memref<2xi32>) + return %0 : i1 } -// CHECK-LABEL: func @dealloc_canonicalize_retained_and_deallocated +// CHECK-LABEL: func @dealloc_erase_empty // CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, 
[[ARG1:%.+]]: i1, [[ARG2:%.+]]: memref<2xi32>) -// CHECK-NEXT: [[V0:%.+]] = bufferization.dealloc ([[ARG2]] : memref<2xi32>) if ([[ARG1]]) retain ([[ARG0]] : memref<2xi32>) -// CHECK-NEXT: [[V1:%.+]] = arith.ori [[V0]], [[ARG1]] -// CHECK-NEXT: return [[ARG1]], [[V1]] : +// CHECK-NEXT: [[FALSE:%.+]] = arith.constant false +// CHECK-NEXT: return [[FALSE]] : // -----