diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h @@ -297,10 +297,16 @@ /// bufferization is necessary. Value getResultBuffer(OpBuilder &b, OpResult result, BufferizationState &state); +/// Bufferize all ops in the given region. +LogicalResult bufferize(Region *region, BufferizationState &state); + +/// Bufferize all ops in the given block. +LogicalResult bufferize(Block *block, BufferizationState &state); + -/// Bufferize the given op. If the op has no tensor OpOperands/OpResults, this -/// function returns immediately. Otherwise, it calls the `bufferize` interface -/// method of `BufferizableOpInterface`. -LogicalResult bufferizeOp(Operation *op, BufferizationState &state); +/// Bufferize the given op. If the op has no tensor OpOperands/OpResults, this +/// function only bufferizes the ops nested in its regions. Otherwise, it calls +/// the `bufferize` interface method of `BufferizableOpInterface`. +LogicalResult bufferize(Operation *op, BufferizationState &state); /// PostAnalysisSteps can be registered with `BufferizationOptions` and are /// executed after the analysis, but before bufferization. They can be used diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td @@ -163,8 +163,13 @@ InterfaceMethod< /*desc=*/[{ Bufferize this op, i.e., rewrite it into a memref-based equivalent. - `bvm` maps tensor values to memref values and this method should map - tensor results to memref results after creating/modifying ops. + Tensor values should be mapped to buffer values using `state`. + + Implementations are required to bufferize nested ops + before returning.
Otherwise, nested ops will not be bufferized. + + This method will never be called on ops that do not have at least one + tensor operand or result. }], /*retType=*/"LogicalResult", /*methodName=*/"bufferize", diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp @@ -392,8 +392,26 @@ } LogicalResult -mlir::linalg::comprehensive_bufferize::bufferizeOp(Operation *op, - BufferizationState &state) { +mlir::linalg::comprehensive_bufferize::bufferize(Region *region, + BufferizationState &state) { + for (Block &block : *region) + if (failed(bufferize(&block, state))) + return failure(); + return success(); +} + +LogicalResult +mlir::linalg::comprehensive_bufferize::bufferize(Block *block, + BufferizationState &state) { + for (Operation &op : *block) + if (failed(bufferize(&op, state))) + return failure(); + return success(); +} + +LogicalResult +mlir::linalg::comprehensive_bufferize::bufferize(Operation *op, + BufferizationState &state) { OpBuilder b(op->getContext()); // Skip BufferCast and TensorLoad ops. @@ -404,15 +422,22 @@ auto isaTensor = [](Type t) { return t.isa<TensorType>(); }; bool hasTensorResult = any_of(op->getResultTypes(), isaTensor); bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor); - if (!hasTensorResult && !hasTensorOperand) + + // No tensor results or operands: Simply bufferize all nested ops. + if (!hasTensorResult && !hasTensorOperand) { + for (Region &region : op->getRegions()) + if (failed(bufferize(&region, state))) + return failure(); return success(); + } - // Bufferize using `BufferizableOpInterface`. + // Bufferize using `BufferizableOpInterface`. Interface implementations are + // responsible for bufferizing nested ops.
b.setInsertionPoint(op); if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op)) return bufferizableOp.bufferize(b, state); - // Other op with tensors. No bufferization method specified. + // Emit error if tensor op is not bufferizable. return op->emitError() << "unsupported op with tensors"; } diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp @@ -784,11 +784,12 @@ //===----------------------------------------------------------------------===// /// FuncOp always creates TensorToMemRef ops. -static LogicalResult bufferize(OpBuilder &b, FuncOp funcOp, - BufferizationState &state) { +static LogicalResult bufferizeFuncOp(FuncOp funcOp, BufferizationState &state) { - // Take a guard before anything else. - OpBuilder::InsertionGuard g(b); + // Use a fresh builder; ops are inserted at the start of the function body. + OpBuilder b(funcOp->getContext()); b.setInsertionPointToStart(&funcOp.body().front()); + + // Create BufferCastOps for function args. for (auto bbArg : funcOp.getArguments()) { auto tensorType = bbArg.getType().dyn_cast<TensorType>(); if (!tensorType) @@ -804,7 +805,9 @@ state.aliasInfo.insertNewBufferEquivalence(bufferCast, bbArg); state.mapBuffer(bbArg, bufferCast); } - return success(); + + // Bufferize function body. + return bufferize(&funcOp.body(), state); } //===----------------------------------------------------------------------===// @@ -923,37 +926,6 @@ return res; } -//===----------------------------------------------------------------------===// -// Bufferization entry-point for functions.
-//===----------------------------------------------------------------------===// - -static LogicalResult bufferizeFuncOpInternals(FuncOp funcOp, - BufferizationState &state) { - LLVM_DEBUG(llvm::dbgs() << "\n\n"); - LDBG("Begin BufferizeFuncOpInternals:\n" << funcOp << '\n'); - OpBuilder b(funcOp->getContext()); - - // Start by bufferizing `funcOp` arguments. - if (failed(bufferize(b, funcOp, state))) - return failure(); - - auto walkFunc = [&](Operation *op) -> WalkResult { - if (failed(bufferizeOp(op, state))) - return failure(); - return success(); - }; - - // Bufferize ops pre-order, i.e., bufferize ops first, then their children. - // This is needed for ops with blocks that have BlockArguments. These must be - // mapped before bufferizing the children. - if (funcOp.walk(walkFunc).wasInterrupted()) - return failure(); - - LDBG("End BufferizeFuncOpInternals:\n" << funcOp << '\n'); - - return success(); -} - //===----------------------------------------------------------------------===// // Bufferization entry-point for modules. //===----------------------------------------------------------------------===// @@ -1380,7 +1352,7 @@ // Bufferization phase. if (!options.testAnalysisOnly) { // Bufferize all ops in funcOp. - if (failed(bufferizeFuncOpInternals(funcOp, state))) + if (failed(bufferizeFuncOp(funcOp, state))) return failure(); // Erase all obsolete ops. @@ -1547,12 +1519,13 @@ BufferizationState &state) const { // TODO: Add bufferization support when needed. scf.execute_region should be // bufferized similar to scf.if.
+ auto executeRegionOp = cast<scf::ExecuteRegionOp>(op); bool hasTensorReturnType = any_of( op->getResultTypes(), [](Type t) { return t.isa<TensorType>(); }); if (hasTensorReturnType) return op->emitError( "scf.execute_region with tensor result not supported"); - return success(); + return comprehensive_bufferize::bufferize(&executeRegionOp.region(), state); } }; @@ -1609,37 +1582,33 @@ LogicalResult bufferize(Operation *op, OpBuilder &b, BufferizationState &state) const { - // scf::IfOp is bufferized after scf::YieldOp in the else branch. - return success(); - } -}; + auto ifOp = cast<scf::IfOp>(op); -/// Bufferize the scf::IfOp. This function is called after the YieldOp was -/// bufferized. -static LogicalResult bufferizeIfOp(scf::IfOp ifOp, OpBuilder &b, - BufferizationState &state) { - // Take a guard before anything else. - OpBuilder::InsertionGuard g(b); - b.setInsertionPoint(ifOp); + // Bufferize then/else blocks. + if (failed(comprehensive_bufferize::bufferize(ifOp.thenBlock(), state))) + return failure(); + if (failed(comprehensive_bufferize::bufferize(ifOp.elseBlock(), state))) + return failure(); - for (OpResult opResult : ifOp->getResults()) { - if (!opResult.getType().isa<TensorType>()) - continue; - // TODO: Atm we bail on unranked TensorType because we don't know how to - // alloc an UnrankedMemRefType + its underlying ranked MemRefType. - assert(opResult.getType().isa<RankedTensorType>() && - "unsupported unranked tensor"); + for (OpResult opResult : ifOp->getResults()) { + if (!opResult.getType().isa<TensorType>()) + continue; + // TODO: Atm we bail on unranked TensorType because we don't know how to + // alloc an UnrankedMemRefType + its underlying ranked MemRefType.
+ assert(opResult.getType().isa<RankedTensorType>() && + "unsupported unranked tensor"); - Value resultBuffer = getResultBuffer(b, opResult, state); - if (!resultBuffer) - return failure(); + Value resultBuffer = getResultBuffer(b, opResult, state); + if (!resultBuffer) + return failure(); - state.aliasInfo.createAliasInfoEntry(resultBuffer); - state.mapBuffer(opResult, resultBuffer); - } + state.aliasInfo.createAliasInfoEntry(resultBuffer); + state.mapBuffer(opResult, resultBuffer); + } - return success(); -} + return success(); + } +}; struct ForOpInterface : public BufferizableOpInterface::ExternalModel(op); // Take a guard before anything else. @@ -1716,41 +1682,39 @@ state.mapBuffer(opResult, resultBuffer); } - return success(); - } -}; + // Bufferize loop body. + if (failed(comprehensive_bufferize::bufferize(&forOp.region(), state))) + return failure(); -/// Bufferize the scf::ForOp. This function is called after the YieldOp was -/// bufferized. -static LogicalResult bufferizeForOp(scf::ForOp forOp, OpBuilder &b, - BufferizationState &state) { - auto yieldOp = cast<scf::YieldOp>(&forOp.region().front().back()); - for (OpOperand &operand : yieldOp->getOpOperands()) { - auto tensorType = operand.get().getType().dyn_cast<TensorType>(); - if (!tensorType) - continue; + // Finish bufferizing scf::ForOp. + auto yieldOp = cast<scf::YieldOp>(&forOp.region().front().back()); + for (OpOperand &operand : yieldOp->getOpOperands()) { + auto tensorType = operand.get().getType().dyn_cast<TensorType>(); + if (!tensorType) + continue; - OpOperand &forOperand = forOp.getOpOperandForResult( - forOp->getResult(operand.getOperandNumber())); - auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand); - Value yieldedBuffer = state.lookupBuffer(operand.get()); - Value bbArgBuffer = state.lookupBuffer(bbArg); - if (!state.aliasInfo.areEquivalentBufferizedValues(yieldedBuffer, - bbArgBuffer)) { - // TODO: this could get resolved with copies but it can also turn into - // swaps so we need to be careful about order of copies.
- return yieldOp->emitError() - << "Yield operand #" << operand.getOperandNumber() - << " does not bufferize to an equivalent buffer to the matching" - << " enclosing scf::for operand"; - } + OpOperand &forOperand = forOp.getOpOperandForResult( + forOp->getResult(operand.getOperandNumber())); + auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand); + Value yieldedBuffer = state.lookupBuffer(operand.get()); + Value bbArgBuffer = state.lookupBuffer(bbArg); + if (!state.aliasInfo.areEquivalentBufferizedValues(yieldedBuffer, + bbArgBuffer)) { + // TODO: this could get resolved with copies but it can also turn into + // swaps so we need to be careful about order of copies. + return yieldOp->emitError() + << "Yield operand #" << operand.getOperandNumber() + << " does not bufferize to an equivalent buffer to the matching" + << " enclosing scf::for operand"; + } - // Buffers are equivalent so the work is already done and we just yield - // the bbArg so that it later canonicalizes away. - operand.set(bbArg); + // Buffers are equivalent so the work is already done and we just yield + // the bbArg so that it later canonicalizes away. + operand.set(bbArg); + } + return success(); } - return success(); -} +}; struct YieldOpInterface : public BufferizableOpInterface::ExternalModel(op); - - if (auto execOp = dyn_cast<scf::ExecuteRegionOp>(yieldOp->getParentOp())) { - if (execOp->getNumResults() != 0) - return execOp->emitError( - "expected result-less scf.execute_region containing op"); - return success(); - } - - // Bufferize scf::IfOp after bufferizing the scf::YieldOp in the else - // branch. - if (auto ifOp = dyn_cast<scf::IfOp>(yieldOp->getParentOp())) { - if (ifOp.elseYield() != yieldOp) - return success(); - return bufferizeIfOp(ifOp, b, state); - } - - // Bufferize scf::ForOp after bufferizing the scf::YieldOp.
- if (auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp())) - return bufferizeForOp(forOp, b, state); - - return yieldOp->emitError("expected scf::ForOp parent for scf::YieldOp"); + if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::ForOp>( + yieldOp->getParentOp())) + return yieldOp->emitError("unsupported scf::YieldOp parent"); + return success(); } }; diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp @@ -340,7 +340,8 @@ static_cast<int32_t>(oldInputs.size()) + numNewInputBuffers, static_cast<int32_t>(oldOutputs.size()) + numNewOutputBuffers})); - return success(); + // Bufferize loop body. + return comprehensive_bufferize::bufferize(&tiledLoopOp.region(), state); } };