diff --git a/mlir/include/mlir/IR/Verifier.h b/mlir/include/mlir/IR/Verifier.h --- a/mlir/include/mlir/IR/Verifier.h +++ b/mlir/include/mlir/IR/Verifier.h @@ -15,8 +15,12 @@ /// Perform (potentially expensive) checks of invariants, used to detect /// compiler bugs, on this operation and any nested operations. On error, this -/// reports the error through the MLIRContext and returns failure. -LogicalResult verify(Operation *op); +/// reports the error through the MLIRContext and returns failure. If +/// `verifyRecursively` is false, this assumes that nested operations have +/// already been properly verified, and does not recursively invoke the verifier +/// on nested operations. +LogicalResult verify(Operation *op, bool verifyRecursively = true); + } // namespace mlir #endif diff --git a/mlir/lib/IR/Verifier.cpp b/mlir/lib/IR/Verifier.cpp --- a/mlir/lib/IR/Verifier.cpp +++ b/mlir/lib/IR/Verifier.cpp @@ -43,6 +43,11 @@ /// This class encapsulates all the state used to verify an operation region. class OperationVerifier { public: + /// If `verifyRecursively` is true, then this will also recursively verify + /// nested operations. + explicit OperationVerifier(bool verifyRecursively) + : verifyRecursively(verifyRecursively) {} + /// Verify the given operation. LogicalResult verifyOpAndDominance(Operation &op); @@ -61,6 +66,10 @@ /// Operation. LogicalResult verifyDominanceOfContainedRegions(Operation &op, DominanceInfo &domInfo); + + /// When false, nested operations are assumed to be already verified, so the + /// blocks of the root operation's regions are not descended into. + bool verifyRecursively; }; } // namespace @@ -81,8 +90,12 @@ return failure(); } - // Check the dominance properties and invariants of any operations in the - // regions contained by the 'opsWithIsolatedRegions' operations. + // If we aren't verifying nested operations, then we're done.
+ if (!verifyRecursively) + return success(); + + // Otherwise, check the dominance properties and invariants of any operations + // in the regions contained by the 'opsWithIsolatedRegions' operations. return failableParallelForEach( op.getContext(), opsWithIsolatedRegions, [&](Operation *op) { return verifyOpAndDominance(*op); }); @@ -120,21 +133,25 @@ // Check each operation, and make sure there are no branches out of the // middle of this block. - for (auto &op : block) { + for (Operation &op : block) { // Only the last instructions is allowed to have successors. if (op.getNumSuccessors() != 0 && &op != &block.back()) return op.emitError( "operation with block successors must terminate its parent block"); + // If we aren't verifying recursively, there is nothing left to check. + if (!verifyRecursively) + continue; + // If this operation has regions and is IsolatedFromAbove, we defer // checking. This allows us to parallelize verification better. if (op.getNumRegions() != 0 && op.hasTrait()) { opsWithIsolatedRegions.push_back(&op); - } else { + // Otherwise, check the operation inline. - if (failed(verifyOperation(op, opsWithIsolatedRegions))) - return failure(); + } else if (failed(verifyOperation(op, opsWithIsolatedRegions))) { + return failure(); } } @@ -185,8 +202,9 @@ auto kindInterface = dyn_cast(op); // Verify that all child regions are ok. + MutableArrayRef regions = op.getRegions(); for (unsigned i = 0; i < numRegions; ++i) { - Region &region = op.getRegion(i); + Region &region = regions[i]; RegionKind kind = kindInterface ? kindInterface.getRegionKind(i) : RegionKind::SSACFG; // Check that Graph Regions only have a single basic block. This is @@ -210,10 +228,13 @@ return emitError(op.getLoc(), "entry block of region may not have predecessors"); - // Verify each of the blocks within the region.
- for (Block &block : region) - if (failed(verifyBlock(block, opsWithIsolatedRegions))) - return failure(); + // Verify each of the blocks within the region if we are verifying + // recursively. + if (verifyRecursively) { + for (Block &block : region) + if (failed(verifyBlock(block, opsWithIsolatedRegions))) + return failure(); + } } } @@ -330,10 +351,10 @@ } } - // Recursively verify dominance within each operation in the - // block, even if the block itself is not reachable, or we are in - // a region which doesn't respect dominance. - if (op.getNumRegions() != 0) { + // Recursively verify dominance within each operation in the block, even + // if the block itself is not reachable, or we are in a region which + // doesn't respect dominance. + if (verifyRecursively && op.getNumRegions() != 0) { // If this operation is IsolatedFromAbove, then we'll handle it in the // outer verification loop. if (op.hasTrait()) @@ -352,9 +373,7 @@ // Entrypoint //===----------------------------------------------------------------------===// -/// Perform (potentially expensive) checks of invariants, used to detect -/// compiler bugs. On error, this reports the error through the MLIRContext and -/// returns failure. -LogicalResult mlir::verify(Operation *op) { - return OperationVerifier().verifyOpAndDominance(*op); +LogicalResult mlir::verify(Operation *op, bool verifyRecursively) { + OperationVerifier verifier(verifyRecursively); + return verifier.verifyOpAndDominance(*op); } diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp --- a/mlir/lib/Pass/Pass.cpp +++ b/mlir/lib/Pass/Pass.cpp @@ -408,22 +408,24 @@ // failed). if (!passFailed && verifyPasses) { bool runVerifierNow = true; + + // If the pass is an adaptor pass, the parent is still verified, but not + // recursively: nested operations were already verified after the nested + // passes ran, yet those passes may have invalidated the parent itself.
+ bool runVerifierRecursively = !isa(pass); + // Reduce compile time by avoiding running the verifier if the pass didn't // change the IR since the last time the verifier was run: // // 1) If the pass said that it preserved all analyses then it can't have // permuted the IR. - // 2) If we just ran an OpToOpPassAdaptor (e.g. to run function passes - // within a module) then each sub-unit will have been verified on the - // subunit (and those passes aren't allowed to modify the parent). // // We run these checks in EXPENSIVE_CHECKS mode out of caution. #ifndef EXPENSIVE_CHECKS - runVerifierNow = !isa(pass) && - !pass->passState->preservedAnalyses.isAll(); + runVerifierNow = !pass->passState->preservedAnalyses.isAll(); #endif if (runVerifierNow) - passFailed = failed(verify(op)); + passFailed = failed(verify(op, runVerifierRecursively)); } // Instrument after the pass has run. diff --git a/mlir/test/Pass/invalid-parent.mlir b/mlir/test/Pass/invalid-parent.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Pass/invalid-parent.mlir @@ -0,0 +1,8 @@ +// RUN: mlir-opt %s -pass-pipeline='builtin.func(test-pass-invalid-parent)' -verify-diagnostics + +// Test that an error is reported when a pass running on a nested function +// makes its parent operation invalid.
+// expected-error@below {{'some_unknown_func' does not reference a valid function}} +func @TestCreateInvalidCallInPass() { + return +} diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -358,6 +358,21 @@ results.add(&dialectCanonicalizationPattern); } +//===----------------------------------------------------------------------===// +// TestCallOp +//===----------------------------------------------------------------------===// + +LogicalResult TestCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + // Check that the callee is specified and resolves to a symbol. + auto fnAttr = (*this)->getAttrOfType("callee"); + if (!fnAttr) + return emitOpError("requires a 'callee' symbol reference attribute"); + if (!symbolTable.lookupNearestSymbolFrom(*this, fnAttr)) + return emitOpError() << "'" << fnAttr.getValue() + << "' does not reference a valid function"; + return success(); +} + //===----------------------------------------------------------------------===// // TestFoldToCallOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -375,6 +375,14 @@ // Test Call Interfaces //===----------------------------------------------------------------------===// +def TestCallOp : TEST_Op<"call", [DeclareOpInterfaceMethods]> { + let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$operands); + let results = (outs Variadic); + let assemblyFormat = [{ + $callee `(` $operands `)` attr-dict `:` functional-type($operands, results) + }]; +} + def ConversionCallOp : TEST_Op<"conversion_call_op", [CallOpInterface]> { let arguments = (ins Variadic:$arg_operands, SymbolRefAttr:$callee); diff --git a/mlir/test/lib/Pass/CMakeLists.txt
b/mlir/test/lib/Pass/CMakeLists.txt --- a/mlir/test/lib/Pass/CMakeLists.txt +++ b/mlir/test/lib/Pass/CMakeLists.txt @@ -11,4 +11,11 @@ LINK_LIBS PUBLIC MLIRIR MLIRPass + MLIRTestDialect + ) + +target_include_directories(MLIRTestPass + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/../Dialect/Test + ${CMAKE_CURRENT_BINARY_DIR}/../Dialect/Test ) diff --git a/mlir/test/lib/Pass/TestPassManager.cpp b/mlir/test/lib/Pass/TestPassManager.cpp --- a/mlir/test/lib/Pass/TestPassManager.cpp +++ b/mlir/test/lib/Pass/TestPassManager.cpp @@ -6,6 +6,7 @@ // //===----------------------------------------------------------------------===// +#include "TestDialect.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" @@ -98,6 +99,27 @@ } }; +/// A test pass that inserts a call to an undefined function into the body of +/// each function it runs on, making the parent operation invalid. +class TestInvalidParentPass + : public PassWrapper> { + StringRef getArgument() const final { return "test-pass-invalid-parent"; } + StringRef getDescription() const final { + return "Test a pass in the pass manager that makes the parent operation " + "invalid"; + } + void getDependentDialects(DialectRegistry &registry) const final { + registry.insert(); + } + void runOnOperation() final { + FunctionOpInterface op = getOperation(); + OpBuilder b(op.getBody()); + b.create(op.getLoc(), TypeRange(), "some_unknown_func", + ValueRange()); + } +}; + /// A test pass that contains a statistic. struct TestStatisticPass : public PassWrapper> { @@ -144,6 +166,7 @@ PassRegistration(); PassRegistration(); + PassRegistration(); PassRegistration();