diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -816,28 +816,37 @@
 
 /// Starting from `value`, follow the use-def chain in reverse, always selecting
 /// the corresponding aliasing OpOperand. Try to find and return a Value for
-/// which `condition` evaluates to true for the aliasing OpOperand. Return an
-/// empty Value if no such Value was found. If `returnLast`, return the last
-/// Value (at the end of the chain), even if it does not satisfy the condition.
-static Value
+/// which `condition` evaluates to true.
+///
+/// When reaching the end of the chain (BlockArgument or Value without aliasing
+/// OpOperands), return the last Value of the chain.
+///
+/// Note: The returned SetVector contains exactly one element.
+static llvm::SetVector<Value>
 findValueInReverseUseDefChain(Value value,
-                              std::function<bool(OpOperand &)> condition,
-                              bool returnLast = false) {
-  while (value.isa<OpResult>()) {
-    auto opResult = value.cast<OpResult>();
+                              std::function<bool(Value)> condition) {
+  llvm::SetVector<Value> result, workingSet;
+  workingSet.insert(value);
+
+  while (!workingSet.empty()) {
+    Value value = workingSet.pop_back_val();
+    if (condition(value) || value.isa<BlockArgument>()) {
+      result.insert(value);
+      continue;
+    }
+
+    OpResult opResult = value.cast<OpResult>();
     SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult);
-    assert(opOperands.size() <= 1 && "more than 1 OpOperand not supported yet");
-    if (opOperands.empty())
-      // No aliasing OpOperand. This could be an unsupported op or an op without
-      // a tensor arg such as InitTensorOp. This is the end of the chain.
-      return returnLast ? value : Value();
-    OpOperand *opOperand = opOperands.front();
-    if (condition(*opOperand))
-      return value;
-    value = opOperand->get();
+    if (opOperands.empty()) {
+      result.insert(value);
+      continue;
+    }
+
+    assert(opOperands.size() == 1 && "multiple OpOperands not supported yet");
+    workingSet.insert(opOperands.front()->get());
   }
-  // Value is a BlockArgument. Reached the end of the chain.
-  return returnLast ? value : Value();
+
+  return result;
 }
 
 /// Find the Value (result) of the last preceding write of a given Value.
@@ -846,20 +855,41 @@
 /// Furthermore, BlockArguments are also assumed to be writes. There is no
 /// analysis across block boundaries.
 static Value findLastPrecedingWrite(Value value) {
-  return findValueInReverseUseDefChain(value, bufferizesToMemoryWrite, true);
+  SetVector<Value> result =
+      findValueInReverseUseDefChain(value, [](Value value) {
+        Operation *op = value.getDefiningOp();
+        if (!op)
+          return true;
+        if (!hasKnownBufferizationAliasingBehavior(op))
+          return true;
+
+        SmallVector<OpOperand *> opOperands =
+            getAliasingOpOperand(value.cast<OpResult>());
+        assert(opOperands.size() <= 1 &&
+               "op with multiple aliasing OpOperands not expected");
+
+        if (opOperands.empty())
+          return true;
+
+        return bufferizesToMemoryWrite(*opOperands.front());
+      });
+  assert(result.size() == 1 && "expected exactly one result");
+  return result.front();
 }
 
 /// Return true if `value` is originating from an ExtractSliceOp that matches
 /// the given InsertSliceOp.
 bool BufferizationAliasInfo::hasMatchingExtractSliceOp(
     Value value, InsertSliceOp insertOp) const {
-  return static_cast<bool>(
-      findValueInReverseUseDefChain(value, [&](OpOperand &opOperand) {
-        if (auto extractOp = dyn_cast<ExtractSliceOp>(opOperand.getOwner()))
-          if (areEquivalentExtractSliceOps(extractOp, insertOp))
-            return true;
-        return false;
-      }));
+  auto condition = [&](Value val) {
+    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
+      if (areEquivalentExtractSliceOps(extractOp, insertOp))
+        return true;
+    return false;
+  };
+
+  return llvm::all_of(findValueInReverseUseDefChain(value, condition),
+                      condition);
 }
 
 /// Given sets of uses and writes, return true if there is a RaW conflict under