diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h
--- a/mlir/include/mlir/Pass/Pass.h
+++ b/mlir/include/mlir/Pass/Pass.h
@@ -24,8 +24,11 @@
 /// The state for a single execution of a pass. This provides a unified
 /// interface for accessing and initializing necessary state for pass execution.
 struct PassExecutionState {
-  PassExecutionState(Operation *ir, AnalysisManager analysisManager)
-      : irAndPassFailed(ir, false), analysisManager(analysisManager) {}
+  PassExecutionState(
+      Operation *ir, AnalysisManager analysisManager,
+      function_ref<LogicalResult(OpPassManager &)> pipelineExecutor)
+      : irAndPassFailed(ir, false), analysisManager(analysisManager),
+        pipelineExecutor(pipelineExecutor) {}
 
   /// The current operation being transformed and a bool for if the pass
   /// signaled a failure.
@@ -36,6 +39,11 @@
 
   /// The set of preserved analyses for the current execution.
   detail::PreservedAnalyses preservedAnalyses;
+
+  /// This is a callback into the PassManager that allows scheduling dynamic
+  /// pipelines that will be rooted at the currently visited operation.
+  /// Note: function_ref is non-owning; the callable must outlive this state.
+  function_ref<LogicalResult(OpPassManager &)> pipelineExecutor;
 };
 } // namespace detail
 
@@ -156,6 +164,13 @@
   /// The polymorphic API that runs the pass over the currently held operation.
   virtual void runOnOperation() = 0;
 
+  /// Schedule an arbitrary pass pipeline on the current operation.
+  /// This can be invoked at any time in a pass to dynamically schedule more
+  /// passes. Returns the result reported by the pipeline executor callback.
+  LogicalResult runPipeline(OpPassManager &pipeline) {
+    return passState->pipelineExecutor(pipeline);
+  }
+
   /// A clone method to create a copy of this pass.
std::unique_ptr<Pass> clone() const {
     auto newInst = clonePass();
diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h
--- a/mlir/include/mlir/Pass/PassManager.h
+++ b/mlir/include/mlir/Pass/PassManager.h
@@ -36,6 +36,7 @@
 
 namespace detail {
 struct OpPassManagerImpl;
+struct PassExecutionState;
 } // end namespace detail
 
 //===----------------------------------------------------------------------===//
@@ -118,6 +119,7 @@
 
   /// Allow access to the constructor.
   friend class PassManager;
+  friend class Pass;
 
   /// Allow access.
   friend detail::OpPassManagerImpl;
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -336,7 +336,16 @@
 
 LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
                                      AnalysisManager am) {
-  pass->passState.emplace(op, am);
+  // Initialize the pass state with a callback for the pass to dynamically
+  // execute a pipeline on the currently visited operation. The callback is a
+  // named local rather than a temporary: function_ref does not own its
+  // callable, so a lambda temporary passed straight to emplace() would be
+  // destroyed at the end of that statement and leave the stored reference
+  // dangling.
+  auto dynamicPipelineCallback = [op, &am](OpPassManager &pipeline) {
+    return OpToOpPassAdaptor::runPipeline(pipeline.getPasses(), op, am);
+  };
+  pass->passState.emplace(op, am, dynamicPipelineCallback);
 
   // Instrument before the pass has run.
PassInstrumentor *pi = am.getPassInstrumentor(); diff --git a/mlir/test/Pass/dynamic_pipeline.mlir b/mlir/test/Pass/dynamic_pipeline.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Pass/dynamic_pipeline.mlir @@ -0,0 +1,44 @@ +// RUN: mlir-opt %s -pass-pipeline='module(test-dynamic-pipeline{op-name=inner_mod1, dynamic-pipeline=func(cse,canonicalize)})' --mlir-disable-threading -print-ir-before-all 2>&1 | FileCheck %s --check-prefix=MOD1 --check-prefix=MOD1-ONLY --check-prefix=CHECK +// RUN: mlir-opt %s -pass-pipeline='module(test-dynamic-pipeline{op-name=inner_mod2, dynamic-pipeline=func(cse,canonicalize)})' --mlir-disable-threading -print-ir-before-all 2>&1 | FileCheck %s --check-prefix=MOD2 --check-prefix=MOD2-ONLY --check-prefix=CHECK +// RUN: mlir-opt %s -pass-pipeline='module(test-dynamic-pipeline{op-name=inner_mod1,inner_mod2, dynamic-pipeline=func(cse,canonicalize)})' --mlir-disable-threading -print-ir-before-all 2>&1 | FileCheck %s --check-prefix=MOD1 --check-prefix=MOD2 --check-prefix=CHECK +// RUN: mlir-opt %s -pass-pipeline='module(test-dynamic-pipeline{dynamic-pipeline=func(cse,canonicalize)})' --mlir-disable-threading -print-ir-before-all 2>&1 | FileCheck %s --check-prefix=MOD1 --check-prefix=MOD2 --check-prefix=CHECK + + +func @f() { + return +} + +// CHECK: IR Dump Before +// CHECK-SAME: TestDynamicPipelinePass +// CHECK-NEXT: module @inner_mod1 +// MOD2-ONLY: dynamic-pipeline skip op name: inner_mod1 +module @inner_mod1 { +// MOD1: Dump Before CSE +// MOD1-NEXT: @foo +// MOD1: Dump Before Canonicalizer +// MOD1-NEXT: @foo + func @foo() { + return + } +// MOD1: Dump Before CSE +// MOD1-NEXT: @baz +// MOD1: Dump Before Canonicalizer +// MOD1-NEXT: @baz + func @baz() { + return + } +} + +// CHECK: IR Dump Before +// CHECK-SAME: TestDynamicPipelinePass +// CHECK-NEXT: module @inner_mod2 +// MOD1-ONLY: dynamic-pipeline skip op name: inner_mod2 +module @inner_mod2 { +// MOD2: Dump Before CSE +// MOD2-NEXT: @foo +// MOD2: Dump Before Canonicalizer 
+// MOD2-NEXT: @foo
+  func @foo() {
+    return
+  }
+}
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -9,6 +9,7 @@
   TestConvertGPUKernelToCubin.cpp
   TestConvertGPUKernelToHsaco.cpp
   TestDominance.cpp
+  TestDynamicPipeline.cpp
   TestLoopFusion.cpp
   TestGpuMemoryPromotion.cpp
   TestGpuParallelLoopMapping.cpp
diff --git a/mlir/test/lib/Transforms/TestDynamicPipeline.cpp b/mlir/test/lib/Transforms/TestDynamicPipeline.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestDynamicPipeline.cpp
@@ -0,0 +1,92 @@
+//===------ TestDynamicPipeline.cpp --- dynamic pipeline test pass --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to test the dynamic pipeline feature.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/LoopUtils.h"
+#include "mlir/Transforms/Passes.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Test pass that parses the textual pipeline given by the `dynamic-pipeline`
+/// option and runs it on the current operation through Pass::runPipeline.
+/// When the `op-name` option is non-empty, only operations whose symbol name
+/// is listed there are processed.
+class TestDynamicPipelinePass
+    : public PassWrapper<TestDynamicPipelinePass, OperationPass<>> {
+public:
+  TestDynamicPipelinePass() = default;
+  TestDynamicPipelinePass(const TestDynamicPipelinePass &) {}
+
+  /// Runs the configured pipeline on the current operation. Does nothing when
+  /// the pipeline string is empty, when the operation does not implement
+  /// SymbolOpInterface, or when its name is filtered out by `op-name`.
+  void runOnOperation() override {
+    llvm::errs() << "Dynamic execute '" << pipeline << "' on "
+                 << getOperation()->getName() << "\n";
+    if (pipeline.empty())
+      return;
+    auto symbolOp = dyn_cast<SymbolOpInterface>(getOperation());
+    if (!symbolOp) {
+      getOperation()->emitWarning()
+          << "Ignoring because not implementing SymbolOpInterface\n";
+      return;
+    }
+
+    auto opName = symbolOp.getName();
+    if (!opNames.empty() && !llvm::is_contained(opNames, opName)) {
+      llvm::errs() << "dynamic-pipeline skip op name: " << opName << "\n";
+      return;
+    }
+    // Build the pass manager on first use and keep it in the member so later
+    // invocations of this pass instance reuse it.
+    if (!pm) {
+      pm = std::make_unique<OpPassManager>(getOperation()->getName(), false);
+      // A malformed pipeline string must fail the pass instead of silently
+      // running whatever was parsed so far; drop the half-built manager so a
+      // later invocation does not execute it either.
+      if (failed(parsePassPipeline(pipeline, *pm, llvm::errs()))) {
+        pm.reset();
+        return signalPassFailure();
+      }
+    }
+    if (failed(runPipeline(*pm)))
+      signalPassFailure();
+  }
+
+  /// Pass manager holding the parsed dynamic pipeline; created on first use.
+  std::unique_ptr<OpPassManager> pm;
+
+  Option<std::string> pipeline{
+      *this, "dynamic-pipeline",
+      llvm::cl::desc("The pipeline description that "
+                     "will run on the filtered function.")};
+  ListOption<std::string> opNames{
+      *this, "op-name", llvm::cl::MiscFlags::CommaSeparated,
+      llvm::cl::desc("List of function name to apply the pipeline to")};
+};
+} // end namespace
+
+namespace mlir {
+/// Registers TestDynamicPipelinePass under the name "test-dynamic-pipeline".
+void registerTestDynamicPipelinePass() {
+  PassRegistration<TestDynamicPipelinePass>(
+      "test-dynamic-pipeline", "Tests the dynamic pipeline feature by applying "
+                               "a pipeline on a selected set of functions");
+}
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -49,6 +49,7 @@
 void
registerTestConvertGPUKernelToHsacoPass(); void registerTestDominancePass(); void registerTestDialect(DialectRegistry &); +void registerTestDynamicPipelinePass(); void registerTestExpandTanhPass(); void registerTestFunc(); void registerTestGpuMemoryPromotionPass(); @@ -100,6 +101,7 @@ #endif registerTestBufferPlacementPreparationPass(); registerTestDominancePass(); + registerTestDynamicPipelinePass(); registerTestFunc(); registerTestExpandTanhPass(); registerTestGpuMemoryPromotionPass();