mlir/bufferization: default BufferizableOpInterface implementations for DestinationStyleOpInterface ops; the tensor.insert and tensor.insert_slice external models drop the overrides these defaults now cover.
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -11,6 +11,7 @@ #include "mlir/IR/Operation.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/DestinationStyleOpInterface.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/SetVector.h" diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td @@ -56,6 +56,12 @@ "const AnalysisState &":$state), /*methodBody=*/"", /*defaultImplementation=*/[{ + if (isa<DestinationStyleOpInterface>($_op.getOperation())) { + // Default for all destination style ops: All inputs and outputs + // bufferize to a memory read. + return true; + } + // Does not have to be implemented for ops without tensor OpOperands. llvm_unreachable("bufferizesToMemoryRead not implemented"); }] @@ -85,6 +91,13 @@ "const AnalysisState &":$state), /*methodBody=*/"", /*defaultImplementation=*/[{ + if (auto dstOp = + dyn_cast<DestinationStyleOpInterface>($_op.getOperation())) { + // Default for all destination style ops: Only outputs bufferize to + // a memory write. + return dstOp.isOutput(&opOperand); + } + // Does not have to be implemented for ops without tensor OpOperands. // Does not have to be implemented for OpOperands that do not have an // aliasing OpResult. @@ -143,6 +156,10 @@ Return the OpResult that aliases with a given OpOperand when bufferized in-place. This method will never be called on OpOperands that do not have a tensor type. + + Note: This method can return multiple OpResults, indicating that a + given OpOperand may at runtime alias with any (or multiple) of the + returned OpResults.
}], /*retType=*/"SmallVector<OpResult>", /*methodName=*/"getAliasingOpResult", @@ -150,6 +167,15 @@ "const AnalysisState &":$state), /*methodBody=*/"", /*defaultImplementation=*/[{ + if (auto dstOp = + dyn_cast<DestinationStyleOpInterface>($_op.getOperation())) { + // Default for all destination style ops: Output operands alias with + // their respective tied OpResults. + if (dstOp.isOutput(&opOperand)) + return { dstOp.getTiedOpResult(&opOperand) }; + return {}; + } + // Does not have to be implemented for ops without tensor OpOperands. llvm_unreachable("getAliasingOpResult not implemented"); }] @@ -165,8 +191,9 @@ return the OpOperands that are yielded by the terminator. Note: This method can return multiple OpOperands, indicating that the - given OpResult may at runtime alias with any of the OpOperands. This - is useful for branches and for ops such as `arith.select`. + given OpResult may at runtime alias with any (or multiple) of the + returned OpOperands. This can be useful for branches and for ops such + as `arith.select`. }], /*retType=*/"SmallVector<OpOperand *>", /*methodName=*/"getAliasingOpOperand", @@ -207,6 +234,11 @@ "const AnalysisState &":$state), /*methodBody=*/"", /*defaultImplementation=*/[{ + if (isa<DestinationStyleOpInterface>($_op.getOperation())) { + // Default for all destination style ops: Equivalent buffers. + return BufferRelation::Equivalent; + } + // Does not have to be implemented for ops without tensor OpResults // that have an aliasing OpOperand.
llvm_unreachable("bufferRelation not implemented"); diff --git a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt @@ -13,6 +13,7 @@ LINK_LIBS PUBLIC MLIRAffineDialect + MLIRDestinationStyleOpInterface MLIRDialect MLIRFuncDialect MLIRIR diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -563,32 +563,12 @@ }; /// Bufferization of tensor.insert. Replace with memref.store. +/// +/// Note: InsertOp implements DestinationStyleOpInterface, so it is sufficient +/// to implement only the `bufferize` interface method. struct InsertOpInterface : public BufferizableOpInterface::ExternalModel<InsertOpInterface, tensor::InsertOp> { - bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, - const AnalysisState &state) const { - return true; - } - - bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, - const AnalysisState &state) const { - return true; - } - - SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand, - const AnalysisState &state) const { - assert(&opOperand == &op->getOpOperand(1) /*dest*/ && - "expected dest OpOperand"); - return {op->getOpResult(0)}; - } - - SmallVector<OpOperand *> - getAliasingOpOperand(Operation *op, OpResult opResult, - const AnalysisState &state) const { - return {&op->getOpOperand(1) /*dest*/}; - } - LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const { auto insertOp = cast<tensor::InsertOp>(op); @@ -601,11 +581,6 @@ replaceOpWithBufferizedValues(rewriter, op, *destMemref); return success(); } - - BufferRelation bufferRelation(Operation *op, OpResult opResult, - const AnalysisState &state) const { - return
BufferRelation::Equivalent; - } }; /// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e. @@ -732,31 +707,12 @@ /// Bufferization of tensor.insert_slice. Replace with a memory copy. Under /// certain circumstances, this op can also be a no-op. +/// +/// Note: InsertSliceOp implements DestinationStyleOpInterface for which the +/// BufferizableOpInterface provides many default method implementations. struct InsertSliceOpInterface : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface, tensor::InsertSliceOp> { - bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, - const AnalysisState &state) const { - return true; - } - - bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, - const AnalysisState &state) const { - return &opOperand == &op->getOpOperand(1) /*dest*/; - } - - SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand, - const AnalysisState &state) const { - if (&opOperand == &op->getOpOperand(1) /*dest*/) - return {op->getResult(0)}; - return {}; - } - - BufferRelation bufferRelation(Operation *op, OpResult opResult, - const AnalysisState &state) const { - return BufferRelation::Equivalent; - } - bool isNotConflicting(Operation *op, OpOperand *uRead, OpOperand *uConflictingWrite, const AnalysisState &state) const { diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -9612,6 +9612,7 @@ ":BufferizationOpsIncGen", ":ControlFlowInterfaces", ":CopyOpInterface", + ":DestinationStyleOpInterface", ":FuncDialect", ":IR", ":InferTypeOpInterface",