diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
@@ -26,20 +26,51 @@
 using namespace mlir;
 using namespace mlir::bufferization;
 
+/// Return true if block "a" is an ancestor of block "b". Blocks are
+/// considered their own ancestors.
+static bool isAncestorBlock(Block *a, Block *b) {
+  if (a == b)
+    return true;
+  for (Operation *op = b->getParentOp(); op; op = op->getParentOp())
+    if (op->getBlock() == a)
+      return true;
+  return false;
+}
+
+/// Return the iterator (within block "a") that corresponds to position "b":
+/// "b"'s own iterator if "b" is a position in "a", otherwise the iterator of
+/// the ancestor operation of "b" that lies directly in "a". Return
+/// "std::nullopt" if "b" is not nested under "a".
+static std::optional<Block::iterator>
+findCommonAncestorBlock(Block *a, std::pair<Block *, Block::iterator> b) {
+  Block *nextBlock = b.first;
+  Block::iterator it = b.second;
+  while (true) {
+    if (nextBlock == a)
+      return it;
+    Operation *op = nextBlock->getParentOp();
+    if (!op->getBlock())
+      return {};
+    it = op->getIterator();
+    nextBlock = op->getBlock();
+  }
+  llvm_unreachable("loop should have returned");
+}
+
 /// Return true if all `neededValues` are in scope at the given
 /// `insertionPoint`.
 static bool
-neededValuesDominateInsertionPoint(const DominanceInfo &domInfo,
-                                   Operation *insertionPoint,
+neededValuesDominateInsertionPoint(std::pair<Block *, Block::iterator> ip,
                                    const SmallVector<Value> &neededValues) {
   for (Value val : neededValues) {
     if (auto bbArg = val.dyn_cast<BlockArgument>()) {
-      Block *owner = bbArg.getOwner();
-      if (!owner->findAncestorOpInBlock(*insertionPoint))
+      // The block argument must be defined by a block that is an ancestor of
+      // the insertion block. (Note: Equal blocks are also ancestors.)
+      if (!isAncestorBlock(bbArg.getOwner(), ip.first))
         return false;
     } else {
-      auto opResult = val.cast<OpResult>();
-      if (!domInfo.dominates(opResult.getOwner(), insertionPoint))
+      // The defining op of the value must be in a block that is an ancestor of
+      // the insertion block.
+      Block *block = val.getDefiningOp()->getBlock();
+      std::optional<Block::iterator> commonAncestor =
+          findCommonAncestorBlock(block, ip);
+      if (!commonAncestor.has_value())
+        return false;
+      // Furthermore, the defining op must appear before the insertion point.
+      if (std::distance(block->begin(), val.getDefiningOp()->getIterator()) >=
+          std::distance(block->begin(), *commonAncestor))
         return false;
     }
   }
@@ -48,55 +79,54 @@
 
 /// Return true if the given `insertionPoint` dominates all uses of
 /// `emptyTensorOp`.
-static bool insertionPointDominatesUses(const DominanceInfo &domInfo,
-                                        Operation *insertionPoint,
+static bool insertionPointDominatesUses(std::pair<Block *, Block::iterator> ip,
                                         Operation *emptyTensorOp) {
-  for (Operation *user : emptyTensorOp->getUsers())
-    if (!domInfo.dominates(insertionPoint, user))
+  Block *block = ip.first;
+  for (Operation *user : emptyTensorOp->getUsers()) {
+    std::optional<Block::iterator> commonAncestor = findCommonAncestorBlock(
+        block, std::make_pair(user->getBlock(), user->getIterator()));
+    if (!commonAncestor.has_value())
       return false;
+    // The replacement is inserted *before* `ip.second`, so it dominates the
+    // user even if the user's ancestor is the op at `ip.second` itself.
+    if (std::distance(block->begin(), ip.second) >
+        std::distance(block->begin(), *commonAncestor))
+      return false;
+  }
   return true;
 }
 
 /// Find a valid insertion point for a replacement of `emptyTensorOp`, assuming
 /// that the replacement may use any value from `neededValues`.
-static Operation *
+static FailureOr<std::pair<Block *, Block::iterator>>
 findValidInsertionPoint(Operation *emptyTensorOp,
                         const SmallVector<Value> &neededValues) {
-  DominanceInfo domInfo;
-
   // Gather all possible insertion points: the location of `emptyTensorOp` and
   // right after the definition of each value in `neededValues`.
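+  // For example, for the (hypothetical) input IR
+  //   %v = "some_op"() : () -> (index)
+  //   %0 = tensor.empty(%v) : tensor<?xf32>
+  // with neededValues = {%v}, the candidates are the position of %0 and the
+  // position right after the definition of %v.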
-  SmallVector<Operation *> insertionPointCandidates;
-  insertionPointCandidates.push_back(emptyTensorOp);
+  SmallVector<std::pair<Block *, Block::iterator>> insertionPointCandidates;
+  insertionPointCandidates.emplace_back(emptyTensorOp->getBlock(),
+                                        emptyTensorOp->getIterator());
   for (Value val : neededValues) {
-    // Note: The anchor op is using all of `neededValues`, so:
-    // * in case of a block argument: There must be at least one op in the
-    //   block (the anchor op or one of its parents).
-    // * in case of an OpResult: There must be at least one op right after the
-    //   defining op (the anchor op or one of its parents).
     if (auto bbArg = val.dyn_cast<BlockArgument>()) {
-      insertionPointCandidates.push_back(
-          &bbArg.getOwner()->getOperations().front());
+      insertionPointCandidates.emplace_back(bbArg.getOwner(),
+                                            bbArg.getOwner()->begin());
     } else {
-      insertionPointCandidates.push_back(val.getDefiningOp()->getNextNode());
+      Operation *op = val.getDefiningOp();
+      // Insertion happens before the given iterator, so "right after the
+      // defining op" is std::next of the defining op's iterator.
+      insertionPointCandidates.emplace_back(op->getBlock(),
+                                            std::next(op->getIterator()));
     }
   }
 
   // Select first matching insertion point.
-  for (Operation *insertionPoint : insertionPointCandidates) {
+  for (auto ip : insertionPointCandidates) {
     // Check if all needed values are in scope.
-    if (!neededValuesDominateInsertionPoint(domInfo, insertionPoint,
-                                            neededValues))
+    if (!neededValuesDominateInsertionPoint(ip, neededValues))
       continue;
     // Check if the insertion point is before all uses.
-    if (!insertionPointDominatesUses(domInfo, insertionPoint, emptyTensorOp))
+    if (!insertionPointDominatesUses(ip, emptyTensorOp))
       continue;
-    return insertionPoint;
+    return ip;
   }
 
   // No suitable insertion point was found.
-  return nullptr;
+  return failure();
 }
 
 /// Try to eliminate tensor::EmptyOps inside `op`. A tensor::EmptyOp is replaced
@@ -151,12 +181,11 @@
       // Find a suitable insertion point. If no suitable insertion point for
       // the replacement can be found, skip this replacement.
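+      // Not finding an insertion point is not an error: this transform is a
+      // best-effort optimization, so the tensor.empty op is simply kept.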
-      Operation *insertionPoint =
-          findValidInsertionPoint(emptyTensorOp, neededValues);
-      if (!insertionPoint)
+      auto ip = findValidInsertionPoint(emptyTensorOp, neededValues);
+      if (failed(ip))
         continue;
 
-      rewriter.setInsertionPoint(insertionPoint);
+      rewriter.setInsertionPoint(ip->first, ip->second);
       Value replacement =
           rewriteFunc(rewriter, emptyTensorOp->getLoc(), operand);
       if (!replacement)
@@ -244,6 +273,28 @@
   return success();
 }
 
+LogicalResult allocTensorAnchoredEmptyTensorEliminationStep(
+    RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
+  return eliminateEmptyTensors(
+      rewriter, op, state,
+      /*anchorMatchFunc=*/
+      [&](OpOperand &operand, SmallVector<Value> &neededValues) {
+        // Only yielded values of alloc_tensor init regions are anchors. The
+        // init region's block argument is needed for the replacement.
+        auto yieldOp = dyn_cast<bufferization::YieldOp>(operand.getOwner());
+        if (!yieldOp)
+          return false;
+        neededValues.push_back(
+            cast<bufferization::AllocTensorOp>(yieldOp->getParentOp())
+                .getBlockArgument());
+        return true;
+      },
+      /*rewriteFunc=*/
+      [](OpBuilder &b, Location loc, OpOperand &operand) {
+        // Replace the tensor.empty with the init region's block argument.
+        return cast<bufferization::AllocTensorOp>(
+                   operand.getOwner()->getParentOp())
+            .getBlockArgument();
+      });
+}
+
 namespace {
 struct EmptyTensorElimination
     : public bufferization::impl::EmptyTensorEliminationBase<
@@ -272,7 +323,10 @@
   IRRewriter rewriter(op->getContext());
   if (failed(bufferization::insertSliceAnchoredEmptyTensorEliminationStep(
           rewriter, op, state)))
-    signalPassFailure();
+    return signalPassFailure();
+  if (failed(
+          allocTensorAnchoredEmptyTensorEliminationStep(rewriter, op, state)))
+    return signalPassFailure();
 }
 
 std::unique_ptr<Pass> mlir::bufferization::createEmptyTensorEliminationPass() {
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis-empty-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis-empty-tensor-elimination.mlir
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis-empty-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis-empty-tensor-elimination.mlir
@@ -47,3 +47,43 @@
 // CHECK-SAME: __equivalent_func_args__ = [0, 0]
   return %2, %2 : tensor<?xf32>, tensor<?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @alloc_tensor_anchored_elimination(
+// CHECK-SAME: %[[t:.*]]: tensor<5xf32>
+func.func @alloc_tensor_anchored_elimination(%t: tensor<5xf32>, %f: f32) -> tensor<10xf32> {
+  // CHECK: bufferization.alloc_tensor() init {
+  // CHECK-NEXT: ^{{.*}}(%[[bbarg:.*]]: tensor<10xf32>):
+  // CHECK-NEXT: %[[fill:.*]] = linalg.fill {{.*}} outs(%[[bbarg]] : tensor<10xf32>)
+  // CHECK-NEXT: %[[inserted:.*]] = tensor.insert_slice %[[t]] into %[[fill]]
+  // CHECK-NEXT: bufferization.yield %[[inserted]]
+  %0 = bufferization.alloc_tensor() init {
+  ^bb0(%arg0: tensor<10xf32>):
+    %1 = tensor.empty() : tensor<10xf32>
+    %2 = linalg.fill ins(%f : f32) outs(%1 : tensor<10xf32>) -> tensor<10xf32>
+    %3 = tensor.insert_slice %t into %2[2][5][1] : tensor<5xf32> into tensor<10xf32>
+    bufferization.yield %3 : tensor<10xf32>
+  } : tensor<10xf32>
+  return %0 : tensor<10xf32>
+}
+
+// -----
+
+// The tensor.empty cannot be eliminated in this example because there is no
+// valid insertion point.
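+// The replacement value would be the init region's block argument, which is
+// not in scope at the position of the tensor.empty op; and hoisting the
+// replacement to the beginning of the init region would not dominate the
+// linalg.fill user, which is outside of the region.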
+
+// CHECK-LABEL: func @alloc_tensor_no_valid_insertion_point(
+// CHECK-SAME: %[[t:.*]]: tensor<5xf32>
+func.func @alloc_tensor_no_valid_insertion_point(%t: tensor<5xf32>, %f: f32) -> tensor<10xf32> {
+  // CHECK: bufferization.alloc_tensor() : tensor<10xf32>
+  // CHECK: bufferization.alloc_tensor() init {
+  %1 = tensor.empty() : tensor<10xf32>
+  %2 = linalg.fill ins(%f : f32) outs(%1 : tensor<10xf32>) -> tensor<10xf32>
+  %0 = bufferization.alloc_tensor() init {
+  ^bb0(%arg0: tensor<10xf32>):
+    %3 = tensor.insert_slice %t into %2[2][5][1] : tensor<5xf32> into tensor<10xf32>
+    bufferization.yield %3 : tensor<10xf32>
+  } : tensor<10xf32>
+  return %0 : tensor<10xf32>
+}
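+
+// -----
+
+// A variant of @alloc_tensor_anchored_elimination without the
+// tensor.insert_slice (a sketch of additional coverage; it assumes the same
+// alloc_tensor-anchored rewrite as above).
+
+// CHECK-LABEL: func @alloc_tensor_anchored_elimination_no_slice(
+func.func @alloc_tensor_anchored_elimination_no_slice(%f: f32) -> tensor<10xf32> {
+  // CHECK: bufferization.alloc_tensor() init {
+  // CHECK-NEXT: ^{{.*}}(%[[bbarg:.*]]: tensor<10xf32>):
+  // CHECK-NEXT: %[[fill:.*]] = linalg.fill {{.*}} outs(%[[bbarg]] : tensor<10xf32>)
+  // CHECK-NEXT: bufferization.yield %[[fill]]
+  %0 = bufferization.alloc_tensor() init {
+  ^bb0(%arg0: tensor<10xf32>):
+    %1 = tensor.empty() : tensor<10xf32>
+    %2 = linalg.fill ins(%f : f32) outs(%1 : tensor<10xf32>) -> tensor<10xf32>
+    bufferization.yield %2 : tensor<10xf32>
+  } : tensor<10xf32>
+  return %0 : tensor<10xf32>
+}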