diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
@@ -257,7 +257,10 @@
 
   /// Apply the transformation patterns in sequence with cleanup
   /// transformations interleaved.
-  void configurePassPipeline(OpPassManager &pm, MLIRContext *context) const;
+  /// If `addEnablePass` is false, the LinalgStrategyEnablePass that otherwise
+  /// follows each transformation step is not added to `pm`.
+  void configurePassPipeline(OpPassManager &pm, MLIRContext *context,
+                             bool addEnablePass = true) const;
 
 private:
   LogicalResult postPatternTransforms(Operation *func) const;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
@@ -29,7 +29,7 @@
 #define DEBUG_TYPE "linalg-codegen-strategy"
 
 void mlir::linalg::CodegenStrategy::configurePassPipeline(
-    OpPassManager &pm, MLIRContext *context) const {
+    OpPassManager &pm, MLIRContext *context, bool addEnablePass) const {
   for (unsigned stepCount = 0, e = transformationSequence.size(); stepCount < e;
        ++stepCount) {
     const std::unique_ptr<Transformation> &t =
@@ -44,7 +44,8 @@
                       : linalg::LinalgTransformationFilter(
                             t->filter, currentState, nextState);
     t->addToPassPipeline(pm, filter);
-    pm.addPass(createLinalgStrategyEnablePass(linalgEnablingOptions));
+    if (addEnablePass)
+      pm.addPass(createLinalgStrategyEnablePass(linalgEnablingOptions));
   }
   pm.addPass(createLinalgStrategyRemoveMarkersPass());
 }
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp
@@ -125,6 +125,10 @@
       *this, "unroll-vector-transfers",
       llvm::cl::desc("Enable full unrolling of vector.transfer operations"),
       llvm::cl::init(false)};
+  Option<bool> runEnablePass{
+      *this, "run-enable-pass",
+      llvm::cl::desc("Run the enable pass between transformations"),
+      llvm::cl::init(true)};
   Option<std::string> anchorOpName{
       *this, "anchor-op",
       llvm::cl::desc(
@@ -178,10 +182,10 @@
               .enableTransferPartialRewrite()
               .enableContractionLowering()
               .enableTransferToSCFConversion());
-  // Created a nested OpPassManager and run.
+  // Create a nested OpPassManager and run.
   FuncOp funcOp = getFunction();
   OpPassManager dynamicPM("builtin.func");
-  strategy.configurePassPipeline(dynamicPM, funcOp.getContext());
+  strategy.configurePassPipeline(dynamicPM, funcOp.getContext(), runEnablePass);
   if (failed(runPipeline(dynamicPM, funcOp)))
     return signalPassFailure();
 }