diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td @@ -143,6 +143,10 @@ Return the OpResult that aliases with a given OpOperand when bufferized in-place. This method will never be called on OpOperands that do not have a tensor type. + + Note: This method can return multiple OpResults, indicating that a + given OpOperand may at runtime alias with any (or multiple) of the + returned OpResults. }], /*retType=*/"SmallVector", /*methodName=*/"getAliasingOpResult", @@ -165,8 +169,9 @@ return the OpOperands that are yielded by the terminator. Note: This method can return multiple OpOperands, indicating that the - given OpResult may at runtime alias with any of the OpOperands. This - is useful for branches and for ops such as `arith.select`. + given OpResult may at runtime alias with any (or multiple) of the + returned OpOperands. This can be useful for branches and for ops such + as `arith.select`. }], /*retType=*/"SmallVector", /*methodName=*/"getAliasingOpOperand", diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h @@ -0,0 +1,59 @@ +//===- DstBufferizableOpInterfaceImpl.h - Dst Op Bufferization --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_BUFFERIZATION_IR_DSTBUFFERIZABLEOPINTERFACEIMPL_H_ +#define MLIR_DIALECT_BUFFERIZATION_IR_DSTBUFFERIZABLEOPINTERFACEIMPL_H_ + +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/Interfaces/DestinationStyleOpInterface.h" + +namespace mlir { +namespace bufferization { + +/// Bufferizable ops that implement the DestinationStyleOpInterface can use this +/// external model base class. It provides default implementations for various +/// required interface methods. +template +struct DstBufferizableOpInterfaceExternalModel + : public BufferizableOpInterface::ExternalModel { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + // All inputs and outputs bufferize to a memory read. + assert(isa(op) && + "expected that op implements DestinationStyleOpInterface"); + return true; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + // Only outputs bufferize to a memory write. + auto dstOp = cast(op); + return dstOp.isOutput(&opOperand); + } + + SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + // Output operands alias with their respective tied OpResults. + auto dstOp = cast(op); + if (dstOp.isOutput(&opOperand)) + return {dstOp.getTiedOpResult(&opOperand)}; + return {}; + } + + BufferRelation bufferRelation(Operation *op, OpResult opResult, + const AnalysisState &state) const { + assert(isa(op) && + "expected that op implements DestinationStyleOpInterface"); + return BufferRelation::Equivalent; + } +}; + +} // namespace bufferization +} // namespace mlir + +#endif // MLIR_DIALECT_BUFFERIZATION_IR_DSTBUFFERIZABLEOPINTERFACEIMPL_H_ diff --git a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt @@ -13,6 +13,7 @@ LINK_LIBS PUBLIC MLIRAffineDialect + MLIRDestinationStyleOpInterface MLIRDialect MLIRFuncDialect MLIRIR diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -11,6 +11,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -563,32 +564,12 @@ }; /// Bufferization of tensor.insert. Replace with memref.store. +/// +/// Note: DstBufferizableOpInterfaceExternalModel provides many default method +/// implementations for DestinationStyle ops. struct InsertOpInterface - : public BufferizableOpInterface::ExternalModel { - bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, - const AnalysisState &state) const { - return true; - } - - bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, - const AnalysisState &state) const { - return true; - } - - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, - const AnalysisState &state) const { - assert(&opOperand == &op->getOpOperand(1) /*dest*/ && - "expected dest OpOperand"); - return {op->getOpResult(0)}; - } - - SmallVector - getAliasingOpOperand(Operation *op, OpResult opResult, - const AnalysisState &state) const { - return {&op->getOpOperand(1) /*dest*/}; - } - + : public DstBufferizableOpInterfaceExternalModel { LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const { auto insertOp = cast(op); @@ -601,11 +582,6 @@ replaceOpWithBufferizedValues(rewriter, op, *destMemref); return success(); } - - BufferRelation bufferRelation(Operation *op, OpResult opResult, - const AnalysisState &state) const { - return BufferRelation::Equivalent; - } }; /// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e. @@ -732,31 +708,12 @@ /// Bufferization of tensor.insert_slice. Replace with a memory copy. Under /// certain circumstances, this op can also be a no-op. +/// +/// Note: DstBufferizableOpInterfaceExternalModel provides many default method +/// implementations for DestinationStyle ops. struct InsertSliceOpInterface - : public BufferizableOpInterface::ExternalModel { - bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, - const AnalysisState &state) const { - return true; - } - - bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, - const AnalysisState &state) const { - return &opOperand == &op->getOpOperand(1) /*dest*/; - } - - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, - const AnalysisState &state) const { - if (&opOperand == &op->getOpOperand(1) /*dest*/) - return {op->getResult(0)}; - return {}; - } - - BufferRelation bufferRelation(Operation *op, OpResult opResult, - const AnalysisState &state) const { - return BufferRelation::Equivalent; - } - + : public DstBufferizableOpInterfaceExternalModel { bool isNotConflicting(Operation *op, OpOperand *uRead, OpOperand *uConflictingWrite, const AnalysisState &state) const { diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -9817,6 +9817,7 @@ hdrs = [ "include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h", "include/mlir/Dialect/Bufferization/IR/Bufferization.h", + "include/mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h", ], includes = ["include"], deps = [ @@ -9828,6 +9829,7 @@ ":BufferizationOpsIncGen", ":ControlFlowInterfaces", ":CopyOpInterface", + ":DestinationStyleOpInterface", ":FuncDialect", ":IR", ":InferTypeOpInterface",