diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h @@ -24,6 +24,11 @@ /// buffers. std::unique_ptr<Pass> createBufferDeallocationPass(); +/// Creates a pass that optimizes `bufferization.dealloc` operations. For +/// example, it reduces the number of alias checks needed at runtime using +/// static alias analysis. +std::unique_ptr<Pass> createBufferDeallocationSimplificationPass(); + /// Run buffer deallocation. LogicalResult deallocateBuffers(Operation *op); diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td @@ -88,6 +88,26 @@ let constructor = "mlir::bufferization::createBufferDeallocationPass()"; } +def BufferDeallocationSimplification : + Pass<"buffer-deallocation-simplification", "func::FuncOp"> { + let summary = "Optimizes `bufferization.dealloc` operations for more " + "efficient codegen"; + let description = [{ + This pass uses static alias analysis to reduce the number of alias checks + required at runtime. Such checks are sometimes necessary to make sure that + memrefs aren't deallocated before their last usage (use after free) or that + some memref isn't deallocated twice (double free).
+ }]; + + let constructor = + "mlir::bufferization::createBufferDeallocationSimplificationPass()"; + + let dependentDialects = [ + "mlir::bufferization::BufferizationDialect", "mlir::arith::ArithDialect", + "mlir::memref::MemRefDialect" + ]; +} + def BufferHoisting : Pass<"buffer-hoisting", "func::FuncOp"> { let summary = "Optimizes placement of allocation operations by moving them " "into common dominators and out of nested regions"; diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp @@ -879,57 +879,6 @@ } }; -/// Remove memrefs to be deallocated that are also present in the retained list -/// since they will always alias and thus never actually be deallocated. -/// Example: -/// ```mlir -/// %0 = bufferization.dealloc (%arg0 : ...) if (%arg1) retain (%arg0 : ...) -/// ``` -/// is canonicalized to -/// ```mlir -/// %0 = bufferization.dealloc retain (%arg0 : ...) -/// ``` -struct DeallocRemoveDeallocMemrefsContainedInRetained - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(DeallocOp deallocOp, - PatternRewriter &rewriter) const override { - // Unique memrefs to be deallocated. - DenseMap retained; - for (auto [i, ret] : llvm::enumerate(deallocOp.getRetained())) - retained[ret] = i; - - // There must not be any duplicates in the retain list anymore because we - // would miss updating one of the result values otherwise.
- if (retained.size() != deallocOp.getRetained().size()) - return failure(); - - SmallVector newMemrefs, newConditions; - for (auto [memref, cond] : - llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) { - if (retained.contains(memref)) { - rewriter.setInsertionPointAfter(deallocOp); - auto orOp = rewriter.create( - deallocOp.getLoc(), - deallocOp.getUpdatedConditions()[retained[memref]], cond); - rewriter.replaceAllUsesExcept( - deallocOp.getUpdatedConditions()[retained[memref]], - orOp.getResult(), orOp); - continue; - } - - newMemrefs.push_back(memref); - newConditions.push_back(cond); - } - - // Return failure if we don't change anything such that we don't run into an - // infinite loop of pattern applications. - return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions, - rewriter); - } -}; - /// Erase deallocation operations where the variadic list of memrefs to /// deallocate is empty. Example: /// ```mlir @@ -988,8 +937,7 @@ void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add(context); } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp @@ -0,0 +1,171 @@ +//===- BufferDeallocationSimplification.cpp -------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements logic for optimizing `bufferization.dealloc` operations +// that requires more analysis than what can be supported by regular +// canonicalization patterns.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/AliasAnalysis.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace bufferization {
+#define GEN_PASS_DEF_BUFFERDEALLOCATIONSIMPLIFICATION
+#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
+} // namespace bufferization
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::bufferization;
+
+//===----------------------------------------------------------------------===//
+// Helpers
+//===----------------------------------------------------------------------===//
+
+/// Replaces, in place, the list of memrefs to deallocate and their conditions
+/// of `deallocOp` with `memrefs` and `conditions`. Returns failure without
+/// touching the op if both lists are already equal to the given ones, so that
+/// a pattern calling this helper only reports success when it changed the IR.
+static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
+                                            ValueRange memrefs,
+                                            ValueRange conditions,
+                                            PatternRewriter &rewriter) {
+  if (deallocOp.getMemrefs() == memrefs &&
+      deallocOp.getConditions() == conditions)
+    return failure();
+
+  rewriter.updateRootInPlace(deallocOp, [&]() {
+    deallocOp.getMemrefsMutable().assign(memrefs);
+    deallocOp.getConditionsMutable().assign(conditions);
+  });
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Patterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// Remove memrefs to be deallocated that are also present in the retained list
+/// since they will always alias and thus never actually be deallocated. The
+/// same applies to base buffers extracted from a retained memref via
+/// `memref.extract_strided_metadata`. The deallocation condition of a removed
+/// memref is OR-ed into the updated condition of the corresponding retained
+/// memref. Currently, the rewrite is only applied if static alias analysis
+/// proves that the retained memref does not alias any other retained memref.
+/// Example:
+/// ```mlir
+/// %0:3 = bufferization.dealloc (%m0 : ...) if (%cond)
+///                              retain (%m0, %m1, %m2 : ...)
+/// ```
+/// is rewritten to
+/// ```mlir
+/// %0:3 = bufferization.dealloc retain (%m0, %m1, %m2 : ...)
+/// %1 = arith.ori %0#0, %cond : i1
+/// ```
+/// given that `%m1` and `%m2` do not alias `%m0`. All former uses of `%0#0`
+/// are replaced by `%1`.
+struct DeallocRemoveDeallocMemrefsContainedInRetained
+    : public OpRewritePattern<DeallocOp> {
+  DeallocRemoveDeallocMemrefsContainedInRetained(MLIRContext *context,
+                                                 AliasAnalysis &aliasAnalysis)
+      : OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {}
+
+  LogicalResult matchAndRewrite(DeallocOp deallocOp,
+                                PatternRewriter &rewriter) const override {
+    // Map each retained memref to its position in the retain list; the updated
+    // condition of a retained memref is the op result at that same position.
+    DenseMap<Value, unsigned> retainedIndex;
+    for (auto [i, ret] : llvm::enumerate(deallocOp.getRetained()))
+      retainedIndex[ret] = i;
+
+    // There must not be any duplicates in the retain list anymore because we
+    // would miss updating one of the result values otherwise.
+    if (retainedIndex.size() != deallocOp.getRetained().size())
+      return failure();
+
+    SmallVector<Value> newMemrefs, newConditions;
+    for (auto [memref, cond] :
+         llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
+      // Case 1: the memref to deallocate is itself in the retain list.
+      if (succeeded(foldIntoRetained(deallocOp, memref, cond, retainedIndex,
+                                     rewriter)))
+        continue;
+
+      // Case 2: the memref to deallocate is the base buffer extracted from a
+      // memref in the retain list.
+      auto extractOp = memref.getDefiningOp<memref::ExtractStridedMetadataOp>();
+      if (extractOp &&
+          succeeded(foldIntoRetained(deallocOp, extractOp.getOperand(), cond,
+                                     retainedIndex, rewriter)))
+        continue;
+
+      newMemrefs.push_back(memref);
+      newConditions.push_back(cond);
+    }
+
+    // Return failure if we don't change anything such that we don't run into an
+    // infinite loop of pattern applications.
+    return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
+                                  rewriter);
+  }
+
+private:
+  /// Tries to drop a deallocation guarded by `cond` in favor of the retained
+  /// memref `retainedMemref`. Fails without modifying the IR if
+  /// `retainedMemref` is not in the retain list of `deallocOp` or if any other
+  /// retained memref is not proven (by alias analysis) to not alias it; only
+  /// this restrictive case, where exactly one retained memref aliases, is
+  /// supported. On success, an `arith.ori` of the retained memref's updated
+  /// condition and `cond` is inserted right after `deallocOp` and all other
+  /// uses of that updated condition are redirected to it.
+  LogicalResult
+  foldIntoRetained(DeallocOp deallocOp, Value retainedMemref, Value cond,
+                   const DenseMap<Value, unsigned> &retainedIndex,
+                   PatternRewriter &rewriter) const {
+    // Use find() instead of operator[] so that a miss never inserts an entry.
+    auto it = retainedIndex.find(retainedMemref);
+    if (it == retainedIndex.end())
+      return failure();
+
+    bool othersProvablyDisjoint =
+        llvm::all_of(deallocOp.getRetained(), [&](Value other) {
+          return other == retainedMemref ||
+                 aliasAnalysis.alias(other, retainedMemref).isNo();
+        });
+    if (!othersProvablyDisjoint)
+      return failure();
+
+    Value updatedCondition = deallocOp.getUpdatedConditions()[it->second];
+    rewriter.setInsertionPointAfter(deallocOp);
+    auto orOp = rewriter.create<arith::OrIOp>(deallocOp.getLoc(),
+                                              updatedCondition, cond);
+    rewriter.replaceAllUsesExcept(updatedCondition, orOp.getResult(), orOp);
+    return success();
+  }
+
+  AliasAnalysis &aliasAnalysis;
+};
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// BufferDeallocationSimplificationPass
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// Simplifies `bufferization.dealloc` operations by greedily applying the
+/// rewrite patterns defined in this file, which rely on static alias analysis
+/// to reduce the number of alias checks required at runtime.
+struct BufferDeallocationSimplificationPass
+    : public bufferization::impl::BufferDeallocationSimplificationBase<
+          BufferDeallocationSimplificationPass> {
+  void runOnOperation() override {
+    // The patterns query this analysis to prove that retained memrefs do not
+    // alias each other.
+    AliasAnalysis &aliasAnalysis = getAnalysis<AliasAnalysis>();
+    RewritePatternSet patterns(&getContext());
+    patterns.add<DeallocRemoveDeallocMemrefsContainedInRetained>(&getContext(),
+                                                                 aliasAnalysis);
+
+    if (failed(
+            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
+      signalPassFailure();
+  }
+};
+
+} // namespace
+
+std::unique_ptr<Pass>
+mlir::bufferization::createBufferDeallocationSimplificationPass() {
+  return std::make_unique<BufferDeallocationSimplificationPass>();
+}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRBufferizationTransforms
   Bufferize.cpp
   BufferDeallocation.cpp
+  BufferDeallocationSimplification.cpp
   BufferOptimizations.cpp
   BufferResultsToOutParams.cpp
   BufferUtils.cpp
diff --git a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
@@ -0,0 +1,45 @@
+// RUN: mlir-opt %s --buffer-deallocation-simplification --split-input-file | FileCheck %s
+
+// Memrefs that are also retained are removed from the dealloc list and their
+// condition is OR-ed into the updated condition of the retained memref. %2 is
+// left unchanged because the alias analysis cannot prove that %arg0 and %arg2
+// do not alias.
+func.func @dealloc_deallocated_in_retained(%arg0: memref<2xi32>, %arg1: i1, %arg2: memref<2xi32>) -> (i1, i1, i1, i1) {
+  %0 = bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg1) retain (%arg0 : memref<2xi32>)
+  %1 = bufferization.dealloc (%arg0, %arg2 : memref<2xi32>, memref<2xi32>) if (%arg1, %arg1) retain (%arg0 : memref<2xi32>)
+  %2:2 = bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg1) retain (%arg0, %arg2 : memref<2xi32>, memref<2xi32>)
+  return %0, %1, %2#0, %2#1 : i1, i1, i1, i1
+}
+
+// CHECK-LABEL: func @dealloc_deallocated_in_retained
+// CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: memref<2xi32>)
+// CHECK-NEXT: [[V0:%.+]] = bufferization.dealloc retain ([[ARG0]] : memref<2xi32>)
+// CHECK-NEXT: [[O0:%.+]] = arith.ori [[V0]], [[ARG1]]
+// CHECK-NEXT: [[V1:%.+]] = bufferization.dealloc ([[ARG2]] : memref<2xi32>) if ([[ARG1]]) retain ([[ARG0]] : memref<2xi32>)
+// CHECK-NEXT: [[O1:%.+]] = arith.ori [[V1]], [[ARG1]]
+// CHECK-NEXT: [[V2:%.+]]:2 = bufferization.dealloc ([[ARG0]] : memref<2xi32>) if ([[ARG1]]) retain ([[ARG0]], [[ARG2]] : memref<2xi32>, memref<2xi32>)
+// CHECK-NEXT: return [[O0]], [[O1]], [[V2]]#0, [[V2]]#1 :
+
+// -----
+
+// Same as above, but the deallocated memrefs are base buffers obtained through
+// `memref.extract_strided_metadata` from the (possibly retained) memrefs.
+func.func @dealloc_deallocated_in_retained_extract_base_memref(%arg0: memref<2xi32>, %arg1: i1, %arg2: memref<2xi32>) -> (i1, i1, i1, i1) {
+  %base_buffer, %offset, %size, %stride = memref.extract_strided_metadata %arg0 : memref<2xi32> -> memref<i32>, index, index, index
+  %base_buffer0, %offset0, %size0, %stride0 = memref.extract_strided_metadata %arg2 : memref<2xi32> -> memref<i32>, index, index, index
+  %0 = bufferization.dealloc (%base_buffer : memref<i32>) if (%arg1) retain (%arg0 : memref<2xi32>)
+  %1 = bufferization.dealloc (%base_buffer, %base_buffer0 : memref<i32>, memref<i32>) if (%arg1, %arg1) retain (%arg0 : memref<2xi32>)
+  %2:2 = bufferization.dealloc (%base_buffer : memref<i32>) if (%arg1) retain (%arg0, %arg2 : memref<2xi32>, memref<2xi32>)
+  return %0, %1, %2#0, %2#1 : i1, i1, i1, i1
+}
+
+// CHECK-LABEL: func @dealloc_deallocated_in_retained_extract_base_memref
+// CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: memref<2xi32>)
+// CHECK-NEXT: [[BASE0:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ARG0]] :
+// CHECK-NEXT: [[BASE1:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ARG2]] :
+// CHECK-NEXT: [[V0:%.+]] = bufferization.dealloc retain ([[ARG0]] : memref<2xi32>)
+// CHECK-NEXT: [[O0:%.+]] = arith.ori [[V0]], [[ARG1]]
+// CHECK-NEXT: [[V1:%.+]] = bufferization.dealloc ([[BASE1]] : memref<i32>) if ([[ARG1]]) retain ([[ARG0]] : memref<2xi32>)
+// CHECK-NEXT: [[O1:%.+]] = arith.ori [[V1]], [[ARG1]]
+// CHECK-NEXT: [[V2:%.+]]:2 = bufferization.dealloc ([[BASE0]] : memref<i32>) if ([[ARG1]]) retain ([[ARG0]], [[ARG2]] : memref<2xi32>, memref<2xi32>)
+// CHECK-NEXT: return [[O0]], [[O1]], [[V2]]#0, [[V2]]#1 :
diff --git a/mlir/test/Dialect/Bufferization/canonicalize.mlir b/mlir/test/Dialect/Bufferization/canonicalize.mlir
--- a/mlir/test/Dialect/Bufferization/canonicalize.mlir
+++ b/mlir/test/Dialect/Bufferization/canonicalize.mlir
@@ -297,19 +297,16 @@
 
 // -----
 
-func.func @dealloc_canonicalize_retained_and_deallocated(%arg0: memref<2xi32>, %arg1: i1, %arg2: memref<2xi32>) -> (i1, i1) {
-  %0 = bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg1) retain (%arg0 : memref<2xi32>)
-  %1 = bufferization.dealloc (%arg0, %arg2 : memref<2xi32>, memref<2xi32>) if (%arg1, %arg1) retain (%arg0 : memref<2xi32>)
+func.func @dealloc_erase_empty(%arg0: memref<2xi32>, %arg1: i1, %arg2: memref<2xi32>) -> i1 {
   bufferization.dealloc
-  bufferization.dealloc retain (%arg0 : memref<2xi32>)
-  return %0, %1 : i1, i1
+  %0 = bufferization.dealloc retain (%arg0 : memref<2xi32>)
+  return %0 : i1
 }
 
-// CHECK-LABEL: func @dealloc_canonicalize_retained_and_deallocated
+// CHECK-LABEL: func @dealloc_erase_empty
 // CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: memref<2xi32>)
-// CHECK-NEXT: [[V0:%.+]] = bufferization.dealloc ([[ARG2]] : memref<2xi32>) if ([[ARG1]]) retain ([[ARG0]] : memref<2xi32>)
-// CHECK-NEXT: [[V1:%.+]] = arith.ori [[V0]], [[ARG1]]
-// CHECK-NEXT: return [[ARG1]], [[V1]] :
+// CHECK-NEXT: [[FALSE:%.+]] = arith.constant false
+// CHECK-NEXT: return [[FALSE]] :
 
 // -----
 