diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -385,13 +385,17 @@ def ApplyRegisteredPassOp : TransformDialectOp<"apply_registered_pass", [TransformOpInterface, TransformEachOpTrait, FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface]> { - let summary = "Applies the specified registered pass"; + let summary = "Applies the specified registered pass or pass pipeline"; let description = [{ - This transform applies the specified pass to the targeted ops. The name of - the pass is specified as a string attribute, as set during pass - registration. Optionally, pass options may be specified as a string - attribute. The pass options syntax is identical to the one used with - "mlir-opt". + This transform applies the specified pass or pass pipeline to the targeted + ops. The name of the pass/pipeline is specified as a string attribute, as + set during pass/pipeline registration. Optionally, pass options may be + specified as a string attribute. The pass options syntax is identical to the + one used with "mlir-opt". + + This op first looks for a pass pipeline with the specified name. If no such + pipeline exists, it looks for a pass with the specified name. If no such + pass exists either, this op produces a definite failure. This transform consumes the target handle and produces a new handle that is mapped to the same op. Passes are not allowed to remove/modify the operation diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h --- a/mlir/include/mlir/Pass/Pass.h +++ b/mlir/include/mlir/Pass/Pass.h @@ -55,12 +55,9 @@ /// Returns the unique identifier that corresponds to this pass. TypeID getTypeID() const { return passID; } - /// Returns the pass info for the specified pass class or null if unknown. 
- static const PassInfo *lookupPassInfo(StringRef passArg); - /// Returns the pass info for this pass, or null if unknown. const PassInfo *lookupPassInfo() const { - return lookupPassInfo(getArgument()); + return PassInfo::lookup(getArgument()); } /// Returns the derived pass name. diff --git a/mlir/include/mlir/Pass/PassRegistry.h b/mlir/include/mlir/Pass/PassRegistry.h --- a/mlir/include/mlir/Pass/PassRegistry.h +++ b/mlir/include/mlir/Pass/PassRegistry.h @@ -105,6 +105,10 @@ std::function<void(function_ref<void(const detail::PassOptions &)>)> optHandler) : PassRegistryEntry(arg, description, builder, std::move(optHandler)) {} + + /// Returns the pass pipeline info for the specified pass pipeline or null if + /// unknown. + static const PassPipelineInfo *lookup(StringRef pipelineArg); }; /// A structure to represent the information for a derived pass class. @@ -114,6 +118,9 @@ /// PassRegistration or registerPass. PassInfo(StringRef arg, StringRef description, const PassAllocatorFunction &allocator); + + /// Returns the pass info for the specified pass class or null if unknown. + static const PassInfo *lookup(StringRef passArg); }; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -718,20 +718,23 @@ if (!payloadCheck.succeeded()) return payloadCheck; - // Get pass from registry. - const PassInfo *passInfo = Pass::lookupPassInfo(getPassName()); - if (!passInfo) { - return emitDefiniteFailure() << "unknown pass: " << getPassName(); - } + // Get pass or pass pipeline from registry. + const PassRegistryEntry *info = PassPipelineInfo::lookup(getPassName()); + if (!info) + info = PassInfo::lookup(getPassName()); + if (!info) + return emitDefiniteFailure() + << "unknown pass or pass pipeline: " << getPassName(); - // Create pass manager with a single pass and run it. 
+ // Create pass manager and run the pass or pass pipeline. PassManager pm(getContext()); - if (failed(passInfo->addToPipeline(pm, getOptions(), [&](const Twine &msg) { + if (failed(info->addToPipeline(pm, getOptions(), [&](const Twine &msg) { emitError(msg); return failure(); }))) { return emitDefiniteFailure() - << "failed to add pass to pipeline: " << getPassName(); + << "failed to add pass or pass pipeline to pipeline: " + << getPassName(); } if (failed(pm.run(target))) { auto diag = emitSilenceableError() << "pass pipeline failed"; diff --git a/mlir/lib/Pass/PassRegistry.cpp b/mlir/lib/Pass/PassRegistry.cpp --- a/mlir/lib/Pass/PassRegistry.cpp +++ b/mlir/lib/Pass/PassRegistry.cpp @@ -139,11 +139,18 @@ } /// Returns the pass info for the specified pass argument or null if unknown. -const PassInfo *mlir::Pass::lookupPassInfo(StringRef passArg) { +const PassInfo *mlir::PassInfo::lookup(StringRef passArg) { auto it = passRegistry->find(passArg); return it == passRegistry->end() ? nullptr : &it->second; } +/// Returns the pass pipeline info for the specified pass pipeline argument or +/// null if unknown. +const PassPipelineInfo *mlir::PassPipelineInfo::lookup(StringRef pipelineArg) { + auto it = passPipelineRegistry->find(pipelineArg); + return it == passPipelineRegistry->end() ? nullptr : &it->second; +} + //===----------------------------------------------------------------------===// // PassOptions //===----------------------------------------------------------------------===// @@ -653,16 +660,14 @@ // pipeline. if (!element.innerPipeline.empty()) return resolvePipelineElements(element.innerPipeline, errorHandler); + // Otherwise, this must be a pass or pass pipeline. // Check to see if a pipeline was registered with this name. 
- auto pipelineRegistryIt = passPipelineRegistry->find(element.name); - if (pipelineRegistryIt != passPipelineRegistry->end()) { - element.registryEntry = &pipelineRegistryIt->second; + if ((element.registryEntry = PassPipelineInfo::lookup(element.name))) return success(); - } // If not, then this must be a specific pass name. - if ((element.registryEntry = Pass::lookupPassInfo(element.name))) + if ((element.registryEntry = PassInfo::lookup(element.name))) return success(); // Emit an error for the unknown pass. diff --git a/mlir/test/Dialect/Transform/test-pass-application.mlir b/mlir/test/Dialect/Transform/test-pass-application.mlir --- a/mlir/test/Dialect/Transform/test-pass-application.mlir +++ b/mlir/test/Dialect/Transform/test-pass-application.mlir @@ -17,6 +17,21 @@ // ----- +// CHECK-LABEL: func @pass_pipeline( +func.func @pass_pipeline() { + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !transform.any_op): + %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // This pipeline does not do anything. Just make sure that the pipeline is + // found and no error is produced. 
+ transform.apply_registered_pass "test-options-pass-pipeline" to %1 : (!transform.any_op) -> !transform.any_op +} + +// ----- + func.func @invalid_pass_name() { return } @@ -24,7 +39,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - // expected-error @below {{unknown pass: non-existing-pass}} + // expected-error @below {{unknown pass or pass pipeline: non-existing-pass}} transform.apply_registered_pass "non-existing-pass" to %1 : (!transform.any_op) -> !transform.any_op } @@ -54,7 +69,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - // expected-error @below {{failed to add pass to pipeline: canonicalize}} + // expected-error @below {{failed to add pass or pass pipeline to pipeline: canonicalize}} transform.apply_registered_pass "canonicalize" to %1 {options = "invalid-option=1"} : (!transform.any_op) -> !transform.any_op }