diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h b/mlir/include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h rename from mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h rename to mlir/include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h +++ b/mlir/include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h @@ -1,4 +1,4 @@ -//===- SCFInterfaceImpl.h - SCF Impl. of BufferizableOpInterface ----------===// +//===- BufferizableOpInterfaceImpl.h - Impl. of BufferizableOpInterface ---===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,19 +6,15 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_SCFINTERFACEIMPL_H -#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_SCFINTERFACEIMPL_H +#ifndef MLIR_DIALECT_SCF_BUFFERIZABLEOPINTERFACEIMPL_H +#define MLIR_DIALECT_SCF_BUFFERIZABLEOPINTERFACEIMPL_H #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" namespace mlir { - class DialectRegistry; -namespace linalg { -namespace comprehensive_bufferize { -namespace scf_ext { - +namespace scf { /// Assert that yielded values of an scf.for op are aliasing their corresponding /// bbArgs. 
This is required because the i-th OpResult of an scf.for op is /// currently assumed to alias with the i-th iter_arg (in the absence of @@ -30,10 +26,7 @@ }; void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry); - -} // namespace scf_ext -} // namespace comprehensive_bufferize -} // namespace linalg +} // namespace scf } // namespace mlir -#endif // MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_SCFINTERFACEIMPL_H +#endif // MLIR_DIALECT_SCF_BUFFERIZABLEOPINTERFACEIMPL_H diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt @@ -2,7 +2,6 @@ AffineInterfaceImpl.cpp LinalgInterfaceImpl.cpp ModuleBufferization.cpp - SCFInterfaceImpl.cpp StdInterfaceImpl.cpp VectorInterfaceImpl.cpp ) @@ -26,16 +25,6 @@ MLIRTensor ) -add_mlir_dialect_library(MLIRSCFBufferizableOpInterfaceImpl - SCFInterfaceImpl.cpp - - LINK_LIBS PUBLIC - MLIRBufferization - MLIRBufferizationTransforms - MLIRIR - MLIRSCF -) - add_mlir_dialect_library(MLIRStdBufferizableOpInterfaceImpl StdInterfaceImpl.cpp diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -47,7 +47,6 @@ MLIRLinalgUtils MLIRModuleBufferization MLIRSCF - MLIRSCFBufferizableOpInterfaceImpl MLIRSCFTransforms MLIRSCFUtils MLIRStdBufferizableOpInterfaceImpl diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp @@ -15,10 +15,10 @@ #include 
"mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h" -#include "mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.h" #include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" @@ -54,7 +54,7 @@ affine_ext::registerBufferizableOpInterfaceExternalModels(registry); arith::registerBufferizableOpInterfaceExternalModels(registry); linalg_ext::registerBufferizableOpInterfaceExternalModels(registry); - scf_ext::registerBufferizableOpInterfaceExternalModels(registry); + scf::registerBufferizableOpInterfaceExternalModels(registry); std_ext::registerModuleBufferizationExternalModels(registry); std_ext::registerBufferizableOpInterfaceExternalModels(registry); tensor::registerBufferizableOpInterfaceExternalModels(registry); @@ -132,7 +132,7 @@ } // Only certain scf.for ops are supported by the analysis. - options->addPostAnalysisStep(); + options->addPostAnalysisStep(); ModuleOp moduleOp = getOperation(); applyEnablingTransformations(moduleOp); diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp rename from mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp rename to mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -1,4 +1,4 @@ -//===- SCFInterfaceImpl.cpp - SCF Impl. 
of BufferizableOpInterface --------===// +//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,7 +6,8 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h" +#include "mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h" + #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/SCF/SCF.h" @@ -14,12 +15,13 @@ #include "mlir/IR/Operation.h" #include "mlir/IR/PatternMatch.h" +using namespace mlir; using namespace mlir::bufferization; +using namespace mlir::scf; namespace mlir { -namespace linalg { -namespace comprehensive_bufferize { -namespace scf_ext { +namespace scf { +namespace { // bufferization.to_memref is not allowed to change the rank. static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) { @@ -384,42 +386,6 @@ } }; -LogicalResult -mlir::linalg::comprehensive_bufferize::scf_ext::AssertScfForAliasingProperties:: - run(Operation *op, BufferizationState &state, - BufferizationAliasInfo &aliasInfo, SmallVector &newOps) { - LogicalResult status = success(); - - op->walk([&](scf::ForOp forOp) { - auto yieldOp = - cast(forOp.getLoopBody().front().getTerminator()); - for (OpOperand &operand : yieldOp->getOpOperands()) { - auto tensorType = operand.get().getType().dyn_cast(); - if (!tensorType) - continue; - - OpOperand &forOperand = forOp.getOpOperandForResult( - forOp->getResult(operand.getOperandNumber())); - auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand); - // Note: This is overly strict. We should check for aliasing bufferized - // values. But we don't have a "must-alias" analysis yet. 
- if (!aliasInfo.areEquivalentBufferizedValues(operand.get(), bbArg)) { - // TODO: this could get resolved with copies but it can also turn into - // swaps so we need to be careful about order of copies. - status = - yieldOp->emitError() - << "Yield operand #" << operand.getOperandNumber() - << " does not bufferize to a buffer that is aliasing the matching" - << " enclosing scf::for operand"; - return WalkResult::interrupt(); - } - } - return WalkResult::advance(); - }); - - return status; -} - /// Bufferization of scf.yield. Bufferized as part of their enclosing ops, so /// this is for analysis only. struct YieldOpInterface @@ -462,18 +428,51 @@ } }; -} // namespace scf_ext -} // namespace comprehensive_bufferize -} // namespace linalg +} // namespace +} // namespace scf } // namespace mlir -void mlir::linalg::comprehensive_bufferize::scf_ext:: - registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) { - registry.addOpInterface(); - registry.addOpInterface(); - registry.addOpInterface(); - registry.addOpInterface(); - registry.addOpInterface>(); +LogicalResult mlir::scf::AssertScfForAliasingProperties::run( + Operation *op, BufferizationState &state, BufferizationAliasInfo &aliasInfo, + SmallVector &newOps) { + LogicalResult status = success(); + + op->walk([&](scf::ForOp forOp) { + auto yieldOp = + cast(forOp.getLoopBody().front().getTerminator()); + for (OpOperand &operand : yieldOp->getOpOperands()) { + auto tensorType = operand.get().getType().dyn_cast(); + if (!tensorType) + continue; + + OpOperand &forOperand = forOp.getOpOperandForResult( + forOp->getResult(operand.getOperandNumber())); + auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand); + // Note: This is overly strict. We should check for aliasing bufferized + // values. But we don't have a "must-alias" analysis yet. 
+ if (!aliasInfo.areEquivalentBufferizedValues(operand.get(), bbArg)) { + // TODO: this could get resolved with copies but it can also turn into + // swaps so we need to be careful about order of copies. + status = + yieldOp->emitError() + << "Yield operand #" << operand.getOperandNumber() + << " does not bufferize to a buffer that is aliasing the matching" + << " enclosing scf::for operand"; + return WalkResult::interrupt(); + } + } + return WalkResult::advance(); + }); + + return status; +} + +void mlir::scf::registerBufferizableOpInterfaceExternalModels( + DialectRegistry &registry) { + registry.addOpInterface(); + registry.addOpInterface(); + registry.addOpInterface(); + registry.addOpInterface(); + registry + .addOpInterface>(); } diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_dialect_library(MLIRSCFTransforms + BufferizableOpInterfaceImpl.cpp Bufferize.cpp ForToWhile.cpp LoopCanonicalization.cpp @@ -20,6 +21,7 @@ MLIRAffine MLIRAffineAnalysis MLIRArithmetic + MLIRBufferization MLIRBufferizationTransforms MLIRDialectUtils MLIRIR diff --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt --- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt @@ -26,7 +26,7 @@ MLIRMemRef MLIRPass MLIRSCF - MLIRSCFBufferizableOpInterfaceImpl + MLIRSCFTransforms MLIRStdBufferizableOpInterfaceImpl MLIRStandard MLIRTensor diff --git a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp --- a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp @@ -18,11 +18,11 @@ #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" #include 
"mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h" -#include "mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Vector/VectorOps.h" @@ -61,7 +61,7 @@ affine_ext::registerBufferizableOpInterfaceExternalModels(registry); arith::registerBufferizableOpInterfaceExternalModels(registry); linalg_ext::registerBufferizableOpInterfaceExternalModels(registry); - scf_ext::registerBufferizableOpInterfaceExternalModels(registry); + scf::registerBufferizableOpInterfaceExternalModels(registry); std_ext::registerBufferizableOpInterfaceExternalModels(registry); tensor::registerBufferizableOpInterfaceExternalModels(registry); vector_ext::registerBufferizableOpInterfaceExternalModels(registry); @@ -106,7 +106,7 @@ auto options = std::make_unique(); if (!allowReturnMemref) - options->addPostAnalysisStep(); + options->addPostAnalysisStep(); options->allowReturnMemref = allowReturnMemref; options->allowUnknownOps = allowUnknownOps; diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -1782,6 +1782,7 @@ "lib/Dialect/SCF/Transforms/*.h", ]), hdrs = [ + "include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h", "include/mlir/Dialect/SCF/Passes.h", "include/mlir/Dialect/SCF/Transforms.h", ], @@ -2435,6 +2436,7 @@ "include/mlir/Dialect/SCF/*.h", ], exclude = [ + 
"include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h", "include/mlir/Dialect/SCF/Transforms.h", ], ), @@ -6656,25 +6658,6 @@ ], ) -cc_library( - name = "SCFBufferizableOpInterfaceImpl", - srcs = [ - "lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp", - ], - hdrs = [ - "include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h", - ], - includes = ["include"], - deps = [ - ":BufferizationDialect", - ":BufferizationTransforms", - ":IR", - ":SCFDialect", - ":Support", - "//llvm:Support", - ], -) - cc_library( name = "StdBufferizableOpInterfaceImpl", srcs = [ @@ -6928,7 +6911,6 @@ ":MemRefDialect", ":ModuleBufferization", ":Pass", - ":SCFBufferizableOpInterfaceImpl", ":SCFDialect", ":SCFTransforms", ":SCFUtils", diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel @@ -400,7 +400,6 @@ "//mlir:LinalgTransforms", "//mlir:MemRefDialect", "//mlir:Pass", - "//mlir:SCFBufferizableOpInterfaceImpl", "//mlir:SCFDialect", "//mlir:SCFTransforms", "//mlir:StandardOps",