diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h @@ -142,6 +142,12 @@ return base->getType() == TypeID::get(); } + /// Return a reference to the BufferizationOptions. + const OneShotBufferizationOptions &getOptions() const { + return static_cast( + AnalysisState::getOptions()); + } + /// Return a reference to the BufferizationAliasInfo. BufferizationAliasInfo &getAliasInfo() { return aliasInfo; } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp @@ -6,28 +6,27 @@ // //===----------------------------------------------------------------------===// // -// One-Shot Analysis analyzes function bodies. Function boundaries (FuncOp -// bbArgs, CallOps, ReturnOps) are treated as "unknown" ops. -// ModuleBufferization.cpp is an extension of One-Shot Analysis for simple -// call graphs. +// One-Shot Analysis analyzes function bodies. By default, function boundaries +// (FuncOp bbArgs, CallOps, ReturnOps) are treated as "unknown" ops. +// OneShotModuleBufferize.cpp is an extension of One-Shot Analysis for +// simple call graphs without loops. // -// One-Shot Bufferize consists of two phases. +// One-Shot Bufferize consists of three phases. // -// 1. Analyze ops to decide which OpResults can bufferize inplace, i.e., without -// inserting buffer copies. The analysis queries op bufferization semantics -// via `BufferizableOpInterface`. -// 2. Bufferize ops by calling `BufferizableOpInterface::bufferize`.
This -// function does not generate buffer copies for OpResults that were decided -// to bufferize inplace during the analysis phase. +// 1. Analyze ops to decide which OpOperands can bufferize inplace, i.e., +// without inserting buffer copies. The analysis queries op bufferization +// semantics via `BufferizableOpInterface`. +// 2. Insert copies for OpOperands that were decided to bufferize out-of-place +// in tensor land during `TensorCopyInsertion`. +// 3. Bufferize ops by calling `BufferizableOpInterface::bufferize`. // -// This file contains only the analysis. The actual bufferization is implemented -// via `bufferizeOp` (Bufferize.h). For convenience, this file also contains a -// helper function `runOneShotBufferize` that analyzes an op (and its nested -// ops) and then bufferizes it. +// This file contains only the analysis. For convenience, this file also +// contains a helper function `runOneShotBufferize` that analyzes an op (and its +// nested ops) and then bufferizes it. // // Inplace bufferization decisions are passed from the analysis to the -// bufferization phase via `AnalysisState` and `BufferizationAliasInfo`. -// They can be printed for debugging purposes with `testAnalysisOnly`. +// `TensorCopyInsertion` phase via `AnalysisState`. They can be printed for +// debugging purposes with `testAnalysisOnly`. // // Ops that do not implement `BufferizableOpInterface` can be analyzed but are // treated conservatively. E.g., the analysis has to assume that their tensor @@ -70,33 +69,30 @@ //===----------------------------------------------------------------------===// // Bufferization-specific attribute manipulation. -// These are for testing and debugging only. Bufferization information is -// stored in BufferizationAliasInfo. When run with `testAnalysisOnly`, the IR -// is annotated with the results of the analysis (copied from -// BufferizationAliasInfo), so that they can be checked in tests. +// These are for testing and debugging only. 
Bufferization information is stored +// in BufferizationAliasInfo. When run with `testAnalysisOnly`, the IR is +// annotated with the results of the analysis, so that they can be checked in +// tests. //===----------------------------------------------------------------------===// -/// Attribute marker to specify op results that can be bufferized inPlace. -constexpr StringLiteral kInPlaceResultsAttrName = "__inplace_operands_attr__"; +/// Attribute marker to specify op operands that bufferize in-place. +constexpr StringLiteral kInPlaceOperandsAttrName = "__inplace_operands_attr__"; /// Mark whether OpOperand will be bufferized inplace. static void setInPlaceOpOperand(OpOperand &opOperand, bool inPlace) { Operation *op = opOperand.getOwner(); - auto attr = - op->getAttr(kInPlaceResultsAttrName).dyn_cast_or_null(); SmallVector inPlaceVector; - if (attr) { - inPlaceVector = SmallVector( - llvm::to_vector<4>(attr.getAsValueRange())); + if (auto attr = op->getAttr(kInPlaceOperandsAttrName)) { + inPlaceVector = SmallVector(llvm::to_vector<4>( + attr.cast().getAsValueRange())); } else { inPlaceVector = SmallVector(op->getNumOperands(), "none"); for (OpOperand &opOperand : op->getOpOperands()) if (opOperand.get().getType().isa()) inPlaceVector[opOperand.getOperandNumber()] = "false"; } - inPlaceVector[opOperand.getOperandNumber()] = inPlace ? "true" : "false"; - op->setAttr(kInPlaceResultsAttrName, + op->setAttr(kInPlaceOperandsAttrName, OpBuilder(op).getStrArrayAttr(inPlaceVector)); } @@ -937,8 +933,7 @@ }; OneShotBufferizationOptions::AnalysisHeuristic heuristic = - static_cast(state.getOptions()) - .analysisHeuristic; + state.getOptions().analysisHeuristic; if (heuristic == OneShotBufferizationOptions::AnalysisHeuristic::BottomUp) { // Default: Walk ops in reverse for better interference analysis. 
for (Operation *op : reverse(ops)) @@ -1057,10 +1052,11 @@ static void annotateOpsWithBufferizationMarkers(Operation *op, const BufferizationAliasInfo &aliasInfo, - AnalysisState &state) { - op->walk([&](Operation *op) { - if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op)) - for (OpOperand &opOperand : op->getOpOperands()) + const BufferizationOptions &options) { + // Add __inplace_operands_attr__. + op->walk([&](BufferizableOpInterface bufferizableOp) { + if (options.isOpAllowed(bufferizableOp.getOperation())) + for (OpOperand &opOperand : bufferizableOp->getOpOperands()) if (opOperand.get().getType().isa()) setInPlaceOpOperand(opOperand, aliasInfo.isInPlace(opOperand)); }); @@ -1140,8 +1136,7 @@ OneShotAnalysisState &state) { DominanceInfo domInfo(op); BufferizationAliasInfo &aliasInfo = state.getAliasInfo(); - const auto &options = - static_cast(state.getOptions()); + const OneShotBufferizationOptions &options = state.getOptions(); if (failed(checkAliasInfoConsistency(op, domInfo, state, aliasInfo))) return failure(); @@ -1174,7 +1169,7 @@ // Annotate operations if we only want to report the analysis. if (options.testAnalysisOnly) - annotateOpsWithBufferizationMarkers(op, aliasInfo, state); + annotateOpsWithBufferizationMarkers(op, aliasInfo, options); return success(!failedAnalysis); } @@ -1186,7 +1181,6 @@ "invalid combination of bufferization flags"); if (!options.copyBeforeWrite) { // If a buffer is copied before every write, no analysis is needed. 
- OneShotAnalysisState state(op, options); if (failed(insertTensorCopies(op, options))) return failure(); } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp @@ -335,9 +335,7 @@ LogicalResult mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state) { - OneShotBufferizationOptions options = - static_cast(state.getOptions()); - assert(options.bufferizeFunctionBoundaries && + assert(state.getOptions().bufferizeFunctionBoundaries && "expected that function boundary bufferization is activated"); FuncAnalysisState &funcState = getOrCreateFuncAnalysisState(state); BufferizationAliasInfo &aliasInfo = state.getAliasInfo();