diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -16,10 +16,9 @@
 // Composability with extensible set of ops is not a first-class concern.
 //
 // Bufferization occurs by:
-// a. performing an inPlace analysis `inPlaceAnalysisFuncOpBody`
-//    which marks each operation within the function with the
-//    `kInPlaceResultsAttrName` attribute.
-// b. traversing each operation in the function and rewriting it in
+// a. performing an inPlace analysis `inPlaceAnalysis` which marks each
+//    operation within the op with the `kInPlaceResultsAttrName` attribute.
+// b. traversing each operation in the op and rewriting it in
 //    buffer form and keeping a BlockAndValueMapping mapping of the
 //    rewrites. New allocations are introduced during this step.
 // TODO: Allocation + depending op hoisting to outermost enclosing
@@ -543,37 +542,6 @@
   return true;
 }
 
-//===----------------------------------------------------------------------===//
-// Bufferization as simple BlockAndValueMapping rewrites.
-//===----------------------------------------------------------------------===//
-
-/// FuncOp always creates TensorToMemRef ops.
-static LogicalResult bufferizeFuncOp(FuncOp funcOp, BufferizationState &state) {
-  // Take a guard before anything else.
-  OpBuilder b(funcOp->getContext());
-  b.setInsertionPointToStart(&funcOp.body().front());
-
-  // Create BufferCastOps for function args.
-  for (auto bbArg : funcOp.getArguments()) {
-    auto tensorType = bbArg.getType().dyn_cast<TensorType>();
-    if (!tensorType)
-      continue;
-    auto rankedTensorType = tensorType.dyn_cast<RankedTensorType>();
-    // Cast the tensor to the most dynamic buffer possible. Further
-    // canonicalizations will clean up.
-    Type memRefType = rankedTensorType
-                          ? getDynamicMemRefType(rankedTensorType)
-                          : getContiguousOrUnrankedMemRefType(tensorType);
-    Value bufferCast =
-        b.create<memref::BufferCastOp>(funcOp.getLoc(), memRefType, bbArg);
-    state.aliasInfo.insertNewBufferEquivalence(bufferCast, bbArg);
-    state.mapBuffer(bbArg, bufferCast);
-  }
-
-  // Bufferize function body.
-  return bufferize(&funcOp.body(), state);
-}
-
 //===----------------------------------------------------------------------===//
 // Bufferization analyses.
 //===----------------------------------------------------------------------===//
@@ -654,42 +622,6 @@
   return success();
 }
 
-/// Analyze the `funcOp` body to determine which OpResults are inplaceable.
-static LogicalResult
-inPlaceAnalysisFuncOpBody(FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
-                          const DominanceInfo &domInfo,
-                          unsigned analysisFuzzerSeed = 0) {
-  LLVM_DEBUG(llvm::dbgs() << "\n\n");
-  LDBG("Begin InPlaceAnalysisFuncOpInternals:\n" << funcOp << '\n');
-  assert(funcOp && funcOp->getNumRegions() > 0 && !funcOp.body().empty() &&
-         "expected a funcOp definition with a body");
-
-  // Collect ops so we can build our own reverse traversal.
-  SmallVector<Operation *> ops;
-  funcOp.walk([&](Operation *op) {
-    // No tensors => no buffers.
-    if (none_of(op->getOperandTypes(), isaTensor) &&
-        none_of(op->getResultTypes(), isaTensor))
-      return;
-    ops.push_back(op);
-  });
-
-  // Set the function arguments marked with inplaceable to be known as
-  // bufferizing to a writeable memory.
-  for (BlockArgument bbArg : funcOp.getArguments()) {
-    BoolAttr inplaceAttr = funcOp.getArgAttrOfType<BoolAttr>(
-        bbArg.getArgNumber(), BufferizableOpInterface::kInplaceableAttrName);
-    if (inplaceAttr && inplaceAttr.getValue())
-      aliasInfo.setBufferizesToWritableMemory(bbArg);
-  }
-
-  LogicalResult res =
-      inPlaceAnalysis(ops, aliasInfo, domInfo, analysisFuzzerSeed);
-  LDBG("End InPlaceAnalysisFuncOpInternals:\n" << funcOp << '\n');
-
-  return res;
-}
-
 /// Assert that the current bufferization decisions are consistent.
 static LogicalResult
 checkAliasInfoConsistency(FuncOp funcOp, const DominanceInfo &domInfo,
@@ -753,9 +685,19 @@
   if (failed(checkAliasInfoConsistency(funcOp, domInfo, aliasInfo)))
     return failure();
 
+  // Collect ops so we can build our own reverse traversal.
+  SmallVector<Operation *> ops;
+  funcOp.walk([&](Operation *op) {
+    // No tensors => no buffers.
+    if (none_of(op->getOperandTypes(), isaTensor) &&
+        none_of(op->getResultTypes(), isaTensor))
+      return;
+    ops.push_back(op);
+  });
+
   // If the analysis fails, just return.
-  if (failed(inPlaceAnalysisFuncOpBody(funcOp, aliasInfo, domInfo,
-                                       options.analysisFuzzerSeed)))
+  if (failed(
+          inPlaceAnalysis(ops, aliasInfo, domInfo, options.analysisFuzzerSeed)))
     return failure();
 
   for (const std::unique_ptr<PostAnalysisStep> &step :
@@ -775,7 +717,11 @@
   }
 
   // Bufferize all ops in funcOp.
-  if (failed(bufferizeFuncOp(funcOp, state)))
+  OpBuilder b(funcOp.getContext());
+  auto bufferizableOp =
+      dyn_cast<BufferizableOpInterface>(funcOp.getOperation());
+  assert(bufferizableOp && "must use ModuleBufferization");
+  if (failed(bufferizableOp.bufferize(b, state)))
     return failure();
 
   // Erase all obsolete ops.
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -629,6 +629,40 @@
   }
 };
 
+struct FuncOpInterface
+    : public BufferizableOpInterface::ExternalModel<FuncOpInterface, FuncOp> {
+  LogicalResult bufferize(Operation *op, OpBuilder &b,
+                          BufferizationState &state) const {
+    auto funcOp = cast<FuncOp>(op);
+
+    // Take a guard before anything else.
+    OpBuilder::InsertionGuard g(b);
+    b.setInsertionPointToStart(&funcOp.body().front());
+
+    // Create BufferCastOps for function args.
+    for (auto bbArg : funcOp.getArguments()) {
+      auto tensorType = bbArg.getType().dyn_cast<TensorType>();
+      if (!tensorType)
+        continue;
+      auto rankedTensorType = tensorType.dyn_cast<RankedTensorType>();
+      // Cast the tensor to the most dynamic buffer possible. Further
+      // canonicalizations will clean up.
+      Type memRefType = rankedTensorType
+                            ? getDynamicMemRefType(rankedTensorType)
+                            : getContiguousOrUnrankedMemRefType(tensorType);
+      Value bufferCast = b.create<memref::BufferCastOp>(funcOp.getLoc(),
+                                                        memRefType, bbArg);
+      state.aliasInfo.insertNewBufferEquivalence(bufferCast, bbArg);
+      state.mapBuffer(bbArg, bufferCast);
+    }
+
+    // Bufferize function body.
+    return comprehensive_bufferize::bufferize(&funcOp.body(), state);
+  }
+
+  bool isAllocationHoistingBarrier(Operation *op) const { return true; }
+};
+
 } // namespace std_ext
 } // namespace comprehensive_bufferize
 } // namespace linalg
@@ -638,7 +672,7 @@
 registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
   registry.addOpInterface<CallOp, std_ext::CallOpInterface>();
   registry.addOpInterface<ReturnOp, std_ext::ReturnOpInterface>();
-  registry.addOpInterface<FuncOp, AllocationHoistingBarrierOnly<FuncOp>>();
+  registry.addOpInterface<FuncOp, std_ext::FuncOpInterface>();
 }
 
 LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
@@ -671,6 +705,15 @@
     if (bbArg.getType().isa<TensorType>())
       aliasInfo.setBufferizesToWritableMemory(bbArg);
 
+  // Set the function arguments marked with inplaceable to be known as
+  // bufferizing to a writeable memory.
+  for (BlockArgument bbArg : funcOp.getArguments()) {
+    BoolAttr inplaceAttr = funcOp.getArgAttrOfType<BoolAttr>(
+        bbArg.getArgNumber(), BufferizableOpInterface::kInplaceableAttrName);
+    if (inplaceAttr && inplaceAttr.getValue())
+      aliasInfo.setBufferizesToWritableMemory(bbArg);
+  }
+
   // Analyze and bufferize funcOp.
   if (failed(runComprehensiveBufferize(funcOp, options, state)))
     return failure();