Patch summary (explanatory text only; git-apply and GNU patch skip any text
that precedes the first "diff --git" header).

  * Hunk at -445: adds getBufferType(op, bbArg, options). It casts `op` to a
    loop op (`forOp`) and returns bufferization::getBufferType() of the value
    held by the operand that forOp.getOpOperandForRegionIterArg(bbArg) returns.
  * Hunk at -474: the loop's bufferize() no longer erases the new loop body's
    terminator when `iterArgs` holds only the induction variable, and no
    longer rewrites the yield's operands through getYieldedValues() after the
    old body has been merged into the new loop.
  * Hunk at -844: the TODO and the early `return success()` guard in the yield
    bufferization now mention only scf.while (previously scf.while/scf.for).
  * Hunk at -854: for a yielded value that passes the isa() type check, when
    the yield's parent casts to `forOp`, the buffer obtained from getBuffer()
    is passed through castBuffer() to the type that getBufferType() reports
    for the region iter arg at the same index.

NOTE(review): this copy of the patch appears to have lost the template
arguments (the angle-bracket parts) after cast / isa / dyn_cast / SmallVector,
e.g. `cast(op)`. The context lines would presumably not match the real file as
written; restore them from upstream before applying.
NOTE(review): indentation and blank context lines were reconstructed from the
hunk-header line counts; confirm against upstream.

diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -445,6 +445,13 @@
     return success();
   }
 
+  BaseMemRefType getBufferType(Operation *op, BlockArgument bbArg,
+                               const BufferizationOptions &options) const {
+    auto forOp = cast(op);
+    return bufferization::getBufferType(
+        forOp.getOpOperandForRegionIterArg(bbArg).get(), options);
+  }
+
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     auto forOp = cast(op);
@@ -474,20 +481,9 @@
         getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(), indices);
     iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());
 
-    // Erase terminator if present.
-    if (iterArgs.size() == 1)
-      rewriter.eraseOp(loopBody->getTerminator());
-
     // Move loop body to new loop.
     rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);
 
-    // Update scf.yield of new loop.
-    auto yieldOp = cast(loopBody->getTerminator());
-    rewriter.setInsertionPoint(yieldOp);
-    SmallVector yieldValues = getYieldedValues(
-        rewriter, yieldOp.getResults(), initArgsTypes, indices, options);
-    yieldOp.getResultsMutable().assign(yieldValues);
-
     // Replace loop results.
     replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());
 
@@ -844,9 +840,9 @@
             yieldOp->getParentOp()))
       return yieldOp->emitError("unsupported scf::YieldOp parent");
 
-    // TODO: Bufferize scf.yield inside scf.while/scf.for here.
-    // (Currently bufferized together with scf.while/scf.for.)
-    if (isa(yieldOp->getParentOp()))
+    // TODO: Bufferize scf.yield inside scf.while here. (Currently bufferized
+    // together with scf.while.)
+    if (isa(yieldOp->getParentOp()))
       return success();
 
     SmallVector newResults;
@@ -854,6 +850,13 @@
       Value value = it.value();
       if (value.getType().isa()) {
         Value buffer = getBuffer(rewriter, value, options);
+        if (auto forOp = dyn_cast(yieldOp->getParentOp())) {
+          BaseMemRefType resultType =
+              cast(forOp.getOperation())
+                  .getBufferType(forOp.getRegionIterArgs()[it.index()],
+                                 options);
+          buffer = castBuffer(rewriter, buffer, resultType);
+        }
         newResults.push_back(buffer);
       } else {
         newResults.push_back(value);