diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -25,7 +25,7 @@ class AnalysisState; class BufferizableOpInterface; -/// Specify fine-grain relationship between buffers to enable more analysis. +/// Specifies a fine-grain relationship between buffers to enable more analysis. enum class BufferRelation { Unknown, // TODO: ResultContainsOperand, @@ -33,13 +33,65 @@ Equivalent }; +/// A maybe aliasing OpOperand. If `isDefinite` is `true`, the OpOperand is +/// guaranteed to alias at runtime. +struct AliasingOpOperand { + AliasingOpOperand(OpOperand *opOperand, BufferRelation relation, + bool isDefinite = true) + : opOperand(opOperand), relation(relation), isDefinite(isDefinite) {} + + OpOperand *opOperand; + BufferRelation relation; + bool isDefinite; +}; + +/// A maybe aliasing OpResult. If `isDefinite` is `true`, the OpResult is +/// guaranteed to alias at runtime. +struct AliasingOpResult { + AliasingOpResult(OpResult opResult, BufferRelation relation, + bool isDefinite = true) + : opResult(opResult), relation(relation), isDefinite(isDefinite) {} + + OpResult opResult; + BufferRelation relation; + bool isDefinite; +}; + +template class AliasList { +public: + /// Create an empty list of aliases. + AliasList() = default; + + /// Create a list of aliases. + AliasList(std::initializer_list elems) { + for (T alias : elems) + addAlias(alias); + } + + /// Create a list of aliases. 
+ AliasList(SmallVector &&aliases) : aliases(std::move(aliases)) {} + + ArrayRef getAliases() const { return aliases; } + + size_t getNumAliases() const { return aliases.size(); } + + void addAlias(T alias) { aliases.push_back(alias); } + + auto begin() const { return aliases.begin(); } + auto end() const { return aliases.end(); } + +private: + /// The list of aliases. + SmallVector aliases; +}; + /// A list of possible aliasing OpOperands. This list models the runtime /// aliasing relationship for an OpResult. -using AliasingOpOperandList = SmallVector; +using AliasingOpOperandList = AliasList; /// A list of possible aliasing OpResults. This list models the runtime /// aliasing relationship for an OpOperand. -using AliasingOpResultList = SmallVector; +using AliasingOpResultList = AliasList; class OpFilter { public: diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td @@ -180,31 +180,50 @@ OpOperand may at runtime alias with any (or multiple) of the returned OpResults. + Each alias is specified with a degree of certainty: + + * MAYBE (`isDefinite = false`): At runtime, buffer(opOperand) may + alias with the specified OpResult. + * DEFINITE (`isDefinite = true`, default): At runtime, + buffer(opOperand) is guaranteed to alias the buffer of the specified + OpResult. This is a stronger property than MAYBE and allows for more + precise analyses. DEFINITE properties should be used when possible. + + Furthermore, each alias is specified with a buffer relation: + + * `BufferRelation::Equivalent`: Both aliases are the exact same + buffer. I.e., same size, no offset, same strides. + * `BufferRelation::Unknown`: There is no further information apart + from the fact that both buffers alias. 
+ False positives are allowed in the list of OpResults, but they can adversely affect the accuracy of the anlysis. On the contrary, omitting potential aliases is incorrect. One possible (conservative) implementation of this interface method, - that is always safe, is to return all tensor OpResults. + that is always safe, is to return all tensor OpResults with + BufferRelation::Unknown and MAYBE. Examples: ``` - // aliasingOpResults(%t) = {%r} + // aliasingOpResults(%t) = DEFINITE {Equivalent %r} %r = tensor.insert_slice %f into %t : tensor<10xf32> - // aliasingOpResults(%t) = {%r} + // aliasingOpResults(%t) = DEFINITE {Unknown %r} + // Note: "Buffer is subset of buffer" relationships are not yet + // supported, so "Unknown" is the best we can do for now. %r = tensor.extract_slice %t[0]][5][1] : tensor<10xf32> to tensor<5xf32> - // aliasingOpResults(%t1) = {%r} - // aliasingOpResults(%t2) = {%r} + // aliasingOpResults(%t1) = MAYBE {Equivalent %r} + // aliasingOpResults(%t2) = MAYBE {Equivalent %r} %r = arith.select %c, %t1, %t2 : tensor<10xf32> // A hypothetical op that bufferizes to rolling a dice and based on // the result to either return buffer(%t) or a newly allocated copy // thereof. - // aliasingOpResults(%t) = {%r} + // aliasingOpResults(%t) = MAYBE {Equivalent %r} %r = "dummy.alias_or_copy(%t) : (tensor<10xf32>) -> (tensor<10xf32>)" ``` }], @@ -234,12 +253,30 @@ OpResult may at runtime alias with any (or multiple) of the returned OpOperands. + This property is specified with a degree of certainty: + + * MAYBE (`isDefinite = false`): At runtime, buffer(opResult) may + alias with the specified OpOperand. + * DEFINITE (`isDefinite = true`, default): At runtime, + buffer(opResult) is guaranteed to alias the buffer of the specified + OpOperand. This is a stronger property than MAYBE and allows for + more precise analyses. DEFINITE properties should be used when + possible. 
+ + For each alias, a BufferRelation can be specified: + + * `BufferRelation::Equivalent`: Both aliases are the exact same + buffer. I.e., same size, no offset, same strides. + * `BufferRelation::Unknown`: There is no further information apart + from the fact that both buffers alias. + False positives are allowed in the list of OpOperands, but they can adversely affect the accuracy of the anlysis. On the contrary, omitting potential aliases is incorrect. One possible (conservative) implementation of this interface method, - that is always safe, is to return all tensor OpOperands. + that is always safe, is to return all tensor OpOperands with + BufferRelation::Unknown and MAYBE. Note: If the returned list of OpOperands is empty, this op definitely bufferizes to a new allocation. In that case `bufferizesToAllocation` @@ -248,18 +285,19 @@ Examples: ``` - // aliasingOpOperands(%r) = {%t} + // aliasingOpOperands(%r) = DEFINITE {Equivalent %t} %r = tensor.insert_slice %f into %t : tensor<10xf32> - // aliasingOpOperands(%r) = {%t} + // aliasingOpOperands(%r) = DEFINITE {Unknown %t} %r = tensor.extract_slice %t[0]][5][1] : tensor<10xf32> to tensor<5xf32> - // aliasingOpOperands(%r) = {%t1, %t2} + // aliasingOpOperands(%r) = MAYBE {Equivalent %t1, Equivalent %t2} %r = arith.select %c, %t1, %t2 : tensor<10xf32> - // aliasingOpOperands(%r) = {} - %r = te + // aliasingOpOperands(%r) = MAYBE {} + %r = tensor.empty() : tensor<10xf32> + ``` }], /*retType=*/"::mlir::bufferization::AliasingOpOperandList", /*methodName=*/"getAliasingOpOperands", @@ -273,28 +311,6 @@ opResult, state); }] >, - InterfaceMethod< - /*desc=*/[{ - Return the buffer relation between the given OpResult and its aliasing - OpOperands when bufferized in-place. Most OpOperands have an - "equivalence" relation. This method will never be called on OpResults - that do not have a tensor type. It will also never be called on - OpResults that do not have at least one aliasing OpOperand. 
- - TODO: Support other relations such as "OpOperand is included in - OpResult". - }], - /*retType=*/"::mlir::bufferization::BufferRelation", - /*methodName=*/"bufferRelation", - /*args=*/(ins "::mlir::OpResult":$opResult, - "const ::mlir::bufferization::AnalysisState &":$state), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - // Does not have to be implemented for ops without tensor OpResults - // that have an aliasing OpOperand. - llvm_unreachable("bufferRelation not implemented"); - }] - >, InterfaceMethod< /*desc=*/[{ Resolve all inplacability conflicts by inserting explicit @@ -489,7 +505,8 @@ cast<::mlir::bufferization::BufferizableOpInterface>(getOperation()); return !bufferizableOp.bufferizesToMemoryRead(opOperand, state) && !bufferizableOp.bufferizesToMemoryWrite(opOperand, state) - && !bufferizableOp.getAliasingOpResults(opOperand, state).empty(); + && bufferizableOp.getAliasingOpResults(opOperand, state) + .getNumAliases() != 0; } }]; } diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h --- a/mlir/include/mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h @@ -41,16 +41,9 @@ // Output operands alias with their respective tied OpResults. 
auto dstOp = cast(op); if (dstOp.isDpsInit(&opOperand)) - return {dstOp.getTiedOpResult(&opOperand)}; + return {{dstOp.getTiedOpResult(&opOperand), BufferRelation::Equivalent}}; return {}; } - - BufferRelation bufferRelation(Operation *op, OpResult opResult, - const AnalysisState &state) const { - assert(isa(op) && - "expected that op implements DestinationStyleOpInterface"); - return BufferRelation::Equivalent; - } }; } // namespace bufferization diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp @@ -76,12 +76,7 @@ AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { - return {op->getResult(0)}; - } - - BufferRelation bufferRelation(Operation *op, OpResult opResult, - const AnalysisState &state) const { - return BufferRelation::Equivalent; + return {{op->getResult(0), BufferRelation::Equivalent}}; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, @@ -128,14 +123,8 @@ AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { - return {op->getOpResult(0) /*result*/}; - } - - AliasingOpOperandList - getAliasingOpOperands(Operation *op, OpResult opResult, - const AnalysisState &state) const { - return {&op->getOpOperand(1) /*true_value*/, - &op->getOpOperand(2) /*false_value*/}; + return {{op->getOpResult(0) /*result*/, BufferRelation::Equivalent, + /*isDefinite=*/false}}; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, @@ -199,11 +188,6 @@ memrefType.getElementType()), memrefType.getMemorySpace()); } - - BufferRelation bufferRelation(Operation *op, OpResult opResult, - const AnalysisState &state) const { - return BufferRelation::Unknown; - } }; } // namespace diff --git 
a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -203,25 +203,29 @@ // entirely? In either case, mark the allocation as "escaping", so that it // will not be deallocated. bool escape = !state.getOptions().createDeallocs || - llvm::any_of(aliasingOpResults, [&](Value v) { - return state.isTensorYielded(v); + llvm::any_of(aliasingOpResults, [&](AliasingOpResult a) { + return state.isTensorYielded(a.opResult); }); - if (aliasingOpResults.size() == 1 && + if (aliasingOpResults.getNumAliases() == 1 && !state.bufferizesToMemoryWrite(opOperand) && - state.getAliasingOpOperands(aliasingOpResults.front()).size() == 1 && - !aliasingOpResults.front().getType().isa()) { + state.getAliasingOpOperands(aliasingOpResults.getAliases()[0].opResult) + .getNumAliases() == 1 && + !aliasingOpResults.getAliases()[0] + .opResult.getType() + .isa()) { // The op itself does not write but may create exactly one alias. Instead // of copying the OpOperand, copy the OpResult. The OpResult can sometimes // be smaller than the OpOperand (e.g., in the case of an extract_slice, // where the result is usually a smaller part of the source). Do not apply // this optimization if the OpResult is an unranked tensor (because those // cannot be copied at the moment). - outOfPlaceOpResults.push_back(aliasingOpResults.front()); + OpResult opResult = aliasingOpResults.getAliases()[0].opResult; + outOfPlaceOpResults.push_back(opResult); if (!state.canOmitTensorCopy(opOperand)) - copiedOpResults.insert(aliasingOpResults.front()); + copiedOpResults.insert(opResult); if (escape) - escapingOpResultCopies.insert(aliasingOpResults.front()); + escapingOpResultCopies.insert(opResult); } else { // In all other cases, make a copy of the OpOperand. 
outOfPlaceOpOperands.push_back(&opOperand); @@ -456,8 +460,8 @@ OpOperand *uMaybeReading = workingSet.pop_back_val(); // Skip over all ops that neither read nor write (but create an alias). if (bufferizesToAliasOnly(*uMaybeReading)) - for (OpResult opResult : getAliasingOpResults(*uMaybeReading)) - for (OpOperand &use : opResult.getUses()) + for (AliasingOpResult alias : getAliasingOpResults(*uMaybeReading)) + for (OpOperand &use : alias.opResult.getUses()) workingSet.push_back(&use); if (bufferizesToMemoryRead(*uMaybeReading)) return true; @@ -491,19 +495,21 @@ // Stop iterating in either one of these cases: // * The current op is not bufferizable or excluded in the filter. // * There are no OpOperands to follow. - // * There is an OpOperand, but it is not an equivalent tensor (only if - // `followEquivalentOnly` is set). - if (!bufferizableOp || aliases.empty() || - (followEquivalentOnly && - bufferizableOp.bufferRelation(opResult, *this) != - BufferRelation::Equivalent)) { + if (!bufferizableOp || aliases.getNumAliases() == 0) { if (alwaysIncludeLeaves) result.insert(value); continue; } - for (OpOperand *o : aliases) - workingSet.insert(o->get()); + for (AliasingOpOperand a : aliases) { + if (followEquivalentOnly && a.relation != BufferRelation::Equivalent) { + // Stop iterating if `followEquivalentOnly` is set but the alias is not + // equivalent. + result.insert(value); + } else { + workingSet.insert(a.opOperand->get()); + } + } } return result; @@ -539,8 +545,8 @@ // Do not copy if the tensor is never read. AliasingOpResultList aliases = getAliasingOpResults(opOperand); if (!bufferizesToMemoryRead(opOperand) && - llvm::none_of(aliases, - [&](OpResult opResult) { return isValueRead(opResult); })) + llvm::none_of( + aliases, [&](AliasingOpResult a) { return isValueRead(a.opResult); })) return true; // Default: Cannot omit the copy. 
@@ -608,8 +614,8 @@ // Note: In the absence of detailed analysis information (e.g., there may be // no function call analysis information), this `getAliasingOpResult` is // conservative and may report additional OpResults as potentially aliasing. - for (OpResult opResult : getAliasingOpResults(*operand)) - for (OpOperand &use : opResult.getUses()) + for (AliasingOpResult alias : getAliasingOpResults(*operand)) + for (OpOperand &use : alias.opResult.getUses()) worklist.push_back(&use); } @@ -841,13 +847,13 @@ // Case 1: OpResults that have no aliasing OpOperand usually bufferize to // memory writes. - if (opOperands.empty()) + if (opOperands.getAliases().empty()) return true; // Case 2: If an aliasing OpOperand bufferizes to a memory write, the OpResult // may bufferize to a memory write. - if (llvm::any_of(opOperands, [&](OpOperand *operand) { - return state.bufferizesToMemoryWrite(*operand); + if (llvm::any_of(opOperands, [&](AliasingOpOperand alias) { + return state.bufferizesToMemoryWrite(*alias.opOperand); })) return true; @@ -886,9 +892,9 @@ return false; return state.bufferizesToMemoryWrite(v); }; - for (OpOperand *operand : opOperands) { + for (AliasingOpOperand alias : opOperands) { if (!state - .findValueInReverseUseDefChain(operand->get(), + .findValueInReverseUseDefChain(alias.opOperand->get(), isMemoryWriteInsideOp, /*followEquivalentOnly=*/false) .empty()) @@ -902,16 +908,17 @@ AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands( OpResult opResult, const AnalysisState &state) { Operation *op = opResult.getDefiningOp(); - AliasingOpOperandList result; + SmallVector result; for (OpOperand &opOperand : op->getOpOperands()) { if (!opOperand.get().getType().isa()) continue; AliasingOpResultList aliasingOpResults = state.getAliasingOpResults(opOperand); - if (llvm::is_contained(aliasingOpResults, opResult)) - result.push_back(&opOperand); + for (const auto &it : aliasingOpResults) + if (it.opResult == opResult) + 
result.emplace_back(&opOperand, it.relation, it.isDefinite); } - return result; + return AliasingOpOperandList(std::move(result)); } FailureOr bufferization::detail::defaultGetBufferType( @@ -926,14 +933,13 @@ // Value is an OpResult. Operation *op = getOwnerOfValue(value); auto opResult = value.cast(); - auto bufferizableOp = cast(op); AnalysisState state(options); AliasingOpOperandList aliases = state.getAliasingOpOperands(opResult); - if (!aliases.empty() && bufferizableOp.bufferRelation(opResult, state) == - BufferRelation::Equivalent) { + if (aliases.getNumAliases() > 0 && + aliases.getAliases()[0].relation == BufferRelation::Equivalent) { // If the OpResult has an equivalent OpOperand, both OpResult and // OpOperand bufferize to the exact same buffer type. - Value equivalentOperand = aliases.front()->get(); + Value equivalentOperand = aliases.getAliases().front().opOperand->get(); return getBufferType(equivalentOperand, options, fixedTypes); } @@ -958,20 +964,20 @@ AliasingOpOperandList bufferization::detail::unknownGetAliasingOpOperands(OpResult opResult) { - // Conservatively assume that everything is aliasing. + // Conservatively assume that everything may be aliasing. AliasingOpOperandList r; for (OpOperand &operand : opResult.getDefiningOp()->getOpOperands()) if (operand.get().getType().isa()) - r.push_back(&operand); + r.addAlias({&operand, BufferRelation::Unknown, /*isDefinite=*/false}); return r; } AliasingOpResultList bufferization::detail::unknownGetAliasingOpResults(OpOperand &opOperand) { - // Conservatively assume that everything is aliasing. + // Conservatively assume that everything may be aliasing. 
AliasingOpResultList r; for (OpResult result : opOperand.getOwner()->getOpResults()) if (result.getType().isa()) - r.push_back(result); + r.addAlias({result, BufferRelation::Unknown, /*isDefinite=*/false}); return r; } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp @@ -186,38 +186,23 @@ auto aliasingReturnVals = funcState.aliasingReturnVals.lookup(funcOp).lookup( opOperand.getOperandNumber()); - AliasingOpResultList result; - for (int64_t resultIdx : aliasingReturnVals) - result.push_back(callOp->getOpResult(resultIdx)); - return result; - } - - BufferRelation bufferRelation(Operation *op, OpResult opResult, - const AnalysisState &state) const { - func::CallOp callOp = cast(op); - FuncOp funcOp = getCalledFunction(callOp); - assert(funcOp && "expected CallOp to a FuncOp"); - if (getFuncOpAnalysisState(state, funcOp) != - FuncOpAnalysisState::Analyzed) { - // Function not analyzed yet. The conservative answer is "None". - return BufferRelation::Unknown; - } - const FuncAnalysisState &funcState = getFuncAnalysisState(state); - std::optional maybeEquiv = - getEquivalentFuncArgIdx(funcOp, funcState, opResult.getResultNumber()); - if (maybeEquiv) { -#ifndef NDEBUG - SmallVector aliasingOpOperands = - getAliasingOpOperands(op, opResult, state); - assert(aliasingOpOperands.size() == 1 && - "expected exactly 1 aliasing OpOperand"); - assert(aliasingOpOperands.front()->getOperandNumber() == *maybeEquiv && + // Check if the aliasing OpResult is equivalent to the OpOperand. 
+ std::optional equivalent = {}; + if (aliasingReturnVals.size() == 1) { + equivalent = getEquivalentFuncArgIdx(funcOp, funcState, + aliasingReturnVals.front()); + assert((!equivalent.has_value() || + *equivalent == opOperand.getOperandNumber()) && "inconsistent analysis state"); -#endif - return BufferRelation::Equivalent; } - return BufferRelation::Unknown; + AliasingOpResultList result; + for (int64_t resultIdx : aliasingReturnVals) + result.addAlias({callOp->getOpResult(resultIdx), + equivalent.has_value() ? BufferRelation::Equivalent + : BufferRelation::Unknown, + /*isDefinite=*/equivalent.has_value()}); + return result; } /// All function arguments are writable. It is the responsibility of the diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp @@ -147,8 +147,8 @@ if (inplaceBufferized.contains(&operand)) return; markInPlace(operand); - for (OpResult result : state.getAliasingOpResults(operand)) - aliasInfo.unionSets(result, operand.get()); + for (AliasingOpResult alias : state.getAliasingOpResults(operand)) + aliasInfo.unionSets(alias.opResult, operand.get()); ++statNumTensorInPlace; } @@ -633,7 +633,8 @@ // use. AliasingOpResultList aliases = state.getAliasingOpResults(*uConflictingWrite); - if (aliases.size() == 1 && aliases[0] == definition) { + if (aliases.getNumAliases() == 1 && + aliases.getAliases()[0].opResult == definition) { LLVM_DEBUG(llvm::dbgs() << " no conflict: definition and write are same\n"); continue; @@ -692,9 +693,10 @@ // there would then be no flow of data from the extract_slice operand to // its result's uses.) 
if (!state.bufferizesToMemoryWrite(use)) { - AliasingOpResultList opResults = state.getAliasingOpResults(use); - if (llvm::any_of(opResults, - [&](OpResult r) { return state.isValueRead(r); })) + AliasingOpResultList aliases = state.getAliasingOpResults(use); + if (llvm::any_of(aliases, [&](AliasingOpResult a) { + return state.isValueRead(a.opResult); + })) res.insert(&use); } } @@ -738,9 +740,9 @@ DenseSet usesRead, usesWrite; getAliasingReads(usesRead, operand.get(), aliasInfo, state); getAliasingInplaceWrites(usesWrite, operand.get(), aliasInfo, state); - for (OpResult result : state.getAliasingOpResults(operand)) { - getAliasingReads(usesRead, result, aliasInfo, state); - getAliasingInplaceWrites(usesWrite, result, aliasInfo, state); + for (AliasingOpResult alias : state.getAliasingOpResults(operand)) { + getAliasingReads(usesRead, alias.opResult, aliasInfo, state); + getAliasingInplaceWrites(usesWrite, alias.opResult, aliasInfo, state); } if (!checkConsistencyOnly && state.bufferizesToMemoryWrite(operand)) usesWrite.insert(&operand); @@ -791,11 +793,10 @@ continue; // Follow reverse SSA use-def chain. - AliasingOpOperandList aliasingOpOperands = - state.getAliasingOpOperands(opResult); - for (OpOperand *opOperand : aliasingOpOperands) - if (aliasInfo.isInPlace(*opOperand) || currentOpOperand == opOperand) - worklist.push_back(opOperand->get()); + for (AliasingOpOperand alias : state.getAliasingOpOperands(opResult)) + if (aliasInfo.isInPlace(*alias.opOperand) || + currentOpOperand == alias.opOperand) + worklist.push_back(alias.opOperand->get()); } return false; } @@ -808,8 +809,8 @@ // Collect writes of all aliases of OpOperand and OpResult. 
DenseSet usesWrite; getAliasingInplaceWrites(usesWrite, operand.get(), aliasInfo, state); - for (OpResult result : state.getAliasingOpResults(operand)) { - getAliasingInplaceWrites(usesWrite, result, aliasInfo, state); + for (AliasingOpResult alias : state.getAliasingOpResults(operand)) { + getAliasingInplaceWrites(usesWrite, alias.opResult, aliasInfo, state); } if (!checkConsistencyOnly && state.bufferizesToMemoryWrite(operand)) usesWrite.insert(&operand); @@ -944,16 +945,42 @@ static void equivalenceAnalysis(SmallVector &ops, BufferizationAliasInfo &aliasInfo, AnalysisState &state) { - for (Operation *op : ops) - if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op)) - for (OpResult opResult : op->getOpResults()) - if (opResult.getType().isa()) - for (OpOperand *opOperand : - bufferizableOp.getAliasingOpOperands(opResult, state)) - if (state.isInPlace(*opOperand)) - if (bufferizableOp.bufferRelation(opResult, state) == - BufferRelation::Equivalent) - aliasInfo.unionEquivalenceClasses(opResult, opOperand->get()); + for (Operation *op : ops) { + if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op)) { + for (OpResult opResult : op->getOpResults()) { + if (!opResult.getType().isa()) + continue; + AliasingOpOperandList aliases = state.getAliasingOpOperands(opResult); + if (aliases.getNumAliases() == 0) + // Nothing to do if there are no aliasing OpOperands. + continue; + + Value firstOperand = aliases.begin()->opOperand->get(); + bool allEquivalent = true; + for (AliasingOpOperand alias : aliases) { + bool isEquiv = alias.relation == BufferRelation::Equivalent; + bool isInPlace = state.isInPlace(*alias.opOperand); + Value operand = alias.opOperand->get(); + if (isEquiv && isInPlace && alias.isDefinite) { + // Found a definite, equivalent alias. Merge equivalence sets. + // There can only be one definite alias, so we can stop here. 
+ aliasInfo.unionEquivalenceClasses(opResult, operand); + allEquivalent = false; + break; + } + if (!isEquiv || !isInPlace) + allEquivalent = false; + if (!state.areEquivalentBufferizedValues(operand, firstOperand)) + allEquivalent = false; + } + + // If all "maybe" aliases are equivalent and the OpResult is not a new + // allocation, it is a definite, equivalent alias. + if (allEquivalent && !bufferizableOp.bufferizesToAllocation(opResult)) + aliasInfo.unionEquivalenceClasses(opResult, firstOperand); + } + } + } } /// Analyze equivalence of tied OpResult/OpOperand pairs of all ops contained diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -120,7 +120,7 @@ auto yieldOp = dyn_cast( executeRegionOp.getRegion().front().getTerminator()); assert(yieldOp && "expected scf.yield terminator in scf.execute_region"); - return {&yieldOp->getOpOperand(resultNum)}; + return {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}}; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, @@ -154,11 +154,6 @@ return success(); } - - BufferRelation bufferRelation(Operation *op, OpResult opResult, - const AnalysisState &state) const { - return BufferRelation::Equivalent; - } }; /// Bufferization of scf.if. Replace with a new scf.if that yields memrefs. 
@@ -174,8 +169,10 @@ auto ifOp = cast(op); size_t resultNum = std::distance(op->getOpResults().begin(), llvm::find(op->getOpResults(), opResult)); - return {&ifOp.thenYield()->getOpOperand(resultNum), - &ifOp.elseYield()->getOpOperand(resultNum)}; + OpOperand *thenOperand = &ifOp.thenYield()->getOpOperand(resultNum); + OpOperand *elseOperand = &ifOp.elseYield()->getOpOperand(resultNum); + return {{thenOperand, BufferRelation::Equivalent, /*isDefinite=*/false}, + {elseOperand, BufferRelation::Equivalent, /*isDefinite=*/false}}; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, @@ -258,20 +255,6 @@ return getMemRefTypeWithFullyDynamicLayout( opResult.getType().cast(), thenBufferType.getMemorySpace()); } - - BufferRelation bufferRelation(Operation *op, OpResult opResult, - const AnalysisState &state) const { - // IfOp results are equivalent to their corresponding yield values if both - // yield values are equivalent to each other. - auto bufferizableOp = cast(op); - SmallVector yieldValues = - bufferizableOp.getAliasingOpOperands(opResult, state); - assert(yieldValues.size() == 2 && "expected 2 yield values"); - bool equivalentYields = state.areEquivalentBufferizedValues( - yieldValues[0]->get(), yieldValues[1]->get()); - return equivalentYields ? BufferRelation::Equivalent - : BufferRelation::Unknown; - } }; /// Helper function for loop bufferization. 
Return the indices of all values @@ -446,7 +429,10 @@ AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { auto forOp = cast(op); - return {forOp.getResultForOpOperand(opOperand)}; + OpResult opResult = forOp.getResultForOpOperand(opOperand); + BufferRelation relation = bufferRelation(op, opResult, state); + return {{opResult, relation, + /*isDefinite=*/relation == BufferRelation::Equivalent}}; } BufferRelation bufferRelation(Operation *op, OpResult opResult, @@ -669,7 +655,10 @@ return {}; // The only aliasing OpResult may be the one at the same index. - return {whileOp->getResult(idx)}; + OpResult opResult = whileOp->getResult(idx); + BufferRelation relation = bufferRelation(op, opResult, state); + return {{opResult, relation, + /*isDefinite=*/relation == BufferRelation::Equivalent}}; } BufferRelation bufferRelation(Operation *op, OpResult opResult, @@ -954,10 +943,13 @@ AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { - if (isa(op->getParentOp())) - return {op->getParentOp()->getResult(opOperand.getOperandNumber())}; + if (auto ifOp = dyn_cast(op->getParentOp())) { + return {{op->getParentOp()->getResult(opOperand.getOperandNumber()), + BufferRelation::Equivalent, /*isDefinite=*/false}}; + } if (isa(op->getParentOp())) - return {op->getParentOp()->getResult(opOperand.getOperandNumber())}; + return {{op->getParentOp()->getResult(opOperand.getOperandNumber()), + BufferRelation::Equivalent}}; return {}; } @@ -1054,12 +1046,8 @@ AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { auto foreachThreadOp = cast(op); - return {foreachThreadOp.getTiedOpResult(&opOperand)}; - } - - BufferRelation bufferRelation(Operation *op, OpResult opResult, - const AnalysisState &state) const { - return BufferRelation::Equivalent; + return {{{foreachThreadOp.getTiedOpResult(&opOperand), + 
BufferRelation::Equivalent}}}; } bool isWritable(Operation *op, Value value, diff --git a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp @@ -43,7 +43,7 @@ auto yieldOp = dyn_cast( assumingOp.getDoRegion().front().getTerminator()); assert(yieldOp && "expected shape.assuming_yield terminator"); - return {&yieldOp->getOpOperand(resultNum)}; + return {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}}; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, @@ -77,11 +77,6 @@ return success(); } - - BufferRelation bufferRelation(Operation *op, OpResult opResult, - const AnalysisState &state) const { - return BufferRelation::Equivalent; - } }; /// Bufferization of shape.assuming_yield. Bufferized as part of their enclosing @@ -105,7 +100,7 @@ "expected that parent is an AssumingOp"); OpResult opResult = op->getParentOp()->getResult(opOperand.getOperandNumber()); - return {opResult}; + return {{opResult, BufferRelation::Equivalent}}; } bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand, diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -99,12 +99,7 @@ AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { - return {op->getOpResult(0)}; - } - - BufferRelation bufferRelation(Operation *op, OpResult opResult, - const AnalysisState &state) const { - return BufferRelation::Equivalent; + return {{op->getOpResult(0), BufferRelation::Equivalent}}; } }; @@ -146,7 +141,7 @@ 
assert(isUniqueCOOType(op->getResultTypes()[0].cast())); // PackOp reuses the input tensors as data/indices instead of creating new // ones when packing into a COO format. - return op->getResults(); + return {{op->getOpResult(0), BufferRelation::Equivalent}}; } }; @@ -168,14 +163,7 @@ const AnalysisState &state) const { // InsertOp returns an alias of its operand. assert(op->getNumResults() == 1); - return op->getResults(); - } - - BufferRelation bufferRelation(Operation *oo, OpResult opResult, - const AnalysisState &state) const { - // InsertOp returns the same object (realloc should not invalidate - // aliases). - return BufferRelation::Equivalent; + return {{op->getOpResult(0), BufferRelation::Equivalent}}; } }; diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -43,12 +43,7 @@ AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { - return {op->getResult(0)}; - } - - BufferRelation bufferRelation(Operation *op, OpResult opResult, - const AnalysisState &state) const { - return BufferRelation::Equivalent; + return {{op->getResult(0), BufferRelation::Equivalent}}; } FailureOr @@ -128,12 +123,7 @@ AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { // TODO: CollapseShapeOp may allocate at runtime. 
- return {op->getOpResult(0)}; - } - - BufferRelation bufferRelation(Operation *op, OpResult opResult, - const AnalysisState &state) const { - return BufferRelation::Equivalent; + return {{op->getOpResult(0), BufferRelation::Equivalent}}; } FailureOr @@ -297,12 +287,7 @@ AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { - return {op->getOpResult(0)}; - } - - BufferRelation bufferRelation(Operation *op, OpResult opResult, - const AnalysisState &state) const { - return BufferRelation::Equivalent; + return {{op->getOpResult(0), BufferRelation::Equivalent}}; } FailureOr @@ -356,14 +341,7 @@ AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { - if (&opOperand == &op->getOpOperand(0) /*source*/) - return {op->getOpResult(0)}; - return {}; - } - - BufferRelation bufferRelation(Operation *op, OpResult opResult, - const AnalysisState &state) const { - return BufferRelation::Unknown; + return {{op->getOpResult(0), BufferRelation::Unknown}}; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, @@ -1000,12 +978,7 @@ AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { - return {op->getOpResult(0)}; - } - - BufferRelation bufferRelation(Operation *op, OpResult opResult, - const AnalysisState &state) const { - return BufferRelation::Equivalent; + return {{op->getOpResult(0), BufferRelation::Equivalent}}; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp @@ -147,7 +147,7 @@ llvm::find(op->getOpResults(), opResult)); auto yieldOp = 
cast(maskOp.getMaskRegion().front().getTerminator()); - return {&yieldOp->getOpOperand(resultNum)}; + return {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}}; } LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter, @@ -217,11 +217,6 @@ replaceOpWithBufferizedValues(rewriter, maskOp, newReturnValues); return success(); } - - BufferRelation bufferRelation(Operation *op, OpResult opResult, - const AnalysisState &state) const { - return BufferRelation::Equivalent; - } }; /// Bufferization of vector.yield. Replaced with a new vector.yield that @@ -241,7 +236,8 @@ AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { - return {op->getParentOp()->getResult(opOperand.getOperandNumber())}; + return {{op->getParentOp()->getResult(opOperand.getOperandNumber()), + BufferRelation::Equivalent}}; } bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,