diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.h b/mlir/include/mlir/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.h
@@ -0,0 +1,23 @@
+//===- BufferDeallocationOpInterfaceImpl.h --------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SCF_TRANSFORMS_BUFFERDEALLOCATIONOPINTERFACEIMPL_H
+#define MLIR_DIALECT_SCF_TRANSFORMS_BUFFERDEALLOCATIONOPINTERFACEIMPL_H
+
+namespace mlir {
+
+class DialectRegistry;
+
+namespace scf {
+/// Registers BufferDeallocationOpInterface external models for SCF ops.
+void registerBufferDeallocationOpInterfaceExternalModels(
+    DialectRegistry &registry);
+} // namespace scf
+} // namespace mlir
+
+#endif // MLIR_DIALECT_SCF_TRANSFORMS_BUFFERDEALLOCATIONOPINTERFACEIMPL_H
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -67,6 +67,7 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.h"
 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
+#include "mlir/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.h"
 #include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/Shape/IR/Shape.h"
@@ -168,6 +169,7 @@
   memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
   memref::registerValueBoundsOpInterfaceExternalModels(registry);
   memref::registerMemorySlotExternalModels(registry);
+  scf::registerBufferDeallocationOpInterfaceExternalModels(registry);
   scf::registerBufferizableOpInterfaceExternalModels(registry);
   scf::registerValueBoundsOpInterfaceExternalModels(registry);
   shape::registerBufferizableOpInterfaceExternalModels(registry);
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.cpp
@@ -0,0 +1,90 @@
+//===- BufferDeallocationOpInterfaceImpl.cpp ------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.h"
+#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+
+using namespace mlir;
+using namespace mlir::bufferization;
+
+namespace {
+/// The `scf.forall.in_parallel` terminator is special in a few ways:
+/// * It does not implement the BranchOpInterface or
+///   RegionBranchTerminatorOpInterface, but the ParallelCombiningOpInterface
+///   which is not supported by BufferDeallocation.
+/// * It has a graph-like region which only allows one specific tensor op.
+/// * After bufferization the nested region is always empty.
+/// For these reasons we provide custom deallocation logic via this external
+/// model.
+///
+/// Example:
+/// ```mlir
+/// scf.forall (%arg1) in (%arg0) {
+///   %alloc = memref.alloc() : memref<2xf32>
+///   ...
+///   <implicit in_parallel terminator here>
+/// }
+/// ```
+/// gets transformed to
+/// ```mlir
+/// scf.forall (%arg1) in (%arg0) {
+///   %alloc = memref.alloc() : memref<2xf32>
+///   ...
+///   bufferization.dealloc (%alloc : memref<2xf32>) if (%true)
+///   <implicit in_parallel terminator here>
+/// }
+/// ```
+struct InParallelOpInterface
+    : public BufferDeallocationOpInterface::ExternalModel<InParallelOpInterface,
+                                                          scf::InParallelOp> {
+  FailureOr<Operation *> process(Operation *op, DeallocationState &state,
+                                 const DeallocationOptions &options) const {
+    auto inParallelOp = cast<scf::InParallelOp>(op);
+    // Insert new ops right before the terminator, i.e. at the end of the
+    // forall body.
+    OpBuilder builder(op);
+    if (!inParallelOp.getRegion().front().empty())
+      return op->emitError("only supported when nested region is empty");
+
+    // Collect the values to deallocate and retain and use them to create the
+    // dealloc operation.
+    Block *block = op->getBlock();
+    SmallVector<Value> memrefs, conditions, toRetain;
+    if (failed(state.getMemrefsAndConditionsToDeallocate(
+            builder, op->getLoc(), block, memrefs, conditions)))
+      return failure();
+
+    state.getMemrefsToRetain(block, nullptr, {}, toRetain);
+    // Nothing to release and nothing to retain: no dealloc op is needed.
+    if (memrefs.empty() && toRetain.empty())
+      return op;
+
+    auto deallocOp = builder.create<bufferization::DeallocOp>(
+        op->getLoc(), memrefs, conditions, toRetain);
+
+    // We want to replace the current ownership of the retained values with the
+    // result values of the dealloc operation as they are always unique.
+    state.resetOwnerships(deallocOp.getRetained(), block);
+    for (auto [retained, ownership] :
+         llvm::zip(deallocOp.getRetained(), deallocOp.getUpdatedConditions()))
+      state.updateOwnership(retained, ownership, block);
+
+    return op;
+  }
+};
+
+} // namespace
+
+void mlir::scf::registerBufferDeallocationOpInterfaceExternalModels(
+    DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, SCFDialect *dialect) {
+    InParallelOp::attachInterface<InParallelOpInterface>(*ctx);
+  });
+}
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIRSCFTransforms
+  BufferDeallocationOpInterfaceImpl.cpp
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
   ForToWhile.cpp
diff --git a/mlir/test/Dialect/SCF/buffer-deallocation.mlir b/mlir/test/Dialect/SCF/buffer-deallocation.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/SCF/buffer-deallocation.mlir
@@ -0,0 +1,26 @@
+// DEFINE: %{canonicalize} = -canonicalize=enable-patterns="bufferization-skip-extract-metadata-of-alloc,bufferization-erase-always-false-dealloc,bufferization-erase-empty-dealloc,bufferization-dealloc-remove-duplicate-retained-memrefs,bufferization-dealloc-remove-duplicate-dealloc-memrefs",region-simplify=false
+
+// RUN: mlir-opt -verify-diagnostics -buffer-deallocation \
+// RUN:   %{canonicalize} -buffer-deallocation-simplification %{canonicalize} -split-input-file %s | FileCheck %s
+
+func.func @parallel_insert_slice_no_conflict(%arg0: index) {
+  %c0 = arith.constant 0 : index
+  %alloc = memref.alloc() : memref<2xf32>
+  scf.forall (%arg1) in (%arg0) {
+    %alloc0 = memref.alloc() : memref<2xf32>
+    %0 = memref.load %alloc[%c0] : memref<2xf32>
+    linalg.fill ins(%0 : f32) outs(%alloc0 : memref<2xf32>)
+  }
+  return
+}
+
+// CHECK-LABEL: func @parallel_insert_slice_no_conflict
+// CHECK-SAME: (%arg0: index)
+// CHECK: [[ALLOC0:%.+]] = memref.alloc(
+// CHECK: scf.forall
+// CHECK: [[ALLOC1:%.+]] = memref.alloc(
+// CHECK: bufferization.dealloc ([[ALLOC1]] : memref<2xf32>) if (%true
+// CHECK-NOT: retain
+// CHECK: }
+// CHECK: bufferization.dealloc ([[ALLOC0]] : memref<2xf32>) if (%true
+// CHECK-NOT: retain