Move the arith_ext::ConstantOpInterface BufferizableOpInterface external model
(for arith::ConstantOp) out of ComprehensiveBufferize.cpp into the new
ArithInterfaceImpl.{h,cpp}, build it as a separate library
(MLIRArithBufferizableOpInterfaceImpl in CMake, ArithBufferizableOpInterfaceImpl
in Bazel), and register it from ComprehensiveBufferizePass.cpp.

diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h
@@ -0,0 +1,27 @@
+//===- ArithInterfaceImpl.h - Arith Impl. of BufferizableOpInterface ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_ARITH_INTERFACE_IMPL_H
+#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_ARITH_INTERFACE_IMPL_H
+
+namespace mlir {
+
+class DialectRegistry;
+
+namespace linalg {
+namespace comprehensive_bufferize {
+namespace arith_ext {
+
+void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
+
+} // namespace arith_ext
+} // namespace comprehensive_bufferize
+} // namespace linalg
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_ARITH_INTERFACE_IMPL_H
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
@@ -0,0 +1,73 @@
+//===- ArithInterfaceImpl.cpp - Arith Impl. of BufferizableOpInterface ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h"
+
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Transforms/BufferUtils.h"
+
+namespace mlir {
+namespace linalg {
+namespace comprehensive_bufferize {
+namespace arith_ext {
+
+struct ConstantOpInterface
+    : public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
+                                                    arith::ConstantOp> {
+  SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
+                                                OpResult opResult) const {
+    return {};
+  }
+
+  LogicalResult bufferize(Operation *op, OpBuilder &b,
+                          BufferizationState &state) const {
+    auto constantOp = cast<arith::ConstantOp>(op);
+    if (!constantOp.getResult().getType().isa<TensorType>())
+      return success();
+    assert(constantOp.getType().dyn_cast<RankedTensorType>() &&
+           "not a constant ranked tensor");
+    auto moduleOp = constantOp->getParentOfType<ModuleOp>();
+    if (!moduleOp) {
+      return constantOp.emitError(
+          "cannot bufferize constants not within builtin.module op");
+    }
+    GlobalCreator globalCreator(moduleOp);
+
+    // Take a guard before anything else.
+    OpBuilder::InsertionGuard g(b);
+    b.setInsertionPoint(constantOp);
+
+    auto globalMemref = globalCreator.getGlobalFor(constantOp);
+    Value memref = b.create<memref::GetGlobalOp>(
+        constantOp.getLoc(), globalMemref.type(), globalMemref.getName());
+    state.aliasInfo.insertNewBufferEquivalence(memref, constantOp.getResult());
+    state.mapBuffer(constantOp, memref);
+
+    return success();
+  }
+
+  bool isWritable(Operation *op, Value value) const {
+    // Memory locations returned by memref::GetGlobalOp may not be written to.
+    assert(value.isa<OpResult>());
+    return false;
+  }
+};
+
+} // namespace arith_ext
+} // namespace comprehensive_bufferize
+} // namespace linalg
+} // namespace mlir
+
+void mlir::linalg::comprehensive_bufferize::arith_ext::
+    registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
+  registry.addOpInterface<arith::ConstantOp, ConstantOpInterface>();
+}
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
@@ -1,4 +1,5 @@
 set(LLVM_OPTIONAL_SOURCES
+  ArithInterfaceImpl.cpp
   BufferizableOpInterface.cpp
   ComprehensiveBufferize.cpp
   LinalgInterfaceImpl.cpp
@@ -17,6 +18,17 @@
   MLIRMemRef
 )
 
+add_mlir_dialect_library(MLIRArithBufferizableOpInterfaceImpl
+  ArithInterfaceImpl.cpp
+
+  LINK_LIBS PUBLIC
+  MLIRArithmetic
+  MLIRBufferizableOpInterface
+  MLIRIR
+  MLIRMemRef
+  MLIRStandardOpsTransforms
+)
+
 add_mlir_dialect_library(MLIRLinalgBufferizableOpInterfaceImpl
   LinalgInterfaceImpl.cpp
 
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -116,16 +116,17 @@
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AsmState.h"
 #include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Dominance.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
-#include "mlir/Transforms/BufferUtils.h"
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Debug.h"
 #include "llvm/Support/FormatVariadic.h"
 
 #define DEBUG_TYPE "comprehensive-module-bufferize"
@@ -1287,52 +1288,6 @@
 namespace mlir {
 namespace linalg {
 namespace comprehensive_bufferize {
-namespace arith_ext {
-
-struct ConstantOpInterface
-    : public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
-                                                    arith::ConstantOp> {
-  SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
-                                                OpResult opResult) const {
-    return {};
-  }
-
-  LogicalResult bufferize(Operation *op, OpBuilder &b,
-                          BufferizationState &state) const {
-    auto constantOp = cast<arith::ConstantOp>(op);
-    if (!isaTensor(constantOp.getResult().getType()))
-      return success();
-    assert(constantOp.getType().dyn_cast<RankedTensorType>() &&
-           "not a constant ranked tensor");
-    auto moduleOp = constantOp->getParentOfType<ModuleOp>();
-    if (!moduleOp) {
-      return constantOp.emitError(
-          "cannot bufferize constants not within builtin.module op");
-    }
-    GlobalCreator globalCreator(moduleOp);
-
-    // Take a guard before anything else.
-    OpBuilder::InsertionGuard g(b);
-    b.setInsertionPoint(constantOp);
-
-    auto globalMemref = globalCreator.getGlobalFor(constantOp);
-    Value memref = b.create<memref::GetGlobalOp>(
-        constantOp.getLoc(), globalMemref.type(), globalMemref.getName());
-    state.aliasInfo.insertNewBufferEquivalence(memref, constantOp.getResult());
-    state.mapBuffer(constantOp, memref);
-
-    return success();
-  }
-
-  bool isWritable(Operation *op, Value value) const {
-    // Memory locations returned by memref::GetGlobalOp may not be written to.
-    assert(value.isa<OpResult>());
-    return false;
-  }
-};
-
-} // namespace arith_ext
-
 namespace scf_ext {
 
 struct ExecuteRegionOpInterface
@@ -1813,7 +1768,6 @@
 } // namespace std_ext
 
 void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
-  registry.addOpInterface<arith::ConstantOp, arith_ext::ConstantOpInterface>();
   registry.addOpInterface<scf::ExecuteRegionOp,
                           scf_ext::ExecuteRegionOpInterface>();
   registry.addOpInterface<scf::ForOp, scf_ext::ForOpInterface>();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -31,6 +31,7 @@
   MLIRAffine
   MLIRAffineUtils
   MLIRAnalysis
+  MLIRArithBufferizableOpInterfaceImpl
   MLIRArithmetic
   MLIRBufferizableOpInterface
   MLIRComplex
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "PassDetail.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
@@ -39,6 +40,7 @@
                 tensor::TensorDialect, vector::VectorDialect, scf::SCFDialect,
                 arith::ArithmeticDialect, StandardOpsDialect, AffineDialect>();
     registerBufferizableOpInterfaceExternalModels(registry);
+    arith_ext::registerBufferizableOpInterfaceExternalModels(registry);
     linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
     tensor_ext::registerBufferizableOpInterfaceExternalModels(registry);
     vector_ext::registerBufferizableOpInterfaceExternalModels(registry);
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -6306,6 +6306,26 @@
     ],
 )
 
+cc_library(
+    name = "ArithBufferizableOpInterfaceImpl",
+    srcs = [
+        "lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp",
+    ],
+    hdrs = [
+        "include/mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h",
+    ],
+    includes = ["include"],
+    deps = [
+        ":ArithmeticDialect",
+        ":BufferizableOpInterface",
+        ":IR",
+        ":MemRefDialect",
+        ":Support",
+        ":TransformUtils",
+        "//llvm:Support",
+    ],
+)
+
 cc_library(
     name = "LinalgBufferizableOpInterfaceImpl",
     srcs = [
@@ -6563,6 +6583,7 @@
         ":Affine",
         ":AffineUtils",
         ":Analysis",
+        ":ArithBufferizableOpInterfaceImpl",
         ":ArithmeticDialect",
         ":BufferizableOpInterface",
         ":ComplexDialect",
@@ -6604,7 +6625,6 @@
     includes = ["include"],
     deps = [
         ":Affine",
-        ":ArithmeticDialect",
         ":BufferizableOpInterface",
         ":DialectUtils",
         ":IR",
@@ -6614,7 +6634,6 @@
         ":SCFDialect",
         ":StandardOps",
         ":Support",
-        ":TransformUtils",
         "//llvm:Support",
     ],
 )