diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h
--- a/mlir/include/mlir/Pass/Pass.h
+++ b/mlir/include/mlir/Pass/Pass.h
@@ -156,6 +156,18 @@
   /// The polymorphic API that runs the pass over the currently held operation.
   virtual void runOnOperation() = 0;
 
+  /// This is the equivalent of `runOnOperation` for a dynamic pass: rather
+  /// than returning a pipeline, it builds one and hands it to `executor`,
+  /// which runs it on the current operation and returns success or failure.
+  virtual void
+  renameDynamicPipeline(function_ref<LogicalResult(OpPassManager &)> executor) {
+    llvm_unreachable("calling renameDynamicPipeline on a non-dynamic pass");
+  }
+
+  /// Return true if this is a dynamic pass: intended to be overridden only in
+  /// the DynamicPassWrapper class.
+  virtual bool isDynamic() const { return false; }
+
   /// A clone method to create a copy of this pass.
   std::unique_ptr<Pass> clone() const {
     auto newInst = clonePass();
@@ -379,6 +391,51 @@
   std::unique_ptr<Pass> clonePass() const override {
     return std::make_unique<PassT>(*static_cast<const PassT *>(this));
   }
+
+private:
+  /// Hidden from derived classes: only dynamic passes may implement this.
+  void
+  renameDynamicPipeline(function_ref<LogicalResult(OpPassManager &)>) final {
+    llvm_unreachable("calling renameDynamicPipeline on a non-dynamic pass");
+  }
+  /// Always return false: use DynamicPassWrapper for dynamic passes.
+  bool isDynamic() const final { return false; }
+};
+
+/// This class provides a CRTP wrapper around a base dynamic pass class to
+/// define several necessary utility methods.
+template <typename PassT, typename BaseT>
+class DynamicPassWrapper : public BaseT {
+public:
+  /// Support isa/dyn_cast functionality for the derived pass class.
+  static bool classof(const Pass *pass) {
+    return pass->getTypeID() == TypeID::get<PassT>();
+  }
+
+protected:
+  DynamicPassWrapper() : BaseT(TypeID::get<PassT>()) {}
+
+  /// Returns the derived pass name.
+  StringRef getName() const override { return llvm::getTypeName<PassT>(); }
+
+  /// A clone method to create a copy of this pass.
+  std::unique_ptr<Pass> clonePass() const override {
+    return std::make_unique<PassT>(*static_cast<const PassT *>(this));
+  }
+
+  /// Build a dynamic pipeline and run it on the current operation via `executor`.
+  void renameDynamicPipeline(
+      function_ref<LogicalResult(OpPassManager &)> executor) override = 0;
+
+private:
+  /// Hide this method from dynamic passes.
+  void runOnOperation() final {
+    llvm::report_fatal_error(
+        "Unexpected use of runOnOperation in a dynamic pass");
+  }
+
+  /// Always returns true to indicate that it is a dynamic pass.
+  bool isDynamic() const final { return true; }
 };
 
 } // end namespace mlir
diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h
--- a/mlir/include/mlir/Pass/PassManager.h
+++ b/mlir/include/mlir/Pass/PassManager.h
@@ -36,6 +36,7 @@
 
 namespace detail {
 struct OpPassManagerImpl;
+struct PassExecutionState;
 } // end namespace detail
 
 //===----------------------------------------------------------------------===//
@@ -118,6 +119,7 @@
 
   /// Allow access to the constructor.
   friend class PassManager;
+  friend class Pass;
 
   /// Allow access.
   friend detail::OpPassManagerImpl;
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -343,8 +343,15 @@
   if (pi)
     pi->runBeforePass(pass, op);
 
-  // Invoke the virtual runOnOperation method.
-  pass->runOnOperation();
+  // Run the pass: a regular pass transforms `op` in `runOnOperation`, while a
+  // dynamic pass builds a pipeline that the callback below runs on `op`.
+  if (pass->isDynamic()) {
+    pass->renameDynamicPipeline([&](OpPassManager &pipeline) {
+      return OpToOpPassAdaptor::runPipeline(pipeline.getPasses(), op, am);
+    });
+  } else {
+    pass->runOnOperation();
+  }
 
   // Invalidate any non preserved analyses.
   am.invalidate(pass->passState->preservedAnalyses);
diff --git a/mlir/test/Pass/dynamic_pipeline.mlir b/mlir/test/Pass/dynamic_pipeline.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Pass/dynamic_pipeline.mlir
@@ -0,0 +1,44 @@
+// RUN: mlir-opt %s -pass-pipeline='module(test-dynamic-pipeline{op-name=inner_mod1, dynamic-pipeline=func(cse,canonicalize)})' --mlir-disable-threading -print-ir-before-all 2>&1 | FileCheck %s --check-prefix=MOD1 --check-prefix=MOD1-ONLY --check-prefix=CHECK
+// RUN: mlir-opt %s -pass-pipeline='module(test-dynamic-pipeline{op-name=inner_mod2, dynamic-pipeline=func(cse,canonicalize)})' --mlir-disable-threading -print-ir-before-all 2>&1 | FileCheck %s --check-prefix=MOD2 --check-prefix=MOD2-ONLY --check-prefix=CHECK
+// RUN: mlir-opt %s -pass-pipeline='module(test-dynamic-pipeline{op-name=inner_mod1,inner_mod2, dynamic-pipeline=func(cse,canonicalize)})' --mlir-disable-threading -print-ir-before-all 2>&1 | FileCheck %s --check-prefix=MOD1 --check-prefix=MOD2 --check-prefix=CHECK
+// RUN: mlir-opt %s -pass-pipeline='module(test-dynamic-pipeline{dynamic-pipeline=func(cse,canonicalize)})' --mlir-disable-threading -print-ir-before-all 2>&1 | FileCheck %s --check-prefix=MOD1 --check-prefix=MOD2 --check-prefix=CHECK
+
+
+func @f() {
+  return
+}
+
+// CHECK: IR Dump Before
+// CHECK-SAME: TestDynamicPipelinePass
+// CHECK-NEXT: module @inner_mod1
+// MOD2-ONLY: dynamic-pipeline skip op name: inner_mod1
+module @inner_mod1 {
+// MOD1: Dump Before CSE
+// MOD1-NEXT: @foo
+// MOD1: Dump Before Canonicalizer
+// MOD1-NEXT: @foo
+  func @foo() {
+    return
+  }
+// MOD1: Dump Before CSE
+// MOD1-NEXT: @baz
+// MOD1: Dump Before Canonicalizer
+// MOD1-NEXT: @baz
+  func @baz() {
+    return
+  }
+}
+
+// CHECK: IR Dump Before
+// CHECK-SAME: TestDynamicPipelinePass
+// CHECK-NEXT: module @inner_mod2
+// MOD1-ONLY: dynamic-pipeline skip op name: inner_mod2
+module @inner_mod2 {
+// MOD2: Dump Before CSE
+// MOD2-NEXT: @foo
+// MOD2: Dump Before Canonicalizer
+// MOD2-NEXT: @foo
+  func @foo() {
+    return
+  }
+}
diff --git a/mlir/test/lib/Transforms/TestDynamicPipeline.cpp b/mlir/test/lib/Transforms/TestDynamicPipeline.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestDynamicPipeline.cpp
@@ -0,0 +1,87 @@
+//===------ TestDynamicPipeline.cpp --- dynamic pipeline test pass --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to test the dynamic pipeline feature.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/LoopUtils.h"
+#include "mlir/Transforms/Passes.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Test pass that parses the `dynamic-pipeline` option into an OpPassManager
+/// and runs it on the current operation when that operation's symbol name
+/// is listed in `op-name` (any symbol is accepted when `op-name` is empty).
+class TestDynamicPipelinePass
+    : public DynamicPassWrapper<TestDynamicPipelinePass, OperationPass<>> {
+public:
+  TestDynamicPipelinePass() = default;
+  // The option members and `pm` are not copy-constructible, so a copy
+  // default-initializes them; the copy parses its own pipeline lazily.
+  TestDynamicPipelinePass(const TestDynamicPipelinePass &) {}
+
+  void renameDynamicPipeline(
+      function_ref<LogicalResult(OpPassManager &)> executor) override {
+    llvm::errs() << "Dynamic execute '" << pipeline << "' on "
+                 << getOperation()->getName() << "\n";
+    if (pipeline.empty())
+      return;
+    auto symbolOp = dyn_cast<SymbolOpInterface>(getOperation());
+    if (!symbolOp) {
+      getOperation()->emitWarning()
+          << "Ignoring because not implementing SymbolOpInterface\n";
+      return;
+    }
+
+    auto opName = symbolOp.getName();
+    if (!opNames.empty() && !llvm::is_contained(opNames, opName)) {
+      llvm::errs() << "dynamic-pipeline skip op name: " << opName << "\n";
+      return;
+    }
+    if (!pm) {
+      // Build the pipeline once, anchored on this operation's name, and
+      // reuse it for the following operations.
+      pm = std::make_unique<OpPassManager>(getOperation()->getName(), false);
+      if (failed(parsePassPipeline(pipeline, *pm, llvm::errs()))) {
+        // Never run a partially-populated pipeline: drop it and fail.
+        pm.reset();
+        signalPassFailure();
+        return;
+      }
+    }
+    if (failed(executor(*pm)))
+      signalPassFailure();
+  }
+
+  /// Pass manager holding the parsed dynamic pipeline, built on first use.
+  std::unique_ptr<OpPassManager> pm;
+
+  Option<std::string> pipeline{
+      *this, "dynamic-pipeline",
+      llvm::cl::desc("The pipeline description that "
+                     "will run on the filtered function.")};
+  ListOption<std::string> opNames{
+      *this, "op-name", llvm::cl::MiscFlags::CommaSeparated,
+      llvm::cl::desc("List of function name to apply the pipeline to")};
+};
+} // end namespace
+
+namespace mlir {
+void registerTestDynamicPipelinePass() {
+  PassRegistration<TestDynamicPipelinePass>(
+      "test-dynamic-pipeline", "Tests the dynamic pipeline feature by applying "
+                               "a pipeline on a selected set of functions");
+}
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -49,6 +49,7 @@
 void registerTestConvertGPUKernelToHsacoPass();
 void registerTestDominancePass();
 void registerTestDialect(DialectRegistry &);
+void registerTestDynamicPipelinePass();
 void registerTestExpandTanhPass();
 void registerTestFunc();
 void registerTestGpuMemoryPromotionPass();
@@ -100,6 +101,7 @@
 #endif
   registerTestBufferPlacementPreparationPass();
   registerTestDominancePass();
+  registerTestDynamicPipelinePass();
   registerTestFunc();
   registerTestExpandTanhPass();
   registerTestGpuMemoryPromotionPass();