diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -479,14 +479,15 @@ /// Lookup the buffer for the given value. If the value was not bufferized /// yet, wrap it in a ToMemrefOp. Otherwise, it is the result of a ToTensorOp, /// from which the memref operand is returned. -Value getBuffer(RewriterBase &rewriter, Value value, - const BufferizationOptions &options); +FailureOr getBuffer(RewriterBase &rewriter, Value value, + const BufferizationOptions &options); /// Return the buffer type for a given Value (tensor) after bufferization. /// /// Note: Op implementations should preferrably call `getBuffer()->getType()`. /// This function should only be used if `getBuffer` cannot be used. -BaseMemRefType getBufferType(Value value, const BufferizationOptions &options); +FailureOr getBufferType(Value value, + const BufferizationOptions &options); /// Replace an op with replacement values. The op is deleted. Tensor OpResults /// must be replaced with memref values. diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td @@ -343,7 +343,7 @@ Return the bufferized type of the given tensor block argument. The block argument is guaranteed to belong to a block of this op. }], - /*retType=*/"BaseMemRefType", + /*retType=*/"FailureOr", /*methodName=*/"getBufferType", /*args=*/(ins "BlockArgument":$bbArg, "const BufferizationOptions &":$options), diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp @@ -84,8 +84,10 @@ auto castOp = cast(op); auto resultTensorType = castOp.getType().cast(); - Value source = getBuffer(rewriter, castOp.getIn(), options); - auto sourceType = source.getType().cast(); + FailureOr source = getBuffer(rewriter, castOp.getIn(), options); + if (failed(source)) + return failure(); + auto sourceType = source->getType().cast(); // Result type should have same layout and address space as the source type. BaseMemRefType resultType; @@ -100,7 +102,7 @@ } replaceOpWithNewBufferizedOp(rewriter, op, resultType, - source); + *source); return success(); } }; @@ -140,8 +142,14 @@ // instead of its OpOperands. In the worst case, 2 copies are inserted at // the moment (one for each tensor). When copying the op result, only one // copy would be needed. - Value trueBuffer = getBuffer(rewriter, selectOp.getTrueValue(), options); - Value falseBuffer = getBuffer(rewriter, selectOp.getFalseValue(), options); + FailureOr maybeTrueBuffer = + getBuffer(rewriter, selectOp.getTrueValue(), options); + FailureOr maybeFalseBuffer = + getBuffer(rewriter, selectOp.getFalseValue(), options); + if (failed(maybeTrueBuffer) || failed(maybeFalseBuffer)) + return failure(); + Value trueBuffer = *maybeTrueBuffer; + Value falseBuffer = *maybeFalseBuffer; // The "true" and the "false" operands must have the same type. If the // buffers have different types, they differ only in their layout map. Cast diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -480,8 +480,8 @@ #endif } -Value bufferization::getBuffer(RewriterBase &rewriter, Value value, - const BufferizationOptions &options) { +FailureOr bufferization::getBuffer(RewriterBase &rewriter, Value value, + const BufferizationOptions &options) { #ifndef NDEBUG auto tensorType = value.getType().dyn_cast(); assert(tensorType && "unexpected non-tensor type"); @@ -494,14 +494,17 @@ // Insert to_memref op. OpBuilder::InsertionGuard g(rewriter); setInsertionPointAfter(rewriter, value); - Type memrefType = getBufferType(value, options); - ensureToMemrefOpIsValid(value, memrefType); - return rewriter.create(value.getLoc(), memrefType, - value); + FailureOr memrefType = getBufferType(value, options); + if (failed(memrefType)) + return failure(); + ensureToMemrefOpIsValid(value, *memrefType); + return rewriter + .create(value.getLoc(), *memrefType, value) + .getResult(); } /// Return the buffer type for a given Value (tensor) after bufferization. -BaseMemRefType +FailureOr bufferization::getBufferType(Value value, const BufferizationOptions &options) { auto tensorType = value.getType().dyn_cast(); assert(tensorType && "unexpected non-tensor type"); diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp @@ -163,8 +163,12 @@ // Get "copy" buffer. Value copyBuffer; - if (getCopy()) - copyBuffer = getBuffer(rewriter, getCopy(), options); + if (getCopy()) { + FailureOr maybeCopyBuffer = getBuffer(rewriter, getCopy(), options); + if (failed(maybeCopyBuffer)) + return failure(); + copyBuffer = *maybeCopyBuffer; + } // Compute memory space of this allocation. unsigned memorySpace; diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp @@ -305,8 +305,13 @@ // Retrieve buffers for tensor operands. Value buffer = newOperands[idx]; - if (!buffer) - buffer = getBuffer(rewriter, opOperand.get(), options); + if (!buffer) { + FailureOr maybeBuffer = + getBuffer(rewriter, opOperand.get(), options); + if (failed(maybeBuffer)) + return failure(); + buffer = *maybeBuffer; + } // Caller / callee type mismatch is handled with a CastOp. auto memRefType = funcType.getInput(idx); diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp @@ -44,15 +44,21 @@ newInputBuffers.push_back(opOperand->get()); continue; } - newInputBuffers.push_back(getBuffer(rewriter, opOperand->get(), options)); + FailureOr buffer = getBuffer(rewriter, opOperand->get(), options); + if (failed(buffer)) + return failure(); + newInputBuffers.push_back(*buffer); } // New output operands for the cloned op. SmallVector newOutputBuffers; for (OpResult opResult : op->getOpResults()) { OpOperand *opOperand = op.getOutputOperand(opResult.getResultNumber()); - Value resultBuffer = getBuffer(rewriter, opOperand->get(), options); - newOutputBuffers.push_back(resultBuffer); + FailureOr resultBuffer = + getBuffer(rewriter, opOperand->get(), options); + if (failed(resultBuffer)) + return failure(); + newOutputBuffers.push_back(*resultBuffer); } // Merge input/output operands. diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -281,14 +281,17 @@ /// Helper function for loop bufferization. Return the bufferized values of the /// given OpOperands. If an operand is not a tensor, return the original value. -static SmallVector getBuffers(RewriterBase &rewriter, - MutableArrayRef operands, - const BufferizationOptions &options) { +static FailureOr> +getBuffers(RewriterBase &rewriter, MutableArrayRef operands, + const BufferizationOptions &options) { SmallVector result; for (OpOperand &opOperand : operands) { if (opOperand.get().getType().isa()) { - Value resultBuffer = getBuffer(rewriter, opOperand.get(), options); - result.push_back(resultBuffer); + FailureOr resultBuffer = + getBuffer(rewriter, opOperand.get(), options); + if (failed(resultBuffer)) + return failure(); + result.push_back(*resultBuffer); } else { result.push_back(opOperand.get()); } @@ -298,36 +301,46 @@ /// Helper function for loop bufferization. Compute the buffer that should be /// yielded from a loop block (loop body or loop condition). -static Value getYieldedBuffer(RewriterBase &rewriter, Value tensor, - BaseMemRefType type, - const BufferizationOptions &options) { +static FailureOr getYieldedBuffer(RewriterBase &rewriter, Value tensor, + BaseMemRefType type, + const BufferizationOptions &options) { assert(tensor.getType().isa() && "expected tensor"); ensureToMemrefOpIsValid(tensor, type); - Value yieldedVal = getBuffer(rewriter, tensor, options); - return castBuffer(rewriter, yieldedVal, type); + FailureOr yieldedVal = getBuffer(rewriter, tensor, options); + if (failed(yieldedVal)) + return failure(); + return castBuffer(rewriter, *yieldedVal, type); } /// Helper function for loop bufferization. Given a range of values, apply /// `func` to those marked in `tensorIndices`. Otherwise, store the unmodified /// value in the result vector. -static SmallVector +static FailureOr> convertTensorValues(ValueRange values, const DenseSet &tensorIndices, - llvm::function_ref func) { + llvm::function_ref(Value, int64_t)> func) { SmallVector result; for (const auto &it : llvm::enumerate(values)) { size_t idx = it.index(); Value val = it.value(); - result.push_back(tensorIndices.contains(idx) ? func(val, idx) : val); + if (tensorIndices.contains(idx)) { + FailureOr maybeVal = func(val, idx); + if (failed(maybeVal)) + return failure(); + result.push_back(*maybeVal); + } else { + result.push_back(val); + } } return result; } /// Helper function for loop bufferization. Given a list of pre-bufferization /// yielded values, compute the list of bufferized yielded values. -SmallVector getYieldedValues(RewriterBase &rewriter, ValueRange values, - TypeRange bufferizedTypes, - const DenseSet &tensorIndices, - const BufferizationOptions &options) { +FailureOr> +getYieldedValues(RewriterBase &rewriter, ValueRange values, + TypeRange bufferizedTypes, + const DenseSet &tensorIndices, + const BufferizationOptions &options) { return convertTensorValues( values, tensorIndices, [&](Value val, int64_t index) { return getYieldedBuffer(rewriter, val, @@ -342,10 +355,19 @@ SmallVector getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs, const DenseSet &tensorIndices) { - return convertTensorValues( - bbArgs, tensorIndices, [&](Value val, int64_t index) { - return rewriter.create(val.getLoc(), val); - }); + SmallVector result; + for (const auto &it : llvm::enumerate(bbArgs)) { + size_t idx = it.index(); + Value val = it.value(); + if (tensorIndices.contains(idx)) { + result.push_back( + rewriter.create(val.getLoc(), val) + .getResult()); + } else { + result.push_back(val); + } + } + return result; } /// Bufferization of scf.for. Replace with a new scf.for that operates on @@ -445,8 +467,9 @@ return success(); } - BaseMemRefType getBufferType(Operation *op, BlockArgument bbArg, - const BufferizationOptions &options) const { + FailureOr + getBufferType(Operation *op, BlockArgument bbArg, + const BufferizationOptions &options) const { auto forOp = cast(op); return bufferization::getBufferType( forOp.getOpOperandForRegionIterArg(bbArg).get(), options); @@ -462,8 +485,11 @@ DenseSet indices = getTensorIndices(forOp.getInitArgs()); // The new memref init_args of the loop. - SmallVector initArgs = + FailureOr> maybeInitArgs = getBuffers(rewriter, forOp.getIterOpOperands(), options); + if (failed(maybeInitArgs)) + return failure(); + SmallVector initArgs = *maybeInitArgs; // Construct a new scf.for op with memref instead of tensor values. auto newForOp = rewriter.create( @@ -689,13 +715,17 @@ getTensorIndices(whileOp.getAfterArguments()); // The new memref init_args of the loop. - SmallVector initArgs = + FailureOr> maybeInitArgs = getBuffers(rewriter, whileOp->getOpOperands(), options); + if (failed(maybeInitArgs)) + return failure(); + SmallVector initArgs = *maybeInitArgs; // The result types of a WhileOp are the same as the "after" bbArg types. SmallVector argsTypesAfter = llvm::to_vector( llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) { - return bufferization::getBufferType(bbArg, options).cast(); + // TODO: error handling + return bufferization::getBufferType(bbArg, options)->cast(); })); // Construct a new scf.while op with memref instead of tensor values. @@ -727,10 +757,12 @@ // Only equivalent buffers or new buffer allocations may be yielded to the // "after" region. // TODO: This could be relaxed for better bufferization results. - SmallVector newConditionArgs = + FailureOr> newConditionArgs = getYieldedValues(rewriter, newConditionOp.getArgs(), argsTypesAfter, indicesAfter, options); - newConditionOp.getArgsMutable().assign(newConditionArgs); + if (failed(newConditionArgs)) + return failure(); + newConditionOp.getArgsMutable().assign(*newConditionArgs); // Set up new iter_args and move the loop body block to the new op. // The old block uses tensors, so wrap the (memref) bbArgs of the new block @@ -746,10 +778,12 @@ // Only equivalent buffers or new buffer allocations may be yielded to the // "before" region. // TODO: This could be relaxed for better bufferization results. - SmallVector newYieldValues = + FailureOr> newYieldValues = getYieldedValues(rewriter, newYieldOp.getResults(), argsTypesBefore, indicesBefore, options); - newYieldOp.getResultsMutable().assign(newYieldValues); + if (failed(newYieldValues)) + return failure(); + newYieldOp.getResultsMutable().assign(*newYieldValues); // Replace loop results. replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults()); @@ -849,13 +883,18 @@ for (const auto &it : llvm::enumerate(yieldOp.getResults())) { Value value = it.value(); if (value.getType().isa()) { - Value buffer = getBuffer(rewriter, value, options); + FailureOr maybeBuffer = getBuffer(rewriter, value, options); + if (failed(maybeBuffer)) + return failure(); + Value buffer = *maybeBuffer; if (auto forOp = dyn_cast(yieldOp->getParentOp())) { - BaseMemRefType resultType = + FailureOr resultType = cast(forOp.getOperation()) .getBufferType(forOp.getRegionIterArgs()[it.index()], options); - buffer = castBuffer(rewriter, buffer, resultType); + if (failed(resultType)) + return failure(); + buffer = castBuffer(rewriter, buffer, *resultType); } newResults.push_back(buffer); } else { @@ -1078,16 +1117,22 @@ // If the op bufferizes out-of-place, allocate the copy before the // ForeachThreadOp. rewriter.setInsertionPoint(foreachThreadOp); - Value destBuffer = getBuffer(rewriter, insertOp.getDest(), options); + FailureOr destBuffer = + getBuffer(rewriter, insertOp.getDest(), options); + if (failed(destBuffer)) + return failure(); // Bufferize the ParallelInsertSliceOp outside of the PerformConcurrentlyOp. rewriter.setInsertionPoint(performConcurrentlyOp); - Value srcBuffer = getBuffer(rewriter, insertOp.getSource(), options); + FailureOr srcBuffer = + getBuffer(rewriter, insertOp.getSource(), options); + if (failed(srcBuffer)) + return failure(); Value subview = rewriter.create( - insertOp.getLoc(), destBuffer, insertOp.getMixedOffsets(), + insertOp.getLoc(), *destBuffer, insertOp.getMixedOffsets(), insertOp.getMixedSizes(), insertOp.getMixedStrides()); // This memcpy will fold away if everything bufferizes in-place. - if (failed(options.createMemCpy(rewriter, insertOp.getLoc(), srcBuffer, + if (failed(options.createMemCpy(rewriter, insertOp.getLoc(), *srcBuffer, subview))) return failure(); rewriter.eraseOp(op); @@ -1095,7 +1140,7 @@ // Replace all uses of ForeachThreadOp (just the corresponding result). rewriter.setInsertionPointAfter(foreachThreadOp); Value toTensorOp = - rewriter.create(foreachThreadOp.getLoc(), destBuffer); + rewriter.create(foreachThreadOp.getLoc(), *destBuffer); unsigned resultNum = 0; for (Operation &nextOp : performConcurrentlyOp.yieldingOps()) { if (&nextOp == op) diff --git a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp @@ -130,10 +130,16 @@ const BufferizationOptions &options) const { auto yieldOp = cast(op); SmallVector newResults; - for (Value value : yieldOp.operands()) - newResults.push_back(value.getType().isa() - ? getBuffer(rewriter, value, options) - : value); + for (Value value : yieldOp.operands()) { + if (value.getType().isa()) { + FailureOr buffer = getBuffer(rewriter, value, options); + if (failed(buffer)) + return failure(); + newResults.push_back(*buffer); + } else { + newResults.push_back(value); + } + } replaceOpWithNewBufferizedOp(rewriter, op, newResults); return success(); diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -52,8 +52,11 @@ auto castOp = cast(op); // The result buffer still has the old (pre-cast) type. - Value resultBuffer = getBuffer(rewriter, castOp.getSource(), options); - auto sourceMemRefType = resultBuffer.getType().cast(); + FailureOr resultBuffer = + getBuffer(rewriter, castOp.getSource(), options); + if (failed(resultBuffer)) + return failure(); + auto sourceMemRefType = resultBuffer->getType().cast(); TensorType resultTensorType = castOp.getResult().getType().cast(); MemRefLayoutAttrInterface layout; @@ -68,11 +71,11 @@ sourceMemRefType.getMemorySpaceAsInt()); // Replace the op with a memref.cast. - assert(memref::CastOp::areCastCompatible(resultBuffer.getType(), + assert(memref::CastOp::areCastCompatible(resultBuffer->getType(), resultMemRefType) && "CallOp::bufferize: cast incompatible"); replaceOpWithNewBufferizedOp(rewriter, op, resultMemRefType, - resultBuffer); + *resultBuffer); return success(); } @@ -108,7 +111,11 @@ const BufferizationOptions &options) const { auto collapseShapeOp = cast(op); RankedTensorType tensorResultType = collapseShapeOp.getResultType(); - Value buffer = getBuffer(rewriter, collapseShapeOp.getSrc(), options); + FailureOr maybeBuffer = + getBuffer(rewriter, collapseShapeOp.getSrc(), options); + if (failed(maybeBuffer)) + return failure(); + Value buffer = *maybeBuffer; auto bufferType = buffer.getType().cast(); if (tensorResultType.getRank() == 0) { @@ -187,9 +194,11 @@ LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const { auto dimOp = cast(op); - auto v = getBuffer(rewriter, dimOp.getSource(), options); - replaceOpWithNewBufferizedOp(rewriter, op, v, - dimOp.getIndex()); + FailureOr v = getBuffer(rewriter, dimOp.getSource(), options); + if (failed(v)) + return failure(); + replaceOpWithNewBufferizedOp(rewriter, op, *v, + dimOp.index()); return success(); } }; @@ -224,12 +233,15 @@ const BufferizationOptions &options) const { auto expandShapeOp = cast(op); auto tensorResultType = expandShapeOp.getResultType(); - auto buffer = getBuffer(rewriter, expandShapeOp.getSrc(), options); + FailureOr buffer = + getBuffer(rewriter, expandShapeOp.getSrc(), options); + if (failed(buffer)) + return failure(); // Memref result type is inferred by the builder based on reassociation // indices and result shape. replaceOpWithNewBufferizedOp( - rewriter, op, tensorResultType.getShape(), buffer, + rewriter, op, tensorResultType.getShape(), *buffer, expandShapeOp.getReassociationIndices()); return success(); } @@ -268,8 +280,11 @@ // Even if this op was decided to bufferize out-of-place, do not insert the // buffer copy yet. This is done later in this function. - auto srcMemref = getBuffer(rewriter, extractSliceOp.getSource(), options); - auto srcMemrefType = srcMemref.getType().cast(); + FailureOr srcMemref = + getBuffer(rewriter, extractSliceOp.getSource(), options); + if (failed(srcMemref)) + return failure(); + auto srcMemrefType = srcMemref->getType().cast(); auto dstTensorType = extractSliceOp.getResult().getType().cast(); @@ -279,7 +294,7 @@ SmallVector mixedSizes = extractSliceOp.getMixedSizes(); SmallVector mixedStrides = extractSliceOp.getMixedStrides(); OffsetSizeAndStrideOpInterface::expandToRank( - srcMemref, mixedOffsets, mixedSizes, mixedStrides, + *srcMemref, mixedOffsets, mixedSizes, mixedStrides, [&](Value target, int64_t dim) -> OpFoldResult { auto shapedType = target.getType().cast(); if (shapedType.isDynamicDim(dim)) @@ -292,7 +307,7 @@ mixedOffsets, mixedSizes, mixedStrides) .cast(); Value subView = rewriter.create( - loc, subviewMemRefType, srcMemref, mixedOffsets, mixedSizes, + loc, subviewMemRefType, *srcMemref, mixedOffsets, mixedSizes, mixedStrides); replaceOpWithBufferizedValues(rewriter, op, subView); @@ -322,9 +337,12 @@ LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const { auto extractOp = cast(op); - Value srcMemref = getBuffer(rewriter, extractOp.getTensor(), options); - replaceOpWithNewBufferizedOp(rewriter, op, srcMemref, - extractOp.getIndices()); + FailureOr srcMemref = + getBuffer(rewriter, extractOp.getTensor(), options); + if (failed(srcMemref)) + return failure(); + replaceOpWithNewBufferizedOp(rewriter, op, *srcMemref, + extractOp.indices()); return success(); } }; @@ -497,10 +515,13 @@ LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const { auto insertOp = cast(op); - Value destMemref = getBuffer(rewriter, insertOp.getDest(), options); + FailureOr destMemref = + getBuffer(rewriter, insertOp.getDest(), options); + if (failed(destMemref)) + return failure(); rewriter.create(insertOp.getLoc(), insertOp.getScalar(), - destMemref, insertOp.getIndices()); - replaceOpWithBufferizedValues(rewriter, op, destMemref); + *destMemref, insertOp.getIndices()); + replaceOpWithBufferizedValues(rewriter, op, *destMemref); return success(); } @@ -655,7 +676,10 @@ // TODO: be very loud about it or even consider failing the pass. auto insertSliceOp = cast(op); Location loc = insertSliceOp.getLoc(); - Value dstMemref = getBuffer(rewriter, insertSliceOp.getDest(), options); + FailureOr dstMemref = + getBuffer(rewriter, insertSliceOp.getDest(), options); + if (failed(dstMemref)) + return failure(); // Expand offsets, sizes and strides to the full rank to handle the // rank-reducing case. @@ -663,7 +687,7 @@ SmallVector mixedSizes = insertSliceOp.getMixedSizes(); SmallVector mixedStrides = insertSliceOp.getMixedStrides(); OffsetSizeAndStrideOpInterface::expandToRank( - dstMemref, mixedOffsets, mixedSizes, mixedStrides, + *dstMemref, mixedOffsets, mixedSizes, mixedStrides, [&](Value target, int64_t dim) -> OpFoldResult { auto shapedType = target.getType().cast(); if (shapedType.isDynamicDim(dim)) @@ -671,23 +695,26 @@ return rewriter.getIndexAttr(shapedType.getDimSize(dim)); }); // Take a subview of the dst. - auto dstMemrefType = dstMemref.getType().cast(); + auto dstMemrefType = dstMemref->getType().cast(); auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType( insertSliceOp.getSourceType().getRank(), dstMemrefType, mixedOffsets, mixedSizes, mixedStrides) .cast(); Value subView = rewriter.create( - loc, subviewMemRefType, dstMemref, mixedOffsets, mixedSizes, + loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes, mixedStrides); // Copy tensor. If this tensor.insert_slice has a matching // tensor.extract_slice, the copy operation will eventually fold away. - auto srcMemref = getBuffer(rewriter, insertSliceOp.getSource(), options); - if (failed(options.createMemCpy(rewriter, loc, srcMemref, subView))) + FailureOr srcMemref = + getBuffer(rewriter, insertSliceOp.getSource(), options); + if (failed(srcMemref)) + return failure(); + if (failed(options.createMemCpy(rewriter, loc, *srcMemref, subView))) return failure(); - replaceOpWithBufferizedValues(rewriter, op, dstMemref); + replaceOpWithBufferizedValues(rewriter, op, *dstMemref); return success(); } }; @@ -714,9 +741,11 @@ LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const { auto rankOp = cast(op); - auto v = getBuffer(rewriter, rankOp.getTensor(), options); + FailureOr v = getBuffer(rewriter, rankOp.getTensor(), options); + if (failed(v)) + return failure(); replaceOpWithNewBufferizedOp(rewriter, op, rankOp.getType(), - v); + *v); return success(); } }; @@ -750,12 +779,16 @@ LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const { auto reshapeOp = cast(op); - Value srcBuffer = getBuffer(rewriter, reshapeOp.getSource(), options); - Value shapeBuffer = getBuffer(rewriter, reshapeOp.getShape(), options); + FailureOr srcBuffer = + getBuffer(rewriter, reshapeOp.getSource(), options); + FailureOr shapeBuffer = + getBuffer(rewriter, reshapeOp.getShape(), options); + if (failed(srcBuffer) || failed(shapeBuffer)) + return failure(); auto resultTensorType = reshapeOp.getResult().getType().cast(); auto resultMemRefType = getMemRefType(resultTensorType, options); replaceOpWithNewBufferizedOp( - rewriter, op, resultMemRefType, srcBuffer, shapeBuffer); + rewriter, op, resultMemRefType, *srcBuffer, *shapeBuffer); return success(); } }; diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp @@ -50,9 +50,11 @@ auto readOp = cast(op); assert(readOp.getShapedType().isa() && "only tensor types expected"); - Value buffer = getBuffer(rewriter, readOp.getSource(), options); + FailureOr buffer = getBuffer(rewriter, readOp.getSource(), options); + if (failed(buffer)) + return failure(); replaceOpWithNewBufferizedOp( - rewriter, readOp, readOp.getVectorType(), buffer, readOp.getIndices(), + rewriter, readOp, readOp.getVectorType(), *buffer, readOp.getIndices(), readOp.getPermutationMap(), readOp.getPadding(), readOp.getMask(), readOp.getInBoundsAttr()); return success(); @@ -97,12 +99,15 @@ "only tensor types expected"); // Create a new transfer_write on buffer that doesn't have a return value. - Value resultBuffer = getBuffer(rewriter, writeOp.getSource(), options); + FailureOr resultBuffer = + getBuffer(rewriter, writeOp.getSource(), options); + if (failed(resultBuffer)) + return failure(); rewriter.create( - writeOp.getLoc(), writeOp.getVector(), resultBuffer, + writeOp.getLoc(), writeOp.getVector(), *resultBuffer, writeOp.getIndices(), writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr()); - replaceOpWithBufferizedValues(rewriter, op, resultBuffer); + replaceOpWithBufferizedValues(rewriter, op, *resultBuffer); return success(); }