diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -25,6 +25,81 @@ class AnalysisState; class BufferizableOpInterface; +/// Specifies a fine-grain relationship between buffers to enable more analysis. +enum class BufferRelation { + None, + // TODO: ResultContainsOperand, + // TODO: OperandContainsResult, + Equivalent +}; + +/// Specifies the type of a property. +enum class Certainty { MayBe, MustBe }; + +/// An aliasing OpOperand. +struct AliasingOpOperand { + AliasingOpOperand(OpOperand *opOperand, BufferRelation relation) + : opOperand(opOperand), relation(relation) {} + + OpOperand *opOperand; + BufferRelation relation; +}; + +/// An aliasing OpResult. +struct AliasingOpResult { + AliasingOpResult(OpResult opResult, BufferRelation relation) + : opResult(opResult), relation(relation) {} + + OpResult opResult; + BufferRelation relation; +}; + +template class AliasList { +public: + /// Create an empty MayBe list of aliases. + AliasList() : certainty(Certainty::MayBe) {} + + /// Create an empty list of aliases with the specified certainty. + explicit AliasList(Certainty certainty) : certainty(certainty) {} + + /// Create a list of aliases with the specified certainty. + AliasList(Certainty certainty, std::initializer_list elems) + : certainty(certainty) { + for (T alias : elems) + addAlias(alias); + } + + /// Create a list of aliases with the specified certainty. 
+ AliasList(Certainty certainty, SmallVector &&aliases) + : aliases(std::move(aliases)), certainty(certainty) {} + + ArrayRef getAliases() const { return aliases; } + + Certainty getCertainty() const { return certainty; } + + size_t getNumAliases() const { return aliases.size(); } + + void addAlias(T alias) { aliases.push_back(alias); } + + auto begin() const { return aliases.begin(); } + auto end() const { return aliases.end(); } + +private: + /// The list of aliases. + SmallVector aliases; + + /// Specifies the type of the property. + /// * MayBe: Either one of the entries may turn out to alias at runtime. The + /// runtime alias is guaranteed to be in the list. + /// * MustBe: Each entry is guaranteed to alias at runtime. + const Certainty certainty; +}; + +/// A list of aliasing OpOperands. +using AliasingOpOperandList = AliasList; +/// A list of aliasing OpResults. +using AliasingOpResultList = AliasList; + class OpFilter { public: /// An op filter entry. Filters can be used to specify which ops should be @@ -300,14 +375,6 @@ SmallVector stateInitializers; }; -/// Specify fine-grain relationship between buffers to enable more analysis. -enum class BufferRelation { - None, - // TODO: ResultContainsOperand, - // TODO: OperandContainsResult, - Equivalent -}; - /// Return `true` if the given value is a BlockArgument of a func::FuncOp. bool isFunctionArgument(Value value); @@ -315,15 +382,15 @@ /// tensor values. class AnalysisState { public: - /// Determine which OpOperand* will alias with `opResult` if the op is + /// Determine which OpOperand* will alias with `result` if the op is /// bufferized in place. Return all tensor OpOperand* if the op is not /// bufferizable. - SmallVector getAliasingOpOperand(OpResult opResult) const; + AliasingOpOperandList getAliasingOpOperands(OpResult result) const; /// Determine which OpResult will alias with `opOperand` if the op is /// bufferized in place. Return all tensor OpResults if the op is not /// bufferizable. 
- SmallVector getAliasingOpResult(OpOperand &opOperand) const; + AliasingOpResultList getAliasingOpResults(OpOperand &opOperand) const; /// Return true if `opOperand` bufferizes to a memory read. Return `true` if /// the op is not bufferizable. @@ -539,6 +606,12 @@ const BufferizationOptions &options); namespace detail { +/// This is the default implementation of +/// BufferizableOpInterface::getAliasingOpOperands. Should not be called from +/// other places. +AliasingOpOperandList defaultGetAliasingOpOperands(OpResult opResult, + const AnalysisState &state); + /// This is the default implementation of /// BufferizableOpInterface::getBufferType. Should not be called from other /// places. @@ -552,13 +625,13 @@ bool defaultIsRepetitiveRegion(BufferizableOpInterface bufferizableOp, unsigned index); -/// This is the default implementation of getAliasingOpOperand in case the -/// defining op does not implement the BufferizableOpInterface. -SmallVector unknownGetAliasingOpOperand(OpResult opResult); +/// This is the default implementation of getAliasingOpOperands in case the +/// owner op does not implement the BufferizableOpInterface. +AliasingOpOperandList unknownGetAliasingOpOperands(OpResult opResult); -/// This is the default implementation of getAliasingOpResult in case the +/// This is the default implementation of getAliasingOpResults in case the /// defining op does not implement the BufferizableOpInterface. 
-SmallVector unknownGetAliasingOpResult(OpOperand &opOperand); +AliasingOpResultList unknownGetAliasingOpResults(OpOperand &opOperand); } // namespace detail } // namespace bufferization diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td @@ -110,15 +110,15 @@ /*defaultImplementation=*/[{ auto bufferizableOp = cast($_op.getOperation()); - SmallVector opOperands = - bufferizableOp.getAliasingOpOperand(opResult, state); - if (opOperands.empty()) + AliasingOpOperandList aliases = + bufferizableOp.getAliasingOpOperands(opResult, state); + if (aliases.getNumAliases() == 0) return true; return llvm::any_of( - opOperands, - [&](OpOperand *operand) { - return bufferizableOp.bufferizesToMemoryWrite(*operand, - state); + aliases, + [&](AliasingOpOperand alias) { + return bufferizableOp.bufferizesToMemoryWrite( + *alias.opOperand, state); }); }] >, @@ -140,22 +140,71 @@ >, InterfaceMethod< /*desc=*/[{ - Return the OpResult that aliases with a given OpOperand when + Return the OpResults that may alias with a given OpOperand when bufferized in-place. This method will never be called on OpOperands that do not have a tensor type. - Note: This method can return multiple OpResults, indicating that a - given OpOperand may at runtime alias with any (or multiple) of the - returned OpResults. + This method can return multiple OpResults, indicating that a given + OpOperand may at runtime alias with any (or multiple) of the returned + OpResults. + + This property is specified with a degree of certainty: + + * Certainty::MayBe: The returned set of OpResults is an exhaustive + set of possible aliases. At runtime, buffer(opOperand) may alias + with any (or none) of the buffers of the returned OpResults. False + positives are allowed. 
On the contrary, omitting potential aliases + is incorrect. + * Certainty::MustBe: The buffers of the returned set of OpResults are + guaranteed to alias with buffer(opOperand). False positives are not + allowed. Omitting potential aliases is also incorrect. + + Certainty::MustBe is a stronger property than Certainty::MayBe and can + enable more efficient bufferization. When possible, aliasing should be + specified as a MustBe relationship. + + For each alias, a BufferRelation can be specified: + + * BufferRelation::Equivalent: Both aliases are the exact same buffer. + I.e., same size, no offset, same strides. + * BufferRelation::None: There is no further information apart from the + fact that both buffers alias. + + One possible (conservative) implementation of this interface method + that is always safe is to return all tensor OpResults with + BufferRelation::None and Certainty::MayBe. + + Examples: + + ``` + // aliasingOpResults(%t) = MustBe {Equivalent %r} + %r = tensor.insert_slice %f into %t : tensor<10xf32> + + // aliasingOpResults(%t) = MustBe {None %r} + %r = tensor.extract_slice %t[0][5][1] + : tensor<10xf32> to tensor<5xf32> + + // aliasingOpResults(%t1) = MayBe {Equivalent %r} + // aliasingOpResults(%t2) = MayBe {Equivalent %r} + %r = arith.select %c, %t1, %t2 : tensor<10xf32> + + // A hypothetical op that bufferizes to roll a dice and based on the + // result to either return buffer(%t) or a newly allocated copy + // thereof. 
+ // aliasingOpResults(%t) = MayBe {Equivalent %r} + %r = "dummy.alias_or_copy"(%t) : (tensor<10xf32>) -> (tensor<10xf32>) + ``` }], - /*retType=*/"::mlir::SmallVector<::mlir::OpResult>", - /*methodName=*/"getAliasingOpResult", + /*retType=*/"::mlir::bufferization::AliasingOpResultList", + /*methodName=*/"getAliasingOpResults", /*args=*/(ins "::mlir::OpOperand &":$opOperand, "const ::mlir::bufferization::AnalysisState &":$state), /*methodBody=*/"", /*defaultImplementation=*/[{ // Does not have to be implemented for ops without tensor OpOperands. - llvm_unreachable("getAliasingOpResult not implemented"); + assert(opOperand.get().getType().isa() && + "expected OpOperand with tensor type"); + llvm_unreachable("getAliasingOpResults not implemented"); }] >, InterfaceMethod< @@ -164,57 +213,37 @@ bufferized in-place. This method will never be called on OpResults that do not have a tensor type. - By default, this method is the inverse of `getAliasingOpResult`. Ops + By default, this method is the inverse of `getAliasingOpResults`. Ops with a region that yield values may want to override this method to return the OpOperands that are yielded by the terminator. - Note: This method can return multiple OpOperands, indicating that the - given OpResult may at runtime alias with any (or multiple) of the - returned OpOperands. This can be useful for branches and for ops such - as `arith.select`. 
+ Examples: + + ``` + // aliasingOpOperands(%r) = MustBe {Equivalent %t} + %r = tensor.insert_slice %f into %t : tensor<10xf32> + + // aliasingOpOperands(%r) = MustBe {None %t} + %r = tensor.extract_slice %t[0][5][1] + : tensor<10xf32> to tensor<5xf32> + + // aliasingOpOperands(%r) = MayBe {Equivalent %t1, Equivalent %t2} + %r = arith.select %c, %t1, %t2 : tensor<10xf32> + + // aliasingOpOperands(%r) = MayBe {} + %r = tensor.empty() : tensor<10xf32> + ``` }], - /*retType=*/"::mlir::SmallVector<::mlir::OpOperand *>", - /*methodName=*/"getAliasingOpOperand", + /*retType=*/"::mlir::bufferization::AliasingOpOperandList", + /*methodName=*/"getAliasingOpOperands", /*args=*/(ins "::mlir::OpResult":$opResult, "const ::mlir::bufferization::AnalysisState &":$state), /*methodBody=*/"", /*defaultImplementation=*/[{ assert(opResult.getType().isa() && "expected OpResult with tensor type"); - SmallVector result; - auto bufferizableOp = - cast($_op.getOperation()); - for (OpOperand &opOperand : $_op.getOperation()->getOpOperands()) { - if (!opOperand.get().getType().isa()) - continue; - SmallVector aliasingOpResults = - bufferizableOp.getAliasingOpResult(opOperand, state); - if (llvm::is_contained(aliasingOpResults, opResult)) - result.push_back(&opOperand); - } - return result; - }] - >, - InterfaceMethod< - /*desc=*/[{ - Return the buffer relation between the given OpResult and its aliasing - OpOperands when bufferized in-place. Most OpOperands have an - "equivalence" relation. This method will never be called on OpResults - that do not have a tensor type. It will also never be called on - OpResults that do not have at least one aliasing OpOperand. - - TODO: Support other relations such as "OpOperand is included in - OpResult". 
- }], - /*retType=*/"::mlir::bufferization::BufferRelation", - /*methodName=*/"bufferRelation", - /*args=*/(ins "::mlir::OpResult":$opResult, - "const ::mlir::bufferization::AnalysisState &":$state), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - // Does not have to be implemented for ops without tensor OpResults - // that have an aliasing OpOperand. - llvm_unreachable("bufferRelation not implemented"); + return ::mlir::bufferization::detail::defaultGetAliasingOpOperands( + opResult, state); }] >, InterfaceMethod< @@ -253,11 +282,15 @@ conversion. The implementation of this method must be consistent with the - remaining methods, in particular `getAliasingOpOperand`. I.e., a + remaining methods, in particular `getAliasingOpOperands`. I.e., a tensor result `r` may only be replaced with: - a) A buffer that aliases one of buffers in getAliasingOpOperand(r). + + a) A buffer that aliases one of buffers in getAliasingOpOperands(r). b) Or: A newly allocated buffer. + If `getAliasingOpOperands` is specified as Certainty::MustBe, the + second option is not available. + This method will never be called on ops that do not have at least one tensor operand/result. 
@@ -408,7 +441,8 @@ cast(getOperation()); return !bufferizableOp.bufferizesToMemoryRead(opOperand, state) && !bufferizableOp.bufferizesToMemoryWrite(opOperand, state) - && !bufferizableOp.getAliasingOpResult(opOperand, state).empty(); + && bufferizableOp.getAliasingOpResults(opOperand, state) + .getNumAliases() != 0; } }]; } diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td @@ -99,7 +99,7 @@ bool bufferizesToMemoryWrite(OpOperand &opOperand, const AnalysisState &state); - SmallVector getAliasingOpResult( + AliasingOpResultList getAliasingOpResults( OpOperand &opOperand, const AnalysisState &state); FailureOr getBufferType( @@ -247,7 +247,7 @@ return false; } - SmallVector getAliasingOpResult( + AliasingOpResultList getAliasingOpResults( OpOperand &opOperand, const AnalysisState &state) const { return {}; } @@ -408,7 +408,7 @@ return true; } - SmallVector getAliasingOpResult( + AliasingOpResultList getAliasingOpResults( OpOperand &opOperand, const AnalysisState &state) const { return {}; } diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h --- a/mlir/include/mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h @@ -36,21 +36,16 @@ return dstOp.isDpsInit(&opOperand); } - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { // Output operands alias with their respective tied OpResults. 
auto dstOp = cast(op); if (dstOp.isDpsInit(&opOperand)) - return {dstOp.getTiedOpResult(&opOperand)}; + return { + Certainty::MustBe, + {{dstOp.getTiedOpResult(&opOperand), BufferRelation::Equivalent}}}; return {}; } - - BufferRelation bufferRelation(Operation *op, OpResult opResult, - const AnalysisState &state) const { - assert(isa(op) && - "expected that op implements DestinationStyleOpInterface"); - return BufferRelation::Equivalent; - } }; } // namespace bufferization diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp @@ -74,14 +74,10 @@ return false; } - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { - return {op->getResult(0)}; - } - - BufferRelation bufferRelation(Operation *op, OpResult opResult, - const AnalysisState &state) const { - return BufferRelation::Equivalent; + return {Certainty::MustBe, + {{op->getResult(0), BufferRelation::Equivalent}}}; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, @@ -126,16 +122,10 @@ return false; } - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { - return {op->getOpResult(0) /*result*/}; - } - - SmallVector - getAliasingOpOperand(Operation *op, OpResult opResult, - const AnalysisState &state) const { - return {&op->getOpOperand(1) /*true_value*/, - &op->getOpOperand(2) /*false_value*/}; + return {Certainty::MayBe, + {{op->getOpResult(0) /*result*/, BufferRelation::Equivalent}}}; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, @@ -199,11 +189,6 @@ memrefType.getElementType()), 
memrefType.getMemorySpace()); } - - BufferRelation bufferRelation(Operation *op, OpResult opResult, - const AnalysisState &state) const { - return BufferRelation::None; - } }; } // namespace diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -177,28 +177,30 @@ if (operandType.isa()) return op->emitError("copies of unranked tensors are not supported"); - SmallVector aliasingOpResults = - state.getAliasingOpResult(opOperand); + AliasingOpResultList aliasingOpResults = + state.getAliasingOpResults(opOperand); // Is the result yielded from a block? Or are deallocations turned off // entirely? In either case, mark the allocation as "escaping", so that it // will not be deallocated. bool escape = !state.getOptions().createDeallocs || - llvm::any_of(aliasingOpResults, [&](Value v) { - return state.isTensorYielded(v); + llvm::any_of(aliasingOpResults, [&](AliasingOpResult a) { + return state.isTensorYielded(a.opResult); }); - if (aliasingOpResults.size() == 1 && + if (aliasingOpResults.getNumAliases() == 1 && !state.bufferizesToMemoryWrite(opOperand) && - state.getAliasingOpOperand(aliasingOpResults.front()).size() == 1) { + state.getAliasingOpOperands(aliasingOpResults.getAliases()[0].opResult) + .getNumAliases() == 1) { // The op itself does not write but may create exactly one alias. Instead // of copying the OpOperand, copy the OpResult. The OpResult can sometimes // be smaller than the OpOperand (e.g., in the case of an extract_slice, // where the result is usually a smaller part of the source). 
- outOfPlaceOpResults.push_back(aliasingOpResults.front()); + OpResult opResult = aliasingOpResults.getAliases()[0].opResult; + outOfPlaceOpResults.push_back(opResult); if (!state.canOmitTensorCopy(opOperand)) - copiedOpResults.insert(aliasingOpResults.front()); + copiedOpResults.insert(opResult); if (escape) - escapingOpResultCopies.insert(aliasingOpResults.front()); + escapingOpResultCopies.insert(opResult); } else { // In all other cases, make a copy of the OpOperand. outOfPlaceOpOperands.push_back(&opOperand); @@ -352,26 +354,26 @@ /// Determine which OpOperand* will alias with `opResult` if the op is /// bufferized in place. Return all tensor OpOperand* if the op is not /// bufferizable. -SmallVector -AnalysisState::getAliasingOpOperand(OpResult opResult) const { +AliasingOpOperandList +AnalysisState::getAliasingOpOperands(OpResult opResult) const { if (Operation *op = opResult.getDefiningOp()) if (auto bufferizableOp = getOptions().dynCastBufferizableOp(op)) - return bufferizableOp.getAliasingOpOperand(opResult, *this); + return bufferizableOp.getAliasingOpOperands(opResult, *this); // The op is not bufferizable. - return detail::unknownGetAliasingOpOperand(opResult); + return detail::unknownGetAliasingOpOperands(opResult); } /// Determine which OpResult will alias with `opOperand` if the op is bufferized /// in place. Return all tensor OpResults if the op is not bufferizable. -SmallVector -AnalysisState::getAliasingOpResult(OpOperand &opOperand) const { +AliasingOpResultList +AnalysisState::getAliasingOpResults(OpOperand &opOperand) const { if (auto bufferizableOp = getOptions().dynCastBufferizableOp(opOperand.getOwner())) - return bufferizableOp.getAliasingOpResult(opOperand, *this); + return bufferizableOp.getAliasingOpResults(opOperand, *this); // The op is not bufferizable. - return detail::unknownGetAliasingOpResult(opOperand); + return detail::unknownGetAliasingOpResults(opOperand); } /// Return true if `opOperand` bufferizes to a memory read. 
Return `true` if the @@ -423,8 +425,8 @@ OpOperand *uMaybeReading = workingSet.pop_back_val(); // Skip over all ops that neither read nor write (but create an alias). if (bufferizesToAliasOnly(*uMaybeReading)) - for (OpResult opResult : getAliasingOpResult(*uMaybeReading)) - for (OpOperand &use : opResult.getUses()) + for (AliasingOpResult alias : getAliasingOpResults(*uMaybeReading)) + for (OpOperand &use : alias.opResult.getUses()) workingSet.push_back(&use); if (bufferizesToMemoryRead(*uMaybeReading)) return true; @@ -453,23 +455,25 @@ OpResult opResult = value.cast(); BufferizableOpInterface bufferizableOp = options.dynCastBufferizableOp(opResult.getDefiningOp()); - SmallVector opOperands = getAliasingOpOperand(opResult); + AliasingOpOperandList aliases = getAliasingOpOperands(opResult); // Stop iterating in either one of these cases: // * The current op is not bufferizable or excluded in the filter. // * There are no OpOperands to follow. - // * There is an OpOperand, but it is not an equivalent tensor (only if - // `followEquivalentOnly` is set). - if (!bufferizableOp || opOperands.empty() || - (followEquivalentOnly && - bufferizableOp.bufferRelation(opResult, *this) != - BufferRelation::Equivalent)) { + if (!bufferizableOp || aliases.getNumAliases() == 0) { result.insert(value); continue; } - for (OpOperand *o : opOperands) - workingSet.insert(o->get()); + for (AliasingOpOperand a : aliases) { + if (followEquivalentOnly && a.relation != BufferRelation::Equivalent) { + // Stop iterating if `followEquivalentOnly` is set but the alias is not + // equivalent. + result.insert(value); + } else { + workingSet.insert(a.opOperand->get()); + } + } } return result; @@ -510,10 +514,10 @@ return true; // Do not copy if the tensor is never read. 
- SmallVector aliasingOpResults = getAliasingOpResult(opOperand); + AliasingOpResultList aliases = getAliasingOpResults(opOperand); if (!bufferizesToMemoryRead(opOperand) && - llvm::none_of(aliasingOpResults, - [&](OpResult opResult) { return isValueRead(opResult); })) + llvm::none_of( + aliases, [&](AliasingOpResult a) { return isValueRead(a.opResult); })) return true; // Default: Cannot omit the copy. @@ -581,8 +585,8 @@ // Note: In the absence of detailed analysis information (e.g., there may be // no function call analysis information), this `getAliasingOpResult` is // conservative and may report additional OpResults as potentially aliasing. - for (OpResult opResult : getAliasingOpResult(*operand)) - for (OpOperand &use : opResult.getUses()) + for (AliasingOpResult alias : getAliasingOpResults(*operand)) + for (OpOperand &use : alias.opResult.getUses()) worklist.push_back(&use); } @@ -806,6 +810,31 @@ // Default implementations of interface methods //===----------------------------------------------------------------------===// +// Compute the AliasingOpOperandList for a given OpResult by inverting the +// getAliasingOpResults mapping of the defining op's OpOperands. +AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands( + OpResult opResult, const AnalysisState &state) { + Operation *op = opResult.getDefiningOp(); + SmallVector result; + std::optional certainty = {}; + for (OpOperand &opOperand : op->getOpOperands()) { + if (!opOperand.get().getType().isa()) + continue; + AliasingOpResultList aliasingOpResults = + state.getAliasingOpResults(opOperand); + for (const auto &it : aliasingOpResults) { + if (it.opResult == opResult) { + result.emplace_back(&opOperand, it.relation); + if (!certainty.has_value()) + certainty = aliasingOpResults.getCertainty(); + break; // Found; stop scanning this OpOperand's aliases. + } + } + } + return AliasingOpOperandList( + certainty.has_value() ? 
*certainty : Certainty::MayBe, std::move(result)); +} + FailureOr bufferization::detail::defaultGetBufferType( Value value, const BufferizationOptions &options, const DenseMap &fixedTypes) { @@ -818,15 +847,13 @@ // Value is an OpResult. Operation *op = getOwnerOfValue(value); auto opResult = value.cast(); - auto bufferizableOp = cast(op); AnalysisState state(options); - auto aliasingOperands = bufferizableOp.getAliasingOpOperand(opResult, state); - if (!aliasingOperands.empty() && - bufferizableOp.bufferRelation(opResult, state) == - BufferRelation::Equivalent) { + AliasingOpOperandList aliases = state.getAliasingOpOperands(opResult); + if (aliases.getNumAliases() > 0 && + aliases.getAliases()[0].relation == BufferRelation::Equivalent) { // If the OpResult has an equivalent OpOperand, both OpResult and // OpOperand bufferize to the exact same buffer type. - Value equivalentOperand = aliasingOperands.front()->get(); + Value equivalentOperand = aliases.getAliases().front().opOperand->get(); return getBufferType(equivalentOperand, options, fixedTypes); } @@ -849,22 +876,22 @@ return regionInterface.isRepetitiveRegion(index); } -SmallVector -bufferization::detail::unknownGetAliasingOpOperand(OpResult opResult) { +AliasingOpOperandList +bufferization::detail::unknownGetAliasingOpOperands(OpResult opResult) { // Conservatively assume that everything is aliasing. - SmallVector r; + AliasingOpOperandList r(Certainty::MayBe); for (OpOperand &operand : opResult.getDefiningOp()->getOpOperands()) if (operand.get().getType().isa()) - r.push_back(&operand); + r.addAlias({&operand, BufferRelation::None}); return r; } -SmallVector -bufferization::detail::unknownGetAliasingOpResult(OpOperand &opOperand) { +AliasingOpResultList +bufferization::detail::unknownGetAliasingOpResults(OpOperand &opOperand) { // Conservatively assume that everything is aliasing. 
- SmallVector r; + AliasingOpResultList r(Certainty::MayBe); for (OpResult result : opOperand.getOwner()->getOpResults()) if (result.getType().isa()) - r.push_back(result); + r.addAlias({result, BufferRelation::None}); return r; } diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp @@ -226,9 +226,9 @@ return false; } -SmallVector -AllocTensorOp::getAliasingOpResult(OpOperand &opOperand, - const AnalysisState &state) { +AliasingOpResultList +AllocTensorOp::getAliasingOpResults(OpOperand &opOperand, + const AnalysisState &state) { // This is a new allocation. It does not alias with any other buffer. return {}; } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp @@ -172,52 +172,37 @@ opOperand.getOperandNumber()); } - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { func::CallOp callOp = cast(op); FuncOp funcOp = getCalledFunction(callOp); assert(funcOp && "expected CallOp to a FuncOp"); if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed) // FuncOp not analyzed yet. Any OpResult may be aliasing. - return detail::unknownGetAliasingOpResult(opOperand); + return detail::unknownGetAliasingOpResults(opOperand); // Get aliasing results from state. 
const FuncAnalysisState &funcState = getFuncAnalysisState(state); auto aliasingReturnVals = funcState.aliasingReturnVals.lookup(funcOp).lookup( opOperand.getOperandNumber()); - SmallVector result; - for (int64_t resultIdx : aliasingReturnVals) - result.push_back(callOp->getOpResult(resultIdx)); - return result; - } - - BufferRelation bufferRelation(Operation *op, OpResult opResult, - const AnalysisState &state) const { - func::CallOp callOp = cast(op); - FuncOp funcOp = getCalledFunction(callOp); - assert(funcOp && "expected CallOp to a FuncOp"); - if (getFuncOpAnalysisState(state, funcOp) != - FuncOpAnalysisState::Analyzed) { - // Function not analyzed yet. The conservative answer is "None". - return BufferRelation::None; - } - const FuncAnalysisState &funcState = getFuncAnalysisState(state); - std::optional maybeEquiv = - getEquivalentFuncArgIdx(funcOp, funcState, opResult.getResultNumber()); - if (maybeEquiv) { -#ifndef NDEBUG - SmallVector aliasingOpOperands = - getAliasingOpOperand(op, opResult, state); - assert(aliasingOpOperands.size() == 1 && - "expected exactly 1 aliasing OpOperand"); - assert(aliasingOpOperands.front()->getOperandNumber() == *maybeEquiv && + // Check if the aliasing OpResult is equivalent to the OpOperand. + std::optional equivalent = {}; + if (aliasingReturnVals.size() == 1) { + equivalent = getEquivalentFuncArgIdx(funcOp, funcState, + aliasingReturnVals.front()); + assert((!equivalent.has_value() || + *equivalent == opOperand.getOperandNumber()) && "inconsistent analysis state"); -#endif - return BufferRelation::Equivalent; } - return BufferRelation::None; + AliasingOpResultList result(equivalent.has_value() ? Certainty::MustBe + : Certainty::MayBe); + for (int64_t resultIdx : aliasingReturnVals) + result.addAlias({callOp->getOpResult(resultIdx), + equivalent.has_value() ? BufferRelation::Equivalent + : BufferRelation::None}); + return result; } /// All function arguments are writable. 
It is the responsibility of the @@ -329,7 +314,7 @@ return false; } - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return {}; } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp @@ -147,8 +147,8 @@ if (inplaceBufferized.contains(&operand)) return; markInPlace(operand); - for (OpResult result : state.getAliasingOpResult(operand)) - aliasInfo.unionSets(result, operand.get()); + for (AliasingOpResult alias : state.getAliasingOpResults(operand)) + aliasInfo.unionSets(alias.opResult, operand.get()); ++statNumTensorInPlace; } @@ -661,9 +661,10 @@ // No conflict if the conflicting write and the last write are the same // use. - SmallVector aliasingOpResult = - state.getAliasingOpResult(*uConflictingWrite); - if (aliasingOpResult.size() == 1 && aliasingOpResult[0] == lastWrite) { + AliasingOpResultList aliases = + state.getAliasingOpResults(*uConflictingWrite); + if (aliases.getNumAliases() == 1 && + aliases.getAliases()[0].opResult == lastWrite) { LLVM_DEBUG(llvm::dbgs() << " no conflict: last write and write are same\n"); continue; @@ -722,9 +723,10 @@ // there would then be no flow of data from the extract_slice operand to // its result's uses.) 
if (!state.bufferizesToMemoryWrite(use)) { - SmallVector opResults = state.getAliasingOpResult(use); - if (llvm::any_of(opResults, - [&](OpResult r) { return state.isValueRead(r); })) + AliasingOpResultList aliases = state.getAliasingOpResults(use); + if (llvm::any_of(aliases, [&](AliasingOpResult a) { + return state.isValueRead(a.opResult); + })) res.insert(&use); } } @@ -768,9 +770,9 @@ DenseSet usesRead, usesWrite; getAliasingReads(usesRead, operand.get(), aliasInfo, state); getAliasingInplaceWrites(usesWrite, operand.get(), aliasInfo, state); - for (OpResult result : state.getAliasingOpResult(operand)) { - getAliasingReads(usesRead, result, aliasInfo, state); - getAliasingInplaceWrites(usesWrite, result, aliasInfo, state); + for (AliasingOpResult alias : state.getAliasingOpResults(operand)) { + getAliasingReads(usesRead, alias.opResult, aliasInfo, state); + getAliasingInplaceWrites(usesWrite, alias.opResult, aliasInfo, state); } if (!checkConsistencyOnly && state.bufferizesToMemoryWrite(operand)) usesWrite.insert(&operand); @@ -821,11 +823,10 @@ continue; // Follow reverse SSA use-def chain. - SmallVector aliasingOpOperands = - state.getAliasingOpOperand(opResult); - for (OpOperand *opOperand : aliasingOpOperands) - if (aliasInfo.isInPlace(*opOperand) || currentOpOperand == opOperand) - worklist.push_back(opOperand->get()); + for (AliasingOpOperand alias : state.getAliasingOpOperands(opResult)) + if (aliasInfo.isInPlace(*alias.opOperand) || + currentOpOperand == alias.opOperand) + worklist.push_back(alias.opOperand->get()); } return false; } @@ -838,8 +839,8 @@ // Collect writes of all aliases of OpOperand and OpResult. 
DenseSet usesWrite; getAliasingInplaceWrites(usesWrite, operand.get(), aliasInfo, state); - for (OpResult result : state.getAliasingOpResult(operand)) { - getAliasingInplaceWrites(usesWrite, result, aliasInfo, state); + for (AliasingOpResult alias : state.getAliasingOpResults(operand)) { + getAliasingInplaceWrites(usesWrite, alias.opResult, aliasInfo, state); } if (!checkConsistencyOnly && state.bufferizesToMemoryWrite(operand)) usesWrite.insert(&operand); @@ -974,16 +975,25 @@ static void equivalenceAnalysis(SmallVector &ops, BufferizationAliasInfo &aliasInfo, AnalysisState &state) { - for (Operation *op : ops) - if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op)) - for (OpResult opResult : op->getOpResults()) - if (opResult.getType().isa()) - for (OpOperand *opOperand : - bufferizableOp.getAliasingOpOperand(opResult, state)) - if (state.isInPlace(*opOperand)) - if (bufferizableOp.bufferRelation(opResult, state) == - BufferRelation::Equivalent) - aliasInfo.unionEquivalenceClasses(opResult, opOperand->get()); + for (Operation *op : ops) { + for (OpOperand &opOperand : op->getOpOperands()) { + if (opOperand.get().getType().isa()) { + if (!state.isInPlace(opOperand)) + // Out-of-place OpOperands bufferize to new allocations and do not + // union equivalence sets. + continue; + AliasingOpResultList aliases = state.getAliasingOpResults(opOperand); + if (aliases.getCertainty() == Certainty::MayBe) + // If it is not certain that the alias will materialize at runtime, + // we have to be conservative. + continue; + for (AliasingOpResult alias : aliases) { + if (alias.relation == BufferRelation::Equivalent) + aliasInfo.unionEquivalenceClasses(alias.opResult, opOperand.get()); + } + } + } + } } /// Analyze equivalence of tied OpResult/OpOperand pairs of all ops contained @@ -995,7 +1005,7 @@ SmallVector ops; op->walk([&](Operation *op) { // No tensors => no buffers. 
- if (none_of(op->getResultTypes(), isaTensor)) + if (none_of(op->getOperandTypes(), isaTensor)) return; ops.push_back(op); }); diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp @@ -102,8 +102,8 @@ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { // Operand is written to if it has an aliasing OpResult. - auto bufferizableOp = cast(op); - return !bufferizableOp.getAliasingOpResult(opOperand, state).empty(); + auto dpsOp = cast(op); + return dpsOp.isDpsInit(&opOperand); } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -57,7 +57,7 @@ return false; } - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return {}; } @@ -104,9 +104,9 @@ struct ExecuteRegionOpInterface : public BufferizableOpInterface::ExternalModel { - SmallVector - getAliasingOpOperand(Operation *op, OpResult opResult, - const AnalysisState &state) const { + AliasingOpOperandList + getAliasingOpOperands(Operation *op, OpResult opResult, + const AnalysisState &state) const { // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be // any SSA value that is in scope. 
To allow for use-def chain traversal // through ExecuteRegionOps in the analysis, the corresponding yield value @@ -120,7 +120,8 @@ auto yieldOp = dyn_cast( executeRegionOp.getRegion().front().getTerminator()); assert(yieldOp && "expected scf.yield terminator in scf.execute_region"); - return {&yieldOp->getOpOperand(resultNum)}; + return {Certainty::MustBe, + {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}}}; } // TODO: For better bufferization results, this could return `true` only if @@ -166,19 +167,14 @@ return success(); } - - BufferRelation bufferRelation(Operation *op, OpResult opResult, - const AnalysisState &state) const { - return BufferRelation::Equivalent; - } }; /// Bufferization of scf.if. Replace with a new scf.if that yields memrefs. struct IfOpInterface : public BufferizableOpInterface::ExternalModel { - SmallVector - getAliasingOpOperand(Operation *op, OpResult opResult, - const AnalysisState &state) const { + AliasingOpOperandList + getAliasingOpOperands(Operation *op, OpResult opResult, + const AnalysisState &state) const { // IfOps do not have tensor OpOperands. The yielded value can be any SSA // value that is in scope. To allow for use-def chain traversal through // IfOps in the analysis, both corresponding yield values from the then/else @@ -186,8 +182,13 @@ auto ifOp = cast(op); size_t resultNum = std::distance(op->getOpResults().begin(), llvm::find(op->getOpResults(), opResult)); - return {&ifOp.thenYield()->getOpOperand(resultNum), - &ifOp.elseYield()->getOpOperand(resultNum)}; + OpOperand *thenOperand = &ifOp.thenYield()->getOpOperand(resultNum); + OpOperand *elseOperand = &ifOp.elseYield()->getOpOperand(resultNum); + bool equivalentYields = state.areEquivalentBufferizedValues( + thenOperand->get(), elseOperand->get()); + return {equivalentYields ? 
Certainty::MustBe : Certainty::MayBe, + {{thenOperand, BufferRelation::Equivalent}, + {elseOperand, BufferRelation::Equivalent}}}; } // TODO: For better bufferization results, this could return `true` only if @@ -301,19 +302,6 @@ return getMemRefTypeWithFullyDynamicLayout( opResult.getType().cast(), thenBufferType.getMemorySpace()); } - - BufferRelation bufferRelation(Operation *op, OpResult opResult, - const AnalysisState &state) const { - // IfOp results are equivalent to their corresponding yield values if both - // yield values are equivalent to each other. - auto bufferizableOp = cast(op); - SmallVector yieldValues = - bufferizableOp.getAliasingOpOperand(opResult, state); - assert(yieldValues.size() == 2 && "expected 2 yield values"); - bool equivalentYields = state.areEquivalentBufferizedValues( - yieldValues[0]->get(), yieldValues[1]->get()); - return equivalentYields ? BufferRelation::Equivalent : BufferRelation::None; - } }; /// Helper function for loop bufferization. Return the indices of all values @@ -485,10 +473,14 @@ return true; } - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { auto forOp = cast(op); - return {forOp.getResultForOpOperand(opOperand)}; + OpResult opResult = forOp.getResultForOpOperand(opOperand); + BufferRelation relation = bufferRelation(op, opResult, state); + return {relation == BufferRelation::Equivalent ? 
Certainty::MustBe + : Certainty::MayBe, + {{opResult, relation}}}; } BufferRelation bufferRelation(Operation *op, OpResult opResult, @@ -698,7 +690,7 @@ return true; } - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { auto whileOp = cast(op); unsigned int idx = opOperand.getOperandNumber(); @@ -710,7 +702,11 @@ return {}; // The only aliasing OpResult may be the one at the same index. - return {whileOp->getResult(idx)}; + OpResult opResult = whileOp->getResult(idx); + BufferRelation relation = bufferRelation(op, opResult, state); + return {relation == BufferRelation::Equivalent ? Certainty::MustBe + : Certainty::MayBe, + {{opResult, relation}}}; } BufferRelation bufferRelation(Operation *op, OpResult opResult, @@ -993,12 +989,23 @@ return false; } - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { - if (isa(op->getParentOp())) - return {op->getParentOp()->getResult(opOperand.getOperandNumber())}; + if (auto ifOp = dyn_cast(op->getParentOp())) { + Region &otherCase = + ifOp->getRegion((op->getParentRegion()->getRegionNumber() + 1) % 2); + auto otherYieldOp = cast(otherCase.back().back()); + bool equivalentYields = state.areEquivalentBufferizedValues( + opOperand.get(), + otherYieldOp->getOperand(opOperand.getOperandNumber())); + return {equivalentYields ? 
Certainty::MustBe : Certainty::MayBe, + {{op->getParentOp()->getResult(opOperand.getOperandNumber()), + BufferRelation::Equivalent}}}; + } if (isa(op->getParentOp())) - return {op->getParentOp()->getResult(opOperand.getOperandNumber())}; + return {Certainty::MustBe, + {{op->getParentOp()->getResult(opOperand.getOperandNumber()), + BufferRelation::Equivalent}}}; return {}; } @@ -1092,15 +1099,12 @@ return true; } - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { auto foreachThreadOp = cast(op); - return {foreachThreadOp.getTiedOpResult(&opOperand)}; - } - - BufferRelation bufferRelation(Operation *op, OpResult opResult, - const AnalysisState &state) const { - return BufferRelation::Equivalent; + return {Certainty::MustBe, + {{foreachThreadOp.getTiedOpResult(&opOperand), + BufferRelation::Equivalent}}}; } bool isWritable(Operation *op, Value value, diff --git a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp @@ -27,9 +27,9 @@ struct AssumingOpInterface : public BufferizableOpInterface::ExternalModel { - SmallVector - getAliasingOpOperand(Operation *op, OpResult opResult, - const AnalysisState &state) const { + AliasingOpOperandList + getAliasingOpOperands(Operation *op, OpResult opResult, + const AnalysisState &state) const { // AssumingOps do not have tensor OpOperands. The yielded value can be any // SSA value that is in scope. 
To allow for use-def chain traversal through // AssumingOps in the analysis, the corresponding yield value is considered @@ -43,7 +43,8 @@ auto yieldOp = dyn_cast( assumingOp.getDoRegion().front().getTerminator()); assert(yieldOp && "expected shape.assuming_yield terminator"); - return {&yieldOp->getOpOperand(resultNum)}; + return {Certainty::MustBe, + {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}}}; } // TODO: For better bufferization results, this could return `true` only if @@ -89,11 +90,6 @@ return success(); } - - BufferRelation bufferRelation(Operation *op, OpResult opResult, - const AnalysisState &state) const { - return BufferRelation::Equivalent; - } }; /// Bufferization of shape.assuming_yield. Bufferized as part of their enclosing @@ -111,11 +107,13 @@ return false; } - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { assert(isa(op->getParentOp()) && "expected that parent is an AssumingOp"); - return {op->getParentOp()->getResult(opOperand.getOperandNumber())}; + OpResult opResult = + op->getParentOp()->getResult(opOperand.getOperandNumber()); + return {Certainty::MustBe, {{opResult, BufferRelation::Equivalent}}}; } bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand, diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -43,7 +43,7 @@ return false; } - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return {}; } @@ -73,7 +73,7 @@ return false; } - SmallVector getAliasingOpResult(Operation 
*op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return {}; } @@ -97,14 +97,10 @@ return false; } - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { - return {op->getOpResult(0)}; - } - - BufferRelation bufferRelation(Operation *op, OpResult opResult, - const AnalysisState &state) const { - return BufferRelation::Equivalent; + return {Certainty::MustBe, + {{op->getOpResult(0), BufferRelation::Equivalent}}}; } }; @@ -136,18 +132,12 @@ return true; } - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { // InsertOp returns an alias of its operand. assert(op->getNumResults() == 1); - return op->getResults(); - } - - BufferRelation bufferRelation(Operation *oo, OpResult opResult, - const AnalysisState &state) const { - // InsertOp returns the same object (realloc should not invalidate - // aliases). 
- return BufferRelation::Equivalent; + return {Certainty::MustBe, + {{op->getOpResult(0), BufferRelation::Equivalent}}}; } }; @@ -164,7 +154,7 @@ return false; } - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return {}; } @@ -185,7 +175,7 @@ return false; } - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return {}; } @@ -206,7 +196,7 @@ return false; } - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return {}; } @@ -227,7 +217,7 @@ return false; } - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return {}; } diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -41,14 +41,10 @@ return false; } - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { - return {op->getResult(0)}; - } - - BufferRelation bufferRelation(Operation *op, OpResult opResult, - const AnalysisState &state) const { - return BufferRelation::Equivalent; + return {Certainty::MustBe, + {{op->getResult(0), BufferRelation::Equivalent}}}; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, @@ -98,16 +94,12 @@ return false; } - SmallVector getAliasingOpResult(Operation 
*op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { - if (&opOperand == &op->getOpOperand(0) /*src*/) - return {op->getOpResult(0)}; - return {}; - } - - BufferRelation bufferRelation(Operation *op, OpResult opResult, - const AnalysisState &state) const { - return BufferRelation::Equivalent; + // CollapseShapeOp may allocate at runtime. + // TODO: Predict future buffer type and compute more accurate Certainty. + return {Certainty::MayBe, + {{op->getOpResult(0), BufferRelation::Equivalent}}}; } FailureOr @@ -214,7 +206,7 @@ return false; } - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return {}; } @@ -263,16 +255,10 @@ return false; } - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { - if (&opOperand == &op->getOpOperand(0) /*src*/) - return {op->getOpResult(0)}; - return {}; - } - - BufferRelation bufferRelation(Operation *op, OpResult opResult, - const AnalysisState &state) const { - return BufferRelation::Equivalent; + return {Certainty::MustBe, + {{op->getOpResult(0), BufferRelation::Equivalent}}}; } FailureOr @@ -324,16 +310,9 @@ return false; } - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { - if (&opOperand == &op->getOpOperand(0) /*source*/) - return {op->getOpResult(0)}; - return {}; - } - - BufferRelation bufferRelation(Operation *op, OpResult opResult, - const AnalysisState &state) const { - return BufferRelation::None; + return {Certainty::MustBe, {{op->getOpResult(0), BufferRelation::None}}}; } LogicalResult bufferize(Operation *op, RewriterBase 
&rewriter, @@ -397,7 +376,7 @@ return false; } - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return {}; } @@ -837,7 +816,7 @@ return false; } - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return {}; } @@ -935,7 +914,7 @@ return false; } - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return {}; } @@ -968,14 +947,10 @@ return false; } - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { - return {op->getOpResult(0)}; - } - - BufferRelation bufferRelation(Operation *op, OpResult opResult, - const AnalysisState &state) const { - return BufferRelation::Equivalent; + return {Certainty::MustBe, + {{op->getOpResult(0), BufferRelation::Equivalent}}}; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, @@ -1000,9 +975,9 @@ struct ParallelInsertSliceOpInterface : public BufferizableOpInterface::ExternalModel< ParallelInsertSliceOpInterface, ParallelInsertSliceOp> { - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { - return {}; + return {}; } bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp +++ 
b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp @@ -41,7 +41,7 @@ return false; } - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return {}; } @@ -110,7 +110,7 @@ return false; } - SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return {}; }