diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -25,6 +25,22 @@ class AnalysisState; class BufferizableOpInterface; +/// Specify fine-grain relationship between buffers to enable more analysis. +enum class BufferRelation { + Unknown, + // TODO: ResultContainsOperand, + // TODO: OperandContainsResult, + Equivalent +}; + +/// A list of possible aliasing OpOperands. This list models the runtime +/// aliasing relationship for an OpResult. +using AliasingOpOperandList = SmallVector; + +/// A list of possible aliasing OpResults. This list models the runtime +/// aliasing relationship for an OpOperand. +using AliasingOpResultList = SmallVector; + class OpFilter { public: /// An op filter entry. Filters can be used to specify which ops should be @@ -300,14 +316,6 @@ SmallVector stateInitializers; }; -/// Specify fine-grain relationship between buffers to enable more analysis. -enum class BufferRelation { - Unknown, - // TODO: ResultContainsOperand, - // TODO: OperandContainsResult, - Equivalent -}; - /// Return `true` if the given value is a BlockArgument of a func::FuncOp. bool isFunctionArgument(Value value); @@ -315,15 +323,15 @@ /// tensor values. class AnalysisState { public: - /// Determine which OpOperand* will alias with `opResult` if the op is + /// Determine which OpOperand* will alias with `result` if the op is /// bufferized in place. Return all tensor OpOperand* if the op is not /// bufferizable. - SmallVector getAliasingOpOperand(OpResult opResult) const; + AliasingOpOperandList getAliasingOpOperands(OpResult result) const; /// Determine which OpResult will alias with `opOperand` if the op is /// bufferized in place. Return all tensor OpResults if the op is not /// bufferizable. 
- SmallVector getAliasingOpResult(OpOperand &opOperand) const; + AliasingOpResultList getAliasingOpResults(OpOperand &opOperand) const; /// Return true if `opOperand` bufferizes to a memory read. Return `true` if /// the op is not bufferizable. @@ -566,6 +574,12 @@ const BufferizationOptions &options); namespace detail { +/// This is the default implementation of +/// BufferizableOpInterface::getAliasingOpOperands. Should not be called from +/// other places. +AliasingOpOperandList defaultGetAliasingOpOperands(OpResult opResult, + const AnalysisState &state); + /// This is the default implementation of /// BufferizableOpInterface::getBufferType. Should not be called from other /// places. @@ -585,13 +599,13 @@ bool defaultIsRepetitiveRegion(BufferizableOpInterface bufferizableOp, unsigned index); -/// This is the default implementation of getAliasingOpOperand in case the +/// This is the default implementation of getAliasingOpOperands in case the /// defining op does not implement the BufferizableOpInterface. -SmallVector unknownGetAliasingOpOperand(OpResult opResult); +AliasingOpOperandList unknownGetAliasingOpOperands(OpResult opResult); -/// This is the default implementation of getAliasingOpResult in case the -/// defining op does not implement the BufferizableOpInterface. -SmallVector unknownGetAliasingOpResult(OpOperand &opOperand); +/// This is the default implementation of getAliasingOpResults in case the +/// owner op does not implement the BufferizableOpInterface. 
+AliasingOpResultList unknownGetAliasingOpResults(OpOperand &opOperand); } // namespace detail } // namespace bufferization diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td @@ -172,22 +172,52 @@ >, InterfaceMethod< /*desc=*/[{ - Return the OpResult that aliases with a given OpOperand when + Return the OpResults that may alias with a given OpOperand when bufferized in-place. This method will never be called on OpOperands that do not have a tensor type. - Note: This method can return multiple OpResults, indicating that a - given OpOperand may at runtime alias with any (or multiple) of the - returned OpResults. + This method can return multiple OpResults, indicating that a given + OpOperand may at runtime alias with any (or multiple) of the returned + OpResults. + + False positives are allowed in the list of OpResults, but they can + adversely affect the accuracy of the analysis. However, + omitting potential aliases is incorrect. + + One possible (conservative) implementation of this interface method, + that is always safe, is to return all tensor OpResults. + + Examples: + + ``` + // aliasingOpResults(%t) = {%r} + %r = tensor.insert_slice %f into %t : tensor<10xf32> + + // aliasingOpResults(%t) = {%r} + %r = tensor.extract_slice %t[0][5][1] + : tensor<10xf32> to tensor<5xf32> + + // aliasingOpResults(%t1) = {%r} + // aliasingOpResults(%t2) = {%r} + %r = arith.select %c, %t1, %t2 : tensor<10xf32> + + // A hypothetical op that bufferizes to rolling a die and, based on + // the result, returning either buffer(%t) or a newly allocated copy + // thereof.
+ // aliasingOpResults(%t) = {%r} + %r = "dummy.alias_or_copy"(%t) : (tensor<10xf32>) -> (tensor<10xf32>) + ``` }], - /*retType=*/"::mlir::SmallVector<::mlir::OpResult>", - /*methodName=*/"getAliasingOpResult", + /*retType=*/"::mlir::bufferization::AliasingOpResultList", + /*methodName=*/"getAliasingOpResults", /*args=*/(ins "::mlir::OpOperand &":$opOperand, "const ::mlir::bufferization::AnalysisState &":$state), /*methodBody=*/"", /*defaultImplementation=*/[{ // Does not have to be implemented for ops without tensor OpOperands. - llvm_unreachable("getAliasingOpResult not implemented"); + assert(opOperand.get().getType().isa() && + "expected OpOperand with tensor type"); + llvm_unreachable("getAliasingOpResults not implemented"); }] >, InterfaceMethod< @@ -196,35 +226,51 @@ bufferized in-place. This method will never be called on OpResults that do not have a tensor type. - By default, this method is the inverse of `getAliasingOpResult`. Ops + By default, this method is the inverse of `getAliasingOpResults`. Ops with a region that yield values may want to override this method to return the OpOperands that are yielded by the terminator. - Note: This method can return multiple OpOperands, indicating that the - given OpResult may at runtime alias with any (or multiple) of the - returned OpOperands. This can be useful for branches and for ops such - as `arith.select`. + This method can return multiple OpOperands, indicating that a given + OpResult may at runtime alias with any (or multiple) of the returned + OpOperands. + + False positives are allowed in the list of OpOperands, but they can + adversely affect the accuracy of the analysis. However, + omitting potential aliases is incorrect. + + One possible (conservative) implementation of this interface method, + that is always safe, is to return all tensor OpOperands. + + Note: If the returned list of OpOperands is empty, this op definitely + bufferizes to a new allocation.
In that case `bufferizesToAllocation` + must return `true`. + + Examples: + ``` + // aliasingOpOperands(%r) = {%t} + %r = tensor.insert_slice %f into %t : tensor<10xf32> + + // aliasingOpOperands(%r) = {%t} + %r = tensor.extract_slice %t[0][5][1] + : tensor<10xf32> to tensor<5xf32> + + // aliasingOpOperands(%r) = {%t1, %t2} + %r = arith.select %c, %t1, %t2 : tensor<10xf32> + + // aliasingOpOperands(%r) = {} + %r = tensor.empty() : tensor<10xf32> + ``` }], - /*retType=*/"::mlir::SmallVector<::mlir::OpOperand *>", - /*methodName=*/"getAliasingOpOperand", + /*retType=*/"::mlir::bufferization::AliasingOpOperandList", + /*methodName=*/"getAliasingOpOperands", /*args=*/(ins "::mlir::OpResult":$opResult, "const ::mlir::bufferization::AnalysisState &":$state), /*methodBody=*/"", /*defaultImplementation=*/[{ assert(opResult.getType().isa() && "expected OpResult with tensor type"); - SmallVector result; - auto bufferizableOp = - cast($_op.getOperation()); - for (OpOperand &opOperand : $_op.getOperation()->getOpOperands()) { - if (!opOperand.get().getType().isa()) - continue; - SmallVector aliasingOpResults = - bufferizableOp.getAliasingOpResult(opOperand, state); - if (llvm::is_contained(aliasingOpResults, opResult)) - result.push_back(&opOperand); - } - return result; + return ::mlir::bufferization::detail::defaultGetAliasingOpOperands( + opResult, state); }] >, InterfaceMethod< @@ -285,10 +331,11 @@ conversion. The implementation of this method must be consistent with the - remaining methods, in particular `getAliasingOpOperand`. I.e., a + remaining methods, in particular `getAliasingOpOperands`. I.e., a tensor result `r` may only be replaced with: - a) A buffer that aliases one of buffers in getAliasingOpOperand(r). - b) Or: A newly allocated buffer. + + a) One of the buffers in getAliasingOpOperands(r). + b) Or: A newly allocated buffer (only if `bufferizesToAllocation`). This method will never be called on ops that do not have at least one tensor operand/result.
@@ -440,7 +487,7 @@ cast(getOperation()); return !bufferizableOp.bufferizesToMemoryRead(opOperand, state) && !bufferizableOp.bufferizesToMemoryWrite(opOperand, state) - && !bufferizableOp.getAliasingOpResult(opOperand, state).empty(); + && !bufferizableOp.getAliasingOpResults(opOperand, state).empty(); } }]; } diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td @@ -100,7 +100,7 @@ bool bufferizesToMemoryWrite(OpOperand &opOperand, const AnalysisState &state); - SmallVector getAliasingOpResult( + AliasingOpResultList getAliasingOpResults( OpOperand &opOperand, const AnalysisState &state); FailureOr getBufferType( @@ -248,7 +248,7 @@ return false; } - SmallVector getAliasingOpResult( + AliasingOpResultList getAliasingOpResults( OpOperand &opOperand, const AnalysisState &state) const { return {}; } @@ -409,7 +409,7 @@ return true; } - SmallVector getAliasingOpResult( + AliasingOpResultList getAliasingOpResults( OpOperand &opOperand, const AnalysisState &state) const { return {}; } diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h --- a/mlir/include/mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h @@ -36,7 +36,7 @@ return dstOp.isDpsInit(&opOperand); } - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { // Output operands alias with their respective tied OpResults. 
auto dstOp = cast(op); diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp @@ -74,7 +74,7 @@ return false; } - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return {op->getResult(0)}; } @@ -126,14 +126,14 @@ return false; } - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return {op->getOpResult(0) /*result*/}; } - SmallVector - getAliasingOpOperand(Operation *op, OpResult opResult, - const AnalysisState &state) const { + AliasingOpOperandList + getAliasingOpOperands(Operation *op, OpResult opResult, + const AnalysisState &state) const { return {&op->getOpOperand(1) /*true_value*/, &op->getOpOperand(2) /*false_value*/}; } diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -181,8 +181,8 @@ if (operandType.isa()) return op->emitError("copying of unranked tensors is not implemented"); - SmallVector aliasingOpResults = - state.getAliasingOpResult(opOperand); + AliasingOpResultList aliasingOpResults = + state.getAliasingOpResults(opOperand); // Is the result yielded from a block? Or are deallocations turned off // entirely? In either case, mark the allocation as "escaping", so that it // will not be deallocated. 
@@ -193,7 +193,7 @@ if (aliasingOpResults.size() == 1 && !state.bufferizesToMemoryWrite(opOperand) && - state.getAliasingOpOperand(aliasingOpResults.front()).size() == 1 && + state.getAliasingOpOperands(aliasingOpResults.front()).size() == 1 && !aliasingOpResults.front().getType().isa()) { // The op itself does not write but may create exactly one alias. Instead // of copying the OpOperand, copy the OpResult. The OpResult can sometimes @@ -359,26 +359,26 @@ /// Determine which OpOperand* will alias with `opResult` if the op is /// bufferized in place. Return all tensor OpOperand* if the op is not /// bufferizable. -SmallVector -AnalysisState::getAliasingOpOperand(OpResult opResult) const { +AliasingOpOperandList +AnalysisState::getAliasingOpOperands(OpResult opResult) const { if (Operation *op = opResult.getDefiningOp()) if (auto bufferizableOp = getOptions().dynCastBufferizableOp(op)) - return bufferizableOp.getAliasingOpOperand(opResult, *this); + return bufferizableOp.getAliasingOpOperands(opResult, *this); // The op is not bufferizable. - return detail::unknownGetAliasingOpOperand(opResult); + return detail::unknownGetAliasingOpOperands(opResult); } /// Determine which OpResult will alias with `opOperand` if the op is bufferized /// in place. Return all tensor OpResults if the op is not bufferizable. -SmallVector -AnalysisState::getAliasingOpResult(OpOperand &opOperand) const { +AliasingOpResultList +AnalysisState::getAliasingOpResults(OpOperand &opOperand) const { if (auto bufferizableOp = getOptions().dynCastBufferizableOp(opOperand.getOwner())) - return bufferizableOp.getAliasingOpResult(opOperand, *this); + return bufferizableOp.getAliasingOpResults(opOperand, *this); // The op is not bufferizable. - return detail::unknownGetAliasingOpResult(opOperand); + return detail::unknownGetAliasingOpResults(opOperand); } /// Return true if `opOperand` bufferizes to a memory read. 
Return `true` if the @@ -440,7 +440,7 @@ OpOperand *uMaybeReading = workingSet.pop_back_val(); // Skip over all ops that neither read nor write (but create an alias). if (bufferizesToAliasOnly(*uMaybeReading)) - for (OpResult opResult : getAliasingOpResult(*uMaybeReading)) + for (OpResult opResult : getAliasingOpResults(*uMaybeReading)) for (OpOperand &use : opResult.getUses()) workingSet.push_back(&use); if (bufferizesToMemoryRead(*uMaybeReading)) @@ -470,14 +470,14 @@ OpResult opResult = value.cast(); BufferizableOpInterface bufferizableOp = options.dynCastBufferizableOp(opResult.getDefiningOp()); - SmallVector opOperands = getAliasingOpOperand(opResult); + AliasingOpOperandList aliases = getAliasingOpOperands(opResult); // Stop iterating in either one of these cases: // * The current op is not bufferizable or excluded in the filter. // * There are no OpOperands to follow. // * There is an OpOperand, but it is not an equivalent tensor (only if // `followEquivalentOnly` is set). - if (!bufferizableOp || opOperands.empty() || + if (!bufferizableOp || aliases.empty() || (followEquivalentOnly && bufferizableOp.bufferRelation(opResult, *this) != BufferRelation::Equivalent)) { @@ -486,7 +486,7 @@ continue; } - for (OpOperand *o : opOperands) + for (OpOperand *o : aliases) workingSet.insert(o->get()); } @@ -522,9 +522,9 @@ return true; // Do not copy if the tensor is never read. - SmallVector aliasingOpResults = getAliasingOpResult(opOperand); + AliasingOpResultList aliases = getAliasingOpResults(opOperand); if (!bufferizesToMemoryRead(opOperand) && - llvm::none_of(aliasingOpResults, + llvm::none_of(aliases, [&](OpResult opResult) { return isValueRead(opResult); })) return true; @@ -593,7 +593,7 @@ // Note: In the absence of detailed analysis information (e.g., there may be // no function call analysis information), this `getAliasingOpResult` is // conservative and may report additional OpResults as potentially aliasing. 
- for (OpResult opResult : getAliasingOpResult(*operand)) + for (OpResult opResult : getAliasingOpResults(*operand)) for (OpOperand &use : opResult.getUses()) worklist.push_back(&use); } @@ -821,8 +821,8 @@ bool bufferization::detail::defaultResultBufferizesToMemoryWrite( OpResult opResult, const AnalysisState &state) { auto bufferizableOp = cast(opResult.getDefiningOp()); - SmallVector opOperands = - bufferizableOp.getAliasingOpOperand(opResult, state); + AliasingOpOperandList opOperands = + bufferizableOp.getAliasingOpOperands(opResult, state); // Case 1: OpResults that have no aliasing OpOperand usually bufferize to // memory writes. @@ -882,6 +882,23 @@ return false; } +// Compute the AliasingOpOperandList for a given OpResult based on +// getAliasingOpResults. +AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands( + OpResult opResult, const AnalysisState &state) { + Operation *op = opResult.getDefiningOp(); + AliasingOpOperandList result; + for (OpOperand &opOperand : op->getOpOperands()) { + if (!opOperand.get().getType().isa()) + continue; + AliasingOpResultList aliasingOpResults = + state.getAliasingOpResults(opOperand); + if (llvm::is_contained(aliasingOpResults, opResult)) + result.push_back(&opOperand); + } + return result; +} + FailureOr bufferization::detail::defaultGetBufferType( Value value, const BufferizationOptions &options, const DenseMap &fixedTypes) { @@ -896,13 +913,12 @@ auto opResult = value.cast(); auto bufferizableOp = cast(op); AnalysisState state(options); - auto aliasingOperands = bufferizableOp.getAliasingOpOperand(opResult, state); - if (!aliasingOperands.empty() && - bufferizableOp.bufferRelation(opResult, state) == - BufferRelation::Equivalent) { + AliasingOpOperandList aliases = state.getAliasingOpOperands(opResult); + if (!aliases.empty() && bufferizableOp.bufferRelation(opResult, state) == + BufferRelation::Equivalent) { // If the OpResult has an equivalent OpOperand, both OpResult and // OpOperand bufferize 
to the exact same buffer type. - Value equivalentOperand = aliasingOperands.front()->get(); + Value equivalentOperand = aliases.front()->get(); return getBufferType(equivalentOperand, options, fixedTypes); } @@ -925,20 +941,20 @@ return regionInterface.isRepetitiveRegion(index); } -SmallVector -bufferization::detail::unknownGetAliasingOpOperand(OpResult opResult) { +AliasingOpOperandList +bufferization::detail::unknownGetAliasingOpOperands(OpResult opResult) { // Conservatively assume that everything is aliasing. - SmallVector r; + AliasingOpOperandList r; for (OpOperand &operand : opResult.getDefiningOp()->getOpOperands()) if (operand.get().getType().isa()) r.push_back(&operand); return r; } -SmallVector -bufferization::detail::unknownGetAliasingOpResult(OpOperand &opOperand) { +AliasingOpResultList +bufferization::detail::unknownGetAliasingOpResults(OpOperand &opOperand) { // Conservatively assume that everything is aliasing. - SmallVector r; + AliasingOpResultList r; for (OpResult result : opOperand.getOwner()->getOpResults()) if (result.getType().isa()) r.push_back(result); diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp @@ -226,9 +226,9 @@ return false; } -SmallVector -AllocTensorOp::getAliasingOpResult(OpOperand &opOperand, - const AnalysisState &state) { +AliasingOpResultList +AllocTensorOp::getAliasingOpResults(OpOperand &opOperand, + const AnalysisState &state) { // This is a new allocation. It does not alias with any other buffer. 
return {}; } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp @@ -172,21 +172,21 @@ opOperand.getOperandNumber()); } - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { func::CallOp callOp = cast(op); FuncOp funcOp = getCalledFunction(callOp); assert(funcOp && "expected CallOp to a FuncOp"); if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed) // FuncOp not analyzed yet. Any OpResult may be aliasing. - return detail::unknownGetAliasingOpResult(opOperand); + return detail::unknownGetAliasingOpResults(opOperand); // Get aliasing results from state. 
const FuncAnalysisState &funcState = getFuncAnalysisState(state); auto aliasingReturnVals = funcState.aliasingReturnVals.lookup(funcOp).lookup( opOperand.getOperandNumber()); - SmallVector result; + AliasingOpResultList result; for (int64_t resultIdx : aliasingReturnVals) result.push_back(callOp->getOpResult(resultIdx)); return result; @@ -209,7 +209,7 @@ if (maybeEquiv) { #ifndef NDEBUG SmallVector aliasingOpOperands = - getAliasingOpOperand(op, opResult, state); + getAliasingOpOperands(op, opResult, state); assert(aliasingOpOperands.size() == 1 && "expected exactly 1 aliasing OpOperand"); assert(aliasingOpOperands.front()->getOperandNumber() == *maybeEquiv && @@ -329,7 +329,7 @@ return false; } - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return {}; } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp @@ -147,7 +147,7 @@ if (inplaceBufferized.contains(&operand)) return; markInPlace(operand); - for (OpResult result : state.getAliasingOpResult(operand)) + for (OpResult result : state.getAliasingOpResults(operand)) aliasInfo.unionSets(result, operand.get()); ++statNumTensorInPlace; } @@ -636,9 +636,9 @@ // No conflict if the conflicting write and the definition are the same // use. 
- SmallVector aliasingOpResult = - state.getAliasingOpResult(*uConflictingWrite); - if (aliasingOpResult.size() == 1 && aliasingOpResult[0] == definition) { + AliasingOpResultList aliases = + state.getAliasingOpResults(*uConflictingWrite); + if (aliases.size() == 1 && aliases[0] == definition) { LLVM_DEBUG(llvm::dbgs() << " no conflict: definition and write are same\n"); continue; @@ -697,7 +697,7 @@ // there would then be no flow of data from the extract_slice operand to // its result's uses.) if (!state.bufferizesToMemoryWrite(use)) { - SmallVector opResults = state.getAliasingOpResult(use); + AliasingOpResultList opResults = state.getAliasingOpResults(use); if (llvm::any_of(opResults, [&](OpResult r) { return state.isValueRead(r); })) res.insert(&use); @@ -743,7 +743,7 @@ DenseSet usesRead, usesWrite; getAliasingReads(usesRead, operand.get(), aliasInfo, state); getAliasingInplaceWrites(usesWrite, operand.get(), aliasInfo, state); - for (OpResult result : state.getAliasingOpResult(operand)) { + for (OpResult result : state.getAliasingOpResults(operand)) { getAliasingReads(usesRead, result, aliasInfo, state); getAliasingInplaceWrites(usesWrite, result, aliasInfo, state); } @@ -796,8 +796,8 @@ continue; // Follow reverse SSA use-def chain. - SmallVector aliasingOpOperands = - state.getAliasingOpOperand(opResult); + AliasingOpOperandList aliasingOpOperands = + state.getAliasingOpOperands(opResult); for (OpOperand *opOperand : aliasingOpOperands) if (aliasInfo.isInPlace(*opOperand) || currentOpOperand == opOperand) worklist.push_back(opOperand->get()); @@ -813,7 +813,7 @@ // Collect writes of all aliases of OpOperand and OpResult. 
DenseSet usesWrite; getAliasingInplaceWrites(usesWrite, operand.get(), aliasInfo, state); - for (OpResult result : state.getAliasingOpResult(operand)) { + for (OpResult result : state.getAliasingOpResults(operand)) { getAliasingInplaceWrites(usesWrite, result, aliasInfo, state); } if (!checkConsistencyOnly && state.bufferizesToMemoryWrite(operand)) @@ -954,7 +954,7 @@ for (OpResult opResult : op->getOpResults()) if (opResult.getType().isa()) for (OpOperand *opOperand : - bufferizableOp.getAliasingOpOperand(opResult, state)) + bufferizableOp.getAliasingOpOperands(opResult, state)) if (state.isInPlace(*opOperand)) if (bufferizableOp.bufferRelation(opResult, state) == BufferRelation::Equivalent) diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp @@ -102,8 +102,8 @@ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { // Operand is written to if it has an aliasing OpResult. 
- auto bufferizableOp = cast(op); - return !bufferizableOp.getAliasingOpResult(opOperand, state).empty(); + auto dpsOp = cast(op); + return dpsOp.isDpsInit(&opOperand); } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -57,7 +57,7 @@ return false; } - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return {}; } @@ -104,9 +104,9 @@ struct ExecuteRegionOpInterface : public BufferizableOpInterface::ExternalModel { - SmallVector - getAliasingOpOperand(Operation *op, OpResult opResult, - const AnalysisState &state) const { + AliasingOpOperandList + getAliasingOpOperands(Operation *op, OpResult opResult, + const AnalysisState &state) const { // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be // any SSA value that is in scope. To allow for use-def chain traversal // through ExecuteRegionOps in the analysis, the corresponding yield value @@ -164,9 +164,9 @@ /// Bufferization of scf.if. Replace with a new scf.if that yields memrefs. struct IfOpInterface : public BufferizableOpInterface::ExternalModel { - SmallVector - getAliasingOpOperand(Operation *op, OpResult opResult, - const AnalysisState &state) const { + AliasingOpOperandList + getAliasingOpOperands(Operation *op, OpResult opResult, + const AnalysisState &state) const { // IfOps do not have tensor OpOperands. The yielded value can be any SSA // value that is in scope. To allow for use-def chain traversal through // IfOps in the analysis, both corresponding yield values from the then/else @@ -265,7 +265,7 @@ // yield values are equivalent to each other. 
auto bufferizableOp = cast(op); SmallVector yieldValues = - bufferizableOp.getAliasingOpOperand(opResult, state); + bufferizableOp.getAliasingOpOperands(opResult, state); assert(yieldValues.size() == 2 && "expected 2 yield values"); bool equivalentYields = state.areEquivalentBufferizedValues( yieldValues[0]->get(), yieldValues[1]->get()); @@ -443,7 +443,7 @@ return true; } - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { auto forOp = cast(op); return {forOp.getResultForOpOperand(opOperand)}; @@ -657,7 +657,7 @@ return true; } - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { auto whileOp = cast(op); unsigned int idx = opOperand.getOperandNumber(); @@ -952,7 +952,7 @@ return false; } - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { if (isa(op->getParentOp())) return {op->getParentOp()->getResult(opOperand.getOperandNumber())}; @@ -1051,7 +1051,7 @@ return true; } - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { auto foreachThreadOp = cast(op); return {foreachThreadOp.getTiedOpResult(&opOperand)}; diff --git a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp @@ -27,9 +27,9 @@ struct AssumingOpInterface : public BufferizableOpInterface::ExternalModel { - SmallVector - getAliasingOpOperand(Operation 
*op, OpResult opResult, - const AnalysisState &state) const { + AliasingOpOperandList + getAliasingOpOperands(Operation *op, OpResult opResult, + const AnalysisState &state) const { // AssumingOps do not have tensor OpOperands. The yielded value can be any // SSA value that is in scope. To allow for use-def chain traversal through // AssumingOps in the analysis, the corresponding yield value is considered @@ -99,11 +99,13 @@ return false; } - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { assert(isa(op->getParentOp()) && "expected that parent is an AssumingOp"); - return {op->getParentOp()->getResult(opOperand.getOperandNumber())}; + OpResult opResult = + op->getParentOp()->getResult(opOperand.getOperandNumber()); + return {opResult}; } bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand, diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -43,7 +43,7 @@ return false; } - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return {}; } @@ -73,7 +73,7 @@ return false; } - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return {}; } @@ -97,7 +97,7 @@ return false; } - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return {op->getOpResult(0)}; } @@ 
-136,7 +136,7 @@ return true; } - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { // InsertOp returns an alias of its operand. assert(op->getNumResults() == 1); @@ -164,7 +164,7 @@ return false; } - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return {}; } @@ -185,7 +185,7 @@ return false; } - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return {}; } @@ -206,7 +206,7 @@ return false; } - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return {}; } @@ -227,7 +227,7 @@ return false; } - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return {}; } @@ -248,7 +248,7 @@ return false; } - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return {}; } diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -41,7 +41,7 @@ return false; } - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const 
{ return {op->getResult(0)}; } @@ -98,11 +98,10 @@ return false; } - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { - if (&opOperand == &op->getOpOperand(0) /*src*/) - return {op->getOpResult(0)}; - return {}; + // TODO: CollapseShapeOp may allocate at runtime. + return {op->getOpResult(0)}; } BufferRelation bufferRelation(Operation *op, OpResult opResult, @@ -214,7 +213,7 @@ return false; } - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return {}; } @@ -263,11 +262,9 @@ return false; } - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { - if (&opOperand == &op->getOpOperand(0) /*src*/) - return {op->getOpResult(0)}; - return {}; + return {op->getOpResult(0)}; } BufferRelation bufferRelation(Operation *op, OpResult opResult, @@ -324,7 +321,7 @@ return false; } - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { if (&opOperand == &op->getOpOperand(0) /*source*/) return {op->getOpResult(0)}; @@ -397,7 +394,7 @@ return false; } - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return {}; } @@ -837,7 +834,7 @@ return false; } - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return {}; } @@ -935,7 +932,7 @@ return false; } - SmallVector 
getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return {}; } @@ -968,7 +965,7 @@ return false; } - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return {op->getOpResult(0)}; } @@ -1000,9 +997,9 @@ struct ParallelInsertSliceOpInterface : public BufferizableOpInterface::ExternalModel< ParallelInsertSliceOpInterface, ParallelInsertSliceOp> { - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { - return {}; + return {}; } bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp @@ -41,7 +41,7 @@ return false; } - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return {}; } @@ -110,7 +110,7 @@ return false; } - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return {}; }