diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
@@ -195,6 +195,29 @@
 /// Register external models implemented for the `BufferizableOpInterface`.
 void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
+
+/// Try to eliminate InitTensorOps inside `funcOp`.
+///
+/// * `rewriteFunc` generates the replacement for the InitTensorOp.
+/// * Only InitTensorOps that are anchored on a matching OpOperand as per
+///   `anchorMatchFunc` are considered. "Anchored" means that there is a path on
+///   the reverse SSA use-def chain, starting from the OpOperand and always
+///   following the aliasing OpOperand, that eventually ends at a single
+///   InitTensorOp.
+/// * The result of `rewriteFunc` must usually be analyzed for inplaceability.
+///   This analysis can be skipped with `skipAnalysis`.
+LogicalResult initTensorElimination(
+    FuncOp funcOp, BufferizationAliasInfo &aliasInfo, DominanceInfo &domInfo,
+    std::function<bool(OpOperand &)> anchorMatchFunc,
+    std::function<Value(OpBuilder &, Location, OpOperand &)> rewriteFunc,
+    bool skipAnalysis = false);
+
+/// Try to eliminate InitTensorOps inside funcOp that are anchored on an
+/// InsertSliceOp, i.e., if it is eventually inserted into another tensor
+/// (and some other conditions are met).
+LogicalResult eliminateInsertSliceAnchoredInitTensorOps(
+    FuncOp funcOp, BufferizationAliasInfo &aliasInfo, DominanceInfo &domInfo);
+
 } // namespace linalg
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -2150,6 +2150,78 @@
   }
 }
 
+/// Try to eliminate InitTensorOps inside funcOp. An InitTensorOp is replaced
+/// with the result of `rewriteFunc` if it is anchored on a matching
+/// OpOperand. "Anchored" means that there is a path on the reverse SSA use-def
+/// chain, starting from the OpOperand and always following the aliasing
+/// OpOperand, that eventually ends at a single InitTensorOp.
+LogicalResult mlir::linalg::initTensorElimination(
+    FuncOp funcOp, BufferizationAliasInfo &aliasInfo, DominanceInfo &domInfo,
+    std::function<bool(OpOperand &)> anchorMatchFunc,
+    std::function<Value(OpBuilder &, Location, OpOperand &)> rewriteFunc,
+    bool skipAnalysis) {
+  OpBuilder b(funcOp->getContext());
+
+  WalkResult status = funcOp->walk([&](Operation *op) {
+    for (OpOperand &operand : op->getOpOperands()) {
+      // Is this a matching OpOperand?
+      if (!anchorMatchFunc(operand))
+        continue;
+
+      SetVector<Value> maybeInitTensor =
+          findValueInReverseUseDefChain(operand.get(), [](Value val) {
+            // Continue traversal until this function returns true.
+            OpResult opResult = val.dyn_cast<OpResult>();
+            if (!opResult)
+              return true;
+            if (getInPlace(opResult) != InPlaceSpec::True)
+              return true;
+            // Only equivalent tensors are supported at the moment.
+            // TODO: Support cases such as extract_slice(init_tensor).
+            SmallVector<OpOperand *> opOperands =
+                getAliasingOpOperand(opResult);
+            if (!llvm::all_of(opOperands, [](OpOperand *operand) {
+                  return bufferRelation(*operand) == BufferRelation::Equivalent;
+                }))
+              return true;
+            return false;
+          });
+
+      // Replace only if the reverse use-def chain ends at exactly one
+      // InitTensorOp.
+      if (maybeInitTensor.size() != 1 ||
+          !maybeInitTensor.front().getDefiningOp<InitTensorOp>())
+        return WalkResult::skip();
+      Value initTensor = maybeInitTensor.front();
+
+      // Create a replacement for the InitTensorOp.
+      b.setInsertionPoint(initTensor.getDefiningOp());
+      Value replacement = rewriteFunc(b, initTensor.getLoc(), operand);
+      if (!replacement)
+        continue;
+
+      // Uses of the InitTensorOp are replaced here, but the op is not deleted.
+      // InitTensorOps without uses are ignored by the bufferization.
+      initTensor.replaceAllUsesWith(replacement);
+      aliasInfo.createAliasInfoEntry(replacement);
+
+      // Run analysis on the newly created op.
+      if (auto opResult = replacement.dyn_cast<OpResult>()) {
+        if (!skipAnalysis) {
+          SmallVector<Operation *> ops(1, replacement.getDefiningOp());
+          if (failed(inPlaceAnalysis(ops, aliasInfo, domInfo)))
+            return WalkResult::interrupt();
+        }
+      }
+    }
+
+    // Advance to the next operation.
+    return WalkResult::advance();
+  });
+
+  return failure(status.wasInterrupted());
+}
+
 /// Try to eliminate InitTensorOps inside funcOp. An InitTensorOp can be
 /// eliminated if it is eventually inserted into another tensor (and some other
 /// conditions are met).
@@ -2178,60 +2250,26 @@
 ///
 /// Note that the newly inserted ExtractSliceOp may have to bufferize
 /// out-of-place due to RaW conflicts.
-static LogicalResult runInitTensorElimination(FuncOp funcOp,
-                                              BufferizationAliasInfo &aliasInfo,
-                                              DominanceInfo &domInfo) {
-  OpBuilder b(funcOp->getContext());
-
-  WalkResult status = funcOp->walk([&](tensor::InsertSliceOp insertOp) {
-    // Only inplace bufferized InsertSliceOps are eligible.
-    if (getInPlace(insertOp->getOpResult(0)) != InPlaceSpec::True)
-      return WalkResult::skip();
-
-    SetVector<Value> maybeInitTensor =
-        findValueInReverseUseDefChain(insertOp.source(), [](Value val) {
-          // Continue traversal until this function returns true.
-          OpResult opResult = val.dyn_cast<OpResult>();
-          if (!opResult)
-            return true;
-          if (getInPlace(opResult) != InPlaceSpec::True)
-            return true;
-          // Only equivalent tensors are supported at the moment. E.g., when
-          // taking a tensor.extract_slice of an init_tensor, we can currently
-          // not eliminate the init_tensor.
-          SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult);
-          if (!llvm::all_of(opOperands, [](OpOperand *operand) {
-                return bufferRelation(*operand) == BufferRelation::Equivalent;
-              }))
-            return true;
+LogicalResult mlir::linalg::eliminateInsertSliceAnchoredInitTensorOps(
+    FuncOp funcOp, BufferizationAliasInfo &aliasInfo, DominanceInfo &domInfo) {
+  return initTensorElimination(
+      funcOp, aliasInfo, domInfo,
+      [](OpOperand &operand) {
+        auto insertSliceOp = dyn_cast<tensor::InsertSliceOp>(operand.getOwner());
+        if (!insertSliceOp)
           return false;
-        });
-    // Replace only if the InsertSliceOp source originates from exactly one
-    // InitTensorOp.
-    if (maybeInitTensor.size() != 1 ||
-        !maybeInitTensor.front().getDefiningOp<InitTensorOp>())
-      return WalkResult::skip();
-    Value initTensor = maybeInitTensor.front();
-
-    b.setInsertionPoint(initTensor.getDefiningOp());
-    auto extractOp = b.create<tensor::ExtractSliceOp>(
-        initTensor.getLoc(), insertOp.dest(), insertOp.getMixedOffsets(),
-        insertOp.getMixedSizes(), insertOp.getMixedStrides());
-    // Uses of the InitTensorOp are replaced here, but the op is not deleted.
-    // InitTensorOps without uses are ignored by the bufferization.
-    initTensor.replaceAllUsesWith(extractOp.result());
-    aliasInfo.createAliasInfoEntry(extractOp.result());
-
-    // Run analysis on the ExtractSliceOp.
-    if (failed(bufferizableInPlaceAnalysisAliasOnlyOp(
-            extractOp->getOpOperand(0), aliasInfo, domInfo)))
-      return WalkResult::interrupt();
-
-    // Advance to the next operation.
-    return WalkResult::advance();
-  });
-
-  return failure(status.wasInterrupted());
+        // Only inplace bufferized InsertSliceOps are eligible.
+        if (getInPlace(insertSliceOp->getOpResult(0)) != InPlaceSpec::True)
+          return false;
+        return &operand == &insertSliceOp->getOpOperand(0) /*source*/;
+      },
+      [](OpBuilder &b, Location loc, OpOperand &operand) {
+        auto insertSliceOp = cast<tensor::InsertSliceOp>(operand.getOwner());
+        auto extractOp = b.create<tensor::ExtractSliceOp>(
+            loc, insertSliceOp.dest(), insertSliceOp.getMixedOffsets(),
+            insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
+        return extractOp.result();
+      });
 }
 
 void LinalgComprehensiveModuleBufferize::runOnOperation() {
@@ -2291,7 +2329,8 @@
 
   // Try to eliminate InitTensorOps to avoid new allocations during the
   // bufferization phase.
-  if (failed(runInitTensorElimination(funcOp, aliasInfo, domInfo))) {
+  if (failed(eliminateInsertSliceAnchoredInitTensorOps(funcOp, aliasInfo,
+                                                       domInfo))) {
    signalPassFailure();
    return;
  }
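Usage note (not part of the patch): the sketch below shows how another transformation could drive the new `initTensorElimination` hook with its own callbacks. `MyAnchorOp`, the function name, and the empty rewrite body are hypothetical placeholders; only the entry point, the callback signatures, and the "null replacement skips this anchor" behavior come from the change above.

static LogicalResult
eliminateMyAnchoredInitTensorOps(FuncOp funcOp,
                                 BufferizationAliasInfo &aliasInfo,
                                 DominanceInfo &domInfo) {
  return mlir::linalg::initTensorElimination(
      funcOp, aliasInfo, domInfo,
      /*anchorMatchFunc=*/
      [](OpOperand &operand) {
        // Anchor only on the first operand of the (hypothetical) MyAnchorOp.
        auto anchorOp = dyn_cast<MyAnchorOp>(operand.getOwner());
        return anchorOp && &operand == &anchorOp->getOpOperand(0);
      },
      /*rewriteFunc=*/
      [](OpBuilder &b, Location loc, OpOperand &operand) -> Value {
        // Build a tensor value that already holds the contents the
        // InitTensorOp would eventually be filled with. Returning a null
        // Value leaves this anchor untouched and the walk continues.
        return Value();
      },
      /*skipAnalysis=*/false);
}

The builder's insertion point is set to the InitTensorOp before `rewriteFunc` is invoked, so the replacement value is created at the point of the original definition and can take over all of its uses. If `skipAnalysis` is false and the replacement is an OpResult, the newly created op is immediately passed to `inPlaceAnalysis` so that later bufferization decisions take it into account.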