[mlir][linalg][bufferize] Move vector BufferizableOpInterface impl to a separate build target

Moves the vector_ext external models (TransferReadOpInterface and
TransferWriteOpInterface) out of ComprehensiveBufferize.cpp into the new
VectorInterfaceImpl.h / VectorInterfaceImpl.cpp. Those files build as their
own library (CMake: MLIRVectorBufferizableOpInterfaceImpl; Bazel:
VectorBufferizableOpInterfaceImpl). The MLIRVector dependency is dropped from
the MLIRComprehensiveBufferize link list and the VectorOps.h include is
removed from ComprehensiveBufferize.cpp. The Linalg transforms now link the
new library, and ComprehensiveBufferizePass registers the vector models via
vector_ext::registerBufferizableOpInterfaceExternalModels(registry).

NOTE(review): this copy of the patch was damaged in extraction. Newlines and
indentation are collapsed, `&registry` apparently went through an `&reg`
entity decode, and `<...>` spans were stripped. `&registry` and the template
arguments of the vector_ext code have been restored here from the upstream
MLIR sources; TODO confirm the `isa<...>` types against the original commit.
The `registry.addOpInterface();` context lines in the ComprehensiveBufferize.cpp
hunk could not be reconstructed. Their visible context already exceeds that
hunk's declared new-side length of 6, so content (possibly a second hunk
header) was lost. Regenerate this patch from the original commit before
trying to apply it.

diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.h @@ -0,0 +1,27 @@ +//===- VectorInterfaceImpl.h - Vector Impl. of BufferizableOpInterface ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_VECTOR_INTERFACE_IMPL_H +#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_VECTOR_INTERFACE_IMPL_H + +namespace mlir { + +class DialectRegistry; + +namespace linalg { +namespace comprehensive_bufferize { +namespace vector_ext { + +void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry); + +} // namespace vector_ext +} // namespace comprehensive_bufferize +} // namespace linalg +} // namespace mlir + +#endif // MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_VECTOR_INTERFACE_IMPL_H diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt @@ -3,6 +3,7 @@ ComprehensiveBufferize.cpp LinalgInterfaceImpl.cpp TensorInterfaceImpl.cpp + VectorInterfaceImpl.cpp ) add_mlir_dialect_library(MLIRBufferizableOpInterface @@ -36,6 +37,15 @@ MLIRTensor ) +add_mlir_dialect_library(MLIRVectorBufferizableOpInterfaceImpl + VectorInterfaceImpl.cpp + + LINK_LIBS PUBLIC + MLIRBufferizableOpInterface + MLIRIR + MLIRVector +) + add_mlir_dialect_library(MLIRComprehensiveBufferize ComprehensiveBufferize.cpp @@ -48,5 +58,4 @@ MLIRSCF 
MLIRStandard MLIRStandardOpsTransforms - MLIRVector ) diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp @@ -114,7 +114,6 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" -#include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/IR/AsmState.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Operation.h" @@ -1926,102 +1925,6 @@ } // namespace std_ext -namespace vector_ext { - -struct TransferReadOpInterface - : public BufferizableOpInterface::ExternalModel<TransferReadOpInterface, - vector::TransferReadOp> { - bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const { - assert(opOperand.get().getType().isa<RankedTensorType>() && - "only tensor types expected"); - return true; - } - - bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const { - assert(opOperand.get().getType().isa<RankedTensorType>() && - "only tensor types expected"); - return false; - } - - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { - return OpResult(); - } - - LogicalResult bufferize(Operation *op, OpBuilder &b, - BufferizationState &state) const { - auto transferReadOp = cast<vector::TransferReadOp>(op); - - // Take a guard before anything else. - OpBuilder::InsertionGuard g(b); - b.setInsertionPoint(op); - - // TransferReadOp always reads from the bufferized op.source(). 
- assert(transferReadOp.getShapedType().isa<TensorType>() && - "only tensor types expected"); - Value v = state.lookupBuffer(transferReadOp.source()); - transferReadOp.sourceMutable().assign(v); - return success(); - } -}; - -struct TransferWriteOpInterface - : public BufferizableOpInterface::ExternalModel<TransferWriteOpInterface, - vector::TransferWriteOp> { - bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const { - assert(opOperand.get().getType().isa<TensorType>() && - "only tensor types expected"); - return true; - } - - bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const { - assert(opOperand.get().getType().isa<TensorType>() && - "only tensor types expected"); - return true; - } - - SmallVector<OpOperand *> getAliasingOpOperand(Operation *op, - OpResult opResult) const { - return {&op->getOpOperand(1)}; - } - - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { - assert(opOperand.get().getType().isa<TensorType>() && - "only tensor types expected"); - return op->getOpResult(0); - } - - BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const { - return BufferRelation::Equivalent; - } - - LogicalResult bufferize(Operation *op, OpBuilder &b, - BufferizationState &state) const { - auto writeOp = cast<vector::TransferWriteOp>(op); - - // Take a guard before anything else. - OpBuilder::InsertionGuard g(b); - b.setInsertionPoint(op); - - // Create a new transfer_write on buffer that doesn't have a return value. - // Leave the previous transfer_write to dead code as it still has uses at - // this point. - assert(writeOp.getShapedType().isa<TensorType>() && - "only tensor types expected"); - Value resultBuffer = getResultBuffer(b, op->getResult(0), state); - if (!resultBuffer) - return failure(); - b.create<vector::TransferWriteOp>( - writeOp.getLoc(), writeOp.vector(), resultBuffer, writeOp.indices(), - writeOp.permutation_map(), - writeOp.in_bounds() ? 
*writeOp.in_bounds() : ArrayAttr()); - state.mapBuffer(op->getResult(0), resultBuffer); - - return success(); - } -}; - -} // namespace vector_ext - void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) { registry.addOpInterface(); registry.addOpInterface(); registry.addOpInterface(); registry.addOpInterface(); - registry.addOpInterface<vector::TransferReadOp, - vector_ext::TransferReadOpInterface>(); - registry.addOpInterface<vector::TransferWriteOp, - vector_ext::TransferWriteOpInterface>(); // Ops that are not bufferizable but are allocation hoisting barriers. registry.addOpInterface>(); diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp @@ -0,0 +1,123 @@ +//===- VectorInterfaceImpl.cpp - Vector Impl. of BufferizableOpInterface --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.h" +#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h" +#include "mlir/Dialect/Vector/VectorOps.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Operation.h" + +namespace mlir { +namespace linalg { +namespace comprehensive_bufferize { +namespace vector_ext { + +struct TransferReadOpInterface + : public BufferizableOpInterface::ExternalModel<TransferReadOpInterface, + vector::TransferReadOp> { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const { + assert(opOperand.get().getType().isa<RankedTensorType>() && + "only tensor types expected"); + return true; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const { + assert(opOperand.get().getType().isa<RankedTensorType>() && + "only tensor types expected"); + return false; + } + + OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { + return OpResult(); + } + + LogicalResult bufferize(Operation *op, OpBuilder &b, + BufferizationState &state) const { + auto transferReadOp = cast<vector::TransferReadOp>(op); + + // Take a guard before anything else. + OpBuilder::InsertionGuard g(b); + b.setInsertionPoint(op); + + // TransferReadOp always reads from the bufferized op.source(). 
+ assert(transferReadOp.getShapedType().isa<TensorType>() && + "only tensor types expected"); + Value v = state.lookupBuffer(transferReadOp.source()); + transferReadOp.sourceMutable().assign(v); + return success(); + } +}; + +struct TransferWriteOpInterface + : public BufferizableOpInterface::ExternalModel<TransferWriteOpInterface, + vector::TransferWriteOp> { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const { + assert(opOperand.get().getType().isa<TensorType>() && + "only tensor types expected"); + return true; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const { + assert(opOperand.get().getType().isa<TensorType>() && + "only tensor types expected"); + return true; + } + + SmallVector<OpOperand *> getAliasingOpOperand(Operation *op, + OpResult opResult) const { + return {&op->getOpOperand(1)}; + } + + OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { + assert(opOperand.get().getType().isa<TensorType>() && + "only tensor types expected"); + return op->getOpResult(0); + } + + BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const { + return BufferRelation::Equivalent; + } + + LogicalResult bufferize(Operation *op, OpBuilder &b, + BufferizationState &state) const { + auto writeOp = cast<vector::TransferWriteOp>(op); + + // Take a guard before anything else. + OpBuilder::InsertionGuard g(b); + b.setInsertionPoint(op); + + // Create a new transfer_write on buffer that doesn't have a return value. + // Leave the previous transfer_write to dead code as it still has uses at + // this point. + assert(writeOp.getShapedType().isa<TensorType>() && + "only tensor types expected"); + Value resultBuffer = getResultBuffer(b, op->getResult(0), state); + if (!resultBuffer) + return failure(); + b.create<vector::TransferWriteOp>( + writeOp.getLoc(), writeOp.vector(), resultBuffer, writeOp.indices(), + writeOp.permutation_map(), + writeOp.in_bounds() ? 
*writeOp.in_bounds() : ArrayAttr()); + state.mapBuffer(op->getResult(0), resultBuffer); + + return success(); + } +}; + +} // namespace vector_ext +} // namespace comprehensive_bufferize +} // namespace linalg +} // namespace mlir + +void mlir::linalg::comprehensive_bufferize::vector_ext:: + registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) { + registry.addOpInterface<vector::TransferReadOp, + vector_ext::TransferReadOpInterface>(); + registry.addOpInterface<vector::TransferWriteOp, + vector_ext::TransferWriteOpInterface>(); +} diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -53,6 +53,7 @@ MLIRTransforms MLIRTransformUtils MLIRVector + MLIRVectorBufferizableOpInterfaceImpl MLIRX86VectorTransforms MLIRVectorToSCF ) diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp @@ -11,6 +11,7 @@ #include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h" +#include "mlir/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" @@ -40,6 +41,7 @@ registerBufferizableOpInterfaceExternalModels(registry); linalg_ext::registerBufferizableOpInterfaceExternalModels(registry); tensor_ext::registerBufferizableOpInterfaceExternalModels(registry); + vector_ext::registerBufferizableOpInterfaceExternalModels(registry); } }; } // end namespace diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel 
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -6345,6 +6345,24 @@ ], ) +cc_library( + name = "VectorBufferizableOpInterfaceImpl", + srcs = [ + "lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp", + ], + hdrs = [ + "include/mlir/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.h", + ], + includes = ["include"], + deps = [ + ":BufferizableOpInterface", + ":IR", + ":Support", + ":VectorOps", + "//llvm:Support", + ], +) + td_library( name = "LinalgDocTdFiles", srcs = ["include/mlir/Dialect/Linalg/IR/LinalgDoc.td"], @@ -6567,6 +6585,7 @@ ":TensorBufferizableOpInterfaceImpl", ":TensorDialect", ":TransformUtils", + ":VectorBufferizableOpInterfaceImpl", ":VectorOps", ":VectorToSCF", ":X86VectorTransforms", @@ -6596,7 +6615,6 @@ ":StandardOps", ":Support", ":TransformUtils", - ":VectorOps", "//llvm:Support", ], )