diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -194,6 +194,40 @@ let assemblyFormat = "attr-dict"; } +def ApplyRegisteredPassOp : TransformDialectOp<"apply_registered_pass", + [TransformOpInterface, TransformEachOpTrait, + FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface]> { + let summary = "Applies the specified registered pass"; + let description = [{ + This transform applies the specified pass to the targeted ops. The name of + the pass is specified as a string attribute, as set during pass + registration. Optionally, pass options may be specified as a string + attribute. The pass options syntax is identical to the one used with + "mlir-opt". + + This transform consumes the target handle and produces a new handle that is + mapped to the same op. Passes are not allowed to remove/modify the operation + that they operate on, so the target op is guaranteed to still exist. The + target handle is invalidated because a pass may arbitrarily modify the body + of targeted ops.
+ }]; + + let arguments = (ins TransformHandleTypeInterface:$target, + StrAttr:$pass_name, + DefaultValuedAttr<StrAttr, "\"\"">:$options); + let results = (outs TransformHandleTypeInterface:$result); + let assemblyFormat = [{ + $pass_name `to` $target attr-dict `:` functional-type(operands, results) + }]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + def CastOp : TransformDialectOp<"cast", [TransformOpInterface, TransformEachOpTrait, DeclareOpInterfaceMethods<CastOpInterface>, diff --git a/mlir/lib/Dialect/Transform/IR/CMakeLists.txt b/mlir/lib/Dialect/Transform/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Transform/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Transform/IR/CMakeLists.txt @@ -15,6 +15,7 @@ MLIRCastInterfaces MLIRIR MLIRParser + MLIRPass MLIRRewrite MLIRSideEffectInterfaces MLIRTransforms diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -16,6 +16,9 @@ #include "mlir/IR/FunctionImplementation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Pass/PassRegistry.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/ScopeExit.h" @@ -47,6 +50,26 @@ ArrayAttr &matchers, ArrayAttr &actions); +/// Checks that the given transform op is neither the payload op nor nested in +/// it; otherwise returns a definite failure diagnostic. Transforming transform +/// IR that is currently being interpreted is generally unsafe.
+static DiagnosedSilenceableFailure +ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform, + Operation *payload) { + Operation *transformAncestor = transform.getOperation(); + while (transformAncestor) { + if (transformAncestor == payload) { + DiagnosedDefiniteFailure diag = + transform.emitDefiniteFailure() + << "cannot apply transform to itself (or one of its ancestors)"; + diag.attachNote(payload->getLoc()) << "target payload op"; + return diag; + } + transformAncestor = transformAncestor->getParentOp(); + } + return DiagnosedSilenceableFailure::success(); +} + #define GET_OP_CLASSES #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc" @@ -243,17 +266,10 @@ // transform IR while it is being interpreted is generally dangerous. Even // more so for the ApplyPatternsOp because the GreedyPatternRewriteDriver // performs many additional simplifications such as dead code elimination. - Operation *transformAncestor = getOperation(); - while (transformAncestor) { - if (transformAncestor == target) { - DiagnosedDefiniteFailure diag = - emitDefiniteFailure() - << "cannot apply transform to itself (or one of its ancestors)"; - diag.attachNote(target->getLoc()) << "target payload op"; - return diag; - } - transformAncestor = transformAncestor->getParentOp(); - } + DiagnosedSilenceableFailure payloadCheck = + ensurePayloadIsSeparateFromTransform(*this, target); + if (!payloadCheck.succeeded()) + return payloadCheck; // Gather all specified patterns.
MLIRContext *ctx = target->getContext(); @@ -357,6 +373,47 @@ op.getCanonicalizationPatterns(patterns, ctx); } +//===----------------------------------------------------------------------===// +// ApplyRegisteredPassOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure +transform::ApplyRegisteredPassOp::applyToOne(Operation *target, + ApplyToEachResultList &results, + transform::TransformState &state) { + // Make sure that this transform is not applied to itself. Modifying the + // transform IR while it is being interpreted is generally dangerous. Even + // more so when applying passes because they may perform a wide range of IR + // modifications. + DiagnosedSilenceableFailure payloadCheck = + ensurePayloadIsSeparateFromTransform(*this, target); + if (!payloadCheck.succeeded()) + return payloadCheck; + + // Get pass from registry. + const PassInfo *passInfo = Pass::lookupPassInfo(getPassName()); + if (!passInfo) + return emitDefiniteFailure() << "unknown pass: " << getPassName(); + + // Create pass manager with a single pass and run it.
+ PassManager pm(getContext()); + if (failed(passInfo->addToPipeline(pm, getOptions(), [&](const Twine &msg) { + emitError(msg); + return failure(); + }))) { + return emitDefiniteFailure() + << "failed to add pass to pipeline: " << getPassName(); + } + if (failed(pm.run(target))) { + auto diag = emitSilenceableError() << "pass pipeline failed"; + diag.attachNote(target->getLoc()) << "target op"; + return diag; + } + + results.push_back(target); + return DiagnosedSilenceableFailure::success(); +} + //===----------------------------------------------------------------------===// // CastOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Transform/test-pass-application.mlir b/mlir/test/Dialect/Transform/test-pass-application.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Transform/test-pass-application.mlir @@ -0,0 +1,72 @@ +// RUN: mlir-opt %s --test-transform-dialect-interpreter -allow-unregistered-dialect --split-input-file --verify-diagnostics | FileCheck %s + +// CHECK-LABEL: func @successful_pass_application( +// CHECK: %[[c5:.*]] = arith.constant 5 : index +// CHECK: return %[[c5]] +func.func @successful_pass_application(%t: tensor<5xf32>) -> index { + %c0 = arith.constant 0 : index + %dim = tensor.dim %t, %c0 : tensor<5xf32> + return %dim : index +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !transform.any_op): + %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_registered_pass "canonicalize" to %1 : (!transform.any_op) -> !transform.any_op +} + +// ----- + +func.func @invalid_pass_name() { + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !transform.any_op): + %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // expected-error @below {{unknown pass: non-existing-pass}} + transform.apply_registered_pass "non-existing-pass" to
%1 : (!transform.any_op) -> !transform.any_op +} + +// ----- + +func.func @not_isolated_from_above(%t: tensor<5xf32>) -> index { + %c0 = arith.constant 0 : index + // expected-note @below {{target op}} + // expected-error @below {{trying to schedule a pass on an operation not marked as 'IsolatedFromAbove'}} + %dim = tensor.dim %t, %c0 : tensor<5xf32> + return %dim : index +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !transform.any_op): + %1 = transform.structured.match ops{["tensor.dim"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // expected-error @below {{pass pipeline failed}} + transform.apply_registered_pass "canonicalize" to %1 : (!transform.any_op) -> !transform.any_op +} + +// ----- + +func.func @invalid_pass_option() { + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !transform.any_op): + %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // expected-error @below {{failed to add pass to pipeline: canonicalize}} + transform.apply_registered_pass "canonicalize" to %1 {options = "invalid-option=1"} : (!transform.any_op) -> !transform.any_op +} + +// ----- + +// CHECK-LABEL: func @valid_pass_option() +func.func @valid_pass_option() { + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !transform.any_op): + %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_registered_pass "canonicalize" to %1 {options = "top-down=false"} : (!transform.any_op) -> !transform.any_op +} diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -10149,6 +10149,7 @@ ":CastInterfaces", ":ControlFlowInterfaces", ":IR", + ":Pass", ":Rewrite", ":SideEffectInterfaces", ":Support",