diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h --- a/mlir/include/mlir/Dialect/Linalg/Passes.h +++ b/mlir/include/mlir/Dialect/Linalg/Passes.h @@ -127,6 +127,10 @@ linalg::LinalgVectorLoweringOptions(), linalg::LinalgTransformationFilter filter = linalg::LinalgTransformationFilter()); + +/// Create a LinalgStrategyRemoveMarkersPass. +std::unique_ptr<OperationPass<FuncOp>> createLinalgStrategyRemoveMarkersPass(); + //===----------------------------------------------------------------------===// // Registration //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td --- a/mlir/include/mlir/Dialect/Linalg/Passes.td +++ b/mlir/include/mlir/Dialect/Linalg/Passes.td @@ -315,4 +315,15 @@ ]; } +def LinalgStrategyRemoveMarkersPass + : FunctionPass<"linalg-strategy-remove-markers-pass"> { + let summary = "Cleanup pass that drops markers."; + let constructor = "mlir::createLinalgStrategyRemoveMarkersPass()"; + let dependentDialects = ["linalg::LinalgDialect"]; + let options = [ + Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"", + "Which func op is the anchor to latch on.">, + ]; +} + #endif // MLIR_DIALECT_LINALG_PASSES diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h @@ -247,7 +247,6 @@ - /// Apply the transformation patterns in sequence with cleanup - /// transformations interleaved.
- LogicalResult transform(FuncOp func) const; + /// Append the transformation passes to `pm`, with cleanup transformations + /// interleaved; the final pass drops the Linalg transformation markers. void configurePassPipeline(OpPassManager &pm, MLIRContext *context) const; private: diff --git a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp @@ -60,15 +60,5 @@ vectorLoweringOptions.vectorTransformOptions = vectorTransformOptions; vectorLoweringOptions.vectorTransferToSCFOptions = vectorToSCFOptions; pm.addPass(createLinalgStrategyLowerVectorsPass(vectorLoweringOptions)); -} - -LogicalResult mlir::linalg::CodegenStrategy::transform(FuncOp funcOp) const { - PassManager pm(funcOp.getContext(), funcOp.getOperationName()); - configurePassPipeline(pm, funcOp.getContext()); - LogicalResult res = pm.run(funcOp); - // Ensure we drop the marker in the end. - funcOp.walk([](LinalgOp op) { - op->removeAttr(LinalgTransforms::kLinalgTransformMarker); - }); - return res; + pm.addPass(createLinalgStrategyRemoveMarkersPass()); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp @@ -287,6 +287,21 @@ LinalgVectorLoweringOptions options; LinalgTransformationFilter filter; }; + +/// Cleanup pass that drops the transformation marker from LinalgOps. +struct LinalgStrategyRemoveMarkersPass + : public LinalgStrategyRemoveMarkersPassBase< + LinalgStrategyRemoveMarkersPass> { + + void runOnFunction() override { + auto funcOp = getFunction(); + if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) + return; + funcOp.walk([](LinalgOp op) { + op->removeAttr(LinalgTransforms::kLinalgTransformMarker); + }); + } +}; } // namespace /// Create a LinalgStrategyTilePass.
@@ -340,3 +355,9 @@ LinalgTransformationFilter filter) { return std::make_unique<LinalgStrategyLowerVectorsPass>(opt, filter); } + +/// Create a LinalgStrategyRemoveMarkersPass. +std::unique_ptr<OperationPass<FuncOp>> +mlir::createLinalgStrategyRemoveMarkersPass() { + return std::make_unique<LinalgStrategyRemoveMarkersPass>(); +} diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp --- a/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp @@ -162,7 +162,13 @@ .setVectorTransferSplit(vectorTransferSplit)) .setVectorTransferToSCFOptions( VectorTransferToSCFOptions().setUnroll(unrollVectorTransfers)); - (void)strategy.transform(getFunction()); + + // Build the strategy pipeline in a nested OpPassManager and run it on funcOp. + FuncOp funcOp = getFunction(); + OpPassManager dynamicPM(funcOp.getOperationName()); + strategy.configurePassPipeline(dynamicPM, funcOp.getContext()); + if (failed(runPipeline(dynamicPM, funcOp))) + return signalPassFailure(); } } // end anonymous namespace