diff --git a/mlir/include/mlir/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.h @@ -0,0 +1,21 @@ +//===- BufferizableOpInterfaceImpl.h - Impl. of BufferizableOpInterface ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_ARITHMETIC_BUFFERIZABLEOPINTERFACEIMPL_H +#define MLIR_DIALECT_ARITHMETIC_BUFFERIZABLEOPINTERFACEIMPL_H + +namespace mlir { + +class DialectRegistry; + +namespace arith { +void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry); +} // namespace arith +} // namespace mlir + +#endif // MLIR_DIALECT_ARITHMETIC_BUFFERIZABLEOPINTERFACEIMPL_H diff --git a/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h @@ -12,17 +12,8 @@ #include "mlir/Pass/Pass.h" namespace mlir { -namespace bufferization { -class BufferizeTypeConverter; -} // namespace bufferization - namespace arith { -/// Add patterns to bufferize Arithmetic ops. -void populateArithmeticBufferizePatterns( - bufferization::BufferizeTypeConverter &typeConverter, - RewritePatternSet &patterns); - /// Create a pass to bufferize Arithmetic ops. 
std::unique_ptr createArithmeticBufferizePass(); diff --git a/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.td @@ -14,8 +14,6 @@ def ArithmeticBufferize : Pass<"arith-bufferize", "FuncOp"> { let summary = "Bufferize Arithmetic dialect ops."; let constructor = "mlir::arith::createArithmeticBufferizePass()"; - let dependentDialects = ["bufferization::BufferizationDialect", - "memref::MemRefDialect"]; } def ArithmeticExpandOps : Pass<"arith-expand", "FuncOp"> { diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h deleted file mode 100644 --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h +++ /dev/null @@ -1,27 +0,0 @@ -//===- ArithInterfaceImpl.h - Arith Impl. of BufferizableOpInterface ------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_ARITHINTERFACEIMPL_H -#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_ARITHINTERFACEIMPL_H - -namespace mlir { - -class DialectRegistry; - -namespace linalg { -namespace comprehensive_bufferize { -namespace arith_ext { - -void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry); - -} // namespace arith_ext -} // namespace comprehensive_bufferize -} // namespace linalg -} // namespace mlir - -#endif // MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_ARITHINTERFACEIMPL_H diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp rename from mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp rename to mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp @@ -1,4 +1,4 @@ -//===- ArithInterfaceImpl.cpp - Arith Impl. of BufferizableOpInterface ----===// +//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
@@ -6,8 +6,7 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h" - +#include "mlir/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h" @@ -18,9 +17,8 @@ using namespace mlir::bufferization; namespace mlir { -namespace linalg { -namespace comprehensive_bufferize { -namespace arith_ext { +namespace arith { +namespace { /// Bufferization of arith.constant. Replace with memref.get_global. struct ConstantOpInterface @@ -100,14 +98,13 @@ return success(); } }; -} // namespace arith_ext -} // namespace comprehensive_bufferize -} // namespace linalg + +} // namespace +} // namespace arith } // namespace mlir -void mlir::linalg::comprehensive_bufferize::arith_ext:: - registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) { - registry.addOpInterface(); - registry - .addOpInterface(); +void mlir::arith::registerBufferizableOpInterfaceExternalModels( + DialectRegistry &registry) { + registry.addOpInterface(); + registry.addOpInterface(); } diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Arithmetic/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Arithmetic/Transforms/Bufferize.cpp @@ -8,61 +8,37 @@ #include "PassDetail.h" +#include "mlir/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Arithmetic/Transforms/Passes.h" +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" using namespace mlir; +using namespace bufferization; namespace { - -/// Bufferize arith.index_cast. 
-struct BufferizeIndexCastOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto tensorType = op.getType().cast(); - rewriter.replaceOpWithNewOp( - op, adaptor.getIn(), - MemRefType::get(tensorType.getShape(), tensorType.getElementType())); - return success(); - } -}; - /// Pass to bufferize Arithmetic ops. struct ArithmeticBufferizePass : public ArithmeticBufferizeBase { void runOnOperation() override { - bufferization::BufferizeTypeConverter typeConverter; - RewritePatternSet patterns(&getContext()); - ConversionTarget target(getContext()); - - target.addLegalDialect(); - - arith::populateArithmeticBufferizePatterns(typeConverter, patterns); + std::unique_ptr options = + getPartialBufferizationOptions(); + options->addToDialectFilter(); - target.addDynamicallyLegalOp( - [&](arith::IndexCastOp op) { - return typeConverter.isLegal(op.getType()); - }); - - if (failed(applyPartialConversion(getOperation(), target, - std::move(patterns)))) + if (failed(bufferizeOp(getOperation(), *options))) signalPassFailure(); } -}; + void getDependentDialects(DialectRegistry &registry) const override { + registry.insert(); + arith::registerBufferizableOpInterfaceExternalModels(registry); + } +}; } // namespace -void mlir::arith::populateArithmeticBufferizePatterns( - bufferization::BufferizeTypeConverter &typeConverter, - RewritePatternSet &patterns) { - patterns.add(typeConverter, patterns.getContext()); -} - std::unique_ptr mlir::arith::createArithmeticBufferizePass() { return std::make_unique(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp @@ -8,11 +8,11 @@ #include 
"PassDetail.h" +#include "mlir/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h" -#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h" @@ -52,7 +52,7 @@ vector::VectorDialect, scf::SCFDialect, arith::ArithmeticDialect, StandardOpsDialect, AffineDialect>(); affine_ext::registerBufferizableOpInterfaceExternalModels(registry); - arith_ext::registerBufferizableOpInterfaceExternalModels(registry); + arith::registerBufferizableOpInterfaceExternalModels(registry); linalg_ext::registerBufferizableOpInterfaceExternalModels(registry); scf_ext::registerBufferizableOpInterfaceExternalModels(registry); std_ext::registerModuleBufferizationExternalModels(registry); diff --git a/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir @@ -96,19 +96,3 @@ } return %5: tensor } - -// ----- - -// CHECK: #[[$MAP:.*]] = affine_map<()[s0] -> (s0)> -// CHECK-LABEL: func @index_cast( -// CHECK-SAME: %[[TENSOR:.*]]: tensor, %[[SCALAR:.*]]: i32 -func @index_cast(%tensor: tensor, %scalar: i32) -> (tensor, index) { - %index_tensor = arith.index_cast %tensor : tensor to tensor - %index_scalar = arith.index_cast %scalar : i32 to index - return %index_tensor, %index_scalar : tensor, index -} -// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : memref -// 
CHECK-NEXT: %[[INDEX_MEMREF:.*]] = arith.index_cast %[[MEMREF]] -// CHECK-SAME: memref to memref -// CHECK-NEXT: %[[INDEX_TENSOR:.*]] = bufferization.to_tensor %[[INDEX_MEMREF]] -// CHECK: return %[[INDEX_TENSOR]] diff --git a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp --- a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp @@ -12,11 +12,11 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h" -#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h" #include "mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h" @@ -59,7 +59,7 @@ vector::VectorDialect, scf::SCFDialect, StandardOpsDialect, arith::ArithmeticDialect, AffineDialect>(); affine_ext::registerBufferizableOpInterfaceExternalModels(registry); - arith_ext::registerBufferizableOpInterfaceExternalModels(registry); + arith::registerBufferizableOpInterfaceExternalModels(registry); linalg_ext::registerBufferizableOpInterfaceExternalModels(registry); scf_ext::registerBufferizableOpInterfaceExternalModels(registry); std_ext::registerBufferizableOpInterfaceExternalModels(registry); diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ 
-6579,27 +6579,6 @@ ], ) -cc_library( - name = "ArithBufferizableOpInterfaceImpl", - srcs = [ - "lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp", - ], - hdrs = [ - "include/mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h", - ], - includes = ["include"], - deps = [ - ":ArithmeticDialect", - ":BufferizationDialect", - ":BufferizationTransforms", - ":IR", - ":MemRefDialect", - ":Support", - ":TransformUtils", - "//llvm:Support", - ], -) - cc_library( name = "LinalgBufferizableOpInterfaceImpl", srcs = [ @@ -6875,8 +6854,8 @@ ":AffineBufferizableOpInterfaceImpl", ":AffineUtils", ":Analysis", - ":ArithBufferizableOpInterfaceImpl", ":ArithmeticDialect", + ":ArithmeticTransforms", ":BufferizationDialect", ":BufferizationTransforms", ":ComplexDialect", @@ -7565,7 +7544,10 @@ "lib/Dialect/Arithmetic/Transforms/*.cpp", "lib/Dialect/Arithmetic/Transforms/*.h", ]), - hdrs = ["include/mlir/Dialect/Arithmetic/Transforms/Passes.h"], + hdrs = [ + "include/mlir/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.h", + "include/mlir/Dialect/Arithmetic/Transforms/Passes.h", + ], includes = ["include"], deps = [ ":ArithmeticDialect", @@ -7576,7 +7558,10 @@ ":MemRefDialect", ":Pass", ":StandardOps", + ":Support", + ":TransformUtils", ":Transforms", + "//llvm:Support", ], ) diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel @@ -388,8 +388,8 @@ "//llvm:Support", "//mlir:Affine", "//mlir:AffineBufferizableOpInterfaceImpl", - "//mlir:ArithBufferizableOpInterfaceImpl", "//mlir:ArithmeticDialect", + "//mlir:ArithmeticTransforms", "//mlir:BufferizationDialect", "//mlir:BufferizationTransforms", "//mlir:GPUDialect",