diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td @@ -326,17 +326,12 @@ && !bufferizableOp.getAliasingOpResult(opOperand, state).empty(); } - // TODO: The following two attributes should belong to the tensor dialect. - // The corresponding verifier should also be in the tensor dialect. + // TODO: This attribute is deprecated. Use `bufferization.writable` or add + // a new attribute in a different dialect. /// Attribute name used to mark region arguments that can be bufferized /// in-place during one-shot bufferization. constexpr const static ::llvm::StringLiteral - kInplaceableAttrName = "linalg.inplaceable"; - - /// Attribute name used to mark the bufferization layout for region - /// arguments during one-shot bufferization. - constexpr const static ::llvm::StringLiteral - kBufferLayoutAttrName = "linalg.buffer_layout"; + kInplaceableAttrName = "linalg.inplaceable"; }]; } diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationBase.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationBase.td --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationBase.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationBase.td @@ -26,6 +26,19 @@ deallocation](/docs/BufferDeallocationInternals/). }]; let dependentDialects = ["memref::MemRefDialect", "tensor::TensorDialect"]; + + let extraClassDeclaration = [{ + /// An attribute that can override writability of buffers of tensor function + /// arguments during One-Shot Module Bufferize. + constexpr const static ::llvm::StringLiteral + kWritableAttrName = "bufferization.writable"; + + /// Attribute name used to mark the bufferization layout for region + /// arguments during One-Shot Module Bufferize. + constexpr const static ::llvm::StringLiteral + kBufferLayoutAttrName = "bufferization.buffer_layout"; + }]; + let hasOperationAttrVerify = 1; } #endif // BUFFERIZATION_BASE diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h @@ -0,0 +1,51 @@ +//===- BufferizableOpInterfaceImpl.h - Impl. of BufferizableOpInterface ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_BUFFERIZATION_TRANSFORMS_FUNCBUFFERIZABLEOPINTERFACEIMPL_H +#define MLIR_BUFFERIZATION_TRANSFORMS_FUNCBUFFERIZABLEOPINTERFACEIMPL_H + +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" + +namespace mlir { +class DialectRegistry; + +namespace bufferization { +namespace func_ext { +/// The state of analysis of a FuncOp. +enum class FuncOpAnalysisState { NotAnalyzed, InProgress, Analyzed }; + +/// Extra analysis state that is required for bufferization of function +/// boundaries. +struct FuncAnalysisState : public DialectAnalysisState { + /// A mapping of ReturnOp OpOperand indices to equivalent FuncOp BBArg + /// indices. 
+  DenseMap<FuncOp, DenseMap<int64_t, int64_t>> equivalentFuncArgs;
+
+  /// A mapping of ReturnOp OpOperand indices to aliasing FuncOp BBArg indices.
+  DenseMap<FuncOp, DenseMap<int64_t, SmallVector<int64_t>>> aliasingFuncArgs;
+
+  /// A mapping of FuncOp BBArg indices to aliasing ReturnOp OpOperand indices.
+  DenseMap<FuncOp, DenseMap<int64_t, SmallVector<int64_t>>> aliasingReturnVals;
+
+  /// A set of all read BlockArguments of FuncOps.
+  DenseMap<FuncOp, DenseSet<int64_t>> readBbArgs;
+
+  /// A set of all written-to BlockArguments of FuncOps.
+  DenseMap<FuncOp, DenseSet<int64_t>> writtenBbArgs;
+
+  /// Keep track of which FuncOps are fully analyzed or currently being
+  /// analyzed.
+  DenseMap<FuncOp, FuncOpAnalysisState> analyzedFuncOps;
+};
+
+void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
+} // namespace func_ext
+} // namespace bufferization
+} // namespace mlir
+
+#endif // MLIR_BUFFERIZATION_TRANSFORMS_FUNCBUFFERIZABLEOPINTERFACEIMPL_H
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
@@ -0,0 +1,31 @@
+//===- OneShotModuleBufferize.h - Bufferization across Func. Boundaries ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTMODULEBUFFERIZE_H
+#define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTMODULEBUFFERIZE_H
+
+namespace mlir {
+
+struct LogicalResult;
+class ModuleOp;
+
+namespace bufferization {
+struct OneShotBufferizationOptions;
+
+/// Run One-Shot Module Bufferization on the given module. Performs a simple
+/// function call analysis to determine which function arguments are
+/// inplaceable. Then analyzes and bufferizes FuncOps one-by-one with One-Shot
+/// Bufferize.
+LogicalResult
+runOneShotModuleBufferize(ModuleOp moduleOp,
+                          bufferization::OneShotBufferizationOptions options);
+
+} // namespace bufferization
+} // namespace mlir
+
+#endif // MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTMODULEBUFFERIZE_H
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td
@@ -200,6 +200,34 @@
     prints analysis results and explains why an OpOperand was decided to
     bufferize out-of-place. This is useful for understanding why One-Shot
     Bufferize chose to insert a certain buffer copy.
+
+    `bufferize-function-boundaries` is an experimental flag for bufferizing
+    `FuncOp`, `ReturnOp` and `CallOp`. This feature is still under development
+    and supports only simple cases at the moment. In particular:
+
+    * Recursive or circular function call graphs are not supported.
+    * If a newly allocated buffer is returned from a function (with
+      `allow-return-allocs`), the buffer will never be deallocated and leak.
+      Such IR needs special handling, e.g., allocation hoisting or reference
+      counting.
+    * External functions (without bodies) that return a tensor are not
+      supported.
+    * Functions with multiple blocks or multiple ReturnOps are not supported.
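+
+    For example (a minimal sketch; the element type and constant values are
+    chosen only for illustration), the following functions are within the
+    supported subset. The callee may write to the buffer of its tensor
+    argument in place, and a copy of `%t0` is created at the call site
+    because `%t0` is read again after the call:
+
+    ```mlir
+    func @callee(%t1 : tensor<?xf32>) -> tensor<?xf32> {
+      %c0 = arith.constant 0 : index
+      %f = arith.constant 1.0 : f32
+      %0 = tensor.insert %f into %t1[%c0] : tensor<?xf32>
+      return %0 : tensor<?xf32>
+    }
+
+    func @caller(%t0 : tensor<?xf32>) -> f32 {
+      %c0 = arith.constant 0 : index
+      %1 = call @callee(%t0) : (tensor<?xf32>) -> (tensor<?xf32>)
+      %2 = tensor.extract %t0[%c0] : tensor<?xf32>
+      return %2 : f32
+    }
+    ```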
+ + One-Shot Bufferize implements the following contract around function calls: + The buffer of function arguments is always writable (unless annotated with + `bufferization.writable = false`). A buffer copy may be inserted at the call + site where necessary. Alias sets and equivalence info is propagated through + function calls. Whenever a function is bufferized, all other functions that + are being called were already analyzed and bufferized, so exact alias and + equivalence information is available. This is why recursive function calls + are not yet supported. + + One-Shot Bufferize gathers additional information during the analysis phase + when function boundary bufferization is activated. E.g., whether a function + argument is read/written and which returned values are aliasing/equivalent. + For debugging purposes, such information can be printed with + `test-analysis-only`. }]; let options = [ Option<"allowReturnAllocs", "allow-return-allocs", "bool", @@ -211,6 +239,9 @@ Option<"analysisFuzzerSeed", "analysis-fuzzer-seed", "unsigned", /*default=*/"0", "Test only: Analyze ops in random order with a given seed (fuzzer)">, + Option<"bufferizeFunctionBoundaries", "bufferize-function-boundaries", + "bool", /*default=*/"0", + "Bufferize function boundaries (experimental).">, Option<"createDeallocs", "create-deallocs", "bool", /*default=*/"true", "Specify if buffers should be deallocated. For compatibility with " "core bufferization passes.">, diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt deleted file mode 100644 --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt +++ /dev/null @@ -1,2 +0,0 @@ -# no targets defined here - diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h deleted file mode 100644 --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h +++ /dev/null @@ -1,43 +0,0 @@ -//===- ModuleBufferization.h - Bufferization across Func. Boundaries ------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_MODULEBUFFERIZATION_H -#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_MODULEBUFFERIZATION_H - -#include - -namespace mlir { - -class DialectRegistry; -struct LogicalResult; -class ModuleOp; - -namespace bufferization { -struct OneShotBufferizationOptions; -} // namespace bufferization - -namespace linalg { -namespace comprehensive_bufferize { - -/// Run Module Bufferization on the given module. Performs a simple function -/// call analysis to determine which function arguments are inplaceable. Then -/// analyzes and bufferizes FuncOps one-by-one with One-Shot Bufferize. 
-LogicalResult -runModuleBufferize(ModuleOp moduleOp, - bufferization::OneShotBufferizationOptions options); - -namespace std_ext { - -void registerModuleBufferizationExternalModels(DialectRegistry ®istry); - -} // namespace std_ext -} // namespace comprehensive_bufferize -} // namespace linalg -} // namespace mlir - -#endif // MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_MODULEBUFFERIZATION_H diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -22,6 +22,7 @@ #include "mlir/Dialect/ArmSVE/ArmSVEDialect.h" #include "mlir/Dialect/Async/IR/Async.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" #include "mlir/Dialect/DLTI/DLTI.h" @@ -44,6 +45,7 @@ #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/Shape/IR/Shape.h" +#include "mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h" @@ -94,8 +96,11 @@ x86vector::X86VectorDialect>(); // clang-format on arith::registerBufferizableOpInterfaceExternalModels(registry); + bufferization::func_ext::registerBufferizableOpInterfaceExternalModels( + registry); linalg::registerBufferizableOpInterfaceExternalModels(registry); scf::registerBufferizableOpInterfaceExternalModels(registry); + shape::registerBufferizableOpInterfaceExternalModels(registry); tensor::registerBufferizableOpInterfaceExternalModels(registry); tensor::registerInferTypeOpInterfaceExternalModels(registry); tensor::registerTilingOpInterfaceExternalModels(registry); diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -32,11 +32,6 @@ using namespace mlir; using namespace bufferization; -/// Attribute name used to mark the bufferization layout for region -/// arguments during linalg comprehensive bufferization. -constexpr const ::llvm::StringLiteral - bufferization::BufferizableOpInterface::kBufferLayoutAttrName; - /// Attribute name used to mark region arguments that can be bufferized /// in-place during linalg comprehensive bufferization. constexpr const ::llvm::StringLiteral diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp @@ -14,6 +14,15 @@ #include "mlir/Dialect/Bufferization/IR/BufferizationOpsDialect.cpp.inc" +/// Attribute name used to mark function arguments who's buffers can be written +/// to during One-Shot Module Bufferize. +constexpr const ::llvm::StringLiteral BufferizationDialect::kWritableAttrName; + +/// Attribute name used to mark the bufferization layout for region arguments +/// during One-Shot Module Bufferize. 
+constexpr const ::llvm::StringLiteral + BufferizationDialect::kBufferLayoutAttrName; + //===----------------------------------------------------------------------===// // Bufferization Dialect Interfaces //===----------------------------------------------------------------------===// @@ -41,3 +50,33 @@ >(); addInterfaces(); } + +LogicalResult +BufferizationDialect::verifyOperationAttribute(Operation *op, + NamedAttribute attr) { + using bufferization::BufferizableOpInterface; + + if (attr.getName() == kWritableAttrName) { + if (!attr.getValue().isa()) { + return op->emitError() << "'" << kWritableAttrName + << "' is expected to be a boolean attribute"; + } + if (!isa(op)) + return op->emitError() << "expected " << attr.getName() + << " to be used on function-like operations"; + return success(); + } + if (attr.getName() == kBufferLayoutAttrName) { + if (!attr.getValue().isa()) { + return op->emitError() << "'" << kBufferLayoutAttrName + << "' is expected to be a affine map attribute"; + } + if (!isa(op)) + return op->emitError() << "expected " << attr.getName() + << " to be used on function-like operations"; + return success(); + } + + return op->emitError() << "attribute '" << attr.getName() + << "' not supported by the bufferization dialect"; +} diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" +#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" #include "mlir/Dialect/Bufferization/Transforms/Passes.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Operation.h" @@ -178,8 +179,10 @@ BufferizationOptions::OpFilterEntry::FilterFn filterFn = [&](Operation *op) { // Disallow non-func dialect ops. I.e., no ops related to function - // calls. - if (isa(op->getDialect())) + // calls. (Unless explicitly activated.) + bool isFuncBoundaryOp = + isa(op->getDialect()) || isa(op); + if (!this->bufferizeFunctionBoundaries && isFuncBoundaryOp) return false; // Filter may be specified via options. 
if (this->dialectFilter.hasValue()) @@ -195,9 +198,16 @@ } ModuleOp moduleOp = getOperation(); - if (failed(runOneShotBufferize(moduleOp, opt))) { - signalPassFailure(); - return; + if (bufferizeFunctionBoundaries) { + if (failed(runOneShotModuleBufferize(moduleOp, opt))) { + signalPassFailure(); + return; + } + } else { + if (failed(runOneShotBufferize(moduleOp, opt))) { + signalPassFailure(); + return; + } } if (opt.testAnalysisOnly) diff --git a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt @@ -4,7 +4,9 @@ BufferOptimizations.cpp BufferResultsToOutParams.cpp BufferUtils.cpp + FuncBufferizableOpInterfaceImpl.cpp OneShotAnalysis.cpp + OneShotModuleBufferize.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Bufferization diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp @@ -0,0 +1,540 @@ +//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Operation.h" + +namespace mlir { +namespace bufferization { +namespace func_ext { + +/// Return the unique ReturnOp that terminates `funcOp`. +/// Return nullptr if there is no such unique ReturnOp. +static func::ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) { + func::ReturnOp returnOp; + for (Block &b : funcOp.getBody()) { + if (auto candidateOp = dyn_cast(b.getTerminator())) { + if (returnOp) + return nullptr; + returnOp = candidateOp; + } + } + return returnOp; +} + +/// Return the index-th bufferized function argument type. This assumes that the +/// specified argument is a tensor. If the tensor is ranked, a layout map may be +/// specified by the user. If no layout map is specified, a fully dynamic map is +/// used. +static BaseMemRefType +getBufferizedFunctionArgType(FuncOp funcOp, int64_t index, + const BufferizationOptions &options) { + auto tensorType = funcOp.getType().getInput(index).dyn_cast(); + assert(tensorType && "expected TensorType"); + BaseMemRefType memrefType = getMemRefType(tensorType, options); + + auto layoutAttr = funcOp.getArgAttrOfType( + index, BufferizationDialect::kBufferLayoutAttrName); + if (!layoutAttr) + return memrefType; + + auto rankedMemrefType = memrefType.dyn_cast(); + assert(rankedMemrefType && "buffer layout not supported on unranked tensors"); + return MemRefType::get(memrefType.getShape(), memrefType.getElementType(), + layoutAttr.getValue(), + memrefType.getMemorySpaceAsInt()); +} + +/// Return the FuncOp called by `callOp`. 
+static FuncOp getCalledFunction(CallOpInterface callOp) {
+  SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
+  if (!sym)
+    return nullptr;
+  return dyn_cast_or_null<FuncOp>(
+      SymbolTable::lookupNearestSymbolFrom(callOp, sym));
+}
+
+/// Get FuncAnalysisState.
+static const FuncAnalysisState &
+getFuncAnalysisState(const AnalysisState &state) {
+  Optional<const FuncAnalysisState *> maybeState =
+      state.getDialectState<FuncAnalysisState>(
+          func::FuncDialect::getDialectNamespace());
+  assert(maybeState.hasValue() && "FuncAnalysisState does not exist");
+  return **maybeState;
+}
+
+/// Return the state (phase) of analysis of the FuncOp.
+static FuncOpAnalysisState getFuncOpAnalysisState(const AnalysisState &state,
+                                                  FuncOp funcOp) {
+  const FuncAnalysisState &funcState = getFuncAnalysisState(state);
+  auto it = funcState.analyzedFuncOps.find(funcOp);
+  if (it == funcState.analyzedFuncOps.end())
+    return FuncOpAnalysisState::NotAnalyzed;
+  return it->second;
+}
+
+/// Return the index of the bbArg in the given FuncOp that is equivalent to the
+/// specified return value (if any).
+static Optional<int64_t> getEquivalentFuncArgIdx(FuncOp funcOp,
+                                                 const FuncAnalysisState &state,
+                                                 int64_t returnValIdx) {
+  auto funcOpIt = state.equivalentFuncArgs.find(funcOp);
+  if (funcOpIt == state.equivalentFuncArgs.end())
+    // No equivalence info stored for funcOp.
+    return None;
+
+  auto retValIt = funcOpIt->getSecond().find(returnValIdx);
+  if (retValIt == funcOpIt->getSecond().end())
+    // Return value has no equivalent bbArg.
+    return None;
+
+  return retValIt->getSecond();
+}
+
+/// If `value` is a memref::CastOp, return its source. Otherwise, return
+/// `value` directly.
+static Value getNonCastedValue(Value value) {
+  while (auto castOp = value.getDefiningOp<memref::CastOp>())
+    value = castOp.source();
+  return value;
+}
+
+struct CallOpInterface
+    : public BufferizableOpInterface::ExternalModel<CallOpInterface,
+                                                    func::CallOp> {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const AnalysisState &state) const {
+    func::CallOp callOp = cast<func::CallOp>(op);
+    FuncOp funcOp = getCalledFunction(callOp);
+    assert(funcOp && "expected CallOp to a FuncOp");
+
+    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
+    if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
+      // FuncOp not analyzed yet. Assume that OpOperand is read.
+      return true;
+
+    auto it = funcState.readBbArgs.find(funcOp);
+    assert(it != funcState.readBbArgs.end() &&
+           "expected analysis info for analyzed FuncOps");
+    return it->second.contains(opOperand.getOperandNumber());
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               const AnalysisState &state) const {
+    func::CallOp callOp = cast<func::CallOp>(op);
+    FuncOp funcOp = getCalledFunction(callOp);
+    assert(funcOp && "expected CallOp to a FuncOp");
+
+    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
+    if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
+      // FuncOp not analyzed yet. Assume that OpOperand is written.
+      return true;
+
+    auto it = funcState.writtenBbArgs.find(funcOp);
+    assert(it != funcState.writtenBbArgs.end() &&
+           "expected analysis info for analyzed FuncOps");
+    return it->second.contains(opOperand.getOperandNumber());
+  }
+
+  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                                            const AnalysisState &state) const {
+    func::CallOp callOp = cast<func::CallOp>(op);
+    FuncOp funcOp = getCalledFunction(callOp);
+    assert(funcOp && "expected CallOp to a FuncOp");
+    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
+    if (getFuncOpAnalysisState(state, funcOp) !=
+        FuncOpAnalysisState::Analyzed) {
+      // FuncOp not analyzed yet. Any OpResult may be aliasing.
+      SmallVector<OpResult> result;
+      for (OpResult opResult : op->getOpResults())
+        if (opResult.getType().isa<TensorType>())
+          result.push_back(opResult);
+      return result;
+    }
+
+    // Get aliasing results from state.
+    auto mapping = funcState.aliasingReturnVals.find(funcOp);
+    assert(mapping != funcState.aliasingReturnVals.end() &&
+           "expected analysis info for analyzed FuncOps");
+    auto aliasingReturnVals =
+        mapping->second.find(opOperand.getOperandNumber());
+    if (aliasingReturnVals == mapping->second.end())
+      return {};
+
+    SmallVector<OpResult> result;
+    for (int64_t resultIdx : aliasingReturnVals->second)
+      result.push_back(callOp->getOpResult(resultIdx));
+    return result;
+  }
+
+  SmallVector<OpOperand *>
+  getAliasingOpOperand(Operation *op, OpResult opResult,
+                       const AnalysisState &state) const {
+    func::CallOp callOp = cast<func::CallOp>(op);
+    FuncOp funcOp = getCalledFunction(callOp);
+    assert(funcOp && "expected CallOp to a FuncOp");
+    const FuncAnalysisState &funcState = getFuncAnalysisState(state);
+    if (getFuncOpAnalysisState(state, funcOp) !=
+        FuncOpAnalysisState::Analyzed) {
+      // FuncOp not analyzed yet. Any OpOperand may be aliasing.
+      SmallVector<OpOperand *> result;
+      for (OpOperand &opOperand : op->getOpOperands())
+        if (opOperand.get().getType().isa<TensorType>())
+          result.push_back(&opOperand);
+      return result;
+    }
+
+    // Get aliasing bbArgs from state.
+    auto mapping = funcState.aliasingFuncArgs.find(funcOp);
+    assert(mapping != funcState.aliasingFuncArgs.end() &&
+           "expected analysis info for analyzed FuncOps");
+    auto aliasingFuncArgs = mapping->second.find(opResult.getResultNumber());
+    if (aliasingFuncArgs == mapping->second.end())
+      return {};
+
+    SmallVector<OpOperand *> result;
+    for (int64_t bbArgIdx : aliasingFuncArgs->second)
+      result.push_back(&callOp->getOpOperand(bbArgIdx));
+    return result;
+  }
+
+  BufferRelation bufferRelation(Operation *op, OpResult opResult,
+                                const AnalysisState &state) const {
+    return BufferRelation::Equivalent;
+  }
+
+  /// All function arguments are writable. It is the responsibility of the
+  /// CallOp to insert buffer copies where necessary.
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          BufferizationState &state) const {
+    func::CallOp callOp = cast<func::CallOp>(op);
+    unsigned numResults = callOp.getNumResults();
+    unsigned numOperands = callOp->getNumOperands();
+    FuncOp funcOp = getCalledFunction(callOp);
+    assert(funcOp && "expected CallOp to a FuncOp");
+    const FuncAnalysisState &moduleState =
+        getFuncAnalysisState(state.getAnalysisState());
+    const OneShotBufferizationOptions &options =
+        static_cast<const OneShotBufferizationOptions &>(state.getOptions());
+
+    // Result types of the bufferized CallOp.
+    SmallVector<Type> resultTypes;
+    // Replacement values for the existing CallOp. These are usually the results
+    // of the bufferized CallOp, unless a tensor result folds onto an operand.
+ SmallVector replacementValues(numResults, Value()); + // For non-tensor results: A mapping from return val indices of the old + // CallOp to return val indices of the bufferized CallOp. + SmallVector> retValMapping(numResults, None); + // Operands of the bufferized CallOp. + SmallVector newOperands(numOperands, Value()); + + // Based on previously gathered equivalence information, we know if a + // tensor result folds onto an operand. These are the only tensor value + // results that are supported at the moment. + // + // For tensors return values that do not fold onto an operand, additional + // work is needed (TODO) to either: + // * hoist a result into an inplaceable operand or + // * devise a better representation to truly return a buffer. + // + // Note: If a function has no body, no equivalence information is + // available. Consequently, a tensor return value cannot be proven to fold + // onto a FuncOp bbArg, so calls to such functions are not bufferizable at + // the moment. + + // 1. Compute the result types of the new CallOp. Tensor results that are + // equivalent to a FuncOp bbArg are no longer returned. + for (const auto &it : llvm::enumerate(callOp.getResultTypes())) { + unsigned returnValIdx = it.index(); + Type returnType = it.value(); + if (!returnType.isa()) { + // Non-tensor values are returned. + retValMapping[returnValIdx] = resultTypes.size(); + resultTypes.push_back(returnType); + continue; + } + + if (Optional bbArgIdx = + getEquivalentFuncArgIdx(funcOp, moduleState, returnValIdx)) { + // Return operands that are equivalent to some bbArg, are not + // returned. + FailureOr bufferOrFailure = + state.getBuffer(rewriter, callOp->getOpOperand(*bbArgIdx)); + if (failed(bufferOrFailure)) + return failure(); + replacementValues[returnValIdx] = *bufferOrFailure; + newOperands[*bbArgIdx] = *bufferOrFailure; + continue; + } + + if (!options.allowReturnAllocs) + return callOp->emitError( + "call to FuncOp that returns non-equivalent tensors not supported"); + + // Returning a memref. This memref is not equivalent to any bbArg. It is + // likely a newly allocated buffer. We may want to hoist such allocations + // to the call site in the future. + retValMapping[returnValIdx] = resultTypes.size(); + resultTypes.push_back(funcOp.getType().getResult(resultTypes.size())); + } + + // 2. Compute bufferized FunctionType. + FunctionType bufferizedFuncType = funcOp.getType(); + + // 3. Rewrite tensor operands as memrefs based on `bufferizedFuncType`. + for (OpOperand &opOperand : callOp->getOpOperands()) { + unsigned idx = opOperand.getOperandNumber(); + Value tensorOperand = opOperand.get(); + + // Non-tensor operands are just copied. + if (!tensorOperand.getType().isa()) { + newOperands[idx] = tensorOperand; + continue; + } + + // Retrieve buffers for tensor operands. Tensor operand buffers, who's + // corresponding FuncOp bbArgs are equivalent to a returned tensor, were + // already stored in `newOperands` during Step 1. + Value buffer = newOperands[idx]; + if (!buffer) { + FailureOr bufferOrFailure = state.getBuffer(rewriter, opOperand); + if (failed(bufferOrFailure)) + return failure(); + buffer = *bufferOrFailure; + } + + // Caller / callee type mismatch is handled with a CastOp. + auto memRefType = bufferizedFuncType.getInput(idx); + // Since we don't yet have a clear layout story, to_memref may + // conservatively turn tensors into more dynamic memref than necessary. 
+ // If the memref type of the callee fails, introduce an extra memref.cast + // that will either canonicalize away or fail compilation until we can do + // something better. + if (buffer.getType() != memRefType) { + assert( + memref::CastOp::areCastCompatible(buffer.getType(), memRefType) && + "CallOp::bufferize: cast incompatible"); + Value castBuffer = rewriter.create(callOp.getLoc(), + memRefType, buffer); + buffer = castBuffer; + } + newOperands[idx] = buffer; + } + + // 4. Create the new CallOp. + Operation *newCallOp = rewriter.create( + callOp.getLoc(), funcOp.getSymName(), resultTypes, newOperands); + newCallOp->setAttrs(callOp->getAttrs()); + // Get replacement values for non-tensor / non-equivalent results. + for (unsigned i = 0; i < replacementValues.size(); ++i) { + if (replacementValues[i]) + continue; + replacementValues[i] = newCallOp->getResult(*retValMapping[i]); + } + + // 5. Replace the old op with the new op. + replaceOpWithBufferizedValues(rewriter, callOp, replacementValues); + + return success(); + } +}; + +struct ReturnOpInterface + : public BufferizableOpInterface::ExternalModel { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + return true; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + return false; + } + + SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + return {}; + } + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + BufferizationState &state) const { +#ifndef NDEBUG + auto returnOp = cast(op); + assert(isa(returnOp->getParentOp()) && + "only support FuncOp parent for ReturnOp"); +#endif // NDEBUG + + // ReturnOps are bufferized as part of FuncOps. + return failure(); + } +}; + +struct FuncOpInterface + : public BufferizableOpInterface::ExternalModel { + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + BufferizationState &state) const { + // Rewrite function bbArgs and return values into buffer form (using the + // canonical memref layout for now). + // + // This relies on a buffer equivalence analysis of each return operand. When + // a result buffer is equivalent to a function bbArg, it is dropped from the + // return values and becomes inplaceable at all callers. + // + // All function bbArgs are writable unless they are explicitly marked as + // read-only. Callers must insert copies when needed. + // + // Note: Returning a memref is possible, but corresponding CallOp + // bufferizations fail unless `allowReturnAllocs`. + auto funcOp = cast(op); + const FuncAnalysisState &moduleState = + getFuncAnalysisState(state.getAnalysisState()); + const BufferizationOptions &options = state.getOptions(); + + // Construct the bufferized function type. + SmallVector argTypes; + for (const auto &it : llvm::enumerate(funcOp.getType().getInputs())) { + Type argType = it.value(); + if (auto tensorType = argType.dyn_cast()) { + argTypes.push_back( + getBufferizedFunctionArgType(funcOp, it.index(), options)); + continue; + } + argTypes.push_back(argType); + } + + // Bodiless functions are assumed opaque and we cannot know the + // bufferization contract they want to enforce. As a consequence, only + // support functions that don't return any tensors atm. 
+ if (funcOp.getBody().empty()) { + FunctionType funcType = funcOp.getType(); + SmallVector retTypes; + for (Type resultType : funcType.getResults()) { + if (resultType.isa()) + return funcOp->emitError() << "cannot bufferize bodiless function " + << "that returns a tensor"; + retTypes.push_back(resultType); + } + funcOp.setType(FunctionType::get(op->getContext(), argTypes, retTypes)); + return success(); + } + + // TODO: Support functions with multiple returns. + func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp); + assert(returnOp && "expected func with single return op"); + + // 1. Rewrite the bbArgs. Turn every tensor bbArg into a memref bbArg. + Block &frontBlock = funcOp.getBody().front(); + for (BlockArgument &bbArg : frontBlock.getArguments()) { + auto tensorType = bbArg.getType().dyn_cast(); + // Non-tensor types stay the same. + if (!tensorType) + continue; + + // Collect all uses of the bbArg. + SmallVector bbArgUses; + for (OpOperand &use : bbArg.getUses()) + bbArgUses.push_back(&use); + + // Change the bbArg type to memref. + Type memrefType = + getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options); + bbArg.setType(memrefType); + + // Replace all uses of the original tensor bbArg. + rewriter.setInsertionPointToStart(&frontBlock); + if (!bbArgUses.empty()) { + // Insert to_tensor because the remaining function body has not been + // bufferized yet. + Value toTensorOp = + rewriter.create(funcOp.getLoc(), bbArg); + for (OpOperand *use : bbArgUses) + use->set(toTensorOp); + } + } + + // 2. For each result, keep track of which inplace argument it reuses. + SmallVector returnValues; + for (OpOperand &returnOperand : returnOp->getOpOperands()) { + Value returnVal = returnOperand.get(); + + // If not a tensor type just forward it. + if (!returnVal.getType().isa()) { + returnValues.push_back(returnVal); + continue; + } + + // If return operand is equivalent to some bbArg, no need to return it. + if (Optional equivBbArgIdx = getEquivalentFuncArgIdx( + funcOp, moduleState, returnOperand.getOperandNumber())) { + rewriter.setInsertionPoint(returnOp); + Location loc = returnOp.getLoc(); + Value toMemrefOp = rewriter.create( + loc, getMemRefType(returnVal.getType().cast(), options), + returnVal); + BlockArgument equivBbArg = funcOp.getArgument(*equivBbArgIdx); + // Note: This copy will fold away. It must be inserted here to ensure + // that `returnVal` still has at least one use and does not fold away. + if (failed( + createMemCpy(rewriter, loc, toMemrefOp, equivBbArg, options))) + return funcOp->emitError("could not generate copy for bbArg"); + continue; + } + + // Cast values at the call site if necessary. + returnValues.push_back( + getNonCastedValue(*state.getBuffer(rewriter, returnOperand))); + } + + // 3. Rewrite the terminator without the in-place bufferizable values. + returnOp.operandsMutable().assign(returnValues); + + // 4. Rewrite the FuncOp type to buffer form. + funcOp.setType(FunctionType::get(op->getContext(), argTypes, + ValueRange(returnValues).getTypes())); + + return success(); + } + + /// Return `true` if the given function argument is writable. + bool isWritable(Operation *op, Value value, + const AnalysisState &state) const { + auto funcOp = cast(op); + BlockArgument bbArg = value.dyn_cast(); + assert(bbArg && "expected BlockArgument"); + + // "bufferization.writable" overrides other writability decisions. This is + // currently used for testing only. 
+ if (BoolAttr writable = funcOp.getArgAttrOfType( + bbArg.getArgNumber(), BufferizationDialect::kWritableAttrName)) + return writable.getValue(); + + // All function arguments are writable by default. + return true; + } + + bool isAllocationHoistingBarrier(Operation *op) const { return true; } +}; + +} // namespace func_ext +} // namespace bufferization +} // namespace mlir + +void mlir::bufferization::func_ext:: + registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry) { + registry.addOpInterface(); + registry.addOpInterface(); + registry.addOpInterface(); +} diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp @@ -0,0 +1,529 @@ +//===- ModuleBufferization.cpp - Bufferization across Func. Boundaries ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Module Bufferization is an extension of One-Shot Bufferize that +// bufferizes function boundaries. It provides `BufferizableOpInterface` +// implementations for FuncOp, CallOp and ReturnOp. +// +// Module Bufferization is run via `runOneShotModuleBufferize(ModuleOp, ...)`. +// This function analyzes the given module and determines the order of analysis +// and bufferization: Functions that are called are processed before their +// respective callers. +// +// After analyzing a FuncOp, additional information about its bbArgs is +// gathered through PostAnalysisStepFns and stored in +// `FuncAnalysisState`. +// +// * `aliasingFuncOpBBArgsAnalysis` determines the equivalent/aliasing bbArgs +// for +// each tensor return value (if any). +// * `funcOpBbArgReadWriteAnalysis` determines whether or not a tensor bbArg is +// read/written. +// +// Only tensors that are equivalent to some FuncOp bbArg may be returned. +// Bufferization currently fails if other tensors (in particular tensors that +// bufferize out-of-place and result in a new buffer allocation) are returned. +// In the future, such allocations could be hoisted to the caller. +// +// Example: `foo` fails bufferization because %0 is not equivalent to any bbArg. +// ``` +// func @foo() -> tensor { +// %0 = linalg.init_tensor [...] : tensor +// return %0 : tensor +// } +// ``` +// +// Module Bufferization implements the following calling convention. +// +// * In the absence of conflicts within a FuncOp, the FuncOp's bbArgs may always +// be written to in-place. +// * If a tensor operand of a CallOp is read after the CallOp, the operand of +// the CallOp must bufferize out-of-place. +// +// Example: The tensor.insert op bufferizes in-place because it is allowed to +// modify the buffer of `%t1` directly. The CallOp in `caller` must bufferize +// out-of-place because `%t0` is modified by the callee but read by the +// tensor.extract op. The analysis of CallOps decides whether an OpOperand must +// bufferize out-of-place based on results of `funcOpBbArgReadWriteAnalysis`. +// ``` +// func @callee(%t1 : tensor) -> tensor { +// %f = ... : f32 +// %0 = tensor.insert %f into %t1[...] : tensor +// return %0 : tensor +// } +// +// func @caller() -> () { +// %t0 = ... 
: tensor +// %1 = call @callee(%t0) : (tensor) -> (tensor) +// %2 = tensor.extract %1[...] : tensor +// } +// ``` +// +// Note: If a function is external, `funcOpBbArgReadWriteAnalysis` cannot +// analyze the function body. In such a case, the CallOp analysis conservatively +// assumes that each tensor OpOperand is both read and written. +// +// TODO: Add FuncOp attributes so that bbArgs of external FuncOps can be marked +// as "not reading" and/or "not writing". + +#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" + +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" +#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Operation.h" + +using namespace mlir; +using namespace mlir::bufferization; +using namespace mlir::bufferization::func_ext; + +/// A mapping of FuncOps to their callers. +using FuncCallerMap = DenseMap>; + +/// Get FuncAnalysisState. +static const FuncAnalysisState & +getFuncAnalysisState(const AnalysisState &state) { + Optional maybeState = + state.getDialectState( + func::FuncDialect::getDialectNamespace()); + assert(maybeState.hasValue() && "FuncAnalysisState does not exist"); + return **maybeState; +} + +/// Get or create FuncAnalysisState. +static FuncAnalysisState &getFuncAnalysisState(AnalysisState &state) { + return state.getOrCreateDialectState( + func::FuncDialect::getDialectNamespace()); +} + +/// Return the state (phase) of analysis of the FuncOp. +static FuncOpAnalysisState getFuncOpAnalysisState(const AnalysisState &state, + FuncOp funcOp) { + const FuncAnalysisState &funcState = getFuncAnalysisState(state); + auto it = funcState.analyzedFuncOps.find(funcOp); + if (it == funcState.analyzedFuncOps.end()) + return FuncOpAnalysisState::NotAnalyzed; + return it->second; +} + +/// Return the unique ReturnOp that terminates `funcOp`. +/// Return nullptr if there is no such unique ReturnOp. +static func::ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) { + func::ReturnOp returnOp; + for (Block &b : funcOp.getBody()) { + if (auto candidateOp = dyn_cast(b.getTerminator())) { + if (returnOp) + return nullptr; + returnOp = candidateOp; + } + } + return returnOp; +} + +namespace { + +/// Annotate IR with the results of the analysis. For testing purposes only. +static void annotateEquivalentReturnBbArg(OpOperand &returnVal, + BlockArgument bbArg) { + const char *kEquivalentArgsAttr = "__equivalent_func_args__"; + Operation *op = returnVal.getOwner(); + + SmallVector equivBbArgs; + if (op->hasAttr(kEquivalentArgsAttr)) { + auto attr = op->getAttr(kEquivalentArgsAttr).cast(); + equivBbArgs = llvm::to_vector<4>(llvm::map_range(attr, [](Attribute a) { + return a.cast().getValue().getSExtValue(); + })); + } else { + equivBbArgs.append(op->getNumOperands(), -1); + } + equivBbArgs[returnVal.getOperandNumber()] = bbArg.getArgNumber(); + + OpBuilder b(op->getContext()); + op->setAttr(kEquivalentArgsAttr, b.getI64ArrayAttr(equivBbArgs)); +} + +/// Store function BlockArguments that are equivalent to/aliasing a returned +/// value in FuncAnalysisState. 
+static LogicalResult +aliasingFuncOpBBArgsAnalysis(Operation *op, AnalysisState &state, + BufferizationAliasInfo &aliasInfo, + SmallVector &newOps) { + FuncAnalysisState &funcState = getFuncAnalysisState(state); + + // Support only single return-terminated block in the function. + auto funcOp = cast(op); + func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp); + assert(returnOp && "expected func with single return op"); + + // Initialize data structure. + auto createdEquiv = funcState.equivalentFuncArgs.try_emplace( + funcOp, DenseMap()); + auto createdAliasingOperands = funcState.aliasingFuncArgs.try_emplace( + funcOp, DenseMap>()); + auto createdAliasingResults = funcState.aliasingReturnVals.try_emplace( + funcOp, DenseMap>()); + (void)createdEquiv; + (void)createdAliasingOperands; + (void)createdAliasingResults; +#ifndef NDEBUG + assert(createdEquiv.second && "equivalence info exists already"); + assert(createdAliasingOperands.second && "aliasing info exists already"); + assert(createdAliasingResults.second && "aliasing info exists already"); +#endif // NDEBUG + + for (OpOperand &returnVal : returnOp->getOpOperands()) + if (returnVal.get().getType().isa()) + for (BlockArgument bbArg : funcOp.getArguments()) + if (bbArg.getType().isa()) { + int64_t returnIdx = returnVal.getOperandNumber(); + int64_t bbArgIdx = bbArg.getArgNumber(); + if (aliasInfo.areEquivalentBufferizedValues(returnVal.get(), bbArg)) { + funcState.equivalentFuncArgs[funcOp][returnIdx] = bbArgIdx; + if (state.getOptions().testAnalysisOnly) + annotateEquivalentReturnBbArg(returnVal, bbArg); + } + if (aliasInfo.areAliasingBufferizedValues(returnVal.get(), bbArg)) { + funcState.aliasingFuncArgs[funcOp][returnIdx].push_back(bbArgIdx); + funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(returnIdx); + } + } + + return success(); +} + +/// Return true if the buffer of the given tensor value is written to. Must not +/// be called for values inside not yet analyzed functions. (Post-analysis +/// steps do not have to be run yet, i.e., "in progress" is also OK.) +static bool isValueWritten(Value value, const AnalysisState &state, + const BufferizationAliasInfo &aliasInfo) { +#ifndef NDEBUG + assert(value.getType().isa() && "expected TensorType"); + FuncOp funcOp; + if (auto bbArg = value.dyn_cast()) { + Operation *owner = bbArg.getOwner()->getParentOp(); + funcOp = isa(owner) ? cast(owner) + : owner->getParentOfType(); + } else { + funcOp = value.getDefiningOp()->getParentOfType(); + } + assert(getFuncOpAnalysisState(state, funcOp) != + FuncOpAnalysisState::NotAnalyzed && + "FuncOp must be fully analyzed or analysis in progress"); +#endif // NDEBUG + + bool isWritten = false; + aliasInfo.applyOnAliases(value, [&](Value val) { + for (OpOperand &use : val.getUses()) + if (state.isInPlace(use) && state.bufferizesToMemoryWrite(use)) + isWritten = true; + }); + return isWritten; +} + +static void annotateFuncArgAccess(FuncOp funcOp, BlockArgument bbArg, + bool isRead, bool isWritten) { + OpBuilder b(funcOp.getContext()); + Attribute accessType; + if (isRead && isWritten) { + accessType = b.getStringAttr("read-write"); + } else if (isRead) { + accessType = b.getStringAttr("read"); + } else if (isWritten) { + accessType = b.getStringAttr("write"); + } else { + accessType = b.getStringAttr("none"); + } + funcOp.setArgAttr(bbArg.getArgNumber(), "bufferization.access", accessType); +} + +/// Determine which FuncOp bbArgs are read and which are written. 
If this +/// PostAnalysisStepFn is run on a function with unknown ops, it will +/// conservatively assume that such ops bufferize to a read + write. +static LogicalResult +funcOpBbArgReadWriteAnalysis(Operation *op, AnalysisState &state, + BufferizationAliasInfo &aliasInfo, + SmallVector &newOps) { + FuncAnalysisState &funcState = getFuncAnalysisState(state); + auto funcOp = cast(op); + + // Initialize data structure. + auto createdRead = + funcState.readBbArgs.try_emplace(funcOp, DenseSet()); + auto createdWritten = + funcState.writtenBbArgs.try_emplace(funcOp, DenseSet()); + (void)createdRead; + (void)createdWritten; +#ifndef NDEBUG + assert(createdRead.second && "bbarg access info exists already"); + assert(createdWritten.second && "bbarg access info exists already"); +#endif // NDEBUG + + // If the function has no body, conservatively assume that all args are + // read + written. + if (funcOp.getBody().empty()) { + for (BlockArgument bbArg : funcOp.getArguments()) { + funcState.readBbArgs[funcOp].insert(bbArg.getArgNumber()); + funcState.writtenBbArgs[funcOp].insert(bbArg.getArgNumber()); + } + + return success(); + } + + for (BlockArgument bbArg : funcOp.getArguments()) { + if (!bbArg.getType().isa()) + continue; + bool isRead = state.isValueRead(bbArg); + bool isWritten = isValueWritten(bbArg, state, aliasInfo); + if (state.getOptions().testAnalysisOnly) + annotateFuncArgAccess(funcOp, bbArg, isRead, isWritten); + if (isRead) + funcState.readBbArgs[funcOp].insert(bbArg.getArgNumber()); + if (isWritten) + funcState.writtenBbArgs[funcOp].insert(bbArg.getArgNumber()); + } + + return success(); +} +} // namespace + +/// Remove bufferization attributes on FuncOp arguments. +static void removeBufferizationAttributes(BlockArgument bbArg) { + auto funcOp = cast(bbArg.getOwner()->getParentOp()); + funcOp.removeArgAttr(bbArg.getArgNumber(), + BufferizationDialect::kBufferLayoutAttrName); + funcOp.removeArgAttr(bbArg.getArgNumber(), + BufferizationDialect::kWritableAttrName); +} + +/// Return the FuncOp called by `callOp`. +static FuncOp getCalledFunction(CallOpInterface callOp) { + SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast(); + if (!sym) + return nullptr; + return dyn_cast_or_null( + SymbolTable::lookupNearestSymbolFrom(callOp, sym)); +} + +/// Gather equivalence info of CallOps. +/// Note: This only adds new equivalence info if the called function was already +/// analyzed. +// TODO: This does not handle cyclic function call graphs etc. +static void equivalenceAnalysis(FuncOp funcOp, + BufferizationAliasInfo &aliasInfo, + FuncAnalysisState &funcState) { + funcOp->walk([&](func::CallOp callOp) { + FuncOp calledFunction = getCalledFunction(callOp); + assert(calledFunction && "could not retrieved called FuncOp"); + + // No equivalence info available for the called function. + if (!funcState.equivalentFuncArgs.count(calledFunction)) + return WalkResult::skip(); + + for (auto it : funcState.equivalentFuncArgs[calledFunction]) { + int64_t returnIdx = it.first; + int64_t bbargIdx = it.second; + Value returnVal = callOp.getResult(returnIdx); + Value argVal = callOp->getOperand(bbargIdx); + aliasInfo.unionEquivalenceClasses(returnVal, argVal); + } + + return WalkResult::advance(); + }); +} + +/// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by +/// callee-caller order (i.e. callees without callers first). +/// Store the map of FuncOp to all its callers in `callerMap`. 
+/// Return `failure()` if a cycle of calls is detected or if we are unable to +/// retrieve the called FuncOp from any CallOpInterface. +static LogicalResult +getFuncOpsOrderedByCalls(ModuleOp moduleOp, + SmallVectorImpl &orderedFuncOps, + FuncCallerMap &callerMap) { + // For each FuncOp, the set of functions called by it (i.e. the union of + // symbols of all nested CallOpInterfaceOp). + DenseMap> calledBy; + // For each FuncOp, the number of CallOpInterface it contains. + DenseMap numberCallOpsContainedInFuncOp; + WalkResult res = moduleOp.walk([&](FuncOp funcOp) -> WalkResult { + if (!funcOp.getBody().empty()) { + func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp); + if (!returnOp) + return funcOp->emitError() + << "cannot bufferize a FuncOp with tensors and " + "without a unique ReturnOp"; + } + + numberCallOpsContainedInFuncOp[funcOp] = 0; + return funcOp.walk([&](CallOpInterface callOp) -> WalkResult { + // Only support CallOp for now. + if (!isa(callOp.getOperation())) + return callOp->emitError() << "expected a CallOp"; + FuncOp calledFunction = getCalledFunction(callOp); + assert(calledFunction && "could not retrieved called FuncOp"); + auto it = callerMap.try_emplace(calledFunction, DenseSet{}); + it.first->getSecond().insert(callOp); + if (calledBy[calledFunction].count(funcOp) == 0) { + calledBy[calledFunction].insert(funcOp); + numberCallOpsContainedInFuncOp[funcOp]++; + } + return WalkResult::advance(); + }); + }); + if (res.wasInterrupted()) + return failure(); + // Iteratively remove function operation that do not call any of the + // functions remaining in the callCounter map and add them to the worklist. + while (!numberCallOpsContainedInFuncOp.empty()) { + auto it = llvm::find_if(numberCallOpsContainedInFuncOp, + [](auto entry) { return entry.getSecond() == 0; }); + if (it == numberCallOpsContainedInFuncOp.end()) + return moduleOp.emitOpError( + "expected callgraph to be free of circular dependencies."); + orderedFuncOps.push_back(it->getFirst()); + for (auto callee : calledBy[it->getFirst()]) + numberCallOpsContainedInFuncOp[callee]--; + numberCallOpsContainedInFuncOp.erase(it); + } + return success(); +} + +static void foreachCaller(const FuncCallerMap &callerMap, FuncOp callee, + llvm::function_ref doit) { + auto itCallers = callerMap.find(callee); + if (itCallers == callerMap.end()) + return; + for (Operation *caller : itCallers->second) + doit(caller); +} + +/// Set the attribute that triggers inplace bufferization on a FuncOp argument +/// `bbArg`. +static void setInPlaceFuncArgument(BlockArgument bbArg, bool inPlace) { + auto funcOp = cast(bbArg.getOwner()->getParentOp()); + funcOp.setArgAttr(bbArg.getArgNumber(), + BufferizableOpInterface::kInplaceableAttrName, + BoolAttr::get(bbArg.getContext(), inPlace)); +} + +/// Annotate the IR with the result of the analysis. For testing/debugging only. +static void annotateOpsWithBufferizationMarkers(FuncOp funcOp, + const AnalysisState &state) { + auto bufferizableOp = cast(funcOp.getOperation()); + for (BlockArgument bbArg : funcOp.getArguments()) + if (bbArg.getType().isa()) + setInPlaceFuncArgument(bbArg, bufferizableOp.isWritable(bbArg, state)); +} + +/// Fold return values that are memref casts. 
+static void foldMemRefCasts(FuncOp funcOp) {
+  if (funcOp.getBody().empty())
+    return;
+
+  func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
+  SmallVector<Type> resultTypes;
+
+  for (OpOperand &operand : returnOp->getOpOperands()) {
+    if (auto castOp = operand.get().getDefiningOp<memref::CastOp>()) {
+      operand.set(castOp.source());
+      resultTypes.push_back(castOp.source().getType());
+    } else {
+      resultTypes.push_back(operand.get().getType());
+    }
+  }
+
+  auto newFuncType = FunctionType::get(
+      funcOp.getContext(), funcOp.getType().getInputs(), resultTypes);
+  funcOp.setType(newFuncType);
+}
+
+LogicalResult mlir::bufferization::runOneShotModuleBufferize(
+    ModuleOp moduleOp, OneShotBufferizationOptions options) {
+  IRRewriter rewriter(moduleOp.getContext());
+  OneShotAnalysisState analysisState(moduleOp, options);
+  BufferizationState bufferizationState(analysisState);
+  FuncAnalysisState &funcState = getFuncAnalysisState(analysisState);
+  BufferizationAliasInfo &aliasInfo = analysisState.getAliasInfo();
+
+  // A list of functions in the order in which they are analyzed + bufferized.
+  SmallVector<FuncOp> orderedFuncOps;
+
+  // A mapping of FuncOps to their callers.
+  FuncCallerMap callerMap;
+
+  if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap)))
+    return failure();
+
+  // Collect bbArg/return value information after the analysis.
+  options.addPostAnalysisStep(aliasingFuncOpBBArgsAnalysis);
+  options.addPostAnalysisStep(funcOpBbArgReadWriteAnalysis);
+
+  // Analyze ops.
+  for (FuncOp funcOp : orderedFuncOps) {
+    // No body => no analysis.
+    if (funcOp.getBody().empty())
+      continue;
+
+    // Now analyzing function.
+    funcState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::InProgress;
+
+    // Gather equivalence info for CallOps.
+    equivalenceAnalysis(funcOp, aliasInfo, funcState);
+
+    // Analyze funcOp.
+    if (failed(analyzeOp(funcOp, analysisState)))
+      return failure();
+
+    // Mark op as fully analyzed.
+    funcState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::Analyzed;
+
+    // Add annotations to function arguments.
+    if (options.testAnalysisOnly)
+      annotateOpsWithBufferizationMarkers(funcOp, analysisState);
+  }
+
+  if (options.testAnalysisOnly)
+    return success();
+
+  // Bufferize functions.
+  for (FuncOp funcOp : orderedFuncOps) {
+    // Note: It would be good to apply cleanups here but we cannot as aliasInfo
+    // would be invalidated.
+    if (failed(bufferizeOp(funcOp, bufferizationState)))
+      return failure();
+    foldMemRefCasts(funcOp);
+  }
+
+  // Check result.
+  for (FuncOp funcOp : orderedFuncOps) {
+    if (!options.allowReturnAllocs &&
+        llvm::any_of(funcOp.getType().getResults(), [](Type t) {
+          return t.isa<MemRefType>();
+        })) {
+      funcOp->emitError("memref return type is unsupported");
+      return failure();
+    }
+  }
+
+  // Finalize all buffers.
+  if (failed(finalizeBuffers(moduleOp, options)))
+    return failure();
+
+  // Post-pass cleanup of function argument attributes.
+ moduleOp.walk([&](FuncOp op) { + for (BlockArgument bbArg : op.getArguments()) + removeBufferizationAttributes(bbArg); + }); + + return success(); +} diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt deleted file mode 100644 --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt +++ /dev/null @@ -1,11 +0,0 @@ -add_mlir_dialect_library(MLIRModuleBufferization - ModuleBufferization.cpp - - LINK_LIBS PUBLIC - MLIRBufferization - MLIRBufferizationTransforms - MLIRFunc - MLIRFuncTransforms - MLIRIR - MLIRMemRef -) diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp deleted file mode 100644 --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp +++ /dev/null @@ -1,1045 +0,0 @@ -//===- ModuleBufferization.cpp - Bufferization across Func. Boundaries ----===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// Module Bufferization is an extension of One-Shot Bufferize that -// bufferizes function boundaries. It provides `BufferizableOpInterface` -// implementations for FuncOp, CallOp and ReturnOp. -// -// Module Bufferization is run via `runModuleBufferize(ModuleOp, ...)`. This -// function analyzes the given module and determines the order of analysis and -// bufferization: Functions that are called are processed before their -// respective callers. -// -// After analyzing a FuncOp, additional information about its bbArgs is -// gathered through PostAnalysisStepFns and stored in -// `FuncAnalysisState`. -// -// * `aliasingFuncOpBBArgsAnalysis` determines the equivalent/aliasing bbArgs -// for -// each tensor return value (if any). -// * `funcOpBbArgReadWriteAnalysis` determines whether or not a tensor bbArg is -// read/written. -// -// Only tensors that are equivalent to some FuncOp bbArg may be returned. -// Bufferization currently fails if other tensors (in particular tensors that -// bufferize out-of-place and result in a new buffer allocation) are returned. -// In the future, such allocations could be hoisted to the caller. -// -// Example: `foo` fails bufferization because %0 is not equivalent to any bbArg. -// ``` -// func @foo() -> tensor { -// %0 = linalg.init_tensor [...] : tensor -// return %0 : tensor -// } -// ``` -// -// Module Bufferization implements the following calling convention. -// -// * In the absence of conflicts within a FuncOp, the FuncOp's bbArgs may always -// be written to in-place. -// * If a tensor operand of a CallOp is read after the CallOp, the operand of -// the CallOp must bufferize out-of-place. -// -// Example: The tensor.insert op bufferizes in-place because it is allowed to -// modify the buffer of `%t1` directly. The CallOp in `caller` must bufferize -// out-of-place because `%t0` is modified by the callee but read by the -// tensor.extract op. The analysis of CallOps decides whether an OpOperand must -// bufferize out-of-place based on results of `funcOpBbArgReadWriteAnalysis`. -// ``` -// func @callee(%t1 : tensor) -> tensor { -// %f = ... : f32 -// %0 = tensor.insert %f into %t1[...] : tensor -// return %0 : tensor -// } -// -// func @caller() -> () { -// %t0 = ... 
: tensor -// %1 = call @callee(%t0) : (tensor) -> (tensor) -// %2 = tensor.extract %1[...] : tensor -// } -// ``` -// -// Note: If a function is external, `funcOpBbArgReadWriteAnalysis` cannot -// analyze the function body. In such a case, the CallOp analysis conservatively -// assumes that each tensor OpOperand is both read and written. -// -// TODO: Add FuncOp attributes so that bbArgs of external FuncOps can be marked -// as "not reading" and/or "not writing". - -#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h" - -#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" -#include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" -#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/IR/Operation.h" - -using namespace mlir; -using namespace linalg; -using namespace tensor; -using namespace comprehensive_bufferize; -using namespace mlir::bufferization; - -/// A mapping of FuncOps to their callers. -using FuncCallerMap = DenseMap>; - -namespace { -/// The state of analysis of a FuncOp. -enum class FuncOpAnalysisState { NotAnalyzed, InProgress, Analyzed }; - -/// Extra analysis state that is required for bufferization of function -/// boundaries. -struct FuncAnalysisState : public DialectAnalysisState { - /// A mapping of ReturnOp OpOperand indices to equivalent FuncOp BBArg - /// indices. - DenseMap> equivalentFuncArgs; - - /// A mapping of ReturnOp OpOperand indices to aliasing FuncOp BBArg indices. - DenseMap>> aliasingFuncArgs; - - /// A mapping of FuncOp BBArg indices to aliasing ReturnOp OpOperand indices. - DenseMap>> aliasingReturnVals; - - /// A set of all read BlockArguments of FuncOps. - DenseMap> readBbArgs; - - /// A set of all written-to BlockArguments of FuncOps. - DenseMap> writtenBbArgs; - - /// Keep track of which FuncOps are fully analyzed or currently being - /// analyzed. - DenseMap analyzedFuncOps; -}; -} // namespace - -/// Get FuncAnalysisState. -static const FuncAnalysisState & -getFuncAnalysisState(const AnalysisState &state) { - Optional maybeState = - state.getDialectState( - func::FuncDialect::getDialectNamespace()); - assert(maybeState.hasValue() && "FuncAnalysisState does not exist"); - return **maybeState; -} - -/// Get or create FuncAnalysisState. -static FuncAnalysisState &getFuncAnalysisState(AnalysisState &state) { - return state.getOrCreateDialectState( - func::FuncDialect::getDialectNamespace()); -} - -/// Return the state (phase) of analysis of the FuncOp. -static FuncOpAnalysisState getFuncOpAnalysisState(const AnalysisState &state, - FuncOp funcOp) { - const FuncAnalysisState &moduleState = getFuncAnalysisState(state); - auto it = moduleState.analyzedFuncOps.find(funcOp); - if (it == moduleState.analyzedFuncOps.end()) - return FuncOpAnalysisState::NotAnalyzed; - return it->second; -} - -/// Return the unique ReturnOp that terminates `funcOp`. -/// Return nullptr if there is no such unique ReturnOp. -static func::ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) { - func::ReturnOp returnOp; - for (Block &b : funcOp.getBody()) { - if (auto candidateOp = dyn_cast(b.getTerminator())) { - if (returnOp) - return nullptr; - returnOp = candidateOp; - } - } - return returnOp; -} - -namespace { - -/// Annotate IR with the results of the analysis. For testing purposes only. 
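(The annotation helper documented by the comment above follows right after this note.) The FuncAnalysisState / getFuncAnalysisState pair above follows a small "per-dialect extension state" pattern: analysis state is extended by storing a derived object in a map keyed by dialect namespace, with a get-or-create accessor and an assert-on-missing lookup. A self-contained sketch of that pattern over plain C++ containers, for illustration only; the MLIR types differ in detail.

#include <cassert>
#include <map>
#include <memory>
#include <string>

struct DialectState {
  virtual ~DialectState() = default;
};

class AnalysisStateModel {
public:
  // Get-or-create the extension state registered under `ns`.
  template <typename StateT>
  StateT &getOrCreateDialectState(const std::string &ns) {
    std::unique_ptr<DialectState> &slot = dialectStates[ns];
    if (!slot)
      slot = std::make_unique<StateT>();
    return static_cast<StateT &>(*slot);
  }

  // Read-only lookup; asserts that the state was created earlier.
  template <typename StateT>
  const StateT &getDialectState(const std::string &ns) const {
    auto it = dialectStates.find(ns);
    assert(it != dialectStates.end() && "state does not exist");
    return static_cast<const StateT &>(*it->second);
  }

private:
  std::map<std::string, std::unique_ptr<DialectState>> dialectStates;
};

// Example extension, playing the role of FuncAnalysisState above.
struct FuncStateModel : DialectState {
  int numAnalyzedFuncs = 0;
};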
-static void annotateEquivalentReturnBbArg(OpOperand &returnVal, - BlockArgument bbArg) { - const char *kEquivalentArgsAttr = "__equivalent_func_args__"; - Operation *op = returnVal.getOwner(); - - SmallVector equivBbArgs; - if (op->hasAttr(kEquivalentArgsAttr)) { - auto attr = op->getAttr(kEquivalentArgsAttr).cast(); - equivBbArgs = llvm::to_vector<4>(llvm::map_range(attr, [](Attribute a) { - return a.cast().getValue().getSExtValue(); - })); - } else { - equivBbArgs.append(op->getNumOperands(), -1); - } - equivBbArgs[returnVal.getOperandNumber()] = bbArg.getArgNumber(); - - OpBuilder b(op->getContext()); - op->setAttr(kEquivalentArgsAttr, b.getI64ArrayAttr(equivBbArgs)); -} - -/// Store function BlockArguments that are equivalent to/aliasing a returned -/// value in FuncAnalysisState. -static LogicalResult -aliasingFuncOpBBArgsAnalysis(Operation *op, AnalysisState &state, - BufferizationAliasInfo &aliasInfo, - SmallVector &newOps) { - FuncAnalysisState &moduleState = getFuncAnalysisState(state); - - // Support only single return-terminated block in the function. - auto funcOp = cast(op); - func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp); - assert(returnOp && "expected func with single return op"); - - // Initialize data structure. - auto createdEquiv = moduleState.equivalentFuncArgs.try_emplace( - funcOp, DenseMap()); - auto createdAliasingOperands = moduleState.aliasingFuncArgs.try_emplace( - funcOp, DenseMap>()); - auto createdAliasingResults = moduleState.aliasingReturnVals.try_emplace( - funcOp, DenseMap>()); - (void)createdEquiv; - (void)createdAliasingOperands; - (void)createdAliasingResults; -#ifndef NDEBUG - assert(createdEquiv.second && "equivalence info exists already"); - assert(createdAliasingOperands.second && "aliasing info exists already"); - assert(createdAliasingResults.second && "aliasing info exists already"); -#endif // NDEBUG - - for (OpOperand &returnVal : returnOp->getOpOperands()) - if (returnVal.get().getType().isa()) - for (BlockArgument bbArg : funcOp.getArguments()) - if (bbArg.getType().isa()) { - int64_t returnIdx = returnVal.getOperandNumber(); - int64_t bbArgIdx = bbArg.getArgNumber(); - if (aliasInfo.areEquivalentBufferizedValues(returnVal.get(), bbArg)) { - moduleState.equivalentFuncArgs[funcOp][returnIdx] = bbArgIdx; - if (state.getOptions().testAnalysisOnly) - annotateEquivalentReturnBbArg(returnVal, bbArg); - } - if (aliasInfo.areAliasingBufferizedValues(returnVal.get(), bbArg)) { - moduleState.aliasingFuncArgs[funcOp][returnIdx].push_back(bbArgIdx); - moduleState.aliasingReturnVals[funcOp][bbArgIdx].push_back( - returnIdx); - } - } - - return success(); -} - -/// Return true if the buffer of the given tensor value is written to. Must not -/// be called for values inside not yet analyzed functions. (Post-analysis -/// steps do not have to be run yet, i.e., "in progress" is also OK.) -static bool isValueWritten(Value value, const AnalysisState &state, - const BufferizationAliasInfo &aliasInfo) { -#ifndef NDEBUG - assert(value.getType().isa() && "expected TensorType"); - FuncOp funcOp; - if (auto bbArg = value.dyn_cast()) { - Operation *owner = bbArg.getOwner()->getParentOp(); - funcOp = isa(owner) ? 
cast(owner) - : owner->getParentOfType(); - } else { - funcOp = value.getDefiningOp()->getParentOfType(); - } - assert(getFuncOpAnalysisState(state, funcOp) != - FuncOpAnalysisState::NotAnalyzed && - "FuncOp must be fully analyzed or analysis in progress"); -#endif // NDEBUG - - bool isWritten = false; - aliasInfo.applyOnAliases(value, [&](Value val) { - for (OpOperand &use : val.getUses()) - if (state.isInPlace(use) && state.bufferizesToMemoryWrite(use)) - isWritten = true; - }); - return isWritten; -} - -static void annotateFuncArgAccess(FuncOp funcOp, BlockArgument bbArg, - bool isRead, bool isWritten) { - OpBuilder b(funcOp.getContext()); - Attribute accessType; - if (isRead && isWritten) { - accessType = b.getStringAttr("read-write"); - } else if (isRead) { - accessType = b.getStringAttr("read"); - } else if (isWritten) { - accessType = b.getStringAttr("write"); - } else { - accessType = b.getStringAttr("none"); - } - funcOp.setArgAttr(bbArg.getArgNumber(), "bufferization.access", accessType); -} - -/// Determine which FuncOp bbArgs are read and which are written. If this -/// PostAnalysisStepFn is run on a function with unknown ops, it will -/// conservatively assume that such ops bufferize to a read + write. -static LogicalResult -funcOpBbArgReadWriteAnalysis(Operation *op, AnalysisState &state, - BufferizationAliasInfo &aliasInfo, - SmallVector &newOps) { - FuncAnalysisState &moduleState = getFuncAnalysisState(state); - auto funcOp = cast(op); - - // Initialize data structure. - auto createdRead = - moduleState.readBbArgs.try_emplace(funcOp, DenseSet()); - auto createdWritten = - moduleState.writtenBbArgs.try_emplace(funcOp, DenseSet()); - (void)createdRead; - (void)createdWritten; -#ifndef NDEBUG - assert(createdRead.second && "bbarg access info exists already"); - assert(createdWritten.second && "bbarg access info exists already"); -#endif // NDEBUG - - // If the function has no body, conservatively assume that all args are - // read + written. - if (funcOp.getBody().empty()) { - for (BlockArgument bbArg : funcOp.getArguments()) { - moduleState.readBbArgs[funcOp].insert(bbArg.getArgNumber()); - moduleState.writtenBbArgs[funcOp].insert(bbArg.getArgNumber()); - } - - return success(); - } - - for (BlockArgument bbArg : funcOp.getArguments()) { - if (!bbArg.getType().isa()) - continue; - bool isRead = state.isValueRead(bbArg); - bool isWritten = isValueWritten(bbArg, state, aliasInfo); - if (state.getOptions().testAnalysisOnly) - annotateFuncArgAccess(funcOp, bbArg, isRead, isWritten); - if (isRead) - moduleState.readBbArgs[funcOp].insert(bbArg.getArgNumber()); - if (isWritten) - moduleState.writtenBbArgs[funcOp].insert(bbArg.getArgNumber()); - } - - return success(); -} -} // namespace - -/// If `value` is a memref::CastOp, return its source. Otherwise, return -/// `value` directly. -static Value getNonCastedValue(Value value) { - while (auto castOp = value.getDefiningOp()) - value = castOp.source(); - return value; -} - -/// Remove the attribute that triggers inplace bufferization on a FuncOp -/// argument `bbArg`. -static void removeBufferizationFuncArguments(BlockArgument bbArg) { - auto funcOp = cast(bbArg.getOwner()->getParentOp()); - funcOp.removeArgAttr(bbArg.getArgNumber(), - BufferizableOpInterface::kBufferLayoutAttrName); - funcOp.removeArgAttr(bbArg.getArgNumber(), - BufferizableOpInterface::kInplaceableAttrName); -} - -/// Return the FuncOp called by `callOp`. 
-static FuncOp getCalledFunction(CallOpInterface callOp) { - SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast(); - if (!sym) - return nullptr; - return dyn_cast_or_null( - SymbolTable::lookupNearestSymbolFrom(callOp, sym)); -} - -/// Return the index-th bufferized function argument type. This assumes that the -/// specified argument is a tensor. If the tensor is ranked, a layout map may be -/// specified by the user. If no layout map is specified, a fully dynamic map is -/// used. -static BaseMemRefType -getBufferizedFunctionArgType(FuncOp funcOp, int64_t index, - const BufferizationOptions &options) { - auto tensorType = funcOp.getType().getInput(index).dyn_cast(); - assert(tensorType && "expected TensorType"); - BaseMemRefType memrefType = getMemRefType(tensorType, options); - - auto layoutAttr = funcOp.getArgAttrOfType( - index, BufferizableOpInterface::kBufferLayoutAttrName); - if (!layoutAttr) - return memrefType; - - auto rankedMemrefType = memrefType.dyn_cast(); - assert(rankedMemrefType && "buffer layout not supported on unranked tensors"); - return MemRefType::get(memrefType.getShape(), memrefType.getElementType(), - layoutAttr.getValue(), - memrefType.getMemorySpaceAsInt()); -} - -/// Gather equivalence info of CallOps. -/// Note: This only adds new equivalence info if the called function was already -/// analyzed. -// TODO: This does not handle cyclic function call graphs etc. -static void equivalenceAnalysis(FuncOp funcOp, - BufferizationAliasInfo &aliasInfo, - FuncAnalysisState &moduleState) { - funcOp->walk([&](func::CallOp callOp) { - FuncOp calledFunction = getCalledFunction(callOp); - assert(calledFunction && "could not retrieved called FuncOp"); - - // No equivalence info available for the called function. - if (!moduleState.equivalentFuncArgs.count(calledFunction)) - return WalkResult::skip(); - - for (auto it : moduleState.equivalentFuncArgs[calledFunction]) { - int64_t returnIdx = it.first; - int64_t bbargIdx = it.second; - Value returnVal = callOp.getResult(returnIdx); - Value argVal = callOp->getOperand(bbargIdx); - aliasInfo.unionEquivalenceClasses(returnVal, argVal); - } - - return WalkResult::advance(); - }); -} - -/// Return the index of the bbArg in the given FuncOp that is equivalent to the -/// specified return value (if any). -static Optional getEquivalentFuncArgIdx(FuncOp funcOp, - const FuncAnalysisState &state, - int64_t returnValIdx) { - auto funcOpIt = state.equivalentFuncArgs.find(funcOp); - if (funcOpIt == state.equivalentFuncArgs.end()) - // No equivalence info stores for funcOp. - return None; - - auto retValIt = funcOpIt->getSecond().find(returnValIdx); - if (retValIt == funcOpIt->getSecond().end()) - // Return value has no equivalent bbArg. - return None; - - return retValIt->getSecond(); -} - -/// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by -/// callee-caller order (i.e. callees without callers first). -/// Store the map of FuncOp to all its callers in `callerMap`. -/// Return `failure()` if a cycle of calls is detected or if we are unable to -/// retrieve the called FuncOp from any CallOpInterface. -static LogicalResult -getFuncOpsOrderedByCalls(ModuleOp moduleOp, - SmallVectorImpl &orderedFuncOps, - FuncCallerMap &callerMap) { - // For each FuncOp, the set of functions called by it (i.e. the union of - // symbols of all nested CallOpInterfaceOp). - DenseMap> calledBy; - // For each FuncOp, the number of CallOpInterface it contains. 
- DenseMap numberCallOpsContainedInFuncOp; - WalkResult res = moduleOp.walk([&](FuncOp funcOp) -> WalkResult { - if (!funcOp.getBody().empty()) { - func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp); - if (!returnOp) - return funcOp->emitError() - << "cannot bufferize a FuncOp with tensors and " - "without a unique ReturnOp"; - } - - numberCallOpsContainedInFuncOp[funcOp] = 0; - return funcOp.walk([&](CallOpInterface callOp) -> WalkResult { - // Only support CallOp for now. - if (!isa(callOp.getOperation())) - return callOp->emitError() << "expected a CallOp"; - FuncOp calledFunction = getCalledFunction(callOp); - assert(calledFunction && "could not retrieved called FuncOp"); - auto it = callerMap.try_emplace(calledFunction, DenseSet{}); - it.first->getSecond().insert(callOp); - if (calledBy[calledFunction].count(funcOp) == 0) { - calledBy[calledFunction].insert(funcOp); - numberCallOpsContainedInFuncOp[funcOp]++; - } - return WalkResult::advance(); - }); - }); - if (res.wasInterrupted()) - return failure(); - // Iteratively remove function operation that do not call any of the - // functions remaining in the callCounter map and add them to the worklist. - while (!numberCallOpsContainedInFuncOp.empty()) { - auto it = llvm::find_if(numberCallOpsContainedInFuncOp, - [](auto entry) { return entry.getSecond() == 0; }); - if (it == numberCallOpsContainedInFuncOp.end()) - return moduleOp.emitOpError( - "expected callgraph to be free of circular dependencies."); - orderedFuncOps.push_back(it->getFirst()); - for (auto callee : calledBy[it->getFirst()]) - numberCallOpsContainedInFuncOp[callee]--; - numberCallOpsContainedInFuncOp.erase(it); - } - return success(); -} - -static void foreachCaller(const FuncCallerMap &callerMap, FuncOp callee, - llvm::function_ref doit) { - auto itCallers = callerMap.find(callee); - if (itCallers == callerMap.end()) - return; - for (Operation *caller : itCallers->second) - doit(caller); -} - -namespace mlir { -namespace linalg { -namespace comprehensive_bufferize { -namespace std_ext { - -struct CallOpInterface - : public BufferizableOpInterface::ExternalModel { - bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, - const AnalysisState &state) const { - func::CallOp callOp = cast(op); - FuncOp funcOp = getCalledFunction(callOp); - assert(funcOp && "expected CallOp to a FuncOp"); - - const FuncAnalysisState &moduleState = getFuncAnalysisState(state); - if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed) - // FuncOp not analyzed yet. Assume that OpOperand is read. - return true; - - auto it = moduleState.readBbArgs.find(funcOp); - assert(it != moduleState.readBbArgs.end() && - "expected analysis info for analyzed FuncOps"); - return it->second.contains(opOperand.getOperandNumber()); - } - - bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, - const AnalysisState &state) const { - func::CallOp callOp = cast(op); - FuncOp funcOp = getCalledFunction(callOp); - assert(funcOp && "expected CallOp to a FuncOp"); - - const FuncAnalysisState &moduleState = getFuncAnalysisState(state); - if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed) - // FuncOp not analyzed yet. Assume that OpOperand is written. 
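(The write-side query that the comment above belongs to continues right below.) Stepping back briefly: getFuncOpsOrderedByCalls earlier in this file is a Kahn-style topological sort of the call graph, emitting callees before their callers and failing on cycles. The same algorithm in self-contained form over plain containers, for illustration only.

#include <algorithm>
#include <map>
#include <optional>
#include <set>
#include <string>
#include <vector>

// Maps each function to the set of functions it calls.
using CallGraph = std::map<std::string, std::set<std::string>>;

// Returns functions in callee-before-caller order, or std::nullopt if the
// call graph is cyclic (the failure case reported above).
std::optional<std::vector<std::string>> orderByCalls(const CallGraph &graph) {
  std::map<std::string, int> numCallees; // outstanding callees per function
  std::map<std::string, std::set<std::string>> calledBy;
  for (const auto &[caller, callees] : graph) {
    numCallees.try_emplace(caller, 0);
    for (const std::string &callee : callees) {
      calledBy[callee].insert(caller);
      numCallees.try_emplace(callee, 0);
    }
    numCallees[caller] = static_cast<int>(callees.size());
  }
  std::vector<std::string> ordered;
  while (!numCallees.empty()) {
    // Pick any function whose callees have all been emitted already.
    auto it = std::find_if(numCallees.begin(), numCallees.end(),
                           [](const auto &entry) { return entry.second == 0; });
    if (it == numCallees.end())
      return std::nullopt; // circular dependency
    ordered.push_back(it->first);
    for (const std::string &caller : calledBy[it->first])
      --numCallees[caller];
    numCallees.erase(it);
  }
  return ordered;
}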
- return true; - - auto it = moduleState.writtenBbArgs.find(funcOp); - assert(it != moduleState.writtenBbArgs.end() && - "expected analysis info for analyzed FuncOps"); - return it->second.contains(opOperand.getOperandNumber()); - } - - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, - const AnalysisState &state) const { - func::CallOp callOp = cast(op); - FuncOp funcOp = getCalledFunction(callOp); - assert(funcOp && "expected CallOp to a FuncOp"); - const FuncAnalysisState &moduleState = getFuncAnalysisState(state); - if (getFuncOpAnalysisState(state, funcOp) != - FuncOpAnalysisState::Analyzed) { - // FuncOp not analyzed yet. Any OpResult may be aliasing. - SmallVector result; - for (OpResult opResult : op->getOpResults()) - if (opResult.getType().isa()) - result.push_back(opResult); - return result; - } - - // Get aliasing results from state. - auto mapping = moduleState.aliasingReturnVals.find(funcOp); - assert(mapping != moduleState.aliasingReturnVals.end() && - "expected analysis info analyzed FuncOps"); - auto aliasingReturnVals = - mapping->second.find(opOperand.getOperandNumber()); - if (aliasingReturnVals == mapping->second.end()) - return {}; - - SmallVector result; - for (int64_t resultIdx : aliasingReturnVals->second) - result.push_back(callOp->getOpResult(resultIdx)); - return result; - } - - SmallVector - getAliasingOpOperand(Operation *op, OpResult opResult, - const AnalysisState &state) const { - func::CallOp callOp = cast(op); - FuncOp funcOp = getCalledFunction(callOp); - assert(funcOp && "expected CallOp to a FuncOp"); - const FuncAnalysisState &moduleState = getFuncAnalysisState(state); - if (getFuncOpAnalysisState(state, funcOp) != - FuncOpAnalysisState::Analyzed) { - // FuncOp not analyzed yet. Any OpOperand may be aliasing. - SmallVector result; - for (OpOperand &opOperand : op->getOpOperands()) - if (opOperand.get().getType().isa()) - result.push_back(&opOperand); - return result; - } - - // Get aliasing bbArgs from state. - auto mapping = moduleState.aliasingFuncArgs.find(funcOp); - assert(mapping != moduleState.aliasingFuncArgs.end() && - "expected analysis info analyzed FuncOps"); - auto aliasingFuncArgs = mapping->second.find(opResult.getResultNumber()); - if (aliasingFuncArgs == mapping->second.end()) - return {}; - - SmallVector result; - for (int64_t bbArgIdx : aliasingFuncArgs->second) - result.push_back(&callOp->getOpOperand(bbArgIdx)); - return result; - } - - BufferRelation bufferRelation(Operation *op, OpResult opResult, - const AnalysisState &state) const { - return BufferRelation::Equivalent; - } - - /// All function arguments are writable. It is the responsibility of the - /// CallOp to insert buffer copies where necessary. - LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - BufferizationState &state) const { - func::CallOp callOp = cast(op); - unsigned numResults = callOp.getNumResults(); - unsigned numOperands = callOp->getNumOperands(); - FuncOp funcOp = getCalledFunction(callOp); - assert(funcOp && "expected CallOp to a FuncOp"); - const FuncAnalysisState &moduleState = - getFuncAnalysisState(state.getAnalysisState()); - const OneShotBufferizationOptions &options = - static_cast(state.getOptions()); - - // Result types of the bufferized CallOp. - SmallVector resultTypes; - // Replacement values for the existing CallOp. These are usually the results - // of the bufferized CallOp, unless a tensor result folds onto an operand. 
- SmallVector replacementValues(numResults, Value()); - // For non-tensor results: A mapping from return val indices of the old - // CallOp to return val indices of the bufferized CallOp. - SmallVector> retValMapping(numResults, None); - // Operands of the bufferized CallOp. - SmallVector newOperands(numOperands, Value()); - - // Based on previously gathered equivalence information, we know if a - // tensor result folds onto an operand. These are the only tensor value - // results that are supported at the moment. - // - // For tensors return values that do not fold onto an operand, additional - // work is needed (TODO) to either: - // * hoist a result into an inplaceable operand or - // * devise a better representation to truly return a buffer. - // - // Note: If a function has no body, no equivalence information is - // available. Consequently, a tensor return value cannot be proven to fold - // onto a FuncOp bbArg, so calls to such functions are not bufferizable at - // the moment. - - // 1. Compute the result types of the new CallOp. Tensor results that are - // equivalent to a FuncOp bbArg are no longer returned. - for (const auto &it : llvm::enumerate(callOp.getResultTypes())) { - unsigned returnValIdx = it.index(); - Type returnType = it.value(); - if (!returnType.isa()) { - // Non-tensor values are returned. - retValMapping[returnValIdx] = resultTypes.size(); - resultTypes.push_back(returnType); - continue; - } - - if (Optional bbArgIdx = - getEquivalentFuncArgIdx(funcOp, moduleState, returnValIdx)) { - // Return operands that are equivalent to some bbArg, are not - // returned. - FailureOr bufferOrFailure = - state.getBuffer(rewriter, callOp->getOpOperand(*bbArgIdx)); - if (failed(bufferOrFailure)) - return failure(); - replacementValues[returnValIdx] = *bufferOrFailure; - newOperands[*bbArgIdx] = *bufferOrFailure; - continue; - } - - if (!options.allowReturnAllocs) - return callOp->emitError( - "call to FuncOp that returns non-equivalent tensors not supported"); - - // Returning a memref. This memref is not equivalent to any bbArg. It is - // likely a newly allocated buffer. We may want to hoist such allocations - // to the call site in the future. - retValMapping[returnValIdx] = resultTypes.size(); - resultTypes.push_back(funcOp.getType().getResult(resultTypes.size())); - } - - // 2. Compute bufferized FunctionType. - FunctionType bufferizedFuncType = funcOp.getType(); - - // 3. Rewrite tensor operands as memrefs based on `bufferizedFuncType`. - for (OpOperand &opOperand : callOp->getOpOperands()) { - unsigned idx = opOperand.getOperandNumber(); - Value tensorOperand = opOperand.get(); - - // Non-tensor operands are just copied. - if (!tensorOperand.getType().isa()) { - newOperands[idx] = tensorOperand; - continue; - } - - // Retrieve buffers for tensor operands. Tensor operand buffers, who's - // corresponding FuncOp bbArgs are equivalent to a returned tensor, were - // already stored in `newOperands` during Step 1. - Value buffer = newOperands[idx]; - if (!buffer) { - FailureOr bufferOrFailure = state.getBuffer(rewriter, opOperand); - if (failed(bufferOrFailure)) - return failure(); - buffer = *bufferOrFailure; - } - - // Caller / callee type mismatch is handled with a CastOp. - auto memRefType = bufferizedFuncType.getInput(idx); - // Since we don't yet have a clear layout story, to_memref may - // conservatively turn tensors into more dynamic memref than necessary. 
- // If the memref type of the callee fails, introduce an extra memref.cast - // that will either canonicalize away or fail compilation until we can do - // something better. - if (buffer.getType() != memRefType) { - assert( - memref::CastOp::areCastCompatible(buffer.getType(), memRefType) && - "CallOp::bufferize: cast incompatible"); - Value castBuffer = rewriter.create(callOp.getLoc(), - memRefType, buffer); - buffer = castBuffer; - } - newOperands[idx] = buffer; - } - - // 4. Create the new CallOp. - Operation *newCallOp = rewriter.create( - callOp.getLoc(), funcOp.getSymName(), resultTypes, newOperands); - newCallOp->setAttrs(callOp->getAttrs()); - // Get replacement values for non-tensor / non-equivalent results. - for (unsigned i = 0; i < replacementValues.size(); ++i) { - if (replacementValues[i]) - continue; - replacementValues[i] = newCallOp->getResult(*retValMapping[i]); - } - - // 5. Replace the old op with the new op. - replaceOpWithBufferizedValues(rewriter, callOp, replacementValues); - - return success(); - } -}; - -struct ReturnOpInterface - : public BufferizableOpInterface::ExternalModel { - bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, - const AnalysisState &state) const { - return true; - } - - bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, - const AnalysisState &state) const { - return false; - } - - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, - const AnalysisState &state) const { - return {}; - } - - LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - BufferizationState &state) const { -#ifndef NDEBUG - auto returnOp = cast(op); - assert(isa(returnOp->getParentOp()) && - "only support FuncOp parent for ReturnOp"); -#endif // NDEBUG - - // ReturnOps are bufferized as part of FuncOps. - return failure(); - } -}; - -struct FuncOpInterface - : public BufferizableOpInterface::ExternalModel { - LogicalResult bufferize(Operation *op, RewriterBase &rewriter, - BufferizationState &state) const { - // Rewrite function bbArgs and return values into buffer form (using the - // canonical memref layout for now). - // - // This relies on a buffer equivalence analysis of each return operand. When - // a result buffer is equivalent to a function bbArg, it is dropped from the - // return values and becomes inplaceable at all callers. - // - // All function bbArgs are writable unless they are explicitly marked as - // read-only. Callers must insert copies when needed. - // - // Note: Returning a memref is possible, but corresponding CallOp - // bufferizations fail unless `allowReturnAllocs`. - auto funcOp = cast(op); - const FuncAnalysisState &moduleState = - getFuncAnalysisState(state.getAnalysisState()); - const BufferizationOptions &options = state.getOptions(); - - // Construct the bufferized function type. - SmallVector argTypes; - for (const auto &it : llvm::enumerate(funcOp.getType().getInputs())) { - Type argType = it.value(); - if (auto tensorType = argType.dyn_cast()) { - argTypes.push_back( - getBufferizedFunctionArgType(funcOp, it.index(), options)); - continue; - } - argTypes.push_back(argType); - } - - // Bodiless functions are assumed opaque and we cannot know the - // bufferization contract they want to enforce. As a consequence, only - // support functions that don't return any tensors atm. 
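(The bodiless-function special case described by the comment above is handled directly below.) The result-folding bookkeeping in CallOpInterface::bufferize above can be summarized independently of the MLIR types: a result that is equivalent to an operand is dropped from the rewritten call and reuses that operand's buffer, and all remaining results are compacted to new positions. A simplified standalone sketch of that bookkeeping, for illustration only.

#include <map>
#include <optional>
#include <vector>

// `equiv` maps an old result index to the operand index whose buffer it
// reuses. Results with an entry disappear from the new call.
struct FoldedResults {
  std::vector<std::optional<int>> newResultIndex; // old result -> new result
  std::vector<std::optional<int>> reusedOperand;  // old result -> operand
};

FoldedResults foldEquivalentResults(int numResults,
                                    const std::map<int, int> &equiv) {
  FoldedResults out;
  out.newResultIndex.assign(numResults, std::nullopt);
  out.reusedOperand.assign(numResults, std::nullopt);
  int numKept = 0;
  for (int i = 0; i < numResults; ++i) {
    auto it = equiv.find(i);
    if (it != equiv.end()) {
      out.reusedOperand[i] = it->second; // replaced by the operand's buffer
      continue;
    }
    out.newResultIndex[i] = numKept++; // still returned, at a new position
  }
  return out;
}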
- if (funcOp.getBody().empty()) { - FunctionType funcType = funcOp.getType(); - SmallVector retTypes; - for (Type resultType : funcType.getResults()) { - if (resultType.isa()) - return funcOp->emitError() << "cannot bufferize bodiless function " - << "that returns a tensor"; - retTypes.push_back(resultType); - } - funcOp.setType(FunctionType::get(op->getContext(), argTypes, retTypes)); - return success(); - } - - // TODO: Support functions with multiple returns. - func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp); - assert(returnOp && "expected func with single return op"); - - // 1. Rewrite the bbArgs. Turn every tensor bbArg into a memref bbArg. - Block &frontBlock = funcOp.getBody().front(); - for (BlockArgument &bbArg : frontBlock.getArguments()) { - auto tensorType = bbArg.getType().dyn_cast(); - // Non-tensor types stay the same. - if (!tensorType) - continue; - - // Collect all uses of the bbArg. - SmallVector bbArgUses; - for (OpOperand &use : bbArg.getUses()) - bbArgUses.push_back(&use); - - // Change the bbArg type to memref. - Type memrefType = - getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options); - bbArg.setType(memrefType); - - // Replace all uses of the original tensor bbArg. - rewriter.setInsertionPointToStart(&frontBlock); - if (!bbArgUses.empty()) { - // Insert to_tensor because the remaining function body has not been - // bufferized yet. - Value toTensorOp = - rewriter.create(funcOp.getLoc(), bbArg); - for (OpOperand *use : bbArgUses) - use->set(toTensorOp); - } - } - - // 2. For each result, keep track of which inplace argument it reuses. - SmallVector returnValues; - for (OpOperand &returnOperand : returnOp->getOpOperands()) { - Value returnVal = returnOperand.get(); - - // If not a tensor type just forward it. - if (!returnVal.getType().isa()) { - returnValues.push_back(returnVal); - continue; - } - - // If return operand is equivalent to some bbArg, no need to return it. - if (Optional equivBbArgIdx = getEquivalentFuncArgIdx( - funcOp, moduleState, returnOperand.getOperandNumber())) { - rewriter.setInsertionPoint(returnOp); - Location loc = returnOp.getLoc(); - Value toMemrefOp = rewriter.create( - loc, getMemRefType(returnVal.getType().cast(), options), - returnVal); - BlockArgument equivBbArg = funcOp.getArgument(*equivBbArgIdx); - // Note: This copy will fold away. It must be inserted here to ensure - // that `returnVal` still has at least one use and does not fold away. - if (failed( - createMemCpy(rewriter, loc, toMemrefOp, equivBbArg, options))) - return funcOp->emitError("could not generate copy for bbArg"); - continue; - } - - // Cast values at the call site if necessary. - returnValues.push_back( - getNonCastedValue(*state.getBuffer(rewriter, returnOperand))); - } - - // 3. Rewrite the terminator without the in-place bufferizable values. - returnOp.operandsMutable().assign(returnValues); - - // 4. Rewrite the FuncOp type to buffer form. - funcOp.setType(FunctionType::get(op->getContext(), argTypes, - ValueRange(returnValues).getTypes())); - - return success(); - } - - /// Return `true` if the given function argument is writable. - bool isWritable(Operation *op, Value value, - const AnalysisState &state) const { - auto funcOp = cast(op); - BlockArgument bbArg = value.dyn_cast(); - assert(bbArg && "expected BlockArgument"); - - // "linalg.inplaceable" overrides other writability decisions. This is - // currently used for testing only. 
- if (BoolAttr inplaceAttr = funcOp.getArgAttrOfType( - bbArg.getArgNumber(), - BufferizableOpInterface::kInplaceableAttrName)) - return inplaceAttr.getValue(); - - // All function arguments are writable by default. - return true; - } - - bool isAllocationHoistingBarrier(Operation *op) const { return true; } -}; - -} // namespace std_ext -} // namespace comprehensive_bufferize -} // namespace linalg -} // namespace mlir - -void mlir::linalg::comprehensive_bufferize::std_ext:: - registerModuleBufferizationExternalModels(DialectRegistry ®istry) { - registry.addOpInterface(); - registry.addOpInterface(); - registry.addOpInterface(); -} - -/// Set the attribute that triggers inplace bufferization on a FuncOp argument -/// `bbArg`. -static void setInPlaceFuncArgument(BlockArgument bbArg, bool inPlace) { - auto funcOp = cast(bbArg.getOwner()->getParentOp()); - funcOp.setArgAttr(bbArg.getArgNumber(), - BufferizableOpInterface::kInplaceableAttrName, - BoolAttr::get(bbArg.getContext(), inPlace)); -} - -/// Annotate the IR with the result of the analysis. For testing/debugging only. -static void annotateOpsWithBufferizationMarkers(FuncOp funcOp, - const AnalysisState &state) { - auto bufferizableOp = cast(funcOp.getOperation()); - for (BlockArgument bbArg : funcOp.getArguments()) - if (bbArg.getType().isa()) - setInPlaceFuncArgument(bbArg, bufferizableOp.isWritable(bbArg, state)); -} - -/// Fold return values that are memref casts. -static void foldMemRefCasts(FuncOp funcOp) { - if (funcOp.getBody().empty()) - return; - - func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp); - SmallVector resultTypes; - - for (OpOperand &operand : returnOp->getOpOperands()) { - if (auto castOp = operand.get().getDefiningOp()) { - operand.set(castOp.source()); - resultTypes.push_back(castOp.source().getType()); - } else { - resultTypes.push_back(operand.get().getType()); - } - } - - auto newFuncType = FunctionType::get( - funcOp.getContext(), funcOp.getType().getInputs(), resultTypes); - funcOp.setType(newFuncType); -} - -LogicalResult mlir::linalg::comprehensive_bufferize::runModuleBufferize( - ModuleOp moduleOp, OneShotBufferizationOptions options) { - IRRewriter rewriter(moduleOp.getContext()); - OneShotAnalysisState analysisState(moduleOp, options); - BufferizationState bufferizationState(analysisState); - FuncAnalysisState &moduleState = getFuncAnalysisState(analysisState); - BufferizationAliasInfo &aliasInfo = analysisState.getAliasInfo(); - - // A list of functions in the order in which they are analyzed + bufferized. - SmallVector orderedFuncOps; - - // A mapping of FuncOps to their callers. - FuncCallerMap callerMap; - - if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap))) - return failure(); - - // Collect bbArg/return value information after the analysis. - options.addPostAnalysisStep(aliasingFuncOpBBArgsAnalysis); - options.addPostAnalysisStep(funcOpBbArgReadWriteAnalysis); - - // Analyze ops. - for (FuncOp funcOp : orderedFuncOps) { - // No body => no analysis. - if (funcOp.getBody().empty()) - continue; - - // Now analyzing function. - moduleState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::InProgress; - - // Gather equivalence info for CallOps. - equivalenceAnalysis(funcOp, aliasInfo, moduleState); - - // Analyze funcOp. - if (failed(analyzeOp(funcOp, analysisState))) - return failure(); - - // Mark op as fully analyzed. - moduleState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::Analyzed; - - // Add annotations to function arguments. 
- if (options.testAnalysisOnly) - annotateOpsWithBufferizationMarkers(funcOp, analysisState); - } - - if (options.testAnalysisOnly) - return success(); - - // Bufferize functions. - for (FuncOp funcOp : orderedFuncOps) { - // Note: It would be good to apply cleanups here but we cannot as aliasInfo - // would be invalidated. - if (failed(bufferizeOp(funcOp, bufferizationState))) - return failure(); - foldMemRefCasts(funcOp); - } - - // Check result. - for (FuncOp funcOp : orderedFuncOps) { - if (!options.allowReturnAllocs && - llvm::any_of(funcOp.getType().getResults(), [](Type t) { - return t.isa(); - })) { - funcOp->emitError("memref return type is unsupported"); - return failure(); - } - } - - // Finalize all buffers. - if (failed(finalizeBuffers(moduleOp, options))) - return failure(); - - // Post-pass cleanup of inplaceable and buffer_layout attributes. - moduleOp.walk([&](FuncOp op) { - for (BlockArgument bbArg : op.getArguments()) - removeBufferizationFuncArguments(bbArg); - }); - - return success(); -} diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp @@ -133,17 +133,6 @@ << " to be used on function-like operations"; return success(); } - if (attr.getName() == BufferizableOpInterface::kBufferLayoutAttrName) { - if (!attr.getValue().isa()) { - return op->emitError() - << "'" << BufferizableOpInterface::kBufferLayoutAttrName - << "' is expected to be a affine map attribute"; - } - if (!isa(op)) - return op->emitError() << "expected " << attr.getName() - << " to be used on function-like operations"; - return success(); - } if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName) return success(); return op->emitError() << "attribute '" << attr.getName() diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -37,6 +37,7 @@ MLIRArithmetic MLIRArithmeticTransforms MLIRBufferization + MLIRBufferizationTransforms MLIRComplex MLIRFunc MLIRFuncToLLVM @@ -47,7 +48,6 @@ MLIRLinalg MLIRLinalgAnalysis MLIRLinalgUtils - MLIRModuleBufferization MLIRSCF MLIRSCFTransforms MLIRSCFUtils diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp @@ -11,10 +11,11 @@ #include "mlir/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" +#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" #include "mlir/Dialect/Bufferization/Transforms/Passes.h" #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h" @@ -28,7 +29,6 @@ using namespace mlir; using namespace mlir::bufferization; using namespace mlir::linalg; -using 
namespace mlir::linalg::comprehensive_bufferize; namespace { struct LinalgComprehensiveModuleBufferize @@ -55,7 +55,7 @@ bufferization::registerAllocationOpInterfaceExternalModels(registry); linalg::registerBufferizableOpInterfaceExternalModels(registry); scf::registerBufferizableOpInterfaceExternalModels(registry); - std_ext::registerModuleBufferizationExternalModels(registry); + func_ext::registerBufferizableOpInterfaceExternalModels(registry); tensor::registerBufferizableOpInterfaceExternalModels(registry); vector::registerBufferizableOpInterfaceExternalModels(registry); } @@ -109,7 +109,7 @@ ModuleOp moduleOp = getOperation(); applyEnablingTransformations(moduleOp); - if (failed(runModuleBufferize(moduleOp, opt))) { + if (failed(runOneShotModuleBufferize(moduleOp, opt))) { signalPassFailure(); return; } diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-partial.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-partial.mlir --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-partial.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-partial.mlir @@ -205,8 +205,8 @@ // ----- // CHECK-SCF-LABEL: func @simple_scf_if( -// CHECK-SCF-SAME: %[[t1:.*]]: tensor {linalg.inplaceable = true}, %[[c:.*]]: i1, %[[pos:.*]]: index -func @simple_scf_if(%t1: tensor {linalg.inplaceable = true}, %c: i1, %pos: index, %f: f32) +// CHECK-SCF-SAME: %[[t1:.*]]: tensor {bufferization.writable = true}, %[[c:.*]]: i1, %[[pos:.*]]: index +func @simple_scf_if(%t1: tensor {bufferization.writable = true}, %c: i1, %pos: index, %f: f32) -> (tensor, index) { // CHECK-SCF: %[[r:.*]] = scf.if %[[c]] -> (memref) { %r1, %r2 = scf.if %c -> (tensor, index) { diff --git a/mlir/test/Dialect/Linalg/one-shot-module-bufferize-allow-return-allocs.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-allow-return-allocs.mlir rename from mlir/test/Dialect/Linalg/one-shot-module-bufferize-allow-return-allocs.mlir rename to mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-allow-return-allocs.mlir --- a/mlir/test/Dialect/Linalg/one-shot-module-bufferize-allow-return-allocs.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-allow-return-allocs.mlir @@ -1,12 +1,12 @@ -// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize=allow-return-allocs -split-input-file | FileCheck %s +// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries=1 allow-return-allocs" -split-input-file | FileCheck %s // Run fuzzer with different seeds. 
-// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="allow-return-allocs test-analysis-only analysis-fuzzer-seed=23" -split-input-file -o /dev/null -// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="allow-return-allocs test-analysis-only analysis-fuzzer-seed=59" -split-input-file -o /dev/null -// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="allow-return-allocs test-analysis-only analysis-fuzzer-seed=91" -split-input-file -o /dev/null +// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries=1 allow-return-allocs test-analysis-only analysis-fuzzer-seed=23" -split-input-file -o /dev/null +// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries=1 allow-return-allocs test-analysis-only analysis-fuzzer-seed=59" -split-input-file -o /dev/null +// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries=1 allow-return-allocs test-analysis-only analysis-fuzzer-seed=91" -split-input-file -o /dev/null // Test bufferization using memref types that have no layout map. -// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="allow-return-allocs fully-dynamic-layout-maps=0" -split-input-file -o /dev/null +// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries=1 allow-return-allocs fully-dynamic-layout-maps=0" -split-input-file -o /dev/null // Make sure that the returned buffer is not deallocated. // TODO: Such buffers currently leak. We need buffer hoisting / ref counting for diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir rename from mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir rename to mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir @@ -1,9 +1,12 @@ -// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="test-analysis-only allow-return-allocs" -split-input-file | FileCheck %s +// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries test-analysis-only allow-return-allocs" -split-input-file | FileCheck %s // Run fuzzer with different seeds. -// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="test-analysis-only allow-return-allocs analysis-fuzzer-seed=23" -split-input-file -o /dev/null -// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="test-analysis-only allow-return-allocs analysis-fuzzer-seed=59" -split-input-file -o /dev/null -// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="test-analysis-only allow-return-allocs analysis-fuzzer-seed=91" -split-input-file -o /dev/null +// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries test-analysis-only allow-return-allocs analysis-fuzzer-seed=23" -split-input-file -o /dev/null +// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries test-analysis-only allow-return-allocs analysis-fuzzer-seed=59" -split-input-file -o /dev/null +// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries test-analysis-only allow-return-allocs analysis-fuzzer-seed=91" -split-input-file -o /dev/null + +// TODO: Extract op-specific test cases and move them to their respective +// dialects. 
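The rewritten test arguments below carry the new bufferization.writable attribute in place of linalg.inplaceable. A hedged sketch of how such an argument attribute is presumably queried on the C++ side, modeled on the deleted isWritable hook in ModuleBufferization.cpp above; the actual replacement is not part of the hunks shown here, so the helper name and default are assumptions.

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"

using namespace mlir;

// Assumption: writability is overridden by a BoolAttr argument attribute
// named "bufferization.writable"; arguments stay writable by default.
static bool isArgWritable(FuncOp funcOp, BlockArgument bbArg) {
  if (auto attr = funcOp.getArgAttrOfType<BoolAttr>(
          bbArg.getArgNumber(),
          bufferization::BufferizationDialect::kWritableAttrName))
    return attr.getValue();
  return true;
}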
//===----------------------------------------------------------------------===// // Simple cases @@ -12,9 +15,9 @@ // ----- // CHECK-LABEL: func @extract_slice_fun( -func @extract_slice_fun(%A : tensor {linalg.inplaceable = false}, +func @extract_slice_fun(%A : tensor {bufferization.writable = false}, // CHECK-SAME: bufferization.access = "read" - %B : tensor {linalg.inplaceable = true}) + %B : tensor {bufferization.writable = true}) // CHECK-SAME: bufferization.access = "read" -> (tensor<4xf32>, tensor<8xf32>) { @@ -36,11 +39,11 @@ // ----- // CHECK-LABEL: func @insert_slice_fun( -func @insert_slice_fun(%A : tensor {linalg.inplaceable = false}, +func @insert_slice_fun(%A : tensor {bufferization.writable = false}, // CHECK-SAME: bufferization.access = "read" - %B : tensor {linalg.inplaceable = true}, + %B : tensor {bufferization.writable = true}, // CHECK-SAME: bufferization.access = "read-write" - %C : tensor<4xf32> {linalg.inplaceable = false}) + %C : tensor<4xf32> {bufferization.writable = false}) // CHECK-SAME: bufferization.access = "read" -> (tensor, tensor) { @@ -62,9 +65,9 @@ // ----- // CHECK-LABEL: func @conflict_on_B( -func @conflict_on_B(%A : tensor<4x4xf32> {linalg.inplaceable = true}, +func @conflict_on_B(%A : tensor<4x4xf32> {bufferization.writable = true}, // CHECK-SAME: bufferization.access = "read" - %B : tensor<4x4xf32> {linalg.inplaceable = true}) + %B : tensor<4x4xf32> {bufferization.writable = true}) // CHECK-SAME: bufferization.access = "read-write" -> (tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>) { @@ -102,9 +105,9 @@ // CHECK-LABEL: func @extract_slice_extract_slice( func @extract_slice_extract_slice( - %A : tensor {linalg.inplaceable = true}, + %A : tensor {bufferization.writable = true}, // CHECK-SAME: bufferization.access = "read" - %B : tensor {linalg.inplaceable = false}) + %B : tensor {bufferization.writable = false}) // CHECK-SAME: bufferization.access = "read" -> (tensor<2xf32>, tensor<2xf32>) { @@ -131,17 +134,17 @@ // CHECK-LABEL: func @insert_slice_insert_slice( func @insert_slice_insert_slice( - %A : tensor {linalg.inplaceable = true}, + %A : tensor {bufferization.writable = true}, // CHECK-SAME: bufferization.access = "read-write" - %A2 : tensor<4xf32> {linalg.inplaceable = true}, + %A2 : tensor<4xf32> {bufferization.writable = true}, // CHECK-SAME: bufferization.access = "read-write" - %A3 : tensor<2xf32> {linalg.inplaceable = true}, + %A3 : tensor<2xf32> {bufferization.writable = true}, // CHECK-SAME: bufferization.access = "read" - %B : tensor {linalg.inplaceable = false}, + %B : tensor {bufferization.writable = false}, // CHECK-SAME: bufferization.access = "read" - %B2 : tensor<4xf32> {linalg.inplaceable = false}, + %B2 : tensor<4xf32> {bufferization.writable = false}, // CHECK-SAME: bufferization.access = "read" - %B3 : tensor<2xf32> {linalg.inplaceable = false}) + %B3 : tensor<2xf32> {bufferization.writable = false}) // CHECK-SAME: bufferization.access = "read" -> (tensor, tensor) { @@ -166,8 +169,8 @@ // CHECK-LABEL: func @extract_slice_nonmatching_insert_slice func @extract_slice_nonmatching_insert_slice( - %A : tensor {linalg.inplaceable = true}, - %B : tensor {linalg.inplaceable = false}, + %A : tensor {bufferization.writable = true}, + %B : tensor {bufferization.writable = false}, %idx: index) -> (tensor, tensor) { @@ -205,8 +208,8 @@ // CHECK-LABEL: func @extract_slice_matching_insert_slice func @extract_slice_matching_insert_slice( - %A : tensor {linalg.inplaceable = true}, - %B : tensor {linalg.inplaceable = false}) + %A : tensor 
{bufferization.writable = true}, + %B : tensor {bufferization.writable = false}) -> (tensor, tensor) { // %r1 bufferizes inplace because %A is inplaceable. @@ -243,7 +246,7 @@ // CHECK-LABEL: @read_of_matching_insert_slice_source func @read_of_matching_insert_slice_source( - %A : tensor {linalg.inplaceable = true}, + %A : tensor {bufferization.writable = true}, %idx : index, %idx2 : index) -> (tensor, vector<5xf32>) @@ -274,7 +277,7 @@ // CHECK-LABEL: @read_of_matching_insert_slice_source_interleaved func @read_of_matching_insert_slice_source_interleaved( - %A : tensor {linalg.inplaceable = true}, + %A : tensor {bufferization.writable = true}, %idx : index, %idx2 : index, %idx3 : index) @@ -318,9 +321,9 @@ // CHECK-LABEL: func @extract_slice_linalg_readonly_use func @extract_slice_linalg_readonly_use( - %A : tensor {linalg.inplaceable = false}, - %B : tensor<4x4xf32> {linalg.inplaceable = false}, - %C : tensor<4x4xf32> {linalg.inplaceable = true}) + %A : tensor {bufferization.writable = false}, + %B : tensor<4x4xf32> {bufferization.writable = false}, + %C : tensor<4x4xf32> {bufferization.writable = true}) -> (tensor<4x4xf32>, tensor<4x4xf32>) { // tensor.extract_slice is only used as a read, no interference irrespective @@ -352,9 +355,9 @@ // CHECK-LABEL: func @extract_slice_to_linalg_write_use func @extract_slice_to_linalg_write_use( - %A : tensor<4x4xf32> {linalg.inplaceable = false}, - %B : tensor {linalg.inplaceable = false}, - %C : tensor {linalg.inplaceable = true}) + %A : tensor<4x4xf32> {bufferization.writable = false}, + %B : tensor {bufferization.writable = false}, + %C : tensor {bufferization.writable = true}) -> (tensor<4x4xf32>, tensor<4x4xf32>) { // Step 4. %sB forward propagates to a write in %D but it is not inplace. @@ -396,9 +399,9 @@ %s2: index, %s3: index, %s4: index, - %A: tensor<8x6xf32> {linalg.inplaceable = false}, - %B: tensor<6x6xf32> {linalg.inplaceable = false}, - %C: tensor<30x20xf32> {linalg.inplaceable = true}) + %A: tensor<8x6xf32> {bufferization.writable = false}, + %B: tensor<6x6xf32> {bufferization.writable = false}, + %C: tensor<30x20xf32> {bufferization.writable = true}) -> tensor<30x20xf32> { // CHECK: tensor.extract_slice @@ -430,9 +433,9 @@ // CHECK-LABEL: func @extract_slice_to_linalg_write_use func @extract_slice_to_linalg_write_use( - %A : tensor<4x4xf32> {linalg.inplaceable = false}, - %B : tensor {linalg.inplaceable = false}, - %C : tensor {linalg.inplaceable = true}) + %A : tensor<4x4xf32> {bufferization.writable = false}, + %B : tensor {bufferization.writable = false}, + %C : tensor {bufferization.writable = true}) -> (tensor<4x4xf32>, tensor<4x4xf32>) { // Step 4. %sB forward propagates to an inplace write in %D. 
@@ -472,9 +475,9 @@ // CHECK-LABEL: func @nested_extract_slice_and_insert func @nested_extract_slice_and_insert( - %A : tensor {linalg.inplaceable = false}, - %B : tensor {linalg.inplaceable = true}, - %C : tensor {linalg.inplaceable = true}, + %A : tensor {bufferization.writable = false}, + %B : tensor {bufferization.writable = true}, + %C : tensor {bufferization.writable = true}, %idx : index, %sz1 : index, %sz2 : index) @@ -564,8 +567,8 @@ // CHECK-LABEL: func @scf_for_yield_only func @scf_for_yield_only( - %A : tensor {linalg.inplaceable = false}, - %B : tensor {linalg.inplaceable = true}, + %A : tensor {bufferization.writable = false}, + %B : tensor {bufferization.writable = true}, %lb : index, %ub : index, %step : index) @@ -596,9 +599,9 @@ // CHECK-LABEL: func @scf_for_with_tensor.insert_slice func @scf_for_with_tensor.insert_slice( - %A : tensor {linalg.inplaceable = false}, - %B : tensor {linalg.inplaceable = true}, - %C : tensor<4xf32> {linalg.inplaceable = false}, + %A : tensor {bufferization.writable = false}, + %B : tensor {bufferization.writable = true}, + %C : tensor<4xf32> {bufferization.writable = false}, %lb : index, %ub : index, %step : index) @@ -634,8 +637,8 @@ // CHECK-LABEL: func @scf_for_deps func @scf_for_deps( - %A : tensor {linalg.inplaceable = true}, - %B : tensor {linalg.inplaceable = true}, + %A : tensor {bufferization.writable = true}, + %B : tensor {bufferization.writable = true}, %lb : index, %ub : index, %step : index) @@ -680,7 +683,7 @@ func private @foo(tensor<64xf32>) // CHECK-LABEL: dependence_through_call -func @dependence_through_call(%I : tensor<64xf32> {linalg.inplaceable = true}) { +func @dependence_through_call(%I : tensor<64xf32> {bufferization.writable = true}) { %f1 = arith.constant 1.000000e+00 : f32 %f2 = arith.constant 2.000000e+00 : f32 @@ -712,8 +715,8 @@ } func @read_dependence_through_scf_and_call( - %I : tensor<64xf32> {linalg.inplaceable = true}, - %I2 : tensor<64xf32> {linalg.inplaceable = true}) { + %I : tensor<64xf32> {bufferization.writable = true}, + %I2 : tensor<64xf32> {bufferization.writable = true}) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c10 = arith.constant 10 : index @@ -785,9 +788,9 @@ // ----- builtin.func @matmul_on_tensors( - %arg0: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, - %arg1: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, - %arg2: tensor<256x256xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) + %arg0: tensor<518x518xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, bufferization.writable = false}, + %arg1: tensor<518x518xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, bufferization.writable = false}, + %arg2: tensor<256x256xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, bufferization.writable = true}) -> tensor<256x256xf32> { %c0 = arith.constant 0 : index @@ -823,9 +826,9 @@ // ----- builtin.func @matmul_on_tensors( - %arg0: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, - %arg1: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, - %arg2: tensor<256x256xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) + %arg0: tensor<518x518xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, 
bufferization.writable = false}, + %arg1: tensor<518x518xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, bufferization.writable = false}, + %arg2: tensor<256x256xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, bufferization.writable = true}) -> tensor<256x256xf32> { %c0 = arith.constant 0 : index @@ -878,11 +881,11 @@ func @insert_slice_chain( %v1: vector<32x90xf32>, %v2: vector<30x90xf32>, - %arg0: tensor<62x126xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, + %arg0: tensor<62x126xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, bufferization.writable = false}, // CHECK-SAME: bufferization.access = "none" - %arg1: tensor<126x90xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, + %arg1: tensor<126x90xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, bufferization.writable = false}, // CHECK-SAME: bufferization.access = "none" - %arg2: tensor<62x90xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) + %arg2: tensor<62x90xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, bufferization.writable = true}) // CHECK-SAME: bufferization.access = "write" -> tensor<62x90xf32> attributes {passthrough = [["target-cpu", "skylake-avx512"], ["prefer-vector-width", "512"]]} { @@ -926,7 +929,7 @@ // Only test IR validity wrt dominance. // CHECK-LABEL: func @ip -func @ip(%t: tensor<10x20xf32> {linalg.inplaceable = true}, +func @ip(%t: tensor<10x20xf32> {bufferization.writable = true}, %x: index, %y: index, %v: vector<5x6xf32>) -> tensor<10x20xf32> { @@ -960,9 +963,9 @@ // CHECK-LABEL: func @linalg_op_same_out_tensors( func @linalg_op_same_out_tensors( - %t1: tensor {linalg.inplaceable = true}, + %t1: tensor {bufferization.writable = true}, // CHECK-SAME: bufferization.access = "read" - %t2: tensor {linalg.inplaceable = true}) + %t2: tensor {bufferization.writable = true}) // CHECK-SAME: bufferization.access = "write" -> (tensor, tensor){ @@ -994,9 +997,9 @@ // CHECK-LABEL: func @linalg_op_same_out_tensors_2( func @linalg_op_same_out_tensors_2( - %t1: tensor {linalg.inplaceable = true}, + %t1: tensor {bufferization.writable = true}, // CHECK-SAME: bufferization.access = "read" - %t2: tensor {linalg.inplaceable = true}) + %t2: tensor {bufferization.writable = true}) // CHECK-SAME: bufferization.access = "write" -> (tensor, tensor, tensor){ @@ -1020,7 +1023,7 @@ func @double_insert_slice_into_alias( %v1: vector<32x90xf32>, %v2: vector<30x90xf32>, - %arg2: tensor<62x90xf32> {linalg.inplaceable = true}, + %arg2: tensor<62x90xf32> {bufferization.writable = true}, %s1: index, %s2: index, %s3: index, %s4: index) -> (tensor<62x90xf32>, tensor) { @@ -1061,7 +1064,7 @@ // CHECK-LABEL: func @interleaved_extract_insert_slice_chain_1 func @interleaved_extract_insert_slice_chain_1( - %arg2: tensor<62x90xf32> {linalg.inplaceable = true}) + %arg2: tensor<62x90xf32> {bufferization.writable = true}) -> (tensor<62x90xf32>) { // CHECK: tensor.extract_slice @@ -1092,7 +1095,7 @@ // CHECK-LABEL: func @interleaved_extract_insert_slice_chain_2 func @interleaved_extract_insert_slice_chain_2( - %arg2: tensor<62x90xf32> {linalg.inplaceable = true}) + %arg2: tensor<62x90xf32> {bufferization.writable = true}) -> (tensor<62x90xf32>) { // CHECK: tensor.extract_slice @@ -1123,7 +1126,7 @@ // CHECK-LABEL: func @extract_once_insert_twice func @extract_once_insert_twice( - %arg2: tensor<62x90xf32> 
{linalg.inplaceable = true}) + %arg2: tensor<62x90xf32> {bufferization.writable = true}) -> (tensor<62x90xf32>) { // CHECK: tensor.extract_slice @@ -1154,7 +1157,7 @@ } // CHECK-LABEL: func @reading_scf_for -func @reading_scf_for(%t1: tensor {linalg.inplaceable = true}, +func @reading_scf_for(%t1: tensor {bufferization.writable = true}, %s: index, %v: vector<5xf32>) -> (tensor, vector<5xf32>) { %c0 = arith.constant 0 : index @@ -1201,7 +1204,7 @@ } // CHECK-LABEL: func @non_reading_scf_for -func @non_reading_scf_for(%t1: tensor {linalg.inplaceable = true}, +func @non_reading_scf_for(%t1: tensor {bufferization.writable = true}, %s: index, %v: vector<5xf32>) -> (tensor, vector<5xf32>) { %c0 = arith.constant 0 : index @@ -1250,8 +1253,8 @@ // This example passes analysis, but it fails when bufferizing. // CHECK-LABEL: func @scf_if_inplace1 -func @scf_if_inplace1(%t1: tensor {linalg.inplaceable = true}, - %t2: tensor {linalg.inplaceable = true}, +func @scf_if_inplace1(%t1: tensor {bufferization.writable = true}, + %t2: tensor {bufferization.writable = true}, %cond: i1) -> tensor { %r = scf.if %cond -> (tensor) { // CHECK: scf.yield @@ -1268,7 +1271,7 @@ // ----- // CHECK-LABEL: func @scf_if_inplace2 -func @scf_if_inplace2(%t1: tensor {linalg.inplaceable = true}, +func @scf_if_inplace2(%t1: tensor {bufferization.writable = true}, %v: vector<5xf32>, %idx: index, %cond: i1) -> tensor { %r = scf.if %cond -> (tensor) { @@ -1289,7 +1292,7 @@ // ----- // CHECK-LABEL: func @scf_if_inplace3 -func @scf_if_inplace3(%t1: tensor {linalg.inplaceable = true}, +func @scf_if_inplace3(%t1: tensor {bufferization.writable = true}, %v1: vector<5xf32>, %v2: vector<5xf32>, %idx: index, %cond: i1) -> tensor { // CHECK: tensor.extract_slice @@ -1317,7 +1320,7 @@ // ----- // CHECK-LABEL: func @scf_if_in_place4 -func @scf_if_in_place4(%t1: tensor {linalg.inplaceable = true}, +func @scf_if_in_place4(%t1: tensor {bufferization.writable = true}, %v: vector<5xf32>, %idx: index, %cond: i1, %cond2: i1) -> (tensor, vector<10xf32>) { %cst = arith.constant 0.0 : f32 @@ -1353,7 +1356,7 @@ // ----- // CHECK-LABEL: func @scf_if_inplace5 -func @scf_if_inplace5(%t1: tensor {linalg.inplaceable = true}, +func @scf_if_inplace5(%t1: tensor {bufferization.writable = true}, %idx: index, %cond: i1) -> tensor { %r = scf.if %cond -> (tensor) { // CHECK: tensor.extract_slice @@ -1385,7 +1388,7 @@ // ----- // CHECK-LABEL: func @scf_if_inplace6 -func @scf_if_inplace6(%t1: tensor {linalg.inplaceable = true}, +func @scf_if_inplace6(%t1: tensor {bufferization.writable = true}, %v1: vector<5xf32>, %v2: vector<5xf32>, %v3: vector<5xf32>, %idx: index, %cond: i1, %cond2: i1) -> tensor { @@ -1426,7 +1429,7 @@ // ----- // CHECK-LABEL: func @scf_if_inplace7 -func @scf_if_inplace7(%t1: tensor {linalg.inplaceable = true}, +func @scf_if_inplace7(%t1: tensor {bufferization.writable = true}, %v1: vector<5xf32>, %v2: vector<5xf32>, %idx: index, %idx2: index, %cond: i1) -> (tensor, vector<5xf32>) { %cst = arith.constant 0.0 : f32 @@ -1456,7 +1459,7 @@ // ----- // CHECK-LABEL: func @scf_if_out_of_place1a -func @scf_if_out_of_place1a(%t1: tensor {linalg.inplaceable = true}, +func @scf_if_out_of_place1a(%t1: tensor {bufferization.writable = true}, %idx: index, %idx2: index, %cond: i1) -> tensor { %r = scf.if %cond -> (tensor) { @@ -1483,7 +1486,7 @@ // ----- // CHECK-LABEL: func @scf_if_out_of_place1b -func @scf_if_out_of_place1b(%t1: tensor {linalg.inplaceable = true}, +func @scf_if_out_of_place1b(%t1: tensor {bufferization.writable = true}, %idx: index, %idx2: 
index, %idx3: index, %cond: i1) -> tensor { %r = scf.if %cond -> (tensor) { @@ -1519,7 +1522,7 @@ // ----- // CHECK-LABEL: func @scf_if_out_of_place1c -func @scf_if_out_of_place1c(%t1: tensor {linalg.inplaceable = true}, +func @scf_if_out_of_place1c(%t1: tensor {bufferization.writable = true}, %idx: index, %idx2: index, %cond: i1) -> tensor { %r = scf.if %cond -> (tensor) { // CHECK: tensor.extract_slice @@ -1550,7 +1553,7 @@ // ----- // CHECK-LABEL: func @scf_if_out_of_place2 -func @scf_if_out_of_place2(%t1: tensor {linalg.inplaceable = true}, +func @scf_if_out_of_place2(%t1: tensor {bufferization.writable = true}, %v: vector<5xf32>, %idx: index, %cond: i1) -> (tensor, vector<10xf32>) { %cst = arith.constant 0.0 : f32 @@ -1574,7 +1577,7 @@ // ----- // CHECK-LABEL: func @scf_if_out_of_place3 -func @scf_if_out_of_place3(%t1: tensor {linalg.inplaceable = true}, +func @scf_if_out_of_place3(%t1: tensor {bufferization.writable = true}, %v: vector<5xf32>, %idx: index, %cond: i1, %cond2: i1) -> (tensor, vector<10xf32>) { %cst = arith.constant 0.0 : f32 @@ -1605,7 +1608,7 @@ // ----- // CHECK-LABEL: func @some_use -func @some_use(%A : tensor {linalg.inplaceable = true}, +func @some_use(%A : tensor {bufferization.writable = true}, %v : vector<5xf32>) -> (tensor) { %idx = arith.constant 0 : index // CHECK: vector.transfer_write @@ -1616,7 +1619,7 @@ // CHECK-LABEL: func @main_func -func @main_func(%A : tensor {linalg.inplaceable = true}, +func @main_func(%A : tensor {bufferization.writable = true}, %v : vector<5xf32>) -> (tensor) { // CHECK: call // CHECK-SAME: {__inplace_operands_attr__ = ["true", "none"] @@ -1646,7 +1649,7 @@ // ----- // CHECK-LABEL: func @to_memref_op_is_reading -func @to_memref_op_is_reading(%t1: tensor {linalg.inplaceable = true}, +func @to_memref_op_is_reading(%t1: tensor {bufferization.writable = true}, %idx1: index, %idx2: index, %idx3: index, %v1: vector<5xf32>) -> (vector<5xf32>, vector<5xf32>) { @@ -1711,8 +1714,8 @@ // CHECK-LABEL: func @write_after_select_read_one // CHECK-SAME: %[[t1:.*]]: tensor {{.*}}, %[[t2:.*]]: tensor func @write_after_select_read_one( - %t1 : tensor {linalg.inplaceable = true}, - %t2 : tensor {linalg.inplaceable = true}, + %t1 : tensor {bufferization.writable = true}, + %t2 : tensor {bufferization.writable = true}, %c : i1) -> (f32, tensor) { @@ -1737,8 +1740,8 @@ // CHECK-LABEL: func @write_after_select_read_both // CHECK-SAME: %[[t1:.*]]: tensor {{.*}}, %[[t2:.*]]: tensor func @write_after_select_read_both( - %t1 : tensor {linalg.inplaceable = true}, - %t2 : tensor {linalg.inplaceable = true}, + %t1 : tensor {bufferization.writable = true}, + %t2 : tensor {bufferization.writable = true}, %c : i1) -> (f32, f32, tensor) { @@ -1766,8 +1769,8 @@ // CHECK-LABEL: func @write_after_select_no_conflict // CHECK-SAME: %[[t1:.*]]: tensor {{.*}}, %[[t2:.*]]: tensor func @write_after_select_no_conflict( - %t1 : tensor {linalg.inplaceable = true}, - %t2 : tensor {linalg.inplaceable = true}, + %t1 : tensor {bufferization.writable = true}, + %t2 : tensor {bufferization.writable = true}, %c : i1) -> (f32, tensor) { diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir rename from mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir rename to mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir +++ 
b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -allow-unregistered-dialect -linalg-comprehensive-module-bufferize -split-input-file -verify-diagnostics +// RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="bufferize-function-boundaries=1" -split-input-file -verify-diagnostics func private @foo() -> tensor @@ -37,7 +37,7 @@ // ----- func @scf_if_not_equivalent( - %cond: i1, %t1: tensor {linalg.inplaceable = true}, + %cond: i1, %t1: tensor {bufferization.writable = true}, %idx: index) -> tensor { %r = scf.if %cond -> (tensor) { scf.yield %t1 : tensor @@ -54,7 +54,7 @@ // ----- func @scf_if_not_aliasing( - %cond: i1, %t1: tensor {linalg.inplaceable = true}, + %cond: i1, %t1: tensor {bufferization.writable = true}, %idx: index) -> f32 { %r = scf.if %cond -> (tensor) { scf.yield %t1 : tensor @@ -85,7 +85,7 @@ // ----- func @scf_for(%A : tensor, - %B : tensor {linalg.inplaceable = true}, + %B : tensor {bufferization.writable = true}, %C : tensor<4xf32>, %lb : index, %ub : index, %step : index) -> (f32, f32) @@ -110,14 +110,14 @@ // ----- -func private @fun_with_side_effects(%A: tensor {linalg.inplaceable = true}) +func private @fun_with_side_effects(%A: tensor {bufferization.writable = true}) -func @foo(%A: tensor {linalg.inplaceable = true}) -> (tensor) { +func @foo(%A: tensor {bufferization.writable = true}) -> (tensor) { call @fun_with_side_effects(%A) : (tensor) -> () return %A: tensor } -func @scf_yield_needs_copy(%A : tensor {linalg.inplaceable = true}, %iters : index) { +func @scf_yield_needs_copy(%A : tensor {bufferization.writable = true}, %iters : index) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %res = scf.for %arg0 = %c0 to %iters step %c1 iter_args(%bbarg = %A) -> (tensor) { @@ -131,7 +131,7 @@ // ----- -func @extract_slice_fun(%A : tensor {linalg.inplaceable = true}) +func @extract_slice_fun(%A : tensor {bufferization.writable = true}) -> tensor<4xf32> { // This bufferizes to a pattern that the cross-function boundary pass needs to @@ -184,6 +184,7 @@ func @main() -> tensor<4xi32> { %r = scf.execute_region -> tensor<4xi32> { %A = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32> + // expected-error @+1 {{operand #0 of ReturnLike op does not satisfy destination passing style}} scf.yield %A: tensor<4xi32> } @@ -194,7 +195,7 @@ // ----- func @to_memref_op_is_writing( - %t1: tensor {linalg.inplaceable = true}, %idx1: index, + %t1: tensor {bufferization.writable = true}, %idx1: index, %idx2: index, %idx3: index, %v1: vector<5xf32>) -> (vector<5xf32>, vector<5xf32>) { // This is a RaW conflict because to_memref is an inplace write and %t1 is // read further down. This will likely have to change with partial diff --git a/mlir/test/Dialect/Linalg/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir rename from mlir/test/Dialect/Linalg/one-shot-module-bufferize.mlir rename to mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir --- a/mlir/test/Dialect/Linalg/one-shot-module-bufferize.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir @@ -1,12 +1,12 @@ -// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize -split-input-file | FileCheck %s +// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries=1" -split-input-file | FileCheck %s // Run fuzzer with different seeds. 
-// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="allow-return-allocs test-analysis-only analysis-fuzzer-seed=23" -split-input-file -o /dev/null -// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="allow-return-allocs test-analysis-only analysis-fuzzer-seed=59" -split-input-file -o /dev/null -// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="allow-return-allocs test-analysis-only analysis-fuzzer-seed=91" -split-input-file -o /dev/null +// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries=1 allow-return-allocs test-analysis-only analysis-fuzzer-seed=23" -split-input-file -o /dev/null +// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries=1 allow-return-allocs test-analysis-only analysis-fuzzer-seed=59" -split-input-file -o /dev/null +// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries=1 allow-return-allocs test-analysis-only analysis-fuzzer-seed=91" -split-input-file -o /dev/null // Test bufferization using memref types that have no layout map. -// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="allow-return-allocs fully-dynamic-layout-maps=0" -split-input-file -o /dev/null +// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries=1 allow-return-allocs fully-dynamic-layout-maps=0" -split-input-file | FileCheck %s --check-prefix=CHECK-NO-LAYOUT-MAP-LABEL // Bufferization of bodiless function with no tensor return value. @@ -31,7 +31,7 @@ // CHECK-NOT: alloc // CHECK-NOT: copy // CHECK: call @private_func(%[[t]]) -func @main(%t: tensor {linalg.inplaceable = true}) -> (f32) { +func @main(%t: tensor {bufferization.writable = true}) -> (f32) { %0 = call @private_func(%t) : (tensor) -> (f32) return %0 : f32 } @@ -50,7 +50,7 @@ // CHECK-DAG: %[[casted:.*]] = memref.cast %[[alloc]] // CHECK: call @private_func(%[[casted]]) // CHECK: memref.dealloc %[[alloc]] -func @main(%t: tensor {linalg.inplaceable = false}) -> (f32) { +func @main(%t: tensor {bufferization.writable = false}) -> (f32) { %0 = call @private_func(%t) : (tensor) -> (f32) return %0 : f32 } @@ -99,7 +99,7 @@ // CHECK-LABEL: func @call_func_with_non_tensor_return( // CHECK-SAME: %[[arg0:.*]]: memref {linalg.inplaceable = true}) -> (f32, tensor) { + %t0: tensor {bufferization.writable = true}) -> (f32, tensor) { // CHECK-NOT: alloc // CHECK-NOT: copy // CHECK: %[[call:.*]] = call @inner_func(%[[arg0]]) @@ -131,7 +131,7 @@ // CHECK-LABEL: func @call_func_with_non_tensor_return( // CHECK-SAME: %[[arg0:.*]]: memref {linalg.inplaceable = false}) -> (f32, tensor) { + %t0: tensor {bufferization.writable = false}) -> (f32, tensor) { // CHECK: %[[alloc:.*]] = memref.alloc // CHECK-DAG: memref.copy %[[arg0]], %[[alloc]] // CHECK-DAG: %[[casted:.*]] = memref.cast %[[alloc]] @@ -177,7 +177,7 @@ // CHECK-DAG: %[[casted:.*]] = memref.cast %[[alloc]] // CHECK: call @f2(%[[casted]]) // CHECK: memref.dealloc %[[alloc]] -func @main(%t: tensor {linalg.inplaceable = false}) -> (f32) { +func @main(%t: tensor {bufferization.writable = false}) -> (f32) { %0 = call @f2(%t) : (tensor) -> (f32) return %0 : f32 } @@ -204,7 +204,7 @@ // CHECK: call @does_not_read(%[[casted]]) // CHECK: %[[r:.*]] = memref.load %[[alloc]] // CHECK: memref.dealloc %[[alloc]] -func @main(%t: tensor {linalg.inplaceable = false}) -> f32 { +func @main(%t: tensor {bufferization.writable = false}) -> f32 { %0 = call @does_not_read(%t) : (tensor) -> (tensor) %idx = arith.constant 4 : index %r = tensor.extract %0[%idx] : tensor @@ -337,9 +337,9 @@ // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: 
memref // CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: memref<4xf32, #[[$DYN_1D_MAP]]> func @bar( - %A : tensor {linalg.inplaceable = true}, - %B : tensor {linalg.inplaceable = true}, - %C : tensor<4xf32> {linalg.inplaceable = true}, + %A : tensor {bufferization.writable = true}, + %B : tensor {bufferization.writable = true}, + %C : tensor<4xf32> {bufferization.writable = true}, %lb : index, %ub : index, %step : index) -> (tensor, tensor) { @@ -440,7 +440,7 @@ // CHECK-SAME: %[[A:[0-9a-zA-Z]*]]: memref // CHECK-SAME: %[[B:[0-9a-zA-Z]*]]: memref // CHECK-SAME: %[[C:[0-9a-zA-Z]*]]: memref -func @callee(%A : tensor {linalg.buffer_layout = affine_map<(i)[s0, s1] -> (i)>}, +func @callee(%A : tensor {bufferization.buffer_layout = affine_map<(i)[s0, s1] -> (i)>}, %B : tensor, %C : tensor) { // CHECK-NEXT: %[[CASTED:.*]] = memref.cast %[[A]] : memref to memref @@ -460,9 +460,9 @@ // CHECK-SAME: %[[A:[0-9a-zA-Z]*]]: memref // CHECK-SAME: %[[B:[0-9a-zA-Z]*]]: memref // CHECK-SAME: %[[C:[0-9a-zA-Z]*]]: memref -func @entry(%A : tensor {linalg.buffer_layout = affine_map<(i)[s0, s1] -> (i)>, linalg.inplaceable = false}, - %B : tensor {linalg.buffer_layout = affine_map<(i)[s0, s1] -> (i)>, linalg.inplaceable = false}, - %C : tensor {linalg.inplaceable = false}) { +func @entry(%A : tensor {bufferization.buffer_layout = affine_map<(i)[s0, s1] -> (i)>, bufferization.writable = false}, + %B : tensor {bufferization.buffer_layout = affine_map<(i)[s0, s1] -> (i)>, bufferization.writable = false}, + %C : tensor {bufferization.writable = false}) { // Note: `callee` does not write to its bbArg directly, but `external_func` // does. Inside `callee`, the writes via `external_func` do not cause a // conflict. However, inside `entry`, the writes do cause a conflict because @@ -498,7 +498,7 @@ // CHECK-LABEL: func @equivalent_func_arg( // CHECK-SAME: %[[arg0:.*]]: memref {linalg.inplaceable = true}, +func @equivalent_func_arg(%t0: tensor {bufferization.writable = true}, %c0: index, %c10: index, %c1: index) -> tensor { // CHECK-NOT: alloc // CHECK-NOT: copy @@ -527,7 +527,7 @@ // CHECK-LABEL: func @equivalent_func_arg_2( // CHECK-SAME: %[[arg0:.*]]: memref {linalg.inplaceable = true}, +func @equivalent_func_arg_2(%t0: tensor {bufferization.writable = true}, %c0: index, %c10: index, %c1: index) -> tensor { // CHECK: scf.for {{.*}} { %1 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%t1 = %t0) -> (tensor) { @@ -542,3 +542,23 @@ } return %1: tensor } + +// ----- + +// Bufferize without fully dynamic layout maps. 
+ +// CHECK-LABEL: func @transfer_read(%{{.*}}: memref) -> vector<4xf32> { +// CHECK-NO-LAYOUT-MAP-LABEL: func @transfer_read(%{{.*}}: memref) -> vector<4xf32> +func @transfer_read( + %A : tensor {bufferization.writable = false}) + -> (vector<4xf32>) +{ + %c0 = arith.constant 0 : index + %f0 = arith.constant 0.0 : f32 + +// CHECK: %[[RES:.*]] = vector.transfer_read {{.*}} : memref, vector<4xf32> + %0 = vector.transfer_read %A[%c0], %f0 : tensor, vector<4xf32> + +// CHECK: return %[[RES]] : vector<4xf32> + return %0 : vector<4xf32> +} diff --git a/mlir/test/Dialect/Linalg/comprehensive-bufferize-analysis-2fill-extract-matmul-all-perms.mlir b/mlir/test/Dialect/Linalg/comprehensive-bufferize-analysis-2fill-extract-matmul-all-perms.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-bufferize-analysis-2fill-extract-matmul-all-perms.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-bufferize-analysis-2fill-extract-matmul-all-perms.mlir @@ -7,9 +7,9 @@ // CHECK-LABEL: func @fill_extract_matmul_ func @fill_extract_matmul_1234( - %arg0: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, - %arg1: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, - %arg2: tensor<256x256xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) + %arg0: tensor<518x518xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, + %arg1: tensor<518x518xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, + %arg2: tensor<256x256xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) -> tensor<256x256xf32> { %c0 = arith.constant 0 : index @@ -34,9 +34,9 @@ // CHECK-LABEL: func @fill_extract_matmul_ func @fill_extract_matmul_1243( - %arg0: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, - %arg1: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, - %arg2: tensor<256x256xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) + %arg0: tensor<518x518xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, + %arg1: tensor<518x518xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, + %arg2: tensor<256x256xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) -> tensor<256x256xf32> { %c0 = arith.constant 0 : index @@ -60,9 +60,9 @@ // ----- // CHECK-LABEL: func @fill_extract_matmul_ -func @fill_extract_matmul_1324(%arg0: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, - %arg1: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, - %arg2: tensor<256x256xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) +func @fill_extract_matmul_1324(%arg0: tensor<518x518xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, + %arg1: tensor<518x518xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, + %arg2: tensor<256x256xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) -> tensor<256x256xf32> { %c0 = 
arith.constant 0 : index @@ -86,9 +86,9 @@ // ----- // CHECK-LABEL: func @fill_extract_matmul_ -func @fill_extract_matmul_1342(%arg0: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, - %arg1: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, - %arg2: tensor<256x256xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) +func @fill_extract_matmul_1342(%arg0: tensor<518x518xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, + %arg1: tensor<518x518xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, + %arg2: tensor<256x256xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) -> tensor<256x256xf32> { %c0 = arith.constant 0 : index @@ -112,9 +112,9 @@ // ----- // CHECK-LABEL: func @fill_extract_matmul_ -func @fill_extract_matmul_1423(%arg0: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, - %arg1: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, - %arg2: tensor<256x256xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) +func @fill_extract_matmul_1423(%arg0: tensor<518x518xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, + %arg1: tensor<518x518xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, + %arg2: tensor<256x256xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) -> tensor<256x256xf32> { %c0 = arith.constant 0 : index @@ -138,9 +138,9 @@ // ----- // CHECK-LABEL: func @fill_extract_matmul_ -func @fill_extract_matmul_1432(%arg0: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, - %arg1: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, - %arg2: tensor<256x256xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) +func @fill_extract_matmul_1432(%arg0: tensor<518x518xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, + %arg1: tensor<518x518xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, + %arg2: tensor<256x256xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) -> tensor<256x256xf32> { %c0 = arith.constant 0 : index @@ -165,9 +165,9 @@ // CHECK-LABEL: func @fill_extract_matmul_ func @fill_extract_matmul_2134( - %arg0: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, - %arg1: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, - %arg2: tensor<256x256xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) + %arg0: tensor<518x518xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, + %arg1: tensor<518x518xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, + %arg2: tensor<256x256xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) -> 
tensor<256x256xf32> { %c0 = arith.constant 0 : index @@ -192,9 +192,9 @@ // CHECK-LABEL: func @fill_extract_matmul_ func @fill_extract_matmul_2143( - %arg0: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, - %arg1: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, - %arg2: tensor<256x256xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) + %arg0: tensor<518x518xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, + %arg1: tensor<518x518xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, + %arg2: tensor<256x256xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) -> tensor<256x256xf32> { %c0 = arith.constant 0 : index @@ -219,9 +219,9 @@ // CHECK-LABEL: func @fill_extract_matmul_ func @fill_extract_matmul_2314( - %arg0: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, - %arg1: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, - %arg2: tensor<256x256xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) + %arg0: tensor<518x518xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, + %arg1: tensor<518x518xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, + %arg2: tensor<256x256xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) -> tensor<256x256xf32> { %c0 = arith.constant 0 : index @@ -246,9 +246,9 @@ // CHECK-LABEL: func @fill_extract_matmul_ func @fill_extract_matmul_2341( - %arg0: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, - %arg1: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, - %arg2: tensor<256x256xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) + %arg0: tensor<518x518xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, + %arg1: tensor<518x518xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, + %arg2: tensor<256x256xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) -> tensor<256x256xf32> { %c0 = arith.constant 0 : index @@ -273,9 +273,9 @@ // CHECK-LABEL: func @fill_extract_matmul_ func @fill_extract_matmul_2413( - %arg0: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, - %arg1: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, - %arg2: tensor<256x256xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) + %arg0: tensor<518x518xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, + %arg1: tensor<518x518xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, + %arg2: tensor<256x256xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) -> tensor<256x256xf32> { %c0 = arith.constant 0 : index @@ -300,9 +300,9 @@ // CHECK-LABEL: 
func @fill_extract_matmul_ func @fill_extract_matmul_2431( - %arg0: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, - %arg1: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, - %arg2: tensor<256x256xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) + %arg0: tensor<518x518xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, + %arg1: tensor<518x518xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, + %arg2: tensor<256x256xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) -> tensor<256x256xf32> { %c0 = arith.constant 0 : index @@ -327,9 +327,9 @@ // CHECK-LABEL: func @fill_extract_matmul_ func @fill_extract_matmul_3124( - %arg0: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, - %arg1: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, - %arg2: tensor<256x256xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) + %arg0: tensor<518x518xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, + %arg1: tensor<518x518xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, + %arg2: tensor<256x256xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) -> tensor<256x256xf32> { %c0 = arith.constant 0 : index @@ -354,9 +354,9 @@ // CHECK-LABEL: func @fill_extract_matmul_ func @fill_extract_matmul_3142( - %arg0: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, - %arg1: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, - %arg2: tensor<256x256xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) + %arg0: tensor<518x518xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, + %arg1: tensor<518x518xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, + %arg2: tensor<256x256xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) -> tensor<256x256xf32> { %c0 = arith.constant 0 : index @@ -381,9 +381,9 @@ // CHECK-LABEL: func @fill_extract_matmul_ func @fill_extract_matmul_3214( - %arg0: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, - %arg1: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, - %arg2: tensor<256x256xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) + %arg0: tensor<518x518xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, + %arg1: tensor<518x518xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, + %arg2: tensor<256x256xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) -> tensor<256x256xf32> { %c0 = arith.constant 0 : index @@ -408,9 +408,9 @@ // CHECK-LABEL: func @fill_extract_matmul_ func @fill_extract_matmul_3241( - %arg0: tensor<518x518xf32> 
{linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, - %arg1: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, - %arg2: tensor<256x256xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) + %arg0: tensor<518x518xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, + %arg1: tensor<518x518xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, + %arg2: tensor<256x256xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) -> tensor<256x256xf32> { %c0 = arith.constant 0 : index @@ -435,9 +435,9 @@ // CHECK-LABEL: func @fill_extract_matmul_ func @fill_extract_matmul_3412( - %arg0: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, - %arg1: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, - %arg2: tensor<256x256xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) + %arg0: tensor<518x518xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, + %arg1: tensor<518x518xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, + %arg2: tensor<256x256xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) -> tensor<256x256xf32> { %c0 = arith.constant 0 : index @@ -462,9 +462,9 @@ // CHECK-LABEL: func @fill_extract_matmul_ func @fill_extract_matmul_3421( - %arg0: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, - %arg1: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, - %arg2: tensor<256x256xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) + %arg0: tensor<518x518xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, + %arg1: tensor<518x518xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, + %arg2: tensor<256x256xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) -> tensor<256x256xf32> { %c0 = arith.constant 0 : index @@ -489,9 +489,9 @@ // CHECK-LABEL: func @fill_extract_matmul_ func @fill_extract_matmul_4123( - %arg0: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, - %arg1: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, - %arg2: tensor<256x256xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) + %arg0: tensor<518x518xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, + %arg1: tensor<518x518xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, + %arg2: tensor<256x256xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) -> tensor<256x256xf32> { %c0 = arith.constant 0 : index @@ -516,9 +516,9 @@ // CHECK-LABEL: func @fill_extract_matmul_ func @fill_extract_matmul_4132( - %arg0: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, - 
%arg1: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, - %arg2: tensor<256x256xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) + %arg0: tensor<518x518xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, + %arg1: tensor<518x518xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, + %arg2: tensor<256x256xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) -> tensor<256x256xf32> { %c0 = arith.constant 0 : index @@ -543,9 +543,9 @@ // CHECK-LABEL: func @fill_extract_matmul_ func @fill_extract_matmul_4213( - %arg0: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, - %arg1: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, - %arg2: tensor<256x256xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) + %arg0: tensor<518x518xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, + %arg1: tensor<518x518xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, + %arg2: tensor<256x256xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) -> tensor<256x256xf32> { %c0 = arith.constant 0 : index @@ -570,9 +570,9 @@ // CHECK-LABEL: func @fill_extract_matmul_ func @fill_extract_matmul_4231( - %arg0: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, - %arg1: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, - %arg2: tensor<256x256xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) + %arg0: tensor<518x518xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, + %arg1: tensor<518x518xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, + %arg2: tensor<256x256xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) -> tensor<256x256xf32> { %c0 = arith.constant 0 : index @@ -597,9 +597,9 @@ // CHECK-LABEL: func @fill_extract_matmul_ func @fill_extract_matmul_4312( - %arg0: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, - %arg1: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, - %arg2: tensor<256x256xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) + %arg0: tensor<518x518xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, + %arg1: tensor<518x518xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, + %arg2: tensor<256x256xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) -> tensor<256x256xf32> { %c0 = arith.constant 0 : index @@ -624,9 +624,9 @@ // CHECK-LABEL: func @fill_extract_matmul_ func @fill_extract_matmul_4321( - %arg0: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, - %arg1: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, 
linalg.inplaceable = false}, - %arg2: tensor<256x256xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) + %arg0: tensor<518x518xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, + %arg1: tensor<518x518xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, + %arg2: tensor<256x256xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) -> tensor<256x256xf32> { %c0 = arith.constant 0 : index diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-aliasing-in.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-aliasing-in.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-aliasing-in.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-aliasing-in.mlir @@ -3,9 +3,9 @@ // CHECK-LABEL: func @linalg_op_bufferizes_inplace_with_input // CHECK-SAME: %[[t1:.*]]: memref, %[[t2:.*]]: memref, %[[t3:.*]]: memref func @linalg_op_bufferizes_inplace_with_input( - %t1: tensor {linalg.inplaceable = true}, - %t2: tensor {linalg.inplaceable = false}, - %t3: tensor {linalg.inplaceable = false}, + %t1: tensor {bufferization.writable = true}, + %t2: tensor {bufferization.writable = false}, + %t3: tensor {bufferization.writable = false}, %s1: index, %s2: index, %cst: f32) -> tensor { // CHECK: linalg.generic {{.*}} ins(%[[t1]], %[[t2]] : {{.*}}) outs(%[[t1]] : {{.*}}) %r = linalg.generic { @@ -27,9 +27,9 @@ // CHECK-LABEL: func @linalg_op_bufferizes_out_of_place_with_input // CHECK-SAME: %[[t1:.*]]: memref, %[[t2:.*]]: memref, %[[t3:.*]]: memref func @linalg_op_bufferizes_out_of_place_with_input( - %t1: tensor {linalg.inplaceable = false}, - %t2: tensor {linalg.inplaceable = false}, - %t3: tensor {linalg.inplaceable = false}, + %t1: tensor {bufferization.writable = false}, + %t2: tensor {bufferization.writable = false}, + %t3: tensor {bufferization.writable = false}, %s1: index, %s2: index, %cst: f32) -> tensor { // CHECK: %[[alloc:.*]] = memref.alloc // CHECK: memref.copy %[[t1]], %[[alloc]] @@ -54,9 +54,9 @@ // CHECK-LABEL: func @linalg_op_output_cannot_alias_with_input // CHECK-SAME: %[[t1:.*]]: memref, %[[t2:.*]]: memref, %[[t3:.*]]: memref func @linalg_op_output_cannot_alias_with_input( - %t1: tensor {linalg.inplaceable = true}, - %t2: tensor {linalg.inplaceable = false}, - %t3: tensor {linalg.inplaceable = true}, + %t1: tensor {bufferization.writable = true}, + %t2: tensor {bufferization.writable = false}, + %t3: tensor {bufferization.writable = true}, %s1: index, %s2: index, %cst: f32) -> tensor { // CHECK: linalg.generic {{.*}} ins(%[[t1]], %[[t2]] : {{.*}}) outs(%[[t3]] : {{.*}}) %r = linalg.generic { diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis-aliasing-in.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis-aliasing-in.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis-aliasing-in.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis-aliasing-in.mlir @@ -16,9 +16,9 @@ // CHECK-LABEL: func @linalg_op_same_out_tensors( func @linalg_op_same_out_tensors( - %t1: tensor {linalg.inplaceable = true}, + %t1: tensor {bufferization.writable = true}, // CHECK-SAME: bufferization.access = "read-write" - %t2: tensor {linalg.inplaceable = true}) + %t2: tensor {bufferization.writable = true}) // CHECK-SAME: bufferization.access = "write" -> (tensor, tensor){ @@ -54,9 +54,9 @@ // CHECK-LABEL: func 
@linalg_op_same_out_tensors_2( func @linalg_op_same_out_tensors_2( - %t1: tensor {linalg.inplaceable = true}, + %t1: tensor {bufferization.writable = true}, // CHECK-SAME: bufferization.access = "read-write" - %t2: tensor {linalg.inplaceable = true}) + %t2: tensor {bufferization.writable = true}) // CHECK-SAME: bufferization.access = "write" -> (tensor, tensor, tensor){ diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis-init-tensor-elimination.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis-init-tensor-elimination.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis-init-tensor-elimination.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis-init-tensor-elimination.mlir @@ -7,7 +7,7 @@ //===----------------------------------------------------------------------===// // CHECK-LABEL: func @buffer_forwarding_conflict -func @buffer_forwarding_conflict(%arg0: tensor {linalg.inplaceable = true}, %arg1: index) -> (tensor, tensor) { +func @buffer_forwarding_conflict(%arg0: tensor {bufferization.writable = true}, %arg1: index) -> (tensor, tensor) { %cst = arith.constant 0.000000e+00 : f32 // CHECK: tensor.extract_slice // CHECK-SAME: {__inplace_operands_attr__ = ["false", "none"] @@ -34,7 +34,7 @@ // ----- // CHECK-LABEL: func @buffer_forwarding_no_conflict -func @buffer_forwarding_no_conflict(%arg0: tensor {linalg.inplaceable = true}, %arg1: index) -> (tensor, tensor) { +func @buffer_forwarding_no_conflict(%arg0: tensor {bufferization.writable = true}, %arg1: index) -> (tensor, tensor) { %cst = arith.constant 0.000000e+00 : f32 // CHECK: tensor.extract_slice // CHECK-SAME: {__inplace_operands_attr__ = ["true", "none"] diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-init-tensor-elimination.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-init-tensor-elimination.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-init-tensor-elimination.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-init-tensor-elimination.mlir @@ -6,7 +6,7 @@ // CHECK-SAME: %[[FUNC_ARG:[0-9a-zA-Z]*]]: memref // CHECK-SAME: %[[sz:[0-9a-zA-Z]*]]: index func @buffer_forwarding_conflict( - %t: tensor {linalg.buffer_layout = affine_map<(d0) -> (d0)>, linalg.inplaceable = true}, + %t: tensor {bufferization.buffer_layout = affine_map<(d0) -> (d0)>, bufferization.writable = true}, %sz: index) -> (tensor, tensor) { @@ -43,7 +43,7 @@ // CHECK-SAME: %[[FUNC_ARG:[0-9a-zA-Z]*]]: memref // CHECK-SAME: %[[sz:[0-9a-zA-Z]*]]: index func @buffer_forwarding_no_conflict( - %t: tensor {linalg.buffer_layout = affine_map<(d0) -> (d0)>, linalg.inplaceable = true}, + %t: tensor {bufferization.buffer_layout = affine_map<(d0) -> (d0)>, bufferization.writable = true}, %sz: index) -> (tensor) { diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir @@ -8,31 +8,13 @@ // Test bufferization using memref types that have no layout map. 
// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="allow-return-allocs fully-dynamic-layout-maps=0" -split-input-file | FileCheck %s --check-prefix=CHECK-NO-LAYOUT-MAP -// CHECK-LABEL: func @transfer_read(%{{.*}}: memref) -> vector<4xf32> { -// CHECK-NO-LAYOUT-MAP-LABEL: func @transfer_read(%{{.*}}: memref) -> vector<4xf32> -func @transfer_read( - %A : tensor {linalg.inplaceable = false}) - -> (vector<4xf32>) -{ - %c0 = arith.constant 0 : index - %f0 = arith.constant 0.0 : f32 - -// CHECK: %[[RES:.*]] = vector.transfer_read {{.*}} : memref, vector<4xf32> - %0 = vector.transfer_read %A[%c0], %f0 : tensor, vector<4xf32> - -// CHECK: return %[[RES]] : vector<4xf32> - return %0 : vector<4xf32> -} - -// ----- - // CHECK-DAG: #[[$map_1d_dyn:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)> // CHECK-LABEL: func @fill_inplace( // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref // CHECK-NO-LAYOUT-MAP-LABEL: func @fill_inplace(%{{.*}}: memref) { func @fill_inplace( - %A : tensor {linalg.inplaceable = true}) + %A : tensor {bufferization.writable = true}) -> tensor { // CHECK: %[[F0:.*]] = arith.constant 0.000000e+00 : f32 @@ -51,7 +33,7 @@ // ----- // CHECK-LABEL: func @tensor_extract(%{{.*}}: memref) -> f32 { -func @tensor_extract(%A : tensor {linalg.inplaceable = false}) -> (f32) { +func @tensor_extract(%A : tensor {bufferization.writable = false}) -> (f32) { %c0 = arith.constant 0 : index // CHECK: %[[RES:.*]] = memref.load {{.*}} : memref @@ -65,12 +47,12 @@ // CHECK-DAG: #[[$map_1d_dyn:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)> -/// No linalg.inplaceable flag, must allocate. +/// No bufferization.writable flag, must allocate. // CHECK-LABEL: func @not_inplace( // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref) -> memref { // CHECK-NO-LAYOUT-MAP-LABEL: func @not_inplace(%{{.*}}: memref) -> memref func @not_inplace( - %A : tensor {linalg.inplaceable = false}) + %A : tensor {bufferization.writable = false}) -> tensor { // CHECK: %[[F0:.*]] = arith.constant 0.000000e+00 : f32 @@ -94,7 +76,7 @@ // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref) { // CHECK-NO-LAYOUT-MAP-LABEL: func @not_inplace(%{{.*}}: memref) { func @not_inplace( - %A : tensor {linalg.inplaceable = true}) + %A : tensor {bufferization.writable = true}) -> tensor { %f0 = arith.constant 0.0 : f32 @@ -120,7 +102,7 @@ // ----- // CHECK-LABEL: func @not_inplace -func @not_inplace(%A : tensor {linalg.inplaceable = true}) -> tensor { +func @not_inplace(%A : tensor {bufferization.writable = true}) -> tensor { /// Within op multiple uses of %A, must alloc. 
// CHECK: alloc %r = linalg.matmul ins(%A, %A: tensor, tensor) @@ -132,7 +114,7 @@ // ----- // CHECK-LABEL: func @vec_inplace -func @vec_inplace(%A : tensor {linalg.inplaceable = true}, %vec : vector<4xf32>) +func @vec_inplace(%A : tensor {bufferization.writable = true}, %vec : vector<4xf32>) -> tensor { %c0 = arith.constant 0 : index @@ -151,7 +133,7 @@ // CHECK-LABEL: func @vec_not_inplace // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref -func @vec_not_inplace(%A : tensor {linalg.inplaceable = true}, %vec : vector<4xf32>) +func @vec_not_inplace(%A : tensor {bufferization.writable = true}, %vec : vector<4xf32>) -> (tensor, tensor) { %c0 = arith.constant 0 : index @@ -182,10 +164,10 @@ // CHECK-SAME: %[[A1:[a-zA-Z0-9]*]]: memref, // CHECK-SAME: %[[t0:[a-zA-Z0-9]*]]: memref<4xf32, #[[$map_1d_dyn]]>, // CHECK-SAME: %[[t1:[a-zA-Z0-9]*]]: memref<4xf32, #[[$map_1d_dyn]]> -func @insert_slice_fun(%A0 : tensor {linalg.inplaceable = false}, - %A1 : tensor {linalg.inplaceable = true}, - %t0 : tensor<4xf32> {linalg.inplaceable = false}, - %t1 : tensor<4xf32> {linalg.inplaceable = true}) +func @insert_slice_fun(%A0 : tensor {bufferization.writable = false}, + %A1 : tensor {bufferization.writable = true}, + %t0 : tensor<4xf32> {bufferization.writable = false}, + %t1 : tensor<4xf32> {bufferization.writable = true}) -> (tensor, tensor, tensor, tensor) { // Hoisted allocs. @@ -230,8 +212,8 @@ // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref // CHECK-SAME: %[[t:[a-zA-Z0-9]*]]: memref<4xf32, #[[$map_1d_dyn]]> func @insert_slice_fun( - %A : tensor {linalg.inplaceable = true}, - %t : tensor<4xf32> {linalg.inplaceable = false}) + %A : tensor {bufferization.writable = true}, + %t : tensor<4xf32> {bufferization.writable = false}) -> tensor { %f0 = arith.constant 0.0 : f32 @@ -258,8 +240,8 @@ // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref // CHECK-SAME: %[[t:[a-zA-Z0-9]*]]: memref<4xf32, #[[$map_1d_dyn]]> func @insert_slice_fun( - %A : tensor {linalg.inplaceable = true}, - %t : tensor<4xf32> {linalg.inplaceable = false}) + %A : tensor {bufferization.writable = true}, + %t : tensor<4xf32> {bufferization.writable = false}) -> tensor { %f0 = arith.constant 0.0 : f32 @@ -286,8 +268,8 @@ // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref // CHECK-SAME: %[[t:[a-zA-Z0-9]*]]: memref<4xf32, #[[$map_1d_dyn]]> func @insert_slice_fun_not_inplace( - %A : tensor {linalg.inplaceable = false}, - %t : tensor<4xf32> {linalg.inplaceable = false}) + %A : tensor {bufferization.writable = false}, + %t : tensor<4xf32> {bufferization.writable = false}) -> tensor { // CHECK: %[[ALLOC:.*]] = memref.alloc(%{{.*}}) {alignment = 128 : i64} : memref @@ -311,8 +293,8 @@ // CHECK-LABEL: func @scf_for_yield_only // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref // CHECK-SAME: %[[t:[a-zA-Z0-9]*]]: memref -func @scf_for_yield_only(%A : tensor {linalg.inplaceable = false}, - %B : tensor {linalg.inplaceable = true}, +func @scf_for_yield_only(%A : tensor {bufferization.writable = false}, + %B : tensor {bufferization.writable = true}, %lb : index, %ub : index, %step : index) -> (tensor, tensor) { @@ -342,7 +324,7 @@ // just want to make sure that it does not crash. 
// CHECK-LABEL: func @nested_scf_for -func @nested_scf_for(%A : tensor {linalg.inplaceable = true}, +func @nested_scf_for(%A : tensor {bufferization.writable = true}, %v : vector<5xf32>) -> tensor { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -366,9 +348,9 @@ // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: memref // CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: memref<4xf32, #[[$map_1d_dyn]]> func @scf_for_with_tensor.insert_slice( - %A : tensor {linalg.inplaceable = false}, - %B : tensor {linalg.inplaceable = true}, - %C : tensor<4xf32> {linalg.inplaceable = false}, + %A : tensor {bufferization.writable = false}, + %B : tensor {bufferization.writable = true}, + %C : tensor<4xf32> {bufferization.writable = false}, %lb : index, %ub : index, %step : index) -> (tensor, tensor) { @@ -409,7 +391,7 @@ // CHECK-LABEL: func @execute_region_with_conflict( // CHECK-SAME: %[[m1:.*]]: memref {linalg.inplaceable = "true"}) +func @execute_region_with_conflict(%t1 : tensor {bufferization.writable = "true"}) -> (f32, tensor, f32) { %f1 = arith.constant 0.0 : f32 @@ -441,9 +423,9 @@ // CHECK-SAME: %[[B:[0-9a-zA-Z]*]]: memref<256x192xf32> // CHECK-SAME: %[[C:[0-9a-zA-Z]*]]: memref<128x192xf32> func @matmul( - %A: tensor<128x256xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, - %B: tensor<256x192xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, - %C: tensor<128x192xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = true}) + %A: tensor<128x256xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, bufferization.writable = false}, + %B: tensor<256x192xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, bufferization.writable = false}, + %C: tensor<128x192xf32> {bufferization.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, bufferization.writable = true}) -> tensor<128x192xf32> { %c0 = arith.constant 0 : index %c256 = arith.constant 256 : index @@ -515,8 +497,8 @@ // CHECK: %[[subview:.*]] = memref.subview %[[A]][{{.*}}] [4] [1] : {{.*}} to memref<4xf32 // CHECK: memref.copy %[[alloc]], %[[subview]] func @tensor_cast_not_in_place( - %A : tensor {linalg.inplaceable = true}, - %B : tensor {linalg.inplaceable = false}, %idx: index) + %A : tensor {bufferization.writable = true}, + %B : tensor {bufferization.writable = false}, %idx: index) -> (tensor) { %r0 = tensor.cast %A : tensor to tensor<4xf32> @@ -535,7 +517,7 @@ // CHECK-LABEL: func @dominance_violation_bug_1 func @dominance_violation_bug_1( - %A : tensor {linalg.inplaceable = false}, + %A : tensor {bufferization.writable = false}, %idx : index) -> tensor { @@ -555,7 +537,7 @@ // CHECK-LABEL: func @scf_if_inplace( // CHECK-SAME: %[[cond:.*]]: i1, %[[t1:.*]]: memref, %[[v:.*]]: vector func @scf_if_inplace(%cond: i1, - %t1: tensor {linalg.inplaceable = true}, + %t1: tensor {bufferization.writable = true}, %v: vector<5xf32>, %idx: index) -> tensor { // CHECK: scf.if %[[cond]] { @@ -584,7 +566,7 @@ // CHECK: vector.transfer_write // CHECK: } // CHECK: } -func @scf_if_inside_scf_for(%t1: tensor {linalg.inplaceable = true}, +func @scf_if_inside_scf_for(%t1: tensor {bufferization.writable = true}, %v: vector<5xf32>, %idx: index, %cond: i1) -> tensor { %c0 = arith.constant 0 : index @@ -608,8 +590,8 @@ // CHECK-SAME: %[[cond:.*]]: i1, %[[A:.*]]: memref<{{.*}}>, %[[B:.*]]: memref<{{.*}}>) -> memref<{{.*}}> func @scf_if_non_equiv_yields( %b : i1, - %A : tensor<4xf32> {linalg.inplaceable = false}, - %B : tensor<4xf32> 
{linalg.inplaceable = false}) + %A : tensor<4xf32> {bufferization.writable = false}, + %B : tensor<4xf32> {bufferization.writable = false}) -> tensor<4xf32> { // CHECK: %[[r:.*]] = arith.select %[[cond]], %[[A]], %[[B]] @@ -626,7 +608,7 @@ // CHECK-LABEL: func @insert_op // CHECK-SAME: %[[t1:.*]]: memref, %[[s:.*]]: f32, %[[i:.*]]: index -func @insert_op(%t1 : tensor {linalg.inplaceable = true}, +func @insert_op(%t1 : tensor {bufferization.writable = true}, %s : f32, %i : index) -> tensor { // CHECK: memref.store %[[s]], %[[t1]][%[[i]]] %0 = tensor.insert %s into %t1[%i] : tensor @@ -637,9 +619,9 @@ // ----- func @gather_like( - %arg0 : tensor {linalg.inplaceable = false}, - %arg1 : tensor {linalg.inplaceable = false}, - %arg2 : tensor {linalg.inplaceable = true}) -> tensor { + %arg0 : tensor {bufferization.writable = false}, + %arg1 : tensor {bufferization.writable = false}, + %arg2 : tensor {bufferization.writable = true}) -> tensor { %0 = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], @@ -669,9 +651,9 @@ // CHECK-LABEL: func @linalg_op_bufferizes_inplace_with_input // CHECK-SAME: %[[t1:.*]]: memref, %[[t2:.*]]: memref, %[[t3:.*]]: memref func @linalg_op_bufferizes_inplace_with_input( - %t1: tensor {linalg.inplaceable = true}, - %t2: tensor {linalg.inplaceable = true}, - %t3: tensor {linalg.inplaceable = true}, + %t1: tensor {bufferization.writable = true}, + %t2: tensor {bufferization.writable = true}, + %t3: tensor {bufferization.writable = true}, %s1: index, %s2: index, %cst: f32) -> tensor { // CHECK: linalg.generic {{.*}} ins(%[[t1]], %[[t2]] : {{.*}}) outs(%[[t3]] : {{.*}}) %r = linalg.generic { @@ -701,7 +683,7 @@ // CHECK-LABEL: func @op_is_reading_but_following_ops_are_not // CHECK-SAME: %[[t0:.*]]: memref {linalg.inplaceable = false}, + %t0 : tensor {bufferization.writable = false}, %cst : f32) -> tensor { @@ -753,8 +735,8 @@ // CHECK-LABEL: func @write_to_select_op_source // CHECK-SAME: %[[t1:.*]]: memref, %[[t2:.*]]: memref func @write_to_select_op_source( - %t1 : tensor {linalg.inplaceable = true}, - %t2 : tensor {linalg.inplaceable = true}, + %t1 : tensor {bufferization.writable = true}, + %t2 : tensor {bufferization.writable = true}, %c : i1) -> (tensor, tensor) { @@ -775,8 +757,8 @@ // CHECK-LABEL: func @write_after_select_read_one // CHECK-SAME: %[[t1:.*]]: memref, %[[t2:.*]]: memref func @write_after_select_read_one( - %t1 : tensor {linalg.inplaceable = true}, - %t2 : tensor {linalg.inplaceable = true}, + %t1 : tensor {bufferization.writable = true}, + %t2 : tensor {bufferization.writable = true}, %c : i1) -> (f32, tensor) { @@ -916,7 +898,7 @@ // CHECK-SAME: %[[A:.*]]: memref, %[[B:.*]]: memref func @scf_for_swapping_yields( - %A : tensor, %B : tensor {linalg.inplaceable = true}, + %A : tensor, %B : tensor {bufferization.writable = true}, %C : tensor<4xf32>, %lb : index, %ub : index, %step : index) -> (f32, f32) { diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -7043,7 +7043,6 @@ ":LinalgStructuredOpsIncGen", ":MathDialect", ":MemRefDialect", - ":ModuleBufferization", ":Pass", ":SCFDialect", ":SCFTransforms", @@ -7065,25 +7064,6 @@ ], ) -cc_library( - name = "ModuleBufferization", - srcs = [ - "lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp", - ], - hdrs = [ - 
"include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h", - ], - includes = ["include"], - deps = [ - ":BufferizationDialect", - ":BufferizationTransforms", - ":FuncDialect", - ":IR", - ":MemRefDialect", - "//llvm:Support", - ], -) - cc_library( name = "TilingInterface", srcs = ["lib/Interfaces/TilingInterface.cpp"],