diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h
@@ -0,0 +1,24 @@
+//===- BufferDeallocationOpInterfaceImpl.h --------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARITH_TRANSFORMS_BUFFERDEALLOCATIONOPINTERFACEIMPL_H
+#define MLIR_DIALECT_ARITH_TRANSFORMS_BUFFERDEALLOCATIONOPINTERFACEIMPL_H
+
+namespace mlir {
+
+class DialectRegistry;
+
+namespace arith {
+/// Registers the BufferDeallocationOpInterface external models for ops of the
+/// Arith dialect in the given registry.
+void registerBufferDeallocationOpInterfaceExternalModels(
+    DialectRegistry &registry);
+} // namespace arith
+} // namespace mlir
+
+#endif // MLIR_DIALECT_ARITH_TRANSFORMS_BUFFERDEALLOCATIONOPINTERFACEIMPL_H
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h
@@ -142,8 +142,8 @@
   /// a new SSA value, returned as the first element of the pair, which has
   /// 'Unique' ownership and can be used instead of the passed Value with the
   /// the ownership indicator returned as the second element of the pair.
-  std::pair<Value, Value> getMemrefWithUniqueOwnership(OpBuilder &builder,
-                                                       Value memref);
+  std::pair<Value, Value>
+  getMemrefWithUniqueOwnership(OpBuilder &builder, Value memref, Block *block);
 
   /// Given two basic blocks and the values passed via block arguments to the
   /// destination block, compute the list of MemRefs that have to be retained in
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.td
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.td
@@ -39,7 +39,34 @@
         /*retType=*/"FailureOr<Operation *>",
         /*methodName=*/"process",
         /*args=*/(ins "DeallocationState &":$state,
-                      "const DeallocationOptions &":$options)>
+                      "const DeallocationOptions &":$options)>,
+      InterfaceMethod<
+        /*desc=*/[{
+          This method allows the implementing operation to specify custom logic
+          to materialize an ownership indicator value for the given MemRef typed
+          value it defines (including block arguments of nested regions). Since
+          the operation itself has more information about its semantics the
+          materialized IR can be more efficient compared to the default
+          implementation and avoid cloning MemRefs and/or doing alias checking
+          at runtime.
+          Note that the same logic could also be implemented in the 'process'
+          method above, however, the IR is always materialized then. If
+          it's desirable to only materialize the IR to compute an updated
+          ownership indicator when needed, it should be implemented using this
+          method (which is especially important if operations are created that
+          cannot be easily canonicalized away anymore).
+        }],
+        /*retType=*/"std::pair<Value, Value>",
+        /*methodName=*/"materializeUniqueOwnershipForMemref",
+        /*args=*/(ins "DeallocationState &":$state,
+                      "const DeallocationOptions &":$options,
+                      "OpBuilder &":$builder,
+                      "Value":$memref),
+        /*methodBody=*/[{}],
+        /*defaultImplementation=*/[{
+          return state.getMemrefWithUniqueOwnership(
+            builder, memref, memref.getParentBlock());
+        }]>,
   ];
 }
 
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h"
+#include "mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h"
 #include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
@@ -132,6 +133,7 @@
 
   // Register all external models.
   affine::registerValueBoundsOpInterfaceExternalModels(registry);
+  arith::registerBufferDeallocationOpInterfaceExternalModels(registry);
   arith::registerBufferizableOpInterfaceExternalModels(registry);
   arith::registerValueBoundsOpInterfaceExternalModels(registry);
   bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.cpp
@@ -0,0 +1,84 @@
+//===- BufferDeallocationOpInterfaceImpl.cpp ------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Operation.h"
+
+using namespace mlir;
+using namespace mlir::bufferization;
+
+namespace {
+/// Provides custom logic to materialize ownership indicator values for the
+/// result value of 'arith.select'. Instead of cloning or runtime alias
+/// checking, this implementation inserts another `arith.select` to choose the
+/// ownership indicator of the operand in the same way the original
+/// `arith.select` chooses the MemRef operand. If at least one of the operand's
+/// ownerships is 'Unknown', fall back to the default implementation.
+///
+/// Example:
+/// ```mlir
+/// // let ownership(%m0) := %o0
+/// // let ownership(%m1) := %o1
+/// %res = arith.select %cond, %m0, %m1
+/// ```
+/// The default implementation would insert a clone and replace all uses of the
+/// result of `arith.select` with that clone:
+/// ```mlir
+/// %res = arith.select %cond, %m0, %m1
+/// %clone = bufferization.clone %res
+/// // let ownership(%res) := 'Unknown'
+/// // let ownership(%clone) := %true
+/// // replace all uses of %res with %clone
+/// ```
+/// This implementation, on the other hand, materializes the following:
+/// ```mlir
+/// %res = arith.select %cond, %m0, %m1
+/// %res_ownership = arith.select %cond, %o0, %o1
+/// // let ownership(%res) := %res_ownership
+/// ```
+struct SelectOpInterface
+    : public BufferDeallocationOpInterface::ExternalModel<SelectOpInterface,
+                                                          arith::SelectOp> {
+  FailureOr<Operation *> process(Operation *op, DeallocationState &state,
+                                 const DeallocationOptions &options) const {
+    return op; // nothing to do
+  }
+
+  std::pair<Value, Value>
+  materializeUniqueOwnershipForMemref(Operation *op, DeallocationState &state,
+                                      const DeallocationOptions &options,
+                                      OpBuilder &builder, Value value) const {
+    auto selectOp = cast<arith::SelectOp>(op);
+    assert(value == selectOp.getResult() &&
+           "Value not defined by this operation");
+
+    Block *block = value.getParentBlock();
+    if (!state.getOwnership(selectOp.getTrueValue(), block).isUnique() ||
+        !state.getOwnership(selectOp.getFalseValue(), block).isUnique())
+      return state.getMemrefWithUniqueOwnership(builder, value, block);
+
+    Value ownership = builder.create<arith::SelectOp>(
+        op->getLoc(), selectOp.getCondition(),
+        state.getOwnership(selectOp.getTrueValue(), block).getIndicator(),
+        state.getOwnership(selectOp.getFalseValue(), block).getIndicator());
+    return {selectOp.getResult(), ownership};
+  }
+};
+
+} // namespace
+
+void mlir::arith::registerBufferDeallocationOpInterfaceExternalModels(
+    DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, ArithDialect *dialect) {
+    SelectOp::attachInterface<SelectOpInterface>(*ctx);
+  });
+}
diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIRArithTransforms
+  BufferDeallocationOpInterfaceImpl.cpp
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
   EmulateUnsupportedFloats.cpp
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
--- a/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
@@ -134,8 +134,8 @@
 
 std::pair<Value, Value>
 DeallocationState::getMemrefWithUniqueOwnership(OpBuilder &builder,
-                                                Value memref) {
-  auto iter = ownershipMap.find({memref, memref.getParentBlock()});
+                                                Value memref, Block *block) {
+  auto iter = ownershipMap.find({memref, block});
   assert(iter != ownershipMap.end() &&
          "Value must already have been registered in the ownership map");
 
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
@@ -376,13 +376,24 @@
   /// Given an SSA value of MemRef type, returns the same of a new SSA value
   /// which has 'Unique' ownership where the ownership indicator is guaranteed
   /// to be always 'true'.
-  Value getMemrefWithGuaranteedOwnership(OpBuilder &builder, Value memref);
+  Value materializeMemrefWithGuaranteedOwnership(OpBuilder &builder,
+                                                 Value memref, Block *block);
 
   /// Returns whether the given operation implements FunctionOpInterface, has
   /// private visibility, and the private-function-dynamic-ownership pass option
   /// is enabled.
bool isFunctionWithoutDynamicOwnership(Operation *op); + /// Given an SSA value of MemRef type, this function queries the + /// BufferDeallocationOpInterface of the defining operation of 'memref' for a + /// materialized ownership indicator for 'memref'. If the op does not + /// implement the interface or if the block for which the materialized value + /// is requested does not match the block in which 'memref' is defined, the + /// default implementation in + /// `DeallocationState::getMemrefWithUniqueOwnership` is queried instead. + std::pair + materializeUniqueOwnership(OpBuilder &builder, Value memref, Block *block); + /// Checks all the preconditions for operations implementing the /// FunctionOpInterface that have to hold for the deallocation to be /// applicable: @@ -428,6 +439,28 @@ // BufferDeallocation Implementation //===----------------------------------------------------------------------===// +std::pair +BufferDeallocation::materializeUniqueOwnership(OpBuilder &builder, Value memref, + Block *block) { + // The interface can only materialize ownership indicators in the same block + // as the defining op. + if (memref.getParentBlock() != block) + return state.getMemrefWithUniqueOwnership(builder, memref, block); + + Operation *owner = memref.getDefiningOp(); + if (!owner) + owner = memref.getParentBlock()->getParentOp(); + + // If the op implements the interface, query it for a materialized ownership + // value. + if (auto deallocOpInterface = dyn_cast(owner)) + return deallocOpInterface.materializeUniqueOwnershipForMemref( + state, options, builder, memref); + + // Otherwise use the default implementation. 
+  return state.getMemrefWithUniqueOwnership(builder, memref, block);
+}
+
 static bool regionOperatesOnMemrefValues(Region &region) {
   WalkResult result = region.walk([](Block *block) {
     if (llvm::any_of(block->getArguments(), isMemref))
@@ -677,11 +710,11 @@
   return newOp.getOperation();
 }
 
-Value BufferDeallocation::getMemrefWithGuaranteedOwnership(OpBuilder &builder,
-                                                           Value memref) {
+Value BufferDeallocation::materializeMemrefWithGuaranteedOwnership(
+    OpBuilder &builder, Value memref, Block *block) {
   // First, make sure we at least have 'Unique' ownership already.
   std::pair<Value, Value> newMemrefAndOnwership =
-      state.getMemrefWithUniqueOwnership(builder, memref);
+      materializeUniqueOwnership(builder, memref, block);
   Value newMemref = newMemrefAndOnwership.first;
   Value condition = newMemrefAndOnwership.second;
 
@@ -785,7 +818,7 @@
         continue;
       }
       auto [memref, condition] =
-          state.getMemrefWithUniqueOwnership(builder, operand);
+          materializeUniqueOwnership(builder, operand, op->getBlock());
       newOperands.push_back(memref);
       ownershipIndicatorsToAdd.push_back(condition);
     }
@@ -868,7 +901,8 @@
       if (!isMemref(val.get()))
         continue;
 
-      val.set(getMemrefWithGuaranteedOwnership(builder, val.get()));
+      val.set(materializeMemrefWithGuaranteedOwnership(builder, val.get(),
+                                                       op->getBlock()));
     }
   }
 
diff --git a/mlir/test/Dialect/Bufferization/Transforms/BufferDeallocation/dealloc-callop-interface.mlir b/mlir/test/Dialect/Bufferization/Transforms/BufferDeallocation/dealloc-callop-interface.mlir
--- a/mlir/test/Dialect/Bufferization/Transforms/BufferDeallocation/dealloc-callop-interface.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/BufferDeallocation/dealloc-callop-interface.mlir
@@ -95,15 +95,15 @@
 // CHECK-NEXT: return
 
 // CHECK-DYNAMIC-LABEL: func @function_call_requries_merged_ownership_mid_block
+// CHECK-DYNAMIC-SAME: ([[ARG0:%.+]]: i1)
 // CHECK-DYNAMIC: [[ALLOC0:%.+]] = memref.alloc(
 // CHECK-DYNAMIC-NEXT: [[ALLOC1:%.+]] = memref.alloca(
-// CHECK-DYNAMIC-NEXT: [[SELECT:%.+]] = arith.select{{.*}}[[ALLOC0]], [[ALLOC1]]
-// CHECK-DYNAMIC-NEXT: [[CLONE:%.+]] = bufferization.clone [[SELECT]]
-// CHECK-DYNAMIC-NEXT: [[RET:%.+]]:2 = call @f([[CLONE]], %true{{[0-9_]*}})
+// CHECK-DYNAMIC-NEXT: [[SELECT:%.+]] = arith.select [[ARG0]], [[ALLOC0]], [[ALLOC1]]
+// CHECK-DYNAMIC-NEXT: [[RET:%.+]]:2 = call @f([[SELECT]], [[ARG0]])
 // CHECK-DYNAMIC-NEXT: test.copy
 // CHECK-DYNAMIC-NEXT: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[RET]]#0
-// CHECK-DYNAMIC-NEXT: bufferization.dealloc ([[ALLOC0]], [[CLONE]], [[BASE]] :
-// CHECK-DYNAMIC-SAME: if (%true{{[0-9_]*}}, %true{{[0-9_]*}}, [[RET]]#1)
+// CHECK-DYNAMIC-NEXT: bufferization.dealloc ([[ALLOC0]], [[BASE]] :
+// CHECK-DYNAMIC-SAME: if (%true{{[0-9_]*}}, [[RET]]#1)
 // CHECK-DYNAMIC-NOT: retain
 // CHECK-DYNAMIC-NEXT: return
 