diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
@@ -247,9 +247,11 @@
 
-  /// Apply the transformation patterns in sequence with cleanup
-  /// transformations interleaved.
-  LogicalResult transform(FuncOp func) const;
+  /// Add to `pm` the passes that apply the transformation patterns in
+  /// sequence with cleanup transformations interleaved.
   void configurePassPipeline(OpPassManager &pm, MLIRContext *context) const;
 
+  /// Clean `funcOp` of CodegenStrategy-specific state.
+  void cleanup(FuncOp funcOp) const;
+
 private:
   LogicalResult postPatternTransforms(Operation *func) const;
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
@@ -62,13 +62,9 @@
   pm.addPass(createLinalgStrategyLowerVectorsPass(vectorLoweringOptions));
 }
 
-LogicalResult mlir::linalg::CodegenStrategy::transform(FuncOp funcOp) const {
-  PassManager pm(funcOp.getContext(), funcOp.getOperationName());
-  configurePassPipeline(pm, funcOp.getContext());
-  LogicalResult res = pm.run(funcOp);
-  // Ensure we drop the marker in the end.
+/// Clean `funcOp` of CodegenStrategy-specific state.
+void mlir::linalg::CodegenStrategy::cleanup(FuncOp funcOp) const {
   funcOp.walk([](LinalgOp op) {
     op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
   });
-  return res;
 }
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
@@ -162,7 +162,15 @@
               .setVectorTransferSplit(vectorTransferSplit))
       .setVectorTransferToSCFOptions(
           VectorTransferToSCFOptions().setUnroll(unrollVectorTransfers));
-  (void)strategy.transform(getFunction());
+
+  // Create a nested OpPassManager and run.
+  FuncOp funcOp = getFunction();
+  OpPassManager dynamicPM("builtin.func");
+  strategy.configurePassPipeline(dynamicPM, funcOp.getContext());
+  if (failed(runPipeline(dynamicPM, funcOp)))
+    return signalPassFailure();
+  // Ensure we drop the marker in the end.
+  strategy.cleanup(funcOp);
 }
 
 } // end anonymous namespace