diff --git a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp @@ -61,42 +61,17 @@ LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const { auto assumingOp = cast(op); - - // Compute new result types. - SmallVector newResultTypes; - for (Type type : assumingOp->getResultTypes()) { - if (auto tensorType = type.dyn_cast()) { - // TODO: Infer the result type instead of computing it. - newResultTypes.push_back(getMemRefType(tensorType, options)); - } else { - newResultTypes.push_back(type); - } - } + assert(assumingOp.getDoRegion().getBlocks().size() == 1 && + "only 1 block supported"); + auto yieldOp = cast( + assumingOp.getDoRegion().front().getTerminator()); // Create new op and move over region. + TypeRange newResultTypes(yieldOp.operands()); auto newOp = rewriter.create( op->getLoc(), newResultTypes, assumingOp.getWitness()); newOp.getDoRegion().takeBody(assumingOp.getRegion()); - // Update terminator. - assert(newOp.getDoRegion().getBlocks().size() == 1 && - "only 1 block supported"); - Block *newBlock = &newOp.getDoRegion().front(); - auto yieldOp = cast(newBlock->getTerminator()); - rewriter.setInsertionPoint(yieldOp); - SmallVector newYieldValues; - for (const auto &it : llvm::enumerate(yieldOp.operands())) { - Value val = it.value(); - if (val.getType().isa()) { - newYieldValues.push_back(rewriter.create( - yieldOp.getLoc(), newResultTypes[it.index()], val)); - } else { - newYieldValues.push_back(val); - } - } - rewriter.replaceOpWithNewOp(yieldOp, - newYieldValues); - // Update all uses of the old op. rewriter.setInsertionPointAfter(newOp); SmallVector newResults; @@ -153,7 +128,14 @@ LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const { - // Op is bufferized as part of AssumingOp. + auto yieldOp = cast(op); + SmallVector newResults; + for (Value value : yieldOp.operands()) + newResults.push_back(value.getType().isa() + ? getBuffer(rewriter, value, options) + : value); + replaceOpWithNewBufferizedOp(rewriter, op, + newResults); return success(); } };