diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h --- a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h @@ -10,6 +10,7 @@ #define MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATION_H_ #include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h" +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationInterfaceImpl.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationInterfaceImpl.h deleted file mode 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationInterfaceImpl.h +++ /dev/null @@ -1,25 +0,0 @@ -//===- BufferizationInterfaceImpl.h - Bufferization Impl. of Op Interface -===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONINTERFACEIMPL_H_ -#define MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONINTERFACEIMPL_H_ - -namespace mlir { - -class DialectRegistry; - -namespace bufferization { -namespace bufferization_ext { - -void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry); - -} // namespace bufferization_ext -} // namespace bufferization -} // namespace mlir - -#endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONINTERFACEIMPL_H_ diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td @@ -10,6 +10,7 @@ #define BUFFERIZATION_OPS include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td" +include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td" include "mlir/Dialect/Bufferization/IR/BufferizationBase.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/CopyOpInterface.td" @@ -64,11 +65,14 @@ // ToTensorOp //===----------------------------------------------------------------------===// -def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", - [SameOperandsAndResultShape, SameOperandsAndResultElementType, - TypesMatchWith<"result type matches tensor equivalent of 'memref'", - "memref", "result", - "memref::getTensorTypeFromMemRefType($_self)">]> { +def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [ + BufferizableOpInterface, + SameOperandsAndResultShape, + SameOperandsAndResultElementType, + TypesMatchWith<"result type matches tensor equivalent of 'memref'", + "memref", "result", + "memref::getTensorTypeFromMemRefType($_self)"> + ]> { let summary = "memref to tensor operation"; let description = [{ 
Create a tensor from a memref, making an independent copy of the element @@ -110,6 +114,35 @@ return resultType.cast(); return {}; } + + //===------------------------------------------------------------------===// + // BufferizableOpInterface implementation + //===------------------------------------------------------------------===// + + // ToTensorOp conceptually loads a tensor from a memory location. The + // One-Shot analysis has no information about the memref that is loaded from + // by ToTensorOp. We have to assume that, after bufferization, the loaded + // tensor may alias with any other bufferized tensor. Since + // ToTensorOp and ToMemrefOp have no aliasing OpOperand/OpResult pairs, this + // cannot be encoded directly in the analysis. However, declaring ToTensorOp + // results as not writable enforces a buffer copy and has the same effect. + + LogicalResult bufferize(RewriterBase &rewriter, + const BufferizationState &state) const { + // to_tensor cannot be bufferized. However, other ops that are using + // to_tensor's result will eventually be bufferized. At that point, they + // will start using to_tensor's memref operand. Once all users of + // to_tensor are bufferized, the op will no longer have any users and can + // be DCE'd. In case of partial bufferization, to_memref(to_tensor(x)) + // constructs may be left over. These are folded by the canonicalizer or + // FinalizingBufferize. + return failure(); + } + + bool isWritable(Value value, const BufferizationState &state) const { + // It is unknown whether the memref operand is writable or not. 
+ return false; + } }]; let assemblyFormat = "$memref attr-dict `:` type($memref)"; @@ -123,11 +156,15 @@ // ToMemrefOp //===----------------------------------------------------------------------===// -def Bufferization_ToMemrefOp : Bufferization_Op<"to_memref", - [SameOperandsAndResultShape, SameOperandsAndResultElementType, NoSideEffect, - TypesMatchWith<"type of 'tensor' is the tensor equivalent of 'memref'", - "memref", "tensor", - "memref::getTensorTypeFromMemRefType($_self)">]> { +def Bufferization_ToMemrefOp : Bufferization_Op<"to_memref", [ + BufferizableOpInterface, + SameOperandsAndResultShape, + SameOperandsAndResultElementType, + NoSideEffect, + TypesMatchWith<"type of 'tensor' is the tensor equivalent of 'memref'", + "memref", "tensor", + "memref::getTensorTypeFromMemRefType($_self)"> + ]> { let summary = "tensor to memref cast operation"; let description = [{ Casts a tensor to a memref. @@ -150,6 +187,44 @@ // This op is fully verified by traits. let verifier = ?; + let extraClassDeclaration = [{ + //===------------------------------------------------------------------===// + // BufferizableOpInterface implementation + //===------------------------------------------------------------------===// + + // Note: ToMemrefOp / ToTensorOp are temporary ops that are inserted at the + // bufferization boundary. When One-Shot bufferization is complete, there + // should be no such ops left over. If `allowUnknownOps` is set (or after a + // partial bufferization pass), such ops may be part of the resulting IR, + // but such IR may no longer be analyzable by One-Shot analysis. + + bool bufferizesToMemoryRead(OpOperand &opOperand, + const BufferizationState &state) const { + // It is unknown whether the resulting memref will be read or not. + return true; + } + + bool bufferizesToMemoryWrite(OpOperand &opOperand, + const BufferizationState &state) const { + // It is unknown whether the resulting memref will be written or not. 
+ return true; + } + + bool mustBufferizeInPlace(OpOperand &opOperand, + const BufferizationState &state) const { + // ToMemrefOps always bufferize inplace. + return true; + } + + OpResult getAliasingOpResult(OpOperand &opOperand, + const BufferizationState &state) const { + return OpResult(); + } + + LogicalResult bufferize(RewriterBase &rewriter, + const BufferizationState &state); + }]; + let assemblyFormat = "$tensor attr-dict `:` type($memref)"; let hasFolder = 1; diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationInterfaceImpl.cpp deleted file mode 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationInterfaceImpl.cpp +++ /dev/null @@ -1,127 +0,0 @@ -//===- BufferizationInterfaceImpl.cpp - Bufferization Impl. of Interface --===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/Bufferization/IR/BufferizationInterfaceImpl.h" -#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" -#include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/IR/Dialect.h" -#include "mlir/IR/Operation.h" - -using namespace mlir; -using namespace mlir::bufferization; - -namespace mlir { -namespace bufferization { -namespace bufferization_ext { - -// TODO: These ops should implement BufferizableOpInterface. - -/// Bufferization of bufferization.to_memref. to_memref(to_tensor(x)) is folded -/// to x. Other to_memref ops are ignored during bufferization. -/// -/// ToMemrefOp casts a tensor into a memref. The resulting memref is the memory -/// location of the incoming tensor once it will be bufferized. 
In the anlysis, -/// the incoming tensor is assumed to bufferize to a memory read and to an -/// inplace memory write, since it is unknown what will happen to the resulting -/// memref. -/// -/// Note: ToMemrefOp / ToTensorOp are temporary ops that are inserted at the -/// bufferization boundary. When bufferization is complete, there should be no -/// such ops left over. If `allowUnknownOps`, such ops may be part of the -/// resulting IR, but such IR may no longer be bufferizable by Comprehensive -/// Bufferize. -struct ToMemrefOpInterface - : public BufferizableOpInterface::ExternalModel { - bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { - // It is unknown whether the resulting memref will be read or not. - return true; - } - - bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { - // It is unknown whether the resulting MemRef will be written or not. - return true; - } - - bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { - // ToMemrefOps always bufferize inplace. - return true; - } - - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { - return OpResult(); - } - - LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationState &state) const { - auto toMemrefOp = cast(op); - - // Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary. - if (auto toTensorOp = - toMemrefOp.tensor().getDefiningOp()) { - Value buffer = toTensorOp.memref(); - - // Insert cast in case to_memref(to_tensor(x))'s type is different from - // x's type. 
- if (toTensorOp.memref().getType() != toMemrefOp.getType()) { - assert(memref::CastOp::areCastCompatible(buffer.getType(), - toMemrefOp.getType()) && - "ToMemrefOp::bufferize : cast incompatible"); - buffer = rewriter.create(toMemrefOp.getLoc(), buffer, - toMemrefOp.getType()); - } - replaceOpWithBufferizedValues(rewriter, toMemrefOp, buffer); - return success(); - } - - return failure(); - } -}; - -/// Bufferization of bufferization.to_tensor. Such ops cannot be bufferized. -/// However, other ops that are using to_tensor's result will eventually be -/// bufferized. At that point, they will start using to_tensor's memref operand. -/// Once all users of to_tensor are bufferized, the op will not have any users -/// anymore and DCE away. -/// -/// ToTensorOp conceptually loads a tensor from a memory location. The analysis -/// has no information about the memref that is loaded from by ToTensorOp. We -/// have to assume that the loaded tensor may after bufferization potentially -/// alias with any other bufferized tensor. Since ToTensorOp and ToMemrefOp have -/// no aliasing OpOperand/OpResult pairs, this cannot be encoded directly in the -/// analysis. However, declaring ToTensorOp results as not writable enforces a -/// buffer copy and has the same effect. -struct ToTensorOpInterface - : public BufferizableOpInterface::ExternalModel { - LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - const BufferizationState &state) const { - return failure(); - } - - bool isWritable(Operation *op, Value value, - const BufferizationState &state) const { - // It is unknown whether the memref operand is writable or not. 
- return false; - } -}; - -} // namespace bufferization_ext -} // namespace bufferization -} // namespace mlir - -void bufferization_ext::registerBufferizableOpInterfaceExternalModels( - DialectRegistry &registry) { - registry.addOpInterface(); - registry.addOpInterface(); -} diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp @@ -182,6 +182,79 @@ } }; +/// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the +/// to_memref op differ, a memref.cast (or an alloc + copy) is needed. +static LogicalResult foldToMemrefToTensorPair(RewriterBase &rewriter, + ToMemrefOp toMemref, + bool allowSameType = true) { + auto memrefToTensor = toMemref.tensor().getDefiningOp(); + if (!memrefToTensor) + return failure(); + + // A to_tensor + to_memref pair with the same types can be folded without + // inserting a cast. + if (memrefToTensor.memref().getType() == toMemref.getType()) { + if (!allowSameType) + // Function can be configured to only handle cases where a cast is needed. + return failure(); + rewriter.replaceOp(toMemref, memrefToTensor.memref()); + return success(); + } + + // If types are definitely not cast-compatible, bail. + if (!memref::CastOp::areCastCompatible(memrefToTensor.memref().getType(), + toMemref.getType())) + return failure(); + + // We already know that the types are potentially cast-compatible. However, + // in case the affine maps are different, we may need to use a copy if we go + // from dynamic to static offset or stride (this function cannot know at + // this point that it is really cast compatible). 
+ auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) { + int64_t sourceOffset, targetOffset; + SmallVector sourceStrides, targetStrides; + if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) || + failed(getStridesAndOffset(target, targetStrides, targetOffset))) + return false; + auto dynamicToStatic = [](int64_t a, int64_t b) { + return a == MemRefType::getDynamicStrideOrOffset() && + b != MemRefType::getDynamicStrideOrOffset(); + }; + if (dynamicToStatic(sourceOffset, targetOffset)) + return false; + for (auto it : zip(sourceStrides, targetStrides)) + if (dynamicToStatic(std::get<0>(it), std::get<1>(it))) + return false; + return true; + }; + + auto memrefToTensorType = + memrefToTensor.memref().getType().dyn_cast(); + auto toMemrefType = toMemref.getType().dyn_cast(); + if (memrefToTensorType && toMemrefType && + !isGuaranteedCastCompatible(memrefToTensorType, toMemrefType)) { + MemRefType resultType = toMemrefType; + auto loc = toMemref.getLoc(); + SmallVector dynamicOperands; + for (int i = 0; i < resultType.getRank(); ++i) { + if (resultType.getShape()[i] != ShapedType::kDynamicSize) + continue; + auto index = rewriter.createOrFold(loc, i); + Value size = rewriter.create(loc, memrefToTensor, index); + dynamicOperands.push_back(size); + } + // TODO: Use alloc/memcpy callback from BufferizationOptions if called via + // BufferizableOpInterface impl of ToMemrefOp. + auto copy = + rewriter.create(loc, resultType, dynamicOperands); + rewriter.create(loc, memrefToTensor.memref(), copy); + rewriter.replaceOp(toMemref, {copy}); + } else + rewriter.replaceOpWithNewOp(toMemref, toMemref.getType(), + memrefToTensor.memref()); + return success(); +} + /// Canonicalize bufferization.to_tensor + bufferization.to_memref to /// memref.cast when type mismatches prevent `ToMemrefOp::fold` to kick in. 
struct TensorLoadToMemref : public OpRewritePattern { @@ -189,62 +262,10 @@ LogicalResult matchAndRewrite(ToMemrefOp toMemref, PatternRewriter &rewriter) const final { - auto memrefToTensor = toMemref.tensor().getDefiningOp(); - // Bail unless we have a memref_to_tensor + tensor_to_memref with different - // types. `ToMemrefOp::fold` handles the same type case. - if (!memrefToTensor || - memrefToTensor.memref().getType() == toMemref.getType()) - return failure(); - // If types are definitely not cast-compatible, bail. - if (!memref::CastOp::areCastCompatible(memrefToTensor.memref().getType(), - toMemref.getType())) - return failure(); - - // We already know that the types are potentially cast-compatible. However - // in case the affine maps are different, we may need to use a copy if we go - // from dynamic to static offset or stride (the canonicalization cannot know - // at this point that it is really cast compatible). - auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) { - int64_t sourceOffset, targetOffset; - SmallVector sourceStrides, targetStrides; - if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) || - failed(getStridesAndOffset(target, targetStrides, targetOffset))) - return false; - auto dynamicToStatic = [](int64_t a, int64_t b) { - return a == MemRefType::getDynamicStrideOrOffset() && - b != MemRefType::getDynamicStrideOrOffset(); - }; - if (dynamicToStatic(sourceOffset, targetOffset)) - return false; - for (auto it : zip(sourceStrides, targetStrides)) - if (dynamicToStatic(std::get<0>(it), std::get<1>(it))) - return false; - return true; - }; - - auto memrefToTensorType = - memrefToTensor.memref().getType().dyn_cast(); - auto toMemrefType = toMemref.getType().dyn_cast(); - if (memrefToTensorType && toMemrefType && - !isGuaranteedCastCompatible(memrefToTensorType, toMemrefType)) { - MemRefType resultType = toMemrefType; - auto loc = toMemref.getLoc(); - SmallVector dynamicOperands; - for (int i = 0; i < 
resultType.getRank(); ++i) { - if (resultType.getShape()[i] != ShapedType::kDynamicSize) - continue; - auto index = rewriter.createOrFold(loc, i); - Value size = rewriter.create(loc, memrefToTensor, index); - dynamicOperands.push_back(size); - } - auto copy = - rewriter.create(loc, resultType, dynamicOperands); - rewriter.create(loc, memrefToTensor.memref(), copy); - rewriter.replaceOp(toMemref, {copy}); - } else - rewriter.replaceOpWithNewOp(toMemref, toMemref.getType(), - memrefToTensor.memref()); - return success(); + // Only handle cases where a cast is needed. The other case is handled by + // the folder. + return foldToMemrefToTensorPair(rewriter, toMemref, + /*allowSameType=*/false); } }; @@ -288,6 +309,12 @@ context); } +LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter, + const BufferizationState &state) { + // Fold to_memref(to_tensor(x)) to x. Insert a cast (or a copy) if necessary. + return foldToMemrefToTensorPair(rewriter, *this); +} + Optional CloneOp::buildDealloc(OpBuilder &builder, Value alloc) { return builder.create(alloc.getLoc(), alloc) .getOperation(); diff --git a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt @@ -1,6 +1,6 @@ add_mlir_dialect_library(MLIRBufferization - PARTIAL_SOURCES_INTENDED AllocationOpInterface.cpp + BufferizableOpInterface.cpp BufferizationOps.cpp BufferizationDialect.cpp @@ -17,17 +17,3 @@ MLIRTensor MLIRMemRef ) - -add_mlir_dialect_library(MLIRBufferizableOpInterface - PARTIAL_SOURCES_INTENDED - BufferizableOpInterface.cpp - BufferizationInterfaceImpl.cpp - - DEPENDS - MLIRBufferizableOpInterfaceIncGen - - LINK_LIBS PUBLIC - MLIRIR - MLIRBufferization - MLIRMemRef -) diff --git a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt +++ 
b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt @@ -10,7 +10,6 @@ MLIRBufferizationPassIncGen LINK_LIBS PUBLIC - MLIRBufferizableOpInterface MLIRBufferization MLIRControlFlowInterfaces MLIRInferTypeOpInterface diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt @@ -13,7 +13,7 @@ LINK_LIBS PUBLIC MLIRAffine - MLIRBufferizableOpInterface + MLIRBufferization ) add_mlir_dialect_library(MLIRArithBufferizableOpInterfaceImpl @@ -21,7 +21,7 @@ LINK_LIBS PUBLIC MLIRArithmetic - MLIRBufferizableOpInterface + MLIRBufferization MLIRIR MLIRMemRef MLIRStandardOpsTransforms @@ -31,7 +31,7 @@ LinalgInterfaceImpl.cpp LINK_LIBS PUBLIC - MLIRBufferizableOpInterface + MLIRBufferization MLIRBufferizationTransforms MLIRIR MLIRLinalg @@ -42,7 +42,7 @@ SCFInterfaceImpl.cpp LINK_LIBS PUBLIC - MLIRBufferizableOpInterface + MLIRBufferization MLIRBufferizationTransforms MLIRIR MLIRSCF @@ -52,7 +52,7 @@ StdInterfaceImpl.cpp LINK_LIBS PUBLIC - MLIRBufferizableOpInterface + MLIRBufferization MLIRStandard ) @@ -60,7 +60,7 @@ VectorInterfaceImpl.cpp LINK_LIBS PUBLIC - MLIRBufferizableOpInterface + MLIRBufferization MLIRIR MLIRVector ) @@ -69,7 +69,7 @@ ModuleBufferization.cpp LINK_LIBS PUBLIC - MLIRBufferizableOpInterface + MLIRBufferization MLIRBufferizationTransforms MLIRIR MLIRMemRef diff --git a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt @@ -14,7 +14,7 @@ LINK_LIBS PUBLIC MLIRAffine MLIRArithmetic - MLIRBufferizableOpInterface + MLIRBufferization MLIRDialectUtils MLIRInferTypeOpInterface MLIRIR diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt --- 
a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -36,7 +36,7 @@ MLIRAnalysis MLIRArithBufferizableOpInterfaceImpl MLIRArithmetic - MLIRBufferizableOpInterface + MLIRBufferization MLIRComplex MLIRInferTypeOpInterface MLIRIR diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp @@ -10,7 +10,6 @@ #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/Bufferization/IR/BufferizationInterfaceImpl.h" #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h" @@ -54,7 +53,6 @@ arith::ArithmeticDialect, StandardOpsDialect, AffineDialect>(); affine_ext::registerBufferizableOpInterfaceExternalModels(registry); arith_ext::registerBufferizableOpInterfaceExternalModels(registry); - bufferization_ext::registerBufferizableOpInterfaceExternalModels(registry); linalg_ext::registerBufferizableOpInterfaceExternalModels(registry); scf_ext::registerBufferizableOpInterfaceExternalModels(registry); std_ext::registerModuleBufferizationExternalModels(registry); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt @@ -12,7 +12,7 @@ LINK_LIBS PUBLIC MLIRArithmetic - MLIRBufferizableOpInterface + MLIRBufferization MLIRIR MLIRLLVMIR MLIRLinalg diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt --- 
a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt @@ -10,7 +10,7 @@ LINK_LIBS PUBLIC MLIRArithmetic - MLIRBufferizableOpInterface + MLIRBufferization MLIRBufferizationTransforms MLIRIR MLIRMemRef diff --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt --- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt @@ -16,7 +16,7 @@ MLIRAffineBufferizableOpInterfaceImpl MLIRArithBufferizableOpInterfaceImpl MLIRArithmetic - MLIRBufferizableOpInterface + MLIRBufferization MLIRBufferizationTransforms MLIRGPUTransforms MLIRLinalg diff --git a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp --- a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp @@ -14,7 +14,6 @@ #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/Bufferization/IR/BufferizationInterfaceImpl.h" #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h" @@ -61,7 +60,6 @@ arith::ArithmeticDialect, AffineDialect>(); affine_ext::registerBufferizableOpInterfaceExternalModels(registry); arith_ext::registerBufferizableOpInterfaceExternalModels(registry); - bufferization_ext::registerBufferizableOpInterfaceExternalModels(registry); linalg_ext::registerBufferizableOpInterfaceExternalModels(registry); scf_ext::registerBufferizableOpInterfaceExternalModels(registry); std_ext::registerBufferizableOpInterfaceExternalModels(registry); diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel 
b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -1958,7 +1958,6 @@ deps = [ ":Affine", ":ArithmeticDialect", - ":BufferizableOpInterface", ":BufferizationDialect", ":IR", ":LLVMDialect", @@ -4423,7 +4422,6 @@ deps = [ ":ArithmeticDialect", ":Async", - ":BufferizableOpInterface", ":BufferizationDialect", ":BufferizationTransforms", ":IR", @@ -6573,27 +6571,6 @@ ], ) -cc_library( - name = "BufferizableOpInterface", - srcs = [ - "lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp", - "lib/Dialect/Bufferization/IR/BufferizationInterfaceImpl.cpp", - ], - hdrs = [ - "include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h", - "include/mlir/Dialect/Bufferization/IR/BufferizationInterfaceImpl.h", - ], - includes = ["include"], - deps = [ - ":BufferizableOpInterfaceIncGen", - ":BufferizationDialect", - ":IR", - ":MemRefDialect", - ":Support", - "//llvm:Support", - ], -) - cc_library( name = "AffineBufferizableOpInterfaceImpl", srcs = [ @@ -6605,7 +6582,7 @@ includes = ["include"], deps = [ ":Affine", - ":BufferizableOpInterface", + ":BufferizationDialect", "//llvm:Support", ], ) @@ -6621,7 +6598,7 @@ includes = ["include"], deps = [ ":ArithmeticDialect", - ":BufferizableOpInterface", + ":BufferizationDialect", ":IR", ":MemRefDialect", ":Support", @@ -6640,7 +6617,6 @@ ], includes = ["include"], deps = [ - ":BufferizableOpInterface", ":BufferizationDialect", ":BufferizationTransforms", ":IR", @@ -6660,7 +6636,6 @@ ], includes = ["include"], deps = [ - ":BufferizableOpInterface", ":BufferizationDialect", ":BufferizationTransforms", ":IR", @@ -6680,7 +6655,7 @@ ], includes = ["include"], deps = [ - ":BufferizableOpInterface", + ":BufferizationDialect", ":IR", ":StandardOps", ":Support", @@ -6698,7 +6673,7 @@ ], includes = ["include"], deps = [ - ":BufferizableOpInterface", + ":BufferizationDialect", ":IR", ":Support", ":VectorOps", @@ -6827,7 
+6802,7 @@ deps = [ ":Affine", ":ArithmeticDialect", - ":BufferizableOpInterface", + ":BufferizationDialect", ":CopyOpInterface", ":DialectUtils", ":IR", @@ -6909,7 +6884,6 @@ ":Analysis", ":ArithBufferizableOpInterfaceImpl", ":ArithmeticDialect", - ":BufferizableOpInterface", ":BufferizationDialect", ":BufferizationTransforms", ":ComplexDialect", @@ -6953,7 +6927,6 @@ ], includes = ["include"], deps = [ - ":BufferizableOpInterface", ":BufferizationDialect", ":BufferizationTransforms", ":DialectUtils", @@ -7968,19 +7941,27 @@ ], tblgen = ":mlir-tblgen", td_file = "include/mlir/Dialect/Bufferization/IR/BufferizationOps.td", - deps = [":BufferizationOpsTdFiles"], + deps = [ + ":BufferizableOpInterfaceTdFiles", + ":BufferizationOpsTdFiles", + ], ) cc_library( name = "BufferizationDialect", srcs = [ + "lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp", "lib/Dialect/Bufferization/IR/BufferizationDialect.cpp", "lib/Dialect/Bufferization/IR/BufferizationOps.cpp", ], - hdrs = ["include/mlir/Dialect/Bufferization/IR/Bufferization.h"], + hdrs = [ + "include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h", + "include/mlir/Dialect/Bufferization/IR/Bufferization.h", + ], includes = ["include"], deps = [ ":AllocationOpInterface", + ":BufferizableOpInterfaceIncGen", ":BufferizationBaseIncGen", ":BufferizationOpsIncGen", ":ControlFlowInterfaces", @@ -7989,6 +7970,7 @@ ":InferTypeOpInterface", ":MemRefDialect", ":StandardOps", + ":Support", ":TensorDialect", ":ViewLikeInterface", "//llvm:Support", @@ -8025,7 +8007,6 @@ deps = [ ":AllocationOpInterface", ":Analysis", - ":BufferizableOpInterface", ":BufferizationDialect", ":BufferizationPassIncGen", ":ControlFlowInterfaces", diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel @@ -388,7 +388,6 @@ 
"//mlir:AffineBufferizableOpInterfaceImpl", "//mlir:ArithBufferizableOpInterfaceImpl", "//mlir:ArithmeticDialect", - "//mlir:BufferizableOpInterface", "//mlir:BufferizationDialect", "//mlir:BufferizationTransforms", "//mlir:GPUDialect",