diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -180,9 +180,8 @@ SmallVector getAliasingOpOperand(OpResult result) const; /// Determine which OpResult will alias with `opOperand` if the op is - /// bufferized in place. Return an empty OpResult if the op is not - /// bufferizable. - OpResult getAliasingOpResult(OpOperand &opOperand) const; + /// bufferized in place. Return an empty vector if the op is not bufferizable. + SmallVector getAliasingOpResult(OpOperand &opOperand) const; /// Return true if `opOperand` bufferizes to a memory read. Return `true` if /// the op is not bufferizable. @@ -396,9 +395,10 @@ return {}; } - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { - return OpResult(); + SmallVector + getAliasingOpResult(Operation *op, OpOperand &opOperand, + const BufferizationState &state) const { + return {}; } BufferRelation bufferRelation(Operation *op, OpResult opResult, diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td @@ -124,7 +124,7 @@ bufferized in-place. This method will never be called on OpOperands that do not have a tensor type. }], - /*retType=*/"OpResult", + /*retType=*/"SmallVector", /*methodName=*/"getAliasingOpResult", /*args=*/(ins "OpOperand &":$opOperand, "const BufferizationState &":$state), @@ -162,8 +162,10 @@ for (OpOperand &opOperand : $_op.getOperation()->getOpOperands()) { if (!opOperand.get().getType().isa()) continue; - if (bufferizableOp.getAliasingOpResult(opOperand, state) == - opResult) + SmallVector aliasingOpResults = + bufferizableOp.getAliasingOpResult(opOperand, state); + if (llvm::find(aliasingOpResults, opResult) + != aliasingOpResults.end()) result.push_back(&opOperand); } return result; @@ -304,8 +306,7 @@ cast(getOperation()); return !bufferizableOp.bufferizesToMemoryRead(opOperand, state) && !bufferizableOp.bufferizesToMemoryWrite(opOperand, state) - && static_cast( - bufferizableOp.getAliasingOpResult(opOperand, state)); + && !bufferizableOp.getAliasingOpResult(opOperand, state).empty(); } // TODO: The following two attributes should belong to the tensor dialect. diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td @@ -211,9 +211,9 @@ return true; } - OpResult getAliasingOpResult(OpOperand &opOperand, - const BufferizationState &state) const { - return OpResult(); + SmallVector getAliasingOpResult( + OpOperand &opOperand, const BufferizationState &state) const { + return {}; } LogicalResult bufferize(RewriterBase &rewriter, diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp @@ -69,9 +69,10 @@ return false; } - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { - return op->getResult(0); + SmallVector + getAliasingOpResult(Operation *op, OpOperand &opOperand, + const BufferizationState &state) const { + return {op->getResult(0)}; } BufferRelation bufferRelation(Operation *op, OpResult opResult, @@ -114,9 +115,10 @@ return false; } - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { - return op->getOpResult(0) /*result*/; + SmallVector + getAliasingOpResult(Operation *op, OpOperand &opOperand, + const BufferizationState &state) const { + return {op->getOpResult(0) /*result*/}; } SmallVector diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -87,12 +87,13 @@ } /// Determine which OpResult will alias with `opOperand` if the op is bufferized -/// in place. Return an empty OpResult if the op is not bufferizable. -OpResult BufferizationState::getAliasingOpResult(OpOperand &opOperand) const { +/// in place. Return an empty vector if the op is not bufferizable. +SmallVector +BufferizationState::getAliasingOpResult(OpOperand &opOperand) const { if (auto bufferizableOp = dyn_cast(opOperand.getOwner())) return bufferizableOp.getAliasingOpResult(opOperand, *this); - return OpResult(); + return {}; } /// Return true if `opOperand` bufferizes to a memory read. Return `true` if the @@ -144,8 +145,9 @@ OpOperand *uMaybeReading = workingSet.pop_back_val(); // Skip over all ops that neither read nor write (but create an alias). if (bufferizesToAliasOnly(*uMaybeReading)) - for (OpOperand &use : getAliasingOpResult(*uMaybeReading).getUses()) - workingSet.push_back(&use); + for (OpResult opResult : getAliasingOpResult(*uMaybeReading)) + for (OpOperand &use : opResult.getUses()) + workingSet.push_back(&use); if (bufferizesToMemoryRead(*uMaybeReading)) return true; } @@ -266,9 +268,10 @@ })) return resultBuffer; // Do not copy if the copied data is never read. - OpResult aliasingOpResult = getAliasingOpResult(opOperand); - if (aliasingOpResult && !bufferizesToMemoryRead(opOperand) && - !isValueRead(aliasingOpResult)) + SmallVector aliasingOpResults = getAliasingOpResult(opOperand); + if (!aliasingOpResults.empty() && !bufferizesToMemoryRead(opOperand) && + llvm::none_of(aliasingOpResults, + [&](OpResult opResult) { return isValueRead(opResult); })) return resultBuffer; // Do not copy if this op does not read the data, but writes it. if (bufferizesToMemoryWrite(opOperand) && !bufferizesToMemoryRead(opOperand)) diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp @@ -140,7 +140,7 @@ void BufferizationAliasInfo::bufferizeInPlace(OpOperand &operand, BufferizationState &state) { markInPlace(operand); - if (OpResult result = state.getAliasingOpResult(operand)) + for (OpResult result : state.getAliasingOpResult(operand)) aliasInfo.unionSets(result, operand.get()); } @@ -196,8 +196,8 @@ for (OpOperand &opOperand : bufferizableOp->getOpOperands()) { if (opOperand.get().getType().isa()) if (bufferizableOp.mustBufferizeInPlace(opOperand, *this)) { - if (OpResult opResult = - bufferizableOp.getAliasingOpResult(opOperand, *this)) + for (OpResult opResult : + bufferizableOp.getAliasingOpResult(opOperand, *this)) aliasInfo.unionAliasSets(opOperand.get(), opResult); aliasInfo.markInPlace(opOperand); } @@ -404,7 +404,9 @@ // No conflict if the conflicting write and the last write are the same // use. - if (state.getAliasingOpResult(*uConflictingWrite) == lastWrite) + SmallVector aliasingOpResult = + state.getAliasingOpResult(*uConflictingWrite); + if (aliasingOpResult.size() == 1 && aliasingOpResult[0] == lastWrite) continue; // All requirements are met. Conflict found! @@ -477,7 +479,7 @@ DenseSet usesRead, usesWrite; getAliasingReads(usesRead, operand.get()); getAliasingInplaceWrites(usesWrite, operand.get()); - if (OpResult result = state.getAliasingOpResult(operand)) { + for (OpResult result : state.getAliasingOpResult(operand)) { getAliasingReads(usesRead, result); getAliasingInplaceWrites(usesWrite, result); } @@ -506,7 +508,7 @@ bool hasWrite = aliasesInPlaceWrite(opOperand.get(), aliasInfo, state) || state.bufferizesToMemoryWrite(opOperand); - if (OpResult opResult = state.getAliasingOpResult(opOperand)) + for (OpResult opResult : state.getAliasingOpResult(opOperand)) hasWrite |= aliasesInPlaceWrite(opResult, aliasInfo, state); return hasWrite; diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp @@ -168,8 +168,7 @@ // Operand is written to if it has an aliasing OpResult. For more details, // see `computeAliasingPairs`. auto bufferizableOp = cast(op); - return static_cast( - bufferizableOp.getAliasingOpResult(opOperand, state)); + return !bufferizableOp.getAliasingOpResult(opOperand, state).empty(); } SmallVector @@ -185,13 +184,16 @@ return {}; } - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + SmallVector + getAliasingOpResult(Operation *op, OpOperand &opOperand, + const BufferizationState &state) const { auto genericOp = cast(op); // Aliasing OpOperand/OpResult pairs are computed by `computeAliasingPairs`. DenseMap pairs = computeAliasingPairs(genericOp); - return pairs[&opOperand]; + if (!pairs.count(&opOperand)) + return {}; + return {pairs[&opOperand]}; } BufferRelation bufferRelation(Operation *op, OpResult opResult, @@ -252,16 +254,19 @@ // Only operands with an aliasing OpResult (i.e., output operands) bufferize // to a memory write. - return static_cast( - bufferizableOp.getAliasingOpResult(opOperand, state)); + return !bufferizableOp.getAliasingOpResult(opOperand, state).empty(); } - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + SmallVector + getAliasingOpResult(Operation *op, OpOperand &opOperand, + const BufferizationState &state) const { auto tiledLoopOp = cast(op); // Output operands are tied to their corresponding OpResults. - return tiledLoopOp.getTiedOpResult(opOperand); + OpResult opResult = tiledLoopOp.getTiedOpResult(opOperand); + if (!opResult) + return {}; + return {opResult}; } BufferRelation bufferRelation(Operation *op, OpResult opResult, @@ -397,9 +402,10 @@ return false; } - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { - return OpResult(); + SmallVector + getAliasingOpResult(Operation *op, OpOperand &opOperand, + const BufferizationState &state) const { + return {}; } bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand, diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp @@ -723,25 +723,24 @@ funcOp.getArgument(opOperand.getOperandNumber())); } - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + SmallVector + getAliasingOpResult(Operation *op, OpOperand &opOperand, + const BufferizationState &state) const { CallOp callOp = cast(op); FuncOp funcOp = getCalledFunction(callOp); assert(funcOp && "expected CallOp to a FuncOp"); const ModuleBufferizationState &moduleState = getModuleBufferizationState(state); + SmallVector result; for (int64_t resultIdx = 0; resultIdx < callOp->getNumResults(); ++resultIdx) if (Optional maybeArgNumber = getEquivalentFuncArgIdx(funcOp, moduleState, resultIdx)) if (*maybeArgNumber == opOperand.getOperandNumber()) - return callOp->getOpResult(resultIdx); + result.push_back(callOp->getOpResult(resultIdx)); - // Note: Returning a non-equivalent tensor from a FuncOp is currently not - // supported an will fail bufferization. (Even if allow-return-memref, it - // will fail when the function is called.) - return OpResult(); + return result; } SmallVector @@ -916,9 +915,10 @@ return false; } - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { - return OpResult(); + SmallVector + getAliasingOpResult(Operation *op, OpOperand &opOperand, + const BufferizationState &state) const { + return {}; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -278,12 +278,13 @@ return true; } - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + SmallVector + getAliasingOpResult(Operation *op, OpOperand &opOperand, + const BufferizationState &state) const { auto forOp = cast(op); if (!opOperand.get().getType().isa()) - return OpResult(); - return forOp.getResultForOpOperand(opOperand); + return {}; + return {forOp.getResultForOpOperand(opOperand)}; } BufferRelation bufferRelation(Operation *op, OpResult opResult, @@ -401,13 +402,14 @@ return false; } - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + SmallVector + getAliasingOpResult(Operation *op, OpOperand &opOperand, + const BufferizationState &state) const { if (isa(op->getParentOp())) - return op->getParentOp()->getResult(opOperand.getOperandNumber()); + return {op->getParentOp()->getResult(opOperand.getOperandNumber())}; if (isa(op->getParentOp())) - return op->getParentOp()->getResult(opOperand.getOperandNumber()); - return OpResult(); + return {op->getParentOp()->getResult(opOperand.getOperandNumber())}; + return {}; } bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand, diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -35,9 +35,10 @@ return false; } - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { - return op->getResult(0); + SmallVector + getAliasingOpResult(Operation *op, OpOperand &opOperand, + const BufferizationState &state) const { + return {op->getResult(0)}; } BufferRelation bufferRelation(Operation *op, OpResult opResult, @@ -93,9 +94,10 @@ return false; } - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { - return OpResult(); + SmallVector + getAliasingOpResult(Operation *op, OpOperand &opOperand, + const BufferizationState &state) const { + return {}; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, @@ -121,11 +123,12 @@ return false; } - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { - return &opOperand == &op->getOpOperand(0) /*source*/ - ? op->getResult(0) - : OpResult(); + SmallVector + getAliasingOpResult(Operation *op, OpOperand &opOperand, + const BufferizationState &state) const { + if (&opOperand == &op->getOpOperand(0) /*source*/) + return {op->getOpResult(0)}; + return {}; } BufferRelation bufferRelation(Operation *op, OpResult opResult, @@ -207,9 +210,10 @@ return false; } - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { - return OpResult(); + SmallVector + getAliasingOpResult(Operation *op, OpOperand &opOperand, + const BufferizationState &state) const { + return {}; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, @@ -371,11 +375,12 @@ return true; } - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + SmallVector + getAliasingOpResult(Operation *op, OpOperand &opOperand, + const BufferizationState &state) const { assert(&opOperand == &op->getOpOperand(1) /*dest*/ && "expected dest OpOperand"); - return op->getOpResult(0); + return {op->getOpResult(0)}; } SmallVector @@ -451,11 +456,12 @@ return &opOperand == &op->getOpOperand(1) /*dest*/; } - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { - return &opOperand == &op->getOpOperand(1) /*dest*/ - ? op->getResult(0) - : OpResult(); + SmallVector + getAliasingOpResult(Operation *op, OpOperand &opOperand, + const BufferizationState &state) const { + if (&opOperand == &op->getOpOperand(1) /*dest*/) + return {op->getResult(0)}; + return {}; } BufferRelation bufferRelation(Operation *op, OpResult opResult, @@ -606,9 +612,10 @@ return false; } - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { - return OpResult(); + SmallVector + getAliasingOpResult(Operation *op, OpOperand &opOperand, + const BufferizationState &state) const { + return {}; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp @@ -40,9 +40,10 @@ return false; } - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { - return OpResult(); + SmallVector + getAliasingOpResult(Operation *op, OpOperand &opOperand, + const BufferizationState &state) const { + return {}; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, @@ -81,11 +82,12 @@ return true; } - OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, - const BufferizationState &state) const { + SmallVector + getAliasingOpResult(Operation *op, OpOperand &opOperand, + const BufferizationState &state) const { assert(opOperand.get().getType().isa() && "only tensor types expected"); - return op->getOpResult(0); + return {op->getOpResult(0)}; } BufferRelation bufferRelation(Operation *op, OpResult opResult,