diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -1426,10 +1426,11 @@ /// Helper function for LinalgOp bufferization. /// When allocating a new buffer, analyze whether `op` wants to read form that /// buffer. Only in that case, a copy of the result buffer may be needed. -static void allocateBuffersForResults(OpBuilder &b, Location loc, LinalgOp op, - SmallVectorImpl &resultBuffers, - BlockAndValueMapping &bvm, - BufferizationAliasInfo &aliasInfo) { +static LogicalResult +allocateBuffersForResults(OpBuilder &b, Location loc, LinalgOp op, + SmallVectorImpl &resultBuffers, + BlockAndValueMapping &bvm, + BufferizationAliasInfo &aliasInfo) { // Take a guard before anything else. OpBuilder::InsertionGuard g(b); b.setInsertionPoint(op); @@ -1441,11 +1442,15 @@ assert(opResult && "could not find correspond OpResult"); bool skipCopy = !op.payloadUsesValueFromOperand(opOperand); Value resultBuffer = getResultBuffer(b, opResult, bvm, aliasInfo, skipCopy); + if (!resultBuffer) + return failure(); resultBuffers.push_back(resultBuffer); } if (op->getNumResults()) map(bvm, op->getResults(), resultBuffers); + + return success(); } /// Generic conversion for any LinalgOp on tensors. @@ -1473,7 +1478,9 @@ } SmallVector newOutputBuffers; // Try to allocate new buffers depending on op's inplace semantics. - allocateBuffersForResults(b, loc, op, newOutputBuffers, bvm, aliasInfo); + if (failed(allocateBuffersForResults(b, loc, op, newOutputBuffers, bvm, + aliasInfo))) + return failure(); // Clone the newly bufferized op. SmallVector newOperands = newInputBuffers; @@ -1642,6 +1649,8 @@ b.setInsertionPoint(castOp); Value resultBuffer = getResultBuffer(b, castOp->getResult(0), bvm, aliasInfo); + if (!resultBuffer) + return failure(); Type sourceType = resultBuffer.getType(); auto rankedMemRefType = sourceType.dyn_cast(); auto unrankedMemRefType = sourceType.dyn_cast(); @@ -1717,6 +1726,8 @@ // TODO: More general: Matching bbArg does not bufferize to a read. Value resultBuffer = getResultBuffer(b, opResult, bvm, aliasInfo); + if (!resultBuffer) + return failure(); OpOperand &opOperand = forOp.getOpOperandForResult(opResult); BlockArgument bbArg = forOp.getRegionIterArgForOpOperand(opOperand); @@ -1838,6 +1849,8 @@ const OpResult &opResult = tiledLoopOp->getResult(resultIndex); OpOperand &yieldOperand = yieldOp->getOpOperand(resultIndex); Value resultBuffer = getResultBuffer(b, opResult, bvm, aliasInfo); + if (!resultBuffer) + return failure(); // Insert mapping and aliasing info. aliasInfo.createAliasInfoEntry(resultBuffer); @@ -1986,6 +1999,9 @@ // buffer. Value dstMemref = getResultBuffer(b, insertSliceOp->getResult(0), bvm, aliasInfo); + if (!dstMemref) + return failure(); + auto dstMemrefType = dstMemref.getType().cast(); Value srcMemref = lookup(bvm, insertSliceOp.source()); @@ -2048,6 +2064,8 @@ // this point. auto writeOp = cast(op.getOperation()); Value resultBuffer = getResultBuffer(b, op->getResult(0), bvm, aliasInfo); + if (!resultBuffer) + return failure(); b.create( op.getLoc(), writeOp.vector(), resultBuffer, writeOp.indices(), writeOp.permutation_map(),