diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h @@ -172,13 +172,6 @@ /// Apply `fun` to all aliases of `v`. void applyOnAliases(Value v, function_ref fun) const; - // TODO: Move these out of BufferizationAliasInfo. - /// Return true if the value is known to bufferize to writable memory. - bool bufferizesToWritableMemory(Value v) const; - - /// Specify that the value is known to bufferize to writable memory. - void setBufferizesToWritableMemory(Value v); - /// Mark a value as in-place bufferized. void markInPlace(OpResult v) { inplaceBufferized.insert(v); } @@ -200,9 +193,6 @@ /// Check that aliasInfo for `v` exists and return a reference to it. EquivalenceClassRangeType getAliases(Value v) const; - /// Set of tensors that are known to bufferize to writable memory. - llvm::DenseSet bufferizeToWritableMemory; - /// Set of all OpResults that were decided to bufferize in-place. llvm::DenseSet inplaceBufferized; @@ -429,7 +419,9 @@ return BufferRelation::None; } - bool isWritable(Operation *op, Value value) const { return false; } + bool isWritable(Operation *op, Value value, BufferizationState &state) const { + return false; + } LogicalResult bufferize(Operation *op, OpBuilder &b, BufferizationState &state) const { diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td @@ -226,7 +226,8 @@ }], /*retType=*/"bool", /*methodName=*/"isWritable", - /*args=*/(ins "Value":$value), + /*args=*/(ins "Value":$value, + "BufferizationState &":$state), /*methodBody=*/"", /*defaultImplementation=*/[{ return value.isa(); diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp @@ -42,7 +42,7 @@ return success(); } - bool isWritable(Operation *op, Value value) const { + bool isWritable(Operation *op, Value value, BufferizationState &state) const { // Memory locations returned by memref::GetGlobalOp may not be written to. assert(value.isa()); return false; diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp @@ -130,15 +130,6 @@ equivalentInfo.unionSets(newValue, alias); } -bool BufferizationAliasInfo::bufferizesToWritableMemory(Value v) const { - return bufferizeToWritableMemory.count(v) > 0; -} - -/// Specify that the value is known to bufferize to writable memory. -void BufferizationAliasInfo::setBufferizesToWritableMemory(Value v) { - bufferizeToWritableMemory.insert(v); -} - /// Return `true` if a value was marked as in-place bufferized. bool BufferizationAliasInfo::isInPlace(OpResult opResult) const { return inplaceBufferized.contains(opResult); diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp @@ -71,7 +71,7 @@ return success(); } - bool isWritable(Operation *op, Value value) const { + bool isWritable(Operation *op, Value value, BufferizationState &state) const { // It is unknown whether the MemRef operand is writable or not. return false; } diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp @@ -171,15 +171,6 @@ OpBuilder(op).getStrArrayAttr(inPlaceVector)); } -/// Set the attribute that triggers inplace bufferization on a FuncOp argument -/// `bbArg`. -static void setInPlaceFuncArgument(BlockArgument bbArg, bool inPlace) { - auto funcOp = cast(bbArg.getOwner()->getParentOp()); - funcOp.setArgAttr(bbArg.getArgNumber(), - BufferizableOpInterface::kInplaceableAttrName, - BoolAttr::get(bbArg.getContext(), inPlace)); -} - //===----------------------------------------------------------------------===// // Printing helpers. //===----------------------------------------------------------------------===// @@ -258,25 +249,22 @@ /// Return true if, under current bufferization decisions, the buffer of `value` /// is not writable. static bool aliasesNonWritableBuffer(Value value, - const BufferizationAliasInfo &aliasInfo) { + const BufferizationAliasInfo &aliasInfo, + BufferizationState &state) { LDBG("WRITABILITY ANALYSIS FOR " << printValueInfo(value) << "\n"); bool foundNonWritableBuffer = false; aliasInfo.applyOnAliases(value, [&](Value v) { - // Some values are known to be writable. - if (aliasInfo.bufferizesToWritableMemory(v)) - return; - // Query BufferizableOpInterface to see if the OpResult is writable. // TODO: Out-of-place bufferized OpResult could be considered writable. if (auto bufferizableOp = v.getDefiningOp()) - if (bufferizableOp && bufferizableOp.isWritable(v)) + if (bufferizableOp && bufferizableOp.isWritable(v, state)) return; // Query BufferizableOpInterface to see if the BlockArgument is writable. if (auto bbArg = v.dyn_cast()) if (auto bufferizableOp = dyn_cast( bbArg.getOwner()->getParentOp())) - if (bufferizableOp.isWritable(bbArg)) + if (bufferizableOp.isWritable(bbArg, state)) return; foundNonWritableBuffer = true; @@ -515,7 +503,8 @@ /// a write to a non-writable buffer. static bool wouldCreateWriteToNonWritableBuffer(OpOperand &opOperand, OpResult opResult, - const BufferizationAliasInfo &aliasInfo) { + const BufferizationAliasInfo &aliasInfo, + BufferizationState &state) { #ifndef NDEBUG SmallVector opOperands = getAliasingOpOperand(opResult); assert(llvm::find(opOperands, &opOperand) != opOperands.end() && @@ -525,9 +514,10 @@ // Certain buffers are not writeable: // 1. A function bbArg that is not inplaceable or // 2. A constant op. - assert(!aliasesNonWritableBuffer(opResult, aliasInfo) && + assert(!aliasesNonWritableBuffer(opResult, aliasInfo, state) && "expected that opResult does not alias non-writable buffer"); - bool nonWritable = aliasesNonWritableBuffer(opOperand.get(), aliasInfo); + bool nonWritable = + aliasesNonWritableBuffer(opOperand.get(), aliasInfo, state); if (!nonWritable) return false; @@ -547,10 +537,9 @@ //===----------------------------------------------------------------------===// /// Determine if `operand` can be bufferized in-place with `result`. -static LogicalResult -bufferizableInPlaceAnalysisImpl(OpOperand &operand, OpResult result, - BufferizationAliasInfo &aliasInfo, - const DominanceInfo &domInfo) { +static LogicalResult bufferizableInPlaceAnalysisImpl( + OpOperand &operand, OpResult result, BufferizationAliasInfo &aliasInfo, + BufferizationState &state, const DominanceInfo &domInfo) { #ifndef NDEBUG SmallVector opOperands = getAliasingOpOperand(result); assert(llvm::find(opOperands, &operand) != opOperands.end() && @@ -565,7 +554,7 @@ << printValueInfo(result) << '\n'); bool foundInterference = - wouldCreateWriteToNonWritableBuffer(operand, result, aliasInfo) || + wouldCreateWriteToNonWritableBuffer(operand, result, aliasInfo, state) || wouldCreateReadAfterWriteInterference(operand, result, domInfo, aliasInfo); @@ -599,6 +588,7 @@ /// RaW dependence violations. static LogicalResult inPlaceAnalysis(SmallVector &ops, BufferizationAliasInfo &aliasInfo, + BufferizationState &state, const DominanceInfo &domInfo, unsigned analysisFuzzerSeed = 0) { if (analysisFuzzerSeed) { @@ -615,8 +605,8 @@ if (opOperand.get().getType().isa()) if (auto bufferizableOp = dyn_cast(op)) if (OpResult opResult = bufferizableOp.getAliasingOpResult(opOperand)) - if (failed(bufferizableInPlaceAnalysisImpl(opOperand, opResult, - aliasInfo, domInfo))) + if (failed(bufferizableInPlaceAnalysisImpl( + opOperand, opResult, aliasInfo, state, domInfo))) return failure(); return success(); @@ -625,6 +615,7 @@ /// Analyze all ops that are contained in `op`. static LogicalResult inPlaceAnalysis(Operation *op, BufferizationAliasInfo &aliasInfo, + BufferizationState &state, const DominanceInfo &domInfo, unsigned analysisFuzzerSeed = 0) { // Collect ops so we can build our own reverse traversal. @@ -637,7 +628,7 @@ ops.push_back(op); }); - return inPlaceAnalysis(ops, aliasInfo, domInfo, analysisFuzzerSeed); + return inPlaceAnalysis(ops, aliasInfo, state, domInfo, analysisFuzzerSeed); } /// Analyze equivalence of tied OpResult/OpOperand pairs of the given ops. @@ -712,15 +703,9 @@ annotateOpsWithBufferizationMarkers(Operation *op, const BufferizationAliasInfo &aliasInfo) { op->walk([&](Operation *op) { - for (OpResult opResult : op->getResults()) { + for (OpResult opResult : op->getResults()) if (opResult.getType().isa()) setInPlaceOpResult(opResult, aliasInfo.isInPlace(opResult)); - if (auto funcOp = dyn_cast(op)) - for (BlockArgument bbArg : funcOp.getArguments()) - if (bbArg.getType().isa()) - setInPlaceFuncArgument(bbArg, - aliasInfo.bufferizesToWritableMemory(bbArg)); - } }); } @@ -739,8 +724,8 @@ // If the analysis fails, just return. Operation *op = funcOp.getOperation(); - if (failed( - inPlaceAnalysis(op, aliasInfo, domInfo, options.analysisFuzzerSeed))) + if (failed(inPlaceAnalysis(op, aliasInfo, state, domInfo, + options.analysisFuzzerSeed))) return failure(); equivalenceAnalysis(op, aliasInfo); @@ -750,7 +735,7 @@ if (failed(step->run(funcOp, state, newOps))) return failure(); // Analyze ops that were created by the PostAnalysisStep. - if (failed(inPlaceAnalysis(newOps, aliasInfo, domInfo))) + if (failed(inPlaceAnalysis(newOps, aliasInfo, state, domInfo))) return failure(); equivalenceAnalysis(newOps, aliasInfo); } diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp @@ -193,7 +193,7 @@ return BufferRelation::Equivalent; } - bool isWritable(Operation *op, Value value) const { + bool isWritable(Operation *op, Value value, BufferizationState &state) const { // Interestingly, linalg::TiledLoopOp's bbArg can **always** be viewed // inplace from the perspective of ops nested under: // 1. Either the matching iter operand is not bufferized inplace and an diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp @@ -36,6 +36,10 @@ /// A mapping of ReturnOp OpOperand indices to equivalent FuncOp BBArg /// indices. DenseMap> equivalentFuncArgs; + + SmallVector orderedFuncOps; + + DenseMap> callerMap; }; } // namespace @@ -689,6 +693,32 @@ return comprehensive_bufferize::bufferize(&funcOp.body(), state); } + /// Return `true` if the given function argument is writable. + bool isWritable(Operation *op, Value value, BufferizationState &state) const { + auto funcOp = cast(op); + BlockArgument bbArg = value.dyn_cast(); + assert(bbArg && "expected BlockArgument"); + ModuleBufferizationState &moduleState = getModuleBufferizationState(state); + + // In a first approximation: + // ========================= + // If the function is called, we can allocate on the caller side which lets + // us force inplace arguments at function boundaries. + // TODO: do not rely on this behavior. + if (moduleState.callerMap.find(funcOp) != moduleState.callerMap.end()) + return true; + + // Set the function arguments marked with inplaceable to be known as + // bufferizing to a writeable memory. + BoolAttr inplaceAttr = funcOp.getArgAttrOfType( + bbArg.getArgNumber(), BufferizableOpInterface::kInplaceableAttrName); + if (inplaceAttr && inplaceAttr.getValue()) + return true; + + // All other function arguments are not writable. + return false; + } + bool isAllocationHoistingBarrier(Operation *op) const { return true; } }; @@ -704,46 +734,44 @@ registry.addOpInterface(); } +/// Set the attribute that triggers inplace bufferization on a FuncOp argument +/// `bbArg`. +static void setInPlaceFuncArgument(BlockArgument bbArg, bool inPlace) { + auto funcOp = cast(bbArg.getOwner()->getParentOp()); + funcOp.setArgAttr(bbArg.getArgNumber(), + BufferizableOpInterface::kInplaceableAttrName, + BoolAttr::get(bbArg.getContext(), inPlace)); +} + +/// Annotate the IR with the result of the analysis. For testing/debugging only. +static void annotateOpsWithBufferizationMarkers(FuncOp funcOp, + BufferizationState &state) { + auto bufferizableOp = cast(funcOp.getOperation()); + for (BlockArgument bbArg : funcOp.getArguments()) + if (bbArg.getType().isa()) + setInPlaceFuncArgument(bbArg, bufferizableOp.isWritable(bbArg, state)); +} + LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize( ModuleOp moduleOp, const BufferizationOptions &options) { - SmallVector orderedFuncOps; - DenseMap> callerMap; - if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap))) - return failure(); - BufferizationState state(moduleOp, options); ModuleBufferizationState &moduleState = getModuleBufferizationState(state); BufferizationAliasInfo &aliasInfo = state.aliasInfo; + if (failed(getFuncOpsOrderedByCalls(moduleOp, moduleState.orderedFuncOps, + moduleState.callerMap))) + return failure(); + // Interestingly, all function args that are not visible outside of a module // can be fully bufferized inplace by guaranteeing the CallOp is bufferized // inplace. Therefore, we just bufferize funcOp as if none of its results were // inplaceable, detect which operands are cloned internally and decide what to // do at call sites. - for (FuncOp funcOp : orderedFuncOps) { + for (FuncOp funcOp : moduleState.orderedFuncOps) { // No body => no analysis. if (funcOp.body().empty()) continue; - // In a first approximation: - // ========================= - // If the function is called, we can allocate on the caller side which lets - // us force inplace arguments at function boundaries. - // TODO: do not rely on this behavior. - if (callerMap.find(funcOp) != callerMap.end()) - for (BlockArgument bbArg : funcOp.getArguments()) - if (bbArg.getType().isa()) - aliasInfo.setBufferizesToWritableMemory(bbArg); - - // Set the function arguments marked with inplaceable to be known as - // bufferizing to a writeable memory. - for (BlockArgument bbArg : funcOp.getArguments()) { - BoolAttr inplaceAttr = funcOp.getArgAttrOfType( - bbArg.getArgNumber(), BufferizableOpInterface::kInplaceableAttrName); - if (inplaceAttr && inplaceAttr.getValue()) - aliasInfo.setBufferizesToWritableMemory(bbArg); - } - // Register extra post analysis steps. These cannot be stored in `options` // because `options` is immutable. PostAnalysisStepList extraSteps; @@ -755,12 +783,16 @@ // Analyze and bufferize funcOp. if (failed(runComprehensiveBufferize(funcOp, options, state, extraSteps))) return failure(); + + // Add annotations to function arguments. + if (options.testAnalysisOnly) + annotateOpsWithBufferizationMarkers(funcOp, state); } if (options.testAnalysisOnly) return success(); - for (FuncOp funcOp : orderedFuncOps) { + for (FuncOp funcOp : moduleState.orderedFuncOps) { // Note: It would be good to apply cleanups here but we cannot as aliasInfo // would be invalidated. if (failed(bufferizeFuncOpBoundary(funcOp, state))) diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp @@ -204,7 +204,7 @@ return equivalentYield ? BufferRelation::Equivalent : BufferRelation::None; } - bool isWritable(Operation *op, Value value) const { + bool isWritable(Operation *op, Value value, BufferizationState &state) const { // Interestingly, scf::ForOp's bbArg can **always** be viewed // inplace from the perspective of ops nested under: // 1. Either the matching iter operand is not bufferized inplace and an