diff --git a/mlir/include/mlir/Dialect/Linalg/CMakeLists.txt b/mlir/include/mlir/Dialect/Linalg/CMakeLists.txt --- a/mlir/include/mlir/Dialect/Linalg/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Linalg/CMakeLists.txt @@ -1,4 +1,5 @@ add_subdirectory(IR) +add_subdirectory(Transforms) set(LLVM_TARGET_DEFINITIONS Passes.td) mlir_tablegen(Passes.h.inc -gen-pass-decls -name Linalg) diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/Transforms/BufferizableOpInterface.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/BufferizableOpInterface.h @@ -0,0 +1,36 @@ +//===- BufferizableOpInterface.h - Comprehensive Bufferize ------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_LINALG_TRANSFORMS_BUFFERIZABLEOPINTERFACE_H_ +#define MLIR_DIALECT_LINALG_TRANSFORMS_BUFFERIZABLEOPINTERFACE_H_ + +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Operation.h" +#include "mlir/Support/LLVM.h" + +namespace mlir { +class BlockAndValueMapping; + +namespace linalg { +class AllocationCallbacks; +class BufferizationAliasInfo; + +/// Specify fine-grained relationship between buffers to enable more analysis.
+enum class BufferRelation { + None, + // TODO: ResultContainsOperand, + // TODO: OperandContainsResult, + Equivalent +}; +} // namespace linalg +} // namespace mlir + +#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterface.h.inc" + +#endif // MLIR_DIALECT_LINALG_TRANSFORMS_BUFFERIZABLEOPINTERFACE_H_ diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Linalg/Transforms/BufferizableOpInterface.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/BufferizableOpInterface.td @@ -0,0 +1,183 @@ +//===-- BufferizableOpInterface.td - Compreh. Bufferize ----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_LINALG_TRANSFORMS_BUFFERIZABLEOPINTERFACE +#define MLIR_DIALECT_LINALG_TRANSFORMS_BUFFERIZABLEOPINTERFACE + +include "mlir/IR/OpBase.td" + +def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> { + let description = [{ + An op interface for Comprehensive Bufferization. Ops that implement this + interface can be bufferized using Comprehensive Bufferization. + }]; + let cppNamespace = "::mlir::linalg"; + let methods = [ + InterfaceMethod< + /*desc=*/[{ + Return `true` if the given OpOperand bufferizes to a memory read. This + method will never be called on OpOperands that do not have a tensor + type. + + Note: It is always safe to consider an OpOperand as a memory read, + even if it does not actually read; however, this can introduce + unnecessary out-of-place bufferization decisions. The analysis of + Comprehensive Bufferize considers OpOperands of unknown ops (that do + not implement this interface) as reading OpOperands.
+ }], + /*retType=*/"bool", + /*methodName=*/"bufferizesToMemoryRead", + /*args=*/(ins "OpOperand &":$opOperand), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + // Does not have to be implemented for ops without tensor OpOperands. + llvm_unreachable("bufferizesToMemoryRead not implemented"); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Return `true` if the given OpOperand bufferizes to a memory write. + This method will never be called on OpOperands that do not have a + tensor type. + + Note: It is always safe to consider an OpOperand as a memory write, + even if it does not actually write; however, this can introduce + unnecessary out-of-place bufferization decisions. The analysis of + Comprehensive Bufferize considers OpOperands of unknown ops (that do + not implement this interface) as writing OpOperands. + }], + /*retType=*/"bool", + /*methodName=*/"bufferizesToMemoryWrite", + /*args=*/(ins "OpOperand &":$opOperand), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + // Does not have to be implemented for ops without tensor OpOperands. + llvm_unreachable("bufferizesToMemoryWrite not implemented"); + }] + >, + // TODO: Simplify this interface by removing `bufferizesToAliasOnly` and + // `getInplaceableOpResult`. Instead, always use `getAliasingOpResult`. If + // `getAliasingOpResult` returns a non-null value, we know that an alias + // is created. If `bufferizesToMemoryRead` and `bufferizesToMemoryWrite` + // return `false`, we know that the operand "bufferizes to alias only". + InterfaceMethod< + /*desc=*/[{ + Return `true` if the given OpOperand creates an alias but neither + reads nor writes. This implies that `bufferizesToMemoryRead` and + `bufferizesToMemoryWrite` must return `false`. This method will never + be called on OpOperands that do not have a tensor type. + + Examples of such ops are `tensor.extract_slice` and `tensor.cast`.
+ }], + /*retType=*/"bool", + /*methodName=*/"bufferizesToAliasOnly", + /*args=*/(ins "OpOperand &":$opOperand), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + // For better debugging, run `bufferizesToMemoryWrite`, which fires an + // assertion when called on an op that should not have tensor + // OpOperands. + (void) cast<BufferizableOpInterface>($_op.getOperation()) + .bufferizesToMemoryWrite(opOperand); + // Return `false` by default, as most ops are not "alias only". + return false; + }] + >, + InterfaceMethod< + /*desc=*/[{ + Return the OpResult that can bufferize in-place with a given + OpOperand. Return a null value if the OpOperand cannot bufferize + in-place. This method will never be called on OpOperands that do not + have a tensor type. + }], + /*retType=*/"OpResult", + /*methodName=*/"getInplaceableOpResult", + /*args=*/(ins "OpOperand &":$opOperand), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + // Does not have to be implemented for ops without tensor OpOperands. + llvm_unreachable("getInplaceableOpResult not implemented"); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Return the OpResult that aliases with a given OpOperand when + bufferized in-place. This is a superset of `getInplaceableOpResult`. + This method will never be called on OpOperands that do not have a + tensor type. + }], + /*retType=*/"OpResult", + /*methodName=*/"getAliasingOpResult", + /*args=*/(ins "OpOperand &":$opOperand), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return cast<BufferizableOpInterface>($_op.getOperation()) + .getInplaceableOpResult(opOperand); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Return the OpOperands that alias with a given OpResult when + bufferized in-place. This method will never be called on OpResults + that do not have a tensor type. + + Note: This method can return multiple OpOperands, indicating that the + given OpResult may at runtime alias with any of the OpOperands. This + is useful for branches and for ops such as `std.select`.
+ }], + /*retType=*/"SmallVector<OpOperand *>", + /*methodName=*/"getAliasingOpOperand", + /*args=*/(ins "OpResult":$opResult), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + // Does not have to be implemented for ops without tensor OpResults. + llvm_unreachable("getAliasingOpOperand not implemented"); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Return the buffer relation between the given OpOperand and its + aliasing OpResult when bufferized in-place. Most OpOperands have an + "equivalence" relation. + + TODO: Support other relations such as "OpOperand is included in + OpResult". + }], + /*retType=*/"BufferRelation", + /*methodName=*/"bufferRelation", + /*args=*/(ins "OpOperand &":$opOperand), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + // Does not have to be implemented for ops without tensor OpOperands. + llvm_unreachable("bufferRelation not implemented"); + }] + >, + // TODO: Simplify method signature: Pass an OpBuilder and a + // BufferizationState object. + InterfaceMethod< + /*desc=*/[{ + Bufferize this op, i.e., rewrite it into a memref-based equivalent. + `bvm` maps tensor values to memref values and this method should map + tensor results to memref results after creating/modifying ops.
+ }], + /*retType=*/"LogicalResult", + /*methodName=*/"bufferize", + /*args=*/(ins "OpBuilder &":$b, + "BlockAndValueMapping &":$bvm, + "BufferizationAliasInfo &":$aliasInfo, + "AllocationCallbacks &":$allocationFn), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + llvm_unreachable("bufferize not implemented"); + return failure(); // Not reached: llvm_unreachable does not return. + }] + >, + ]; +} + +#endif // MLIR_DIALECT_LINALG_TRANSFORMS_BUFFERIZABLEOPINTERFACE diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Linalg/Transforms/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/CMakeLists.txt @@ -0,0 +1,5 @@ +set(LLVM_TARGET_DEFINITIONS BufferizableOpInterface.td) +mlir_tablegen(BufferizableOpInterface.h.inc -gen-op-interface-decls) +mlir_tablegen(BufferizableOpInterface.cpp.inc -gen-op-interface-defs) +add_public_tablegen_target(MLIRBufferizableOpInterfaceIncGen) +add_dependencies(mlir-headers MLIRBufferizableOpInterfaceIncGen) diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h @@ -34,14 +34,6 @@ /// uses BufferizationAliasInfo. class BufferizationAliasInfo { public: - /// Specify fine-grain relationship between buffers to enable more analysis. - enum class BufferRelation { - None, - // TODO: ResultContainsOperand, - // TODO: OperandContainsResult, - Equivalent - }; - explicit BufferizationAliasInfo(Operation *rootOp); /// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`.
In the diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterface.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterface.cpp @@ -0,0 +1,17 @@ +//===- BufferizableOpInterface.cpp - Comprehensive Bufferize --------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterface.h" + +namespace mlir { +namespace linalg { + +#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterface.cpp.inc" // Generated by -gen-op-interface-defs. + +} // namespace linalg +} // namespace mlir diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -1,3 +1,38 @@ +set(LLVM_OPTIONAL_SOURCES # This directory's sources are split between the two libraries below. + Bufferize.cpp + BufferizableOpInterface.cpp + CodegenStrategy.cpp + ComprehensiveBufferize.cpp + Detensorize.cpp + Distribution.cpp + DropUnitDims.cpp + ElementwiseOpFusion.cpp + ElementwiseToLinalg.cpp + Fusion.cpp + FusionOnTensors.cpp + Generalization.cpp + Hoisting.cpp + HoistPadding.cpp + InlineScalarOperands.cpp + Interchange.cpp + Loops.cpp + LinalgStrategyPasses.cpp + Promotion.cpp + Tiling.cpp + Transforms.cpp + Vectorization.cpp +) + +add_mlir_dialect_library(MLIRBufferizableOpInterface + BufferizableOpInterface.cpp + + DEPENDS + MLIRBufferizableOpInterfaceIncGen + + LINK_LIBS PUBLIC + MLIRIR +) + add_mlir_dialect_library(MLIRLinalgTransforms Bufferize.cpp CodegenStrategy.cpp @@ -32,6 +67,7 @@ MLIRAffineUtils MLIRAnalysis MLIRArithmetic + MLIRBufferizableOpInterface MLIRComplex MLIRInferTypeOpInterface MLIRIR diff --git
a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -112,6 +112,7 @@ #include "PassDetail.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterface.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" @@ -138,8 +139,6 @@ using namespace linalg; using namespace tensor; -using BufferRelation = BufferizationAliasInfo::BufferRelation; - #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") #define LDBG(X) LLVM_DEBUG(DBGS() << X) diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -6110,6 +6110,54 @@ deps = [":LinalgStructuredOpsTdFiles"], ) +td_library( + name = "BufferizableOpInterfaceTdFiles", + srcs = [ + "include/mlir/Dialect/Linalg/Transforms/BufferizableOpInterface.td", + ], + includes = ["include"], + deps = [ + ":OpBaseTdFiles", + ], +) + +gentbl_cc_library( + name = "BufferizableOpInterfaceIncGen", + strip_include_prefix = "include", + tbl_outs = [ + ( + ["-gen-op-interface-decls"], + "include/mlir/Dialect/Linalg/Transforms/BufferizableOpInterface.h.inc", + ), + ( + ["-gen-op-interface-defs"], + "include/mlir/Dialect/Linalg/Transforms/BufferizableOpInterface.cpp.inc", + ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Dialect/Linalg/Transforms/BufferizableOpInterface.td", + deps = [ + ":BufferizableOpInterfaceTdFiles", + ], +) + +cc_library( + name = "BufferizableOpInterface", + srcs = [ + "lib/Dialect/Linalg/Transforms/BufferizableOpInterface.cpp", + ], + hdrs = [
"include/mlir/Dialect/Linalg/Transforms/BufferizableOpInterface.h", + ], + includes = ["include"], + deps = [ + ":BufferizableOpInterfaceIncGen", + ":IR", + ":Support", + "//llvm:Support", + ], +) + td_library( name = "LinalgDocTdFiles", srcs = ["include/mlir/Dialect/Linalg/IR/LinalgDoc.td"], @@ -6288,10 +6336,13 @@ cc_library( name = "LinalgTransforms", - srcs = glob([ - "lib/Dialect/Linalg/Transforms/*.cpp", - "lib/Dialect/Linalg/Transforms/*.h", - ]) + [ + srcs = glob( + [ + "lib/Dialect/Linalg/Transforms/*.cpp", + "lib/Dialect/Linalg/Transforms/*.h", + ], + exclude = ["lib/Dialect/Linalg/Transforms/BufferizableOpInterface.cpp"], # Compiled by the :BufferizableOpInterface target instead. + ) + [ "lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp", "lib/Dialect/Linalg/Utils/Utils.cpp", ], @@ -6311,6 +6362,7 @@ ":AffineUtils", ":Analysis", ":ArithmeticDialect", + ":BufferizableOpInterface", ":ComplexDialect", ":DialectUtils", ":IR",