diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -115,7 +115,9 @@ if (tensorResultType.getRank() == 0) { // 0-d collapses must go through a different op builder. - Value buffer = *state.getBuffer(rewriter, srcOperand); + auto buffer = state.getBuffer(rewriter, srcOperand); + if (failed(buffer)) + return failure(); MemRefType resultType; if (bufferType.getLayout().isIdentity()) { @@ -138,7 +140,7 @@ } replaceOpWithNewBufferizedOp( - rewriter, op, resultType, buffer, collapseShapeOp.reassociation()); + rewriter, op, resultType, *buffer, collapseShapeOp.reassociation()); return success(); } @@ -152,11 +154,13 @@ ? None : Optional( BufferizationState::ForceInPlacability::FORCE_OUT_OF_PLACE); - Value buffer = *state.getBuffer(rewriter, srcOperand, overrideInPlace); + auto buffer = state.getBuffer(rewriter, srcOperand, overrideInPlace); + if (failed(buffer)) + return failure(); // Result type is inferred by the builder. replaceOpWithNewBufferizedOp( - rewriter, op, buffer, collapseShapeOp.getReassociationIndices()); + rewriter, op, *buffer, collapseShapeOp.getReassociationIndices()); return success(); } }; @@ -183,8 +187,11 @@ LogicalResult bufferize(Operation *op, RewriterBase &rewriter, BufferizationState &state) const { auto dimOp = cast(op); - Value v = *state.getBuffer(rewriter, dimOp->getOpOperand(0) /*source*/); - replaceOpWithNewBufferizedOp(rewriter, op, v, dimOp.index()); + auto v = state.getBuffer(rewriter, dimOp->getOpOperand(0) /*source*/); + if (failed(v)) + return failure(); + replaceOpWithNewBufferizedOp(rewriter, op, *v, + dimOp.index()); return success(); } }; @@ -219,13 +226,15 @@ BufferizationState &state) const { auto expandShapeOp = cast(op); auto tensorResultType = expandShapeOp.getResultType(); - Value buffer = - *state.getBuffer(rewriter, expandShapeOp->getOpOperand(0) /*src*/); + auto buffer = + state.getBuffer(rewriter, expandShapeOp->getOpOperand(0) /*src*/); + if (failed(buffer)) + return failure(); // Memref result type is inferred by the builder based on reassociation // indices and result shape. replaceOpWithNewBufferizedOp( - rewriter, op, tensorResultType.getShape(), buffer, + rewriter, op, tensorResultType.getShape(), *buffer, expandShapeOp.getReassociationIndices()); return success(); } @@ -264,10 +273,12 @@ // Even if this op was decided to bufferize out-of-place, do not insert the // buffer copy yet. This is done later in this function. - Value srcMemref = - *state.getBuffer(rewriter, extractSliceOp->getOpOperand(0) /*source*/, - BufferizationState::ForceInPlacability::FORCE_INPLACE); - auto srcMemrefType = srcMemref.getType().cast(); + auto srcMemref = + state.getBuffer(rewriter, extractSliceOp->getOpOperand(0) /*source*/, + BufferizationState::ForceInPlacability::FORCE_INPLACE); + if (failed(srcMemref)) + return failure(); + auto srcMemrefType = srcMemref->getType().cast(); auto dstTensorType = extractSliceOp.result().getType().cast(); @@ -289,7 +300,7 @@ SmallVector mixedSizes = extractSliceOp.getMixedSizes(); SmallVector mixedStrides = extractSliceOp.getMixedStrides(); OffsetSizeAndStrideOpInterface::expandToRank( - srcMemref, mixedOffsets, mixedSizes, mixedStrides, + *srcMemref, mixedOffsets, mixedSizes, mixedStrides, [&](Value target, int64_t dim) -> OpFoldResult { auto shapedType = target.getType().cast(); if (shapedType.isDynamicDim(dim)) @@ -302,7 +313,7 @@ mixedOffsets, mixedSizes, mixedStrides) .cast(); Value subView = rewriter.create( - loc, subviewMemRefType, srcMemref, mixedOffsets, mixedSizes, + loc, subviewMemRefType, *srcMemref, mixedOffsets, mixedSizes, mixedStrides); // If not inplaceable, copy. @@ -342,9 +353,11 @@ LogicalResult bufferize(Operation *op, RewriterBase &rewriter, BufferizationState &state) const { auto extractOp = cast(op); - Value srcMemref = - *state.getBuffer(rewriter, extractOp->getOpOperand(0) /*tensor*/); - replaceOpWithNewBufferizedOp(rewriter, op, srcMemref, + auto srcMemref = + state.getBuffer(rewriter, extractOp->getOpOperand(0) /*tensor*/); + if (failed(srcMemref)) + return failure(); + replaceOpWithNewBufferizedOp(rewriter, op, *srcMemref, extractOp.indices()); return success(); } @@ -703,10 +716,10 @@ // Copy tensor. If this tensor.insert_slice has a matching // tensor.extract_slice, the copy operation will eventually fold away. - Value srcMemref = - *state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/); - if (failed(createMemCpy(rewriter, loc, srcMemref, subView, - state.getOptions()))) + auto srcMemref = + state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/); + if (failed(srcMemref) || failed(createMemCpy(rewriter, loc, *srcMemref, + subView, state.getOptions()))) return failure(); replaceOpWithBufferizedValues(rewriter, op, *dstMemref); @@ -736,9 +749,11 @@ LogicalResult bufferize(Operation *op, RewriterBase &rewriter, BufferizationState &state) const { auto rankOp = cast(op); - Value v = *state.getBuffer(rewriter, rankOp->getOpOperand(0) /*source*/); + auto v = state.getBuffer(rewriter, rankOp->getOpOperand(0) /*source*/); + if (failed(v)) + return failure(); replaceOpWithNewBufferizedOp(rewriter, op, rankOp.getType(), - v); + *v); return success(); } };