diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h deleted file mode 100644 --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h +++ /dev/null @@ -1,27 +0,0 @@ -//===- StdInterfaceImpl.h - Standard Impl. of BufferizableOpInterface- ----===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_STD_INTERFACE_IMPL_H -#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_STD_INTERFACE_IMPL_H - -namespace mlir { - -class DialectRegistry; - -namespace linalg { -namespace comprehensive_bufferize { -namespace std_ext { - -void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry); - -} // namespace std_ext -} // namespace comprehensive_bufferize -} // namespace linalg -} // namespace mlir - -#endif // MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_STD_INTERFACE_IMPL_H diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.h @@ -0,0 +1,18 @@ +//===- BufferizableOpInterfaceImpl.h - Impl. of BufferizableOpInterface ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_STANDARDOPS_BUFFERIZABLEOPINTERFACEIMPL_H +#define MLIR_DIALECT_STANDARDOPS_BUFFERIZABLEOPINTERFACEIMPL_H + +namespace mlir { +class DialectRegistry; + +void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry); +} // namespace mlir + +#endif // MLIR_DIALECT_STANDARDOPS_BUFFERIZABLEOPINTERFACEIMPL_H diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h @@ -23,10 +23,6 @@ class RewritePatternSet; -void populateStdBufferizePatterns( - bufferization::BufferizeTypeConverter &typeConverter, - RewritePatternSet &patterns); - /// Creates an instance of std bufferization pass. std::unique_ptr createStdBufferizePass(); diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td @@ -14,8 +14,6 @@ def StdBufferize : Pass<"std-bufferize", "FuncOp"> { let summary = "Bufferize the std dialect"; let constructor = "mlir::createStdBufferizePass()"; - let dependentDialects = ["bufferization::BufferizationDialect", - "memref::MemRefDialect", "scf::SCFDialect"]; } def FuncBufferize : Pass<"func-bufferize", "ModuleOp"> { diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt @@ -25,14 +25,6 @@ MLIRTensor ) -add_mlir_dialect_library(MLIRStdBufferizableOpInterfaceImpl 
StdInterfaceImpl.cpp - - LINK_LIBS PUBLIC - MLIRBufferization - MLIRStandard -) - add_mlir_dialect_library(MLIRVectorBufferizableOpInterfaceImpl VectorInterfaceImpl.cpp diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -49,7 +49,6 @@ MLIRSCF MLIRSCFTransforms MLIRSCFUtils - MLIRStdBufferizableOpInterfaceImpl MLIRPass MLIRStandard MLIRStandardOpsTransforms diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp @@ -15,10 +15,10 @@ #include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h" -#include "mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" @@ -56,7 +56,7 @@ linalg_ext::registerBufferizableOpInterfaceExternalModels(registry); scf::registerBufferizableOpInterfaceExternalModels(registry); std_ext::registerModuleBufferizationExternalModels(registry); - std_ext::registerBufferizableOpInterfaceExternalModels(registry); + mlir::registerBufferizableOpInterfaceExternalModels(registry); tensor::registerBufferizableOpInterfaceExternalModels(registry); 
vector_ext::registerBufferizableOpInterfaceExternalModels(registry); } diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.cpp b/mlir/lib/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.cpp rename from mlir/lib/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.cpp rename to mlir/lib/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.cpp @@ -1,4 +1,4 @@ -//===- StdInterfaceImpl.cpp - Standard Impl. of BufferizableOpInterface ---===// +//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,19 +6,18 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h" +#include "mlir/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Operation.h" +using namespace mlir; using namespace mlir::bufferization; namespace mlir { -namespace linalg { -namespace comprehensive_bufferize { -namespace std_ext { +namespace { /// Bufferization of std.select. Just replace the operands. 
struct SelectOpInterface @@ -69,12 +68,10 @@ } }; -} // namespace std_ext -} // namespace comprehensive_bufferize -} // namespace linalg +} // namespace } // namespace mlir -void mlir::linalg::comprehensive_bufferize::std_ext:: - registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) { - registry.addOpInterface(); +void mlir::registerBufferizableOpInterfaceExternalModels( + DialectRegistry &registry) { + registry.addOpInterface(); } diff --git a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp @@ -12,64 +12,34 @@ #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" #include "PassDetail.h" +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/StandardOps/Transforms/Passes.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/BlockAndValueMapping.h" -#include "mlir/Transforms/DialectConversion.h" using namespace mlir; - -namespace { -class BufferizeSelectOp : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(SelectOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (!op.getCondition().getType().isa()) - return rewriter.notifyMatchFailure(op, "requires scalar condition"); - - rewriter.replaceOpWithNewOp(op, adaptor.getCondition(), - adaptor.getTrueValue(), - adaptor.getFalseValue()); - return success(); - } -}; -} // namespace - -void mlir::populateStdBufferizePatterns( - bufferization::BufferizeTypeConverter &typeConverter, - RewritePatternSet &patterns) { - 
patterns.add(typeConverter, patterns.getContext()); -} +using namespace mlir::bufferization; namespace { struct StdBufferizePass : public StdBufferizeBase { void runOnOperation() override { - auto *context = &getContext(); - bufferization::BufferizeTypeConverter typeConverter; - RewritePatternSet patterns(context); - ConversionTarget target(*context); - - target.addLegalDialect(); + std::unique_ptr options = + getPartialBufferizationOptions(); + options->addToDialectFilter(); - populateStdBufferizePatterns(typeConverter, patterns); - // We only bufferize the case of tensor selected type and scalar condition, - // as that boils down to a select over memref descriptors (don't need to - // touch the data). - target.addDynamicallyLegalOp([&](SelectOp op) { - return typeConverter.isLegal(op.getType()) || - !op.getCondition().getType().isa(); - }); - if (failed(applyPartialConversion(getOperation(), target, - std::move(patterns)))) + if (failed(bufferizeOp(getOperation(), *options))) signalPassFailure(); } + + void getDependentDialects(DialectRegistry &registry) const override { + registry.insert(); + mlir::registerBufferizableOpInterfaceExternalModels(registry); + } }; } // namespace diff --git a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_dialect_library(MLIRStandardOpsTransforms + BufferizableOpInterfaceImpl.cpp Bufferize.cpp DecomposeCallGraphTypes.cpp FuncBufferize.cpp @@ -13,6 +14,7 @@ LINK_LIBS PUBLIC MLIRAffine MLIRArithmeticTransforms + MLIRBufferization MLIRBufferizationTransforms MLIRIR MLIRMemRef diff --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt --- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt @@ -27,8 +27,8 @@ MLIRPass MLIRSCF MLIRSCFTransforms - 
MLIRStdBufferizableOpInterfaceImpl MLIRStandard + MLIRStandardOpsTransforms MLIRTensor MLIRTensorTransforms MLIRTransformUtils diff --git a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp --- a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp @@ -18,12 +18,12 @@ #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h" -#include "mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/Pass/PassManager.h" @@ -62,7 +62,7 @@ arith::registerBufferizableOpInterfaceExternalModels(registry); linalg_ext::registerBufferizableOpInterfaceExternalModels(registry); scf::registerBufferizableOpInterfaceExternalModels(registry); - std_ext::registerBufferizableOpInterfaceExternalModels(registry); + mlir::registerBufferizableOpInterfaceExternalModels(registry); tensor::registerBufferizableOpInterfaceExternalModels(registry); vector_ext::registerBufferizableOpInterfaceExternalModels(registry); } diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -6658,24 +6658,6 @@ ], ) -cc_library( - name = "StdBufferizableOpInterfaceImpl", - srcs = 
[ - "lib/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.cpp", - ], - hdrs = [ - "include/mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h", - ], - includes = ["include"], - deps = [ - ":BufferizationDialect", - ":IR", - ":StandardOps", - ":Support", - "//llvm:Support", - ], -) - cc_library( name = "VectorBufferizableOpInterfaceImpl", srcs = [ @@ -6916,7 +6898,6 @@ ":SCFUtils", ":StandardOps", ":StandardOpsTransforms", - ":StdBufferizableOpInterfaceImpl", ":Support", ":TensorDialect", ":TensorTransforms", diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel @@ -403,7 +403,7 @@ "//mlir:SCFDialect", "//mlir:SCFTransforms", "//mlir:StandardOps", - "//mlir:StdBufferizableOpInterfaceImpl", + "//mlir:StandardOpsTransforms", "//mlir:TensorDialect", "//mlir:TensorTransforms", "//mlir:TransformUtils",