diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h --- a/mlir/include/mlir/Dialect/Linalg/Passes.h +++ b/mlir/include/mlir/Dialect/Linalg/Passes.h @@ -73,6 +73,10 @@ std::unique_ptr createLinalgComprehensiveModuleBufferizePass( const bufferization::OneShotBufferizationOptions &options); +/// Create a pass that tries to eliminate init_tensor ops that are anchored on +/// insert_slice ops. +std::unique_ptr createLinalgInitTensorEliminationPass(); + /// Create a pass to convert Linalg operations which work on tensors to use /// buffers instead. std::unique_ptr> createLinalgBufferizePass(); diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td --- a/mlir/include/mlir/Dialect/Linalg/Passes.td +++ b/mlir/include/mlir/Dialect/Linalg/Passes.td @@ -62,10 +62,6 @@ Option<"analysisFuzzerSeed", "analysis-fuzzer-seed", "unsigned", /*default=*/"0", "Analyze ops in random order with a given seed (fuzzer)">, - Option<"initTensorElimination", "init-tensor-elimination", "bool", - /*default=*/"false", - "(Experimental) Try to eliminate init_tensor operations that are " - "anchored at an insert_slice op">, Option<"createDeallocs", "create-deallocs", "bool", /*default=*/"true", "Specify if buffers should be deallocated. For compatibility with " "core bufferization passes.">, @@ -73,6 +69,18 @@ let constructor = "mlir::createLinalgComprehensiveModuleBufferizePass()"; } +def LinalgInitTensorElimination : Pass<"linalg-eliminate-init-tensors"> { + let summary = "Try to eliminate insert_slice-anchored init_tensor ops."; + let description = [{ + This pass tries to eliminate all insert_slice op-anchored init_tensor ops. + I.e., when a value that is aliasing with an init_tensor op is inserted into + another tensor, this pass tries to rewrite the IR in such a way that the + destination tensor of the insert_slice op is used directly instead of the + init_tensor result. 
+ }]; + let constructor = "mlir::createLinalgInitTensorEliminationPass()"; +} + def LinalgFoldUnitExtentDims : Pass<"linalg-fold-unit-extent-dims", ""> { let summary = "Remove unit-extent dimension in Linalg ops on tensors"; let constructor = "mlir::createLinalgFoldUnitExtentDimsPass()"; diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h @@ -33,21 +33,16 @@ /// on the reverse SSA use-def chain, starting from the OpOperand and always /// following the aliasing OpOperand, that eventually ends at a single /// InitTensorOp. -/// * The result of `rewriteFunc` must usually be analyzed for inplacability. -/// This analysis can be skipped with `skipAnalysis`. -LogicalResult -eliminateInitTensors(Operation *op, bufferization::AnalysisState &state, - bufferization::BufferizationAliasInfo &aliasInfo, - AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc, - SmallVector &newOps); +LogicalResult eliminateInitTensors(RewriterBase &rewriter, Operation *op, + bufferization::AnalysisState &state, + AnchorMatchFn anchorMatchFunc, + RewriteFn rewriteFunc); /// Try to eliminate InitTensorOps inside `op` that are anchored on an /// InsertSliceOp, i.e., if it is eventually inserted into another tensor /// (and some other conditions are met). 
LogicalResult insertSliceAnchoredInitTensorEliminationStep( - Operation *op, bufferization::AnalysisState &state, - bufferization::BufferizationAliasInfo &aliasInfo, - SmallVector &newOps); + RewriterBase &rewriter, Operation *op, bufferization::AnalysisState &state); void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry); diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp @@ -334,16 +334,17 @@ /// OpOperand. "Anchored" means that there is a path on the reverse SSA use-def /// chain, starting from the OpOperand and always following the aliasing /// OpOperand, that eventually ends at a single InitTensorOp. -LogicalResult mlir::linalg::eliminateInitTensors( - Operation *op, AnalysisState &state, BufferizationAliasInfo &aliasInfo, - AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc, - SmallVector &newOps) { - OpBuilder b(op->getContext()); +LogicalResult mlir::linalg::eliminateInitTensors(RewriterBase &rewriter, + Operation *op, + AnalysisState &state, + AnchorMatchFn anchorMatchFunc, + RewriteFn rewriteFunc) { + OpBuilder::InsertionGuard g(rewriter); WalkResult status = op->walk([&](Operation *op) { for (OpOperand &operand : op->getOpOperands()) { // Skip operands that do not bufferize inplace. - if (!aliasInfo.isInPlace(operand)) + if (!state.isInPlace(operand)) continue; // All values that are needed to create the replacement op. SmallVector neededValues; @@ -359,14 +360,14 @@ SmallVector opOperands = state.getAliasingOpOperand(opResult); if (!llvm::all_of(opOperands, [&](OpOperand *operand) { - return aliasInfo.isInPlace(*operand); + return state.isInPlace(*operand); })) return true; // Only equivalent tensors are supported at the moment. 
// TODO: Support cases such as extract_slice(init_tensor) return !llvm::all_of(opOperands, [&](OpOperand *operand) { - return aliasInfo.areEquivalentBufferizedValues(operand->get(), - opResult); + return state.areEquivalentBufferizedValues(operand->get(), + opResult); }); }); @@ -384,21 +385,13 @@ continue; // Create a replacement for the InitTensorOp. - b.setInsertionPoint(insertionPoint); - Value replacement = rewriteFunc(b, initTensor.getLoc(), operand); + rewriter.setInsertionPoint(insertionPoint); + Value replacement = rewriteFunc(rewriter, initTensor.getLoc(), operand); if (!replacement) continue; - // Uses of the InitTensorOp are replaced here, but the op is not deleted. - // InitTensorOps without uses are ignored by the bufferization. - initTensor.replaceAllUsesWith(replacement); - aliasInfo.createAliasInfoEntry(replacement); - aliasInfo.unionAliasSets(initTensor, replacement); - aliasInfo.unionEquivalenceClasses(initTensor, replacement); - - // Register replacement ops. - if (Operation *newOp = replacement.getDefiningOp()) - newOps.push_back(newOp); + // Replace the InitTensorOp. + rewriter.replaceOp(initTensor.getDefiningOp(), replacement); } // Advance to the next operation. @@ -428,28 +421,20 @@ /// /// Starting from an InsertSliceOp, an InitTensorOp at the end of the insert /// source's reverse use-def chain is eliminated if: -/// * The InsertSliceOp was decided to bufferize inplace. /// * On the reverse use-def chain path from the InsertSliceOp to the /// InitTensorOp, all ops were decided to bufferize inplace and the buffer /// relation is "equivalent" (TODO: can be relaxed if needed). /// * The reverse use-def chain has exactly one end, which is the InitTensorOp. -/// -/// Note that the newly inserted ExtractSliceOp may have to bufferize -/// out-of-place due to RaW conflicts. 
LogicalResult mlir::linalg::insertSliceAnchoredInitTensorEliminationStep( - Operation *op, AnalysisState &state, BufferizationAliasInfo &aliasInfo, - SmallVector &newOps) { + RewriterBase &rewriter, Operation *op, AnalysisState &state) { return eliminateInitTensors( - op, state, aliasInfo, + rewriter, op, state, /*anchorMatchFunc=*/ [&](OpOperand &operand, SmallVector &neededValues) { auto insertSliceOp = dyn_cast(operand.getOwner()); if (!insertSliceOp) return false; - // Only inplace bufferized InsertSliceOps are eligible. - if (!aliasInfo.isInPlace(insertSliceOp->getOpOperand(1) /*dest*/)) - return false; if (&operand != &insertSliceOp->getOpOperand(0) /*source*/) return false; @@ -487,8 +472,7 @@ auto extractOp = b.create( loc, t, insertOp.dest(), mixedOffsets, mixedSizes, mixedStrides); return extractOp.result(); - }, - newOps); + }); } void mlir::linalg::registerBufferizableOpInterfaceExternalModels( diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp @@ -63,6 +63,17 @@ private: llvm::Optional options; }; + +struct LinalgInitTensorElimination + : public LinalgInitTensorEliminationBase { + LinalgInitTensorElimination() = default; + + void runOnOperation() override; + + void getDependentDialects(DialectRegistry &registry) const override { + registry.insert(); + } +}; } // namespace static void applyEnablingTransformations(ModuleOp moduleOp) { @@ -100,9 +111,6 @@ opt.testAnalysisOnly = testAnalysisOnly; opt.alwaysAliasingWithDest = alwaysAliasingWithDest; opt.bufferizeFunctionBoundaries = true; - if (initTensorElimination) { - opt.addPostAnalysisStep(insertSliceAnchoredInitTensorEliminationStep); - } } else { opt = *options; } @@ -125,6 +133,20 @@ (void)runPipeline(cleanupPipeline, moduleOp); } +void 
LinalgInitTensorElimination::runOnOperation() { + Operation *op = getOperation(); + OneShotBufferizationOptions options; + OneShotAnalysisState state(op, options); + if (failed(analyzeOp(op, state))) { + signalPassFailure(); + return; + } + + IRRewriter rewriter(op->getContext()); + if (failed(insertSliceAnchoredInitTensorEliminationStep(rewriter, op, state))) + signalPassFailure(); +} + std::unique_ptr mlir::createLinalgComprehensiveModuleBufferizePass() { return std::make_unique(); } @@ -133,3 +155,7 @@ const OneShotBufferizationOptions &options) { return std::make_unique(options); } + +std::unique_ptr mlir::createLinalgInitTensorEliminationPass() { + return std::make_unique(); +} diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis-init-tensor-elimination.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis-init-tensor-elimination.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis-init-tensor-elimination.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis-init-tensor-elimination.mlir @@ -1,6 +1,4 @@ -// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="test-analysis-only allow-return-allocs init-tensor-elimination" -split-input-file | FileCheck %s - -// ----- +// RUN: mlir-opt %s -linalg-eliminate-init-tensors -one-shot-bufferize="bufferize-function-boundaries test-analysis-only allow-return-allocs" -split-input-file | FileCheck %s //===----------------------------------------------------------------------===// // InitTensorOp elimination diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-init-tensor-elimination.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-init-tensor-elimination.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-init-tensor-elimination.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-init-tensor-elimination.mlir @@ -1,6 +1,4 @@ -// RUN: mlir-opt %s 
-linalg-comprehensive-module-bufferize="allow-return-allocs init-tensor-elimination" -canonicalize -split-input-file | FileCheck %s - -// ----- +// RUN: mlir-opt %s -linalg-eliminate-init-tensors -one-shot-bufferize="bufferize-function-boundaries allow-return-allocs" -canonicalize -split-input-file | FileCheck %s // CHECK: func @buffer_forwarding_conflict( // CHECK-SAME: %[[FUNC_ARG:[0-9a-zA-Z]*]]: memref