diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -962,6 +962,36 @@
   let hasVerifier = 1;
 }
 
+def VerifyOp : TransformDialectOp<"verify",
+    [TransformOpInterface, TransformEachOpTrait,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let summary = "Verifies the targeted ops";
+  let description = [{
+    This transform verifies the targeted ops. If at least one op fails to
+    verify, the transform fails definitely.
+
+    Note: This op was designed for debugging purposes and should be used like an
+    assertion. It is intentional that this op produces a definite failure and
+    not a silenceable one. Correctness of the program should not depend on this
+    op.
+
+    This transform reads the target handle and the payload IR.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs);
+  let assemblyFormat = "$target attr-dict `:` type($target)";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+      ::mlir::transform::TransformRewriter &rewriter,
+      ::mlir::Operation *target,
+      ::mlir::transform::ApplyToEachResultList &results,
+      ::mlir::transform::TransformState &state);
+  }];
+}
+
 def YieldOp : TransformDialectOp<"yield",
     [Terminator, DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
   let summary = "Yields operation handles from a transform IR region";
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -17,6 +17,7 @@
 #include "mlir/IR/Dominance.h"
 #include "mlir/IR/FunctionImplementation.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Verifier.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
@@ -2019,6 +2020,31 @@
   effects.emplace_back(MemoryEffects::Write::get());
 }
 
+//===----------------------------------------------------------------------===//
+// VerifyOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::VerifyOp::applyToOne(transform::TransformRewriter &rewriter,
+                                Operation *target,
+                                transform::ApplyToEachResultList &results,
+                                transform::TransformState &state) {
+  if (failed(::mlir::verify(target))) {
+    DiagnosedDefiniteFailure diag = emitDefiniteFailure()
+                                    << "failed to verify payload op";
+    diag.attachNote(target->getLoc()) << "payload op";
+    return diag;
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::VerifyOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::onlyReadsHandle(getTarget(), effects);
+  // Verification inspects, but never modifies, the payload IR.
+  transform::onlyReadsPayload(effects);
+}
+
 //===----------------------------------------------------------------------===//
 // YieldOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -1915,3 +1915,31 @@
   // CHECK: test_produce_param(#{{.*}}) : !transform.affine_map
   transform.test_produce_param(affine_map<(d0) -> ()>) : !transform.affine_map
 }
+
+// -----
+
+func.func @verify_success(%arg0: f64) -> f64 {
+  return %arg0 : f64
+}
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+  %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+  transform.verify %0 : !transform.any_op
+}
+
+// -----
+
+// expected-error @below{{fail_to_verify is set}}
+// expected-note @below{{payload op}}
+func.func @verify_failure(%arg0: f64) -> f64 {
+  return %arg0 : f64
+}
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+  %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+  transform.test_produce_invalid_ir %0 : !transform.any_op
+  // expected-error @below{{failed to verify payload op}}
+  transform.verify %0 : !transform.any_op
+}
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -747,6 +747,12 @@
     transform::producesHandle(result, effects);
 }
 
+LogicalResult mlir::test::TestDummyPayloadOp::verify() {
+  if (getFailToVerify())
+    return emitOpError() << "fail_to_verify is set";
+  return success();
+}
+
 DiagnosedSilenceableFailure
 mlir::test::TestTrackedRewriteOp::apply(transform::TransformRewriter &rewriter,
                                         transform::TransformResults &results,
@@ -892,6 +898,31 @@
   transform::onlyReadsHandle(getReplacement(), effects);
 }
 
+DiagnosedSilenceableFailure mlir::test::TestProduceInvalidIR::applyToOne(
+    transform::TransformRewriter &rewriter, Operation *target,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  // The invalid op is inserted at the start of the first region's entry
+  // block, so bail out gracefully on targets that have no such block.
+  if (target->getNumRegions() == 0 || target->getRegion(0).empty()) {
+    DiagnosedSilenceableFailure diag =
+        emitSilenceableError() << "expected target with a non-empty region";
+    diag.attachNote(target->getLoc()) << "target op";
+    return diag;
+  }
+  // Provide some IR that does not verify.
+  rewriter.setInsertionPointToStart(&target->getRegion(0).front());
+  rewriter.create<TestDummyPayloadOp>(target->getLoc(), TypeRange(),
+                                      ValueRange(), /*failToVerify=*/true);
+  return DiagnosedSilenceableFailure::success();
+}
+
+void mlir::test::TestProduceInvalidIR::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::onlyReadsHandle(getTarget(), effects);
+  transform::modifiesPayload(effects);
+}
+
 namespace {
 /// Test extension of the Transform dialect. Registers additional ops and
 /// declares PDL as dependent dialect since the additional ops are using PDL
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -518,10 +518,12 @@
   : Op<Transform_Dialect, "test_dummy_payload_op",
       [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
        TransformOpInterface]> {
-  let arguments = (ins Variadic<AnyType>:$args);
+  let arguments = (ins Variadic<AnyType>:$args,
+                       UnitAttr:$fail_to_verify);
   let results = (outs Variadic<AnyType>:$outs);
   let assemblyFormat = "$args attr-dict `:` functional-type(operands, results)";
   let cppNamespace = "::mlir::test";
+  let hasVerifier = 1;
 
   let extraClassDeclaration = [{
     DiagnosedSilenceableFailure apply(transform::TransformRewriter &rewriter,
@@ -574,4 +576,21 @@
   let cppNamespace = "::mlir::test";
 }
 
+def TestProduceInvalidIR
+  : Op<Transform_Dialect, "test_produce_invalid_ir",
+       [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+        TransformOpInterface, TransformEachOpTrait]> {
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs);
+  let assemblyFormat = "$target attr-dict `:` type($target)";
+  let cppNamespace = "::mlir::test";
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+      ::mlir::transform::TransformRewriter &rewriter,
+      ::mlir::Operation *target,
+      ::mlir::transform::ApplyToEachResultList &results,
+      ::mlir::transform::TransformState &state);
+  }];
+}
+
 #endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD