diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h @@ -21,6 +21,7 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/DestinationStyleOpInterface.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/VectorInterfaces.h" diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -16,6 +16,7 @@ include "mlir/Dialect/Vector/Interfaces/MaskingInterfaces.td" include "mlir/IR/EnumAttr.td" include "mlir/Interfaces/ControlFlowInterfaces.td" +include "mlir/Interfaces/DestinationStyleOpInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/VectorInterfaces.td" @@ -1270,7 +1271,8 @@ DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, - AttrSizedOperandSegments + AttrSizedOperandSegments, + DestinationStyleOpInterface ]>, Arguments<(ins AnyVectorOfAnyRank:$vector, AnyShaped:$source, @@ -1393,6 +1395,10 @@ /// This method is added to maintain uniformity with load/store /// ops of other dialects. Value getValue() { return getVector(); } + + std::pair getOutputsPositionRange() { + return {1, 2}; // `source` operand + } }]; let hasFolder = 1; diff --git a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt @@ -12,6 +12,7 @@ MLIRArithDialect MLIRControlFlowInterfaces MLIRDataLayoutInterfaces + MLIRDestinationStyleOpInterface MLIRDialectUtils MLIRIR MLIRMaskingInterfaces diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp @@ -9,6 +9,7 @@ #include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Operation.h" @@ -63,35 +64,12 @@ /// Bufferization of vector.transfer_write. Replace with a new /// vector.transfer_write that operates on a memref. +/// +/// Note: DstBufferizableOpInterfaceExternalModel provides many default method +/// implementations for DestinationStyle ops. struct TransferWriteOpInterface - : public BufferizableOpInterface::ExternalModel { - bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, - const AnalysisState &state) const { - assert(opOperand.get().getType().isa() && - "only tensor types expected"); - return true; - } - - bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, - const AnalysisState &state) const { - assert(opOperand.get().getType().isa() && - "only tensor types expected"); - return true; - } - - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, - const AnalysisState &state) const { - assert(opOperand.get().getType().isa() && - "only tensor types expected"); - return {op->getOpResult(0)}; - } - - BufferRelation bufferRelation(Operation *op, OpResult opResult, - const AnalysisState &state) const { - return BufferRelation::Equivalent; - } - + : public DstBufferizableOpInterfaceExternalModel { LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const { auto writeOp = cast(op); diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -3246,6 +3246,7 @@ ":ArithDialect", ":ArithUtils", ":ControlFlowInterfaces", + ":DestinationStyleOpInterface", ":DialectUtils", ":IR", ":InferTypeOpInterface", @@ -8211,6 +8212,7 @@ includes = ["include"], deps = [ ":ControlFlowInterfacesTdFiles", + ":DestinationStyleOpInterfaceTdFiles", ":InferTypeOpInterfaceTdFiles", ":MaskingInterfacesTdFiles", ":OpBaseTdFiles",