diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -409,6 +409,9 @@ /// Return true if `v1` and `v2` bufferize to equivalent buffers. virtual bool areEquivalentBufferizedValues(Value v1, Value v2) const = 0; + /// Return true if `v1` and `v2` may bufferize to aliasing buffers. + virtual bool areAliasingBufferizedValues(Value v1, Value v2) const = 0; + /// Return `true` if the given tensor has undefined contents. virtual bool hasUndefinedContents(OpOperand *opOperand) const = 0; diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h @@ -19,30 +19,10 @@ class BufferizationAliasInfo; class OneShotAnalysisState; -/// PostAnalysisStepFns can be registered with `BufferizationOptions` and are -/// executed after the analysis, but before bufferization. They can be used to -/// implement custom dialect-specific optimizations. They may modify the IR, but -/// must keep `aliasInfo` consistent. Newly created operations and operations -/// that should be re-analyzed must be added to `newOps`. -using PostAnalysisStepFn = std::function &)>; - -using PostAnalysisStepList = SmallVector; - /// Options for analysis-enabled bufferization. struct OneShotBufferizationOptions : public BufferizationOptions { OneShotBufferizationOptions() = default; - /// Register a "post analysis" step. Such steps are executed after the - /// analysis, but before bufferization. - void addPostAnalysisStep(PostAnalysisStepFn fn) { - postAnalysisSteps.push_back(fn); - } - - /// Registered post analysis steps. 
- PostAnalysisStepList postAnalysisSteps; - /// Specifies whether returning newly allocated memrefs should be allowed. /// Otherwise, a pass failure is triggered. bool allowReturnAllocs = false; @@ -165,6 +145,9 @@ /// Return true if `v1` and `v2` bufferize to equivalent buffers. bool areEquivalentBufferizedValues(Value v1, Value v2) const override; + /// Return true if `v1` and `v2` may bufferize to aliasing buffers. + bool areAliasingBufferizedValues(Value v1, Value v2) const override; + /// Return `true` if the given tensor has undefined contents. bool hasUndefinedContents(OpOperand *opOperand) const override; @@ -180,6 +163,10 @@ /// `yieldedTensors`. Also include all aliasing tensors in the same block. void gatherYieldedTensors(Operation *op); + /// Return true if the buffer of the given tensor value is written to. Must + /// not be called for values inside not yet analyzed functions. + bool isValueWritten(Value value) const; + private: /// `aliasInfo` keeps track of aliasing and equivalent values. Only internal /// functions and `runOneShotBufferize` may access this object. diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -503,6 +503,13 @@ return false; } + /// Return true if `v1` and `v2` may bufferize to aliasing buffers. + bool areAliasingBufferizedValues(Value v1, Value v2) const override { + // There is no analysis, so we do not know if the values are aliasing. The + // conservative answer is "true". + return true; + } + /// Return `true` if the given tensor has undefined contents. bool hasUndefinedContents(OpOperand *opOperand) const override { // There is no analysis, so the conservative answer is "false". 
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp @@ -216,6 +216,11 @@ return aliasInfo.areEquivalentBufferizedValues(v1, v2); } +bool OneShotAnalysisState::areAliasingBufferizedValues(Value v1, + Value v2) const { + return aliasInfo.areAliasingBufferizedValues(v1, v2); +} + // Gather yielded tensors in `yieldedTensors` by querying all aliases. This is // to ensure that such information is available during bufferization time. // Alias information can no longer be queried through BufferizationAliasInfo @@ -290,6 +295,17 @@ return yieldedTensors.contains(tensor); } +bool OneShotAnalysisState::isValueWritten(Value value) const { + assert(value.getType().isa<TensorType>() && "expected TensorType"); + bool isWritten = false; + aliasInfo.applyOnAliases(value, [&](Value val) { + for (OpOperand &use : val.getUses()) + if (isInPlace(use) && bufferizesToMemoryWrite(use)) + isWritten = true; + }); + return isWritten; +} + //===----------------------------------------------------------------------===// // Bufferization-specific alias analysis. //===----------------------------------------------------------------------===// @@ -935,16 +951,6 @@ return failure(); equivalenceAnalysis(op, aliasInfo, state); - for (const PostAnalysisStepFn &fn : options.postAnalysisSteps) { - SmallVector newOps; - if (failed(fn(op, state, aliasInfo, newOps))) - return failure(); - // Analyze ops that were created by the PostAnalysisStepFn.
- if (failed(inPlaceAnalysis(newOps, aliasInfo, state, domInfo))) - return failure(); - equivalenceAnalysis(newOps, aliasInfo, state); - } - bool failedAnalysis = false; if (!options.allowReturnAllocs) { SmallVector newOps; diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp @@ -16,7 +16,7 @@ // respective callers. // // After analyzing a FuncOp, additional information about its bbArgs is -// gathered through PostAnalysisStepFns and stored in `FuncAnalysisState`. +// gathered and stored in `FuncAnalysisState`. // // * `aliasingFuncOpBBArgsAnalysis` determines the equivalent/aliasing bbArgs // for @@ -155,14 +155,11 @@ /// Store function BlockArguments that are equivalent to/aliasing a returned /// value in FuncAnalysisState. -static LogicalResult -aliasingFuncOpBBArgsAnalysis(Operation *op, AnalysisState &state, - BufferizationAliasInfo &aliasInfo, - SmallVector &newOps) { +static LogicalResult aliasingFuncOpBBArgsAnalysis(func::FuncOp funcOp, + OneShotAnalysisState &state) { FuncAnalysisState &funcState = getFuncAnalysisState(state); // Support only single return-terminated block in the function. 
- auto funcOp = cast(op); func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp); assert(returnOp && "expected func with single return op"); @@ -172,12 +169,12 @@ if (bbArg.getType().isa()) { int64_t returnIdx = returnVal.getOperandNumber(); int64_t bbArgIdx = bbArg.getArgNumber(); - if (aliasInfo.areEquivalentBufferizedValues(returnVal.get(), bbArg)) { + if (state.areEquivalentBufferizedValues(returnVal.get(), bbArg)) { funcState.equivalentFuncArgs[funcOp][returnIdx] = bbArgIdx; if (state.getOptions().testAnalysisOnly) annotateEquivalentReturnBbArg(returnVal, bbArg); } - if (aliasInfo.areAliasingBufferizedValues(returnVal.get(), bbArg)) { + if (state.areAliasingBufferizedValues(returnVal.get(), bbArg)) { funcState.aliasingFuncArgs[funcOp][returnIdx].push_back(bbArgIdx); funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(returnIdx); } @@ -186,35 +183,6 @@ return success(); } -/// Return true if the buffer of the given tensor value is written to. Must not -/// be called for values inside not yet analyzed functions. (Post-analysis -/// steps do not have to be run yet, i.e., "in progress" is also OK.) -static bool isValueWritten(Value value, const AnalysisState &state, - const BufferizationAliasInfo &aliasInfo) { -#ifndef NDEBUG - assert(value.getType().isa() && "expected TensorType"); - func::FuncOp funcOp; - if (auto bbArg = value.dyn_cast()) { - Operation *owner = bbArg.getOwner()->getParentOp(); - funcOp = isa(owner) ? 
cast(owner) - : owner->getParentOfType(); - } else { - funcOp = value.getDefiningOp()->getParentOfType(); - } - assert(getFuncOpAnalysisState(state, funcOp) != - FuncOpAnalysisState::NotAnalyzed && - "FuncOp must be fully analyzed or analysis in progress"); -#endif // NDEBUG - - bool isWritten = false; - aliasInfo.applyOnAliases(value, [&](Value val) { - for (OpOperand &use : val.getUses()) - if (state.isInPlace(use) && state.bufferizesToMemoryWrite(use)) - isWritten = true; - }); - return isWritten; -} - static void annotateFuncArgAccess(func::FuncOp funcOp, BlockArgument bbArg, bool isRead, bool isWritten) { OpBuilder b(funcOp.getContext()); @@ -231,15 +199,12 @@ funcOp.setArgAttr(bbArg.getArgNumber(), "bufferization.access", accessType); } -/// Determine which FuncOp bbArgs are read and which are written. If this -/// PostAnalysisStepFn is run on a function with unknown ops, it will -/// conservatively assume that such ops bufferize to a read + write. -static LogicalResult -funcOpBbArgReadWriteAnalysis(Operation *op, AnalysisState &state, - BufferizationAliasInfo &aliasInfo, - SmallVector &newOps) { +/// Determine which FuncOp bbArgs are read and which are written. When run on a +/// function with unknown ops, we conservatively assume that such ops bufferize +/// to a read + write. +static LogicalResult funcOpBbArgReadWriteAnalysis(func::FuncOp funcOp, + OneShotAnalysisState &state) { FuncAnalysisState &funcState = getFuncAnalysisState(state); - auto funcOp = cast(op); // If the function has no body, conservatively assume that all args are // read + written. 
@@ -256,7 +221,7 @@ if (!bbArg.getType().isa()) continue; bool isRead = state.isValueRead(bbArg); - bool isWritten = isValueWritten(bbArg, state, aliasInfo); + bool isWritten = state.isValueWritten(bbArg); if (state.getOptions().testAnalysisOnly) annotateFuncArgAccess(funcOp, bbArg, isRead, isWritten); if (isRead) @@ -434,10 +399,6 @@ if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap))) return failure(); - // Collect bbArg/return value information after the analysis. - options.addPostAnalysisStep(aliasingFuncOpBBArgsAnalysis); - options.addPostAnalysisStep(funcOpBbArgReadWriteAnalysis); - // Analyze ops. for (func::FuncOp funcOp : orderedFuncOps) { // No body => no analysis. @@ -454,6 +415,11 @@ if (failed(analyzeOp(funcOp, analysisState))) return failure(); + // Record equivalent/aliasing bbArgs and read/written bbArgs of `funcOp`. + if (failed(aliasingFuncOpBBArgsAnalysis(funcOp, analysisState)) || + failed(funcOpBbArgReadWriteAnalysis(funcOp, analysisState))) + return failure(); + // Mark op as fully analyzed. funcState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::Analyzed;