diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td @@ -233,6 +233,18 @@ && static_cast( bufferizableOp.getAliasingOpResult(opOperand)); } + + // TODO: The following two attributes should belong to the tensor dialect. + // The corresponding verifier should also be in the tensor dialect. + /// Attribute name used to mark region arguments that can be bufferized + /// in-place during linalg comprehensive bufferization. + constexpr const static ::llvm::StringLiteral + kInplaceableAttrName = "linalg.inplaceable"; + + /// Attribute name used to mark the bufferization layout for region + /// arguments during linalg comprehensive bufferization. + constexpr const static ::llvm::StringLiteral + kBufferLayoutAttrName = "linalg.buffer_layout"; }]; } diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td @@ -48,16 +48,6 @@ constexpr const static ::llvm::StringLiteral kMemoizedIndexingMapsAttrName = "linalg.memoized_indexing_maps"; - /// Attribute name used to mark region arguments that can be bufferized - /// in-place during linalg comprehensive bufferization. - constexpr const static ::llvm::StringLiteral - kInplaceableAttrName = "linalg.inplaceable"; - - /// Attribute name used to mark the bufferization layout for region - /// arguments during linalg comprehensive bufferization. - constexpr const static ::llvm::StringLiteral - kBufferLayoutAttrName = "linalg.buffer_layout"; - using RegionBuilderFunType = llvm::function_ref; RegionBuilderFunType getRegionBuilder(StringRef name) { diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp @@ -207,7 +207,8 @@ /// `bbArg`. static void setInPlaceFuncArgument(BlockArgument bbArg, bool inPlace) { auto funcOp = cast(bbArg.getOwner()->getParentOp()); - funcOp.setArgAttr(bbArg.getArgNumber(), LinalgDialect::kInplaceableAttrName, + funcOp.setArgAttr(bbArg.getArgNumber(), + BufferizableOpInterface::kInplaceableAttrName, BoolAttr::get(bbArg.getContext(), inPlace)); } @@ -216,9 +217,9 @@ static void removeBufferizationFuncArguments(BlockArgument bbArg) { auto funcOp = cast(bbArg.getOwner()->getParentOp()); funcOp.removeArgAttr(bbArg.getArgNumber(), - LinalgDialect::kBufferLayoutAttrName); + BufferizableOpInterface::kBufferLayoutAttrName); funcOp.removeArgAttr(bbArg.getArgNumber(), - LinalgDialect::kInplaceableAttrName); + BufferizableOpInterface::kInplaceableAttrName); } //===----------------------------------------------------------------------===// @@ -1112,7 +1113,7 @@ // bufferizing to a writeable memory. for (BlockArgument bbArg : funcOp.getArguments()) { BoolAttr inplaceAttr = funcOp.getArgAttrOfType( - bbArg.getArgNumber(), LinalgDialect::kInplaceableAttrName); + bbArg.getArgNumber(), BufferizableOpInterface::kInplaceableAttrName); if (inplaceAttr && inplaceAttr.getValue()) aliasInfo.setBufferizesToWritableMemory(bbArg); } @@ -1145,7 +1146,7 @@ Location loc, Value from, Value to) { - b.create(loc, from, to); + b.create(loc, from, to); } LogicalResult mlir::linalg::comprehensive_bufferize::bufferizeOp( @@ -1498,7 +1499,7 @@ Type inputType = it.value(); auto memrefType = inputType.dyn_cast(); auto layoutAttr = funcOp.getArgAttrOfType( - argNumber, LinalgDialect::kBufferLayoutAttrName); + argNumber, BufferizableOpInterface::kBufferLayoutAttrName); AffineMap desiredLayoutMap = layoutAttr ? layoutAttr.getValue() : AffineMap(); AffineMap currentLayoutMap = diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" +#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" @@ -62,14 +63,6 @@ constexpr const ::llvm::StringLiteral LinalgDialect::kMemoizedIndexingMapsAttrName; -/// Attribute name used to mark the bufferization layout for region -/// arguments during linalg comprehensive bufferization. -constexpr const ::llvm::StringLiteral LinalgDialect::kBufferLayoutAttrName; - -/// Attribute name used to mark region arguments that can be bufferized -/// in-place during linalg comprehensive bufferization. -constexpr const ::llvm::StringLiteral LinalgDialect::kInplaceableAttrName; - /// Trait to check if T provides a `regionBuilder` method. template using has_region_builder = decltype(T::regionBuilder); @@ -147,20 +140,24 @@ LogicalResult LinalgDialect::verifyOperationAttribute(Operation *op, NamedAttribute attr) { - if (attr.first == LinalgDialect::kInplaceableAttrName) { + using comprehensive_bufferize::BufferizableOpInterface; + + if (attr.first == BufferizableOpInterface::kInplaceableAttrName) { if (!attr.second.isa()) { - return op->emitError() << "'" << LinalgDialect::kInplaceableAttrName - << "' is expected to be a boolean attribute"; + return op->emitError() + << "'" << BufferizableOpInterface::kInplaceableAttrName + << "' is expected to be a boolean attribute"; } if (!op->hasTrait()) return op->emitError() << "expected " << attr.first << " to be used on function-like operations"; return success(); } - if (attr.first == LinalgDialect::kBufferLayoutAttrName) { + if (attr.first == BufferizableOpInterface::kBufferLayoutAttrName) { if (!attr.second.isa()) { - return op->emitError() << "'" << LinalgDialect::kBufferLayoutAttrName - << "' is expected to be a affine map attribute"; + return op->emitError() + << "'" << BufferizableOpInterface::kBufferLayoutAttrName + << "' is expected to be a affine map attribute"; } if (!op->hasTrait()) return op->emitError() << "expected " << attr.first diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp @@ -61,6 +61,12 @@ options.allocationFns->deallocationFn = [](OpBuilder &b, Location loc, Value v) {}; } + // TODO: Change to memref::CopyOp (default memCpyFn). + options.allocationFns->memCpyFn = [](OpBuilder &b, Location loc, Value from, + Value to) { + b.create(loc, from, to); + }; + options.allowReturnMemref = allowReturnMemref; options.analysisFuzzerSeed = analysisFuzzerSeed; options.testAnalysisOnly = testAnalysisOnly; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" @@ -266,7 +267,9 @@ if (auto arg = val.dyn_cast()) if (auto funcOp = dyn_cast(arg.getOwner()->getParentOp())) if (auto attr = funcOp.getArgAttrOfType( - arg.getArgNumber(), linalg::LinalgDialect::kInplaceableAttrName)) + arg.getArgNumber(), + linalg::comprehensive_bufferize::BufferizableOpInterface:: + kInplaceableAttrName)) return attr.getValue(); return false; } diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -1712,6 +1712,7 @@ deps = [ ":Affine", ":ArithmeticDialect", + ":BufferizableOpInterface", ":IR", ":LLVMDialect", ":LinalgOps", @@ -6296,6 +6297,7 @@ deps = [ ":Affine", ":ArithmeticDialect", + ":BufferizableOpInterface", ":CopyOpInterface", ":DialectUtils", ":IR",