diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -225,7 +225,7 @@ key.getDefiningOp()->getParentOfType()->dump(); } llvm::errs() << "NO VALUE FOR KEY: " << key << "\n"; - abort(); + return Value(); } return bvm.lookup(key); } @@ -537,9 +537,10 @@ /// the Linalg op. If the tensor is an "init" tensor (i.e. its value is /// actually used in the payload region), we additionally copy the original /// value into the newly allocated buffer. -static void allocateBuffersForResults(OpBuilder &b, Location loc, LinalgOp op, - SmallVectorImpl &resultBuffers, - BlockAndValueMapping &bvm) { +static LogicalResult +allocateBuffersForResults(OpBuilder &b, Location loc, LinalgOp op, + SmallVectorImpl &resultBuffers, + BlockAndValueMapping &bvm) { // Take a guard before anything else. OpBuilder::InsertionGuard g(b); @@ -561,7 +562,10 @@ auto inplaceOp = cast(op.getOperation()); OpResult tiedResult = inplaceOp.getTiedOpResult(opOperand); if (getInPlace(tiedResult) == InPlaceSpec::True) { - resultBuffers.push_back(lookup(bvm, output)); + Value v = lookup(bvm, output); + if (!v) + return failure(); + resultBuffers.push_back(v); continue; } @@ -571,11 +575,17 @@ resultBuffers.push_back(alloc); // Additionally, if the output buffer is used, clone its value for now. 
- if (op.payloadUsesValueFromOpOperand(&opOperand)) - b.create(loc, lookup(bvm, output), alloc); + if (op.payloadUsesValueFromOpOperand(&opOperand)) { + Value v = lookup(bvm, output); + if (!v) + return failure(); + b.create(loc, v, alloc); + } } if (op->getNumResults()) map(bvm, op->getResults(), resultBuffers); + + return success(); } static void finalizeBufferAllocation(OpBuilder &b, LinalgOp op, @@ -611,10 +621,15 @@ Location loc = op.getLoc(); SmallVector newInputBuffers; newInputBuffers.reserve(op.getNumInputs()); - for (Value v : op.getInputs()) - newInputBuffers.push_back(lookup(bvm, v)); + for (Value in : op.getInputs()) { + Value v = lookup(bvm, in); + if (!v) + return failure(); + newInputBuffers.push_back(v); + } SmallVector newOutputBuffers; - allocateBuffersForResults(b, loc, op, newOutputBuffers, bvm); + if (failed(allocateBuffersForResults(b, loc, op, newOutputBuffers, bvm))) + return failure(); finalizeBufferAllocation(b, op, newInputBuffers, newOutputBuffers, bvm); return success(); } @@ -623,8 +638,12 @@ /// behind that will get DCE'd. 
static LogicalResult bufferize(OpBuilder &b, memref::DimOp dimOp, BlockAndValueMapping &bvm) { - if (dimOp.memrefOrTensor().getType().isa()) - dimOp.memrefOrTensorMutable().assign(lookup(bvm, dimOp.memrefOrTensor())); + if (dimOp.memrefOrTensor().getType().isa()) { + Value v = lookup(bvm, dimOp.memrefOrTensor()); + if (!v) + return failure(); + dimOp.memrefOrTensorMutable().assign(v); + } return success(); } @@ -664,8 +683,10 @@ auto tensorType = operand.get().getType().dyn_cast(); if (!tensorType) continue; - operand.set(b.create(returnOp.getLoc(), - lookup(bvm, operand.get()))); + Value v = lookup(bvm, operand.get()); + if (!v) + return failure(); + operand.set(b.create(returnOp.getLoc(), v)); } return success(); } @@ -689,10 +710,14 @@ map(bvm, subTensorInsertOp.result(), dstMemref); } else { dstMemref = lookup(bvm, subTensorInsertOp.dest()); + if (!dstMemref) + return failure(); } auto dstMemrefType = dstMemref.getType().cast(); Value srcMemref = lookup(bvm, subTensorInsertOp.source()); + if (!srcMemref) + return failure(); auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType( subTensorInsertOp.getSourceType().getRank(), dstMemrefType, @@ -738,7 +763,10 @@ /// transfer_read from buffer if (auto readOp = dyn_cast(op.getOperation())) { - readOp.sourceMutable().assign(lookup(bvm, op.source())); + Value v = lookup(bvm, op.source()); + if (!v) + return failure(); + readOp.sourceMutable().assign(v); return success(); } @@ -756,6 +784,8 @@ // InPlace write will result in memref.tensor_load(x) which must // canonicalize away with one of it uses. newInputBuffer = lookup(bvm, writeOp.source()); + if (!newInputBuffer) + return failure(); } // Create a new transfer_write on buffer that doesn't have a return value. @@ -793,8 +823,8 @@ // Skip BufferCast and TensorLoad ops.
.Case( [&](auto) { return success(); }) - .Case( + .Case( [&](auto op) { return bufferize(b, op, bvm); }) .Default([&](Operation *op) { auto isaTensor = [](Type t) { return t.isa(); };