diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -551,41 +551,41 @@
 }
 
 /// Determine which OpOperand* will alias with `result` if the op is bufferized
-/// in place.
-/// Return None if the owner of `opOperand` does not have known
-/// bufferization aliasing behavior, which indicates that the op must allocate
-/// all of its tensor results.
-/// TODO: in the future this may need to evolve towards a list of OpOperand*.
-static Optional<OpOperand *> getAliasingOpOperand(OpResult result) {
+/// in place. Multiple OpOperands may potentially alias with an OpResult (e.g.,
+/// std.select in the future). Returns an empty vector if none can alias.
+static SmallVector<OpOperand *> getAliasingOpOperand(OpResult result) {
+  SmallVector<OpOperand *> r;
+  // Unknown ops are handled conservatively and never bufferize in-place.
   if (!hasKnownBufferizationAliasingBehavior(result.getDefiningOp()))
-    return None;
-  return TypeSwitch<Operation *, OpOperand *>(result.getDefiningOp())
-      .Case([&](tensor::CastOp op) { return &op->getOpOperand(0); })
-      .Case([&](ConstantOp op) { return nullptr; })
-      .Case([&](ExtractSliceOp op) { return &op->getOpOperand(0); })
+    return {};
+  TypeSwitch<Operation *>(result.getDefiningOp())
+      .Case([&](tensor::CastOp op) { r.push_back(&op->getOpOperand(0)); })
+      .Case([&](ExtractSliceOp op) { r.push_back(&op->getOpOperand(0)); })
       // In the case of scf::ForOp, this currently assumes the iter_args / yield
       // are 1-1. This may fail and is verified at the end.
       // TODO: update this.
       .Case([&](scf::ForOp op) {
-        return &op.getIterOpOperands()[result.getResultNumber()];
+        r.push_back(&op.getIterOpOperands()[result.getResultNumber()]);
       })
-      .Case([&](InitTensorOp op) { return nullptr; })
-      .Case([&](InsertSliceOp op) { return &op->getOpOperand(1); })
+      .Case([&](InsertSliceOp op) { r.push_back(&op->getOpOperand(1)); })
       .Case([&](LinalgOp op) {
-        return op.getOutputTensorOperands()[result.getResultNumber()];
+        r.push_back(op.getOutputTensorOperands()[result.getResultNumber()]);
       })
       .Case([&](TiledLoopOp op) {
         // TODO: TiledLoopOp helper method to avoid leaking impl details.
-        return &op->getOpOperand(op.getNumControlOperands() +
-                                 op.getNumInputs() + result.getResultNumber());
+        r.push_back(&op->getOpOperand(op.getNumControlOperands() +
+                                      op.getNumInputs() +
+                                      result.getResultNumber()));
       })
-      .Case([&](vector::TransferWriteOp op) { return &op->getOpOperand(1); })
-      .Case([&](CallOpInterface op) { return nullptr; })
+      .Case([&](vector::TransferWriteOp op) {
+        r.push_back(&op->getOpOperand(1));
+      })
+      .Case<ConstantOp, InitTensorOp, CallOpInterface>([&](auto op) {})
       .Default([&](Operation *op) {
         op->dump();
         llvm_unreachable("unexpected defining op");
-        return nullptr;
       });
+  return r;
 }
 
 /// If the an ExtractSliceOp is bufferized in-place, the source operand will
@@ -879,8 +879,8 @@
 /// dominance).
 bool BufferizationAliasInfo::wouldCreateReadAfterWriteInterference(
     OpOperand &operand, OpResult result, const DominanceInfo &domInfo) const {
-  assert(getAliasingOpOperand(result) == &operand &&
+  assert(llvm::is_contained(getAliasingOpOperand(result), &operand) &&
          "operand and result do not match");
 
   Operation *opToBufferize = result.getDefiningOp();
   Value opResult = result;
@@ -975,8 +975,8 @@
 /// a write to a non-writable buffer.
 bool BufferizationAliasInfo::wouldCreateWriteToNonWritableBuffer(
     OpOperand &opOperand, OpResult opResult) const {
-  assert(getAliasingOpOperand(opResult) == &opOperand &&
+  assert(llvm::is_contained(getAliasingOpOperand(opResult), &opOperand) &&
          "operand and result do not match");
 
   // Certain buffers are not writeable:
   //   1. A function bbArg that is not inplaceable or
@@ -1126,9 +1126,10 @@
     Operation *candidateOp = mit->v.getDefiningOp();
     if (!candidateOp)
       continue;
-    auto maybeAliasingOperand = getAliasingOpOperand(mit->v.cast<OpResult>());
-    if (!maybeAliasingOperand || !*maybeAliasingOperand ||
-        !bufferizesToMemoryWrite(**maybeAliasingOperand))
+    SmallVector<OpOperand *> operands =
+        getAliasingOpOperand(mit->v.cast<OpResult>());
+    assert(operands.size() <= 1 && "more than 1 OpOperand not supported yet");
+    if (operands.empty() || !bufferizesToMemoryWrite(*operands.front()))
       continue;
     LDBG("---->clobbering candidate: " << printOperationInfo(candidateOp)
                                        << '\n');
@@ -1414,9 +1415,12 @@
                              bool skipCopy = false) {
   OpBuilder::InsertionGuard guard(b);
   Operation *op = result.getOwner();
-  Optional<OpOperand *> maybeOperand = getAliasingOpOperand(result);
-  assert(maybeOperand && "corresponding OpOperand not found");
-  Value operand = (*maybeOperand)->get();
+  SmallVector<OpOperand *> aliasingOperands = getAliasingOpOperand(result);
+  assert(!aliasingOperands.empty() && "corresponding OpOperand not found");
+  // TODO: Support multiple OpOperands.
+  assert(aliasingOperands.size() == 1 &&
+         "more than 1 OpOperand not supported yet");
+  Value operand = aliasingOperands.front()->get();
   Value operandBuffer = lookup(bvm, operand);
   assert(operandBuffer && "operand buffer not found");
 
@@ -2159,8 +2163,8 @@
 bufferizableInPlaceAnalysisImpl(OpOperand &operand, OpResult result,
                                 BufferizationAliasInfo &aliasInfo,
                                 const DominanceInfo &domInfo) {
-  assert(getAliasingOpOperand(result) == &operand &&
+  assert(llvm::is_contained(getAliasingOpOperand(result), &operand) &&
          "operand and result do not match");
 
   int64_t resultNumber = result.getResultNumber();
   (void)resultNumber;