diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -191,12 +191,20 @@
   /// back to this one which accepts everything.
   LogicalResult verify() { return success(); }
   LogicalResult verifyRegions() { return success(); }
+
+  /// Optional verification hook that may query analyses. It is invoked after
+  /// the trait verifiers and `verify()` have succeeded, with the analysis
+  /// manager given to `mlir::verify`. That manager is shared by every
+  /// operation verified in the same invocation, so it may have been created
+  /// for an ancestor of this operation; it is null when none was provided.
+  LogicalResult verifyAnalyzed(AnalysisManager *analysisManager) {
+    return success();
+  }
 
-  /// Parse the custom form of an operation. Unless overridden, this method will
-  /// first try to get an operation parser from the op's dialect. Otherwise the
-  /// custom assembly form of an op is always rejected. Op implementations
-  /// should implement this to return failure. On success, they should fill in
-  /// result with the fields to use.
+  /// Parse the custom assembly form of an operation. Unless overridden, the
+  /// custom assembly form of an op is always rejected. Op implementations
+  /// should override this and return failure on malformed input; on success,
+  /// they should fill in `result` with the fields to use.
   static ParseResult parse(OpAsmParser &parser, OperationState &result);
 
   /// Print the operation. Unless overridden, this method will first try to get
@@ -1832,15 +1840,20 @@
     return ConcreteType::populateDefaultAttrs;
   }
   /// Implementation of `VerifyInvariantsFn` OperationName hook.
-  static LogicalResult verifyInvariants(Operation *op) {
+  /// `analysisManager` is forwarded to `verifyAnalyzed`, which only runs once
+  /// the trait verifiers and `verify()` have succeeded.
+  static LogicalResult
+  verifyInvariants(Operation *op, AnalysisManager *analysisManager = nullptr) {
     static_assert(hasNoDataMembers(),
                   "Op class shouldn't define new data members");
     return failure(
         failed(op_definition_impl::verifyTraits<Traits<ConcreteType>...>(op)) ||
-        failed(cast<ConcreteType>(op).verify()));
+        failed(cast<ConcreteType>(op).verify()) ||
+        failed(cast<ConcreteType>(op).verifyAnalyzed(analysisManager)));
   }
   static OperationName::VerifyInvariantsFn getVerifyInvariantsFn() {
-    return static_cast<LogicalResult (*)(Operation *)>(&verifyInvariants);
+    return static_cast<LogicalResult (*)(Operation *, AnalysisManager *)>(
+        &verifyInvariants);
   }
   /// Implementation of `VerifyRegionInvariantsFn` OperationName hook.
   static LogicalResult verifyRegionInvariants(Operation *op) {
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -32,6 +32,7 @@
 } // namespace llvm
 
 namespace mlir {
+class AnalysisManager;
 class Dialect;
 class DictionaryAttr;
 class ElementsAttr;
@@ -76,8 +77,8 @@
       const RegisteredOperationName &, NamedAttrList &) const>;
   using PrintAssemblyFn =
       llvm::unique_function<void(Operation *, OpAsmPrinter &, StringRef) const>;
-  using VerifyInvariantsFn =
-      llvm::unique_function<LogicalResult(Operation *) const>;
+  using VerifyInvariantsFn = llvm::unique_function<LogicalResult(
+      Operation *, AnalysisManager *) const>;
   using VerifyRegionInvariantsFn =
       llvm::unique_function<LogicalResult(Operation *) const>;
 
@@ -303,8 +304,11 @@
-  /// These hooks implement the verifiers for this operation. It should emits
-  /// an error message and returns failure if a problem is detected, or returns
-  /// success if everything is ok.
-  LogicalResult verifyInvariants(Operation *op) const {
-    return impl->verifyInvariantsFn(op);
+  /// These hooks implement the verifiers for this operation. They should emit
+  /// an error message and return failure if a problem is detected, or return
+  /// success if everything is ok. `analysisManager` is forwarded unchanged to
+  /// the registered `VerifyInvariantsFn` and may be null.
+  LogicalResult
+  verifyInvariants(Operation *op,
+                   AnalysisManager *analysisManager = nullptr) const {
+    return impl->verifyInvariantsFn(op, analysisManager);
   }
   LogicalResult verifyRegionInvariants(Operation *op) const {
     return impl->verifyRegionInvariantsFn(op);
diff --git a/mlir/include/mlir/IR/Verifier.h b/mlir/include/mlir/IR/Verifier.h
--- a/mlir/include/mlir/IR/Verifier.h
+++ b/mlir/include/mlir/IR/Verifier.h
@@ -10,6 +10,7 @@
 #define MLIR_IR_VERIFIER_H
 
 namespace mlir {
+class AnalysisManager;
 struct LogicalResult;
 class Operation;
 
@@ -19,7 +20,13 @@
 /// `verifyRecursively` is false, this assumes that nested operations have
 /// already been properly verified, and does not recursively invoke the verifier
 /// on nested operations.
-LogicalResult verify(Operation *op, bool verifyRecursively = true);
+///
+/// If `analysisManager` is non-null, it is forwarded unchanged to the
+/// invariant verifier of every operation visited, including nested ones.
+/// Analyses obtained from it describe the operation the manager was created
+/// for, which is not necessarily the operation being verified.
+LogicalResult verify(Operation *op, bool verifyRecursively = true,
+                     AnalysisManager *analysisManager = nullptr);
 
 } // namespace mlir
 
diff --git a/mlir/lib/IR/ExtensibleDialect.cpp b/mlir/lib/IR/ExtensibleDialect.cpp
--- a/mlir/lib/IR/ExtensibleDialect.cpp
+++ b/mlir/lib/IR/ExtensibleDialect.cpp
@@ -351,7 +351,7 @@
 std::unique_ptr<DynamicOpDefinition> DynamicOpDefinition::get(
     StringRef name, ExtensibleDialect *dialect,
     OperationName::VerifyInvariantsFn &&verifyFn,
-    OperationName::VerifyInvariantsFn &&verifyRegionFn,
+    OperationName::VerifyRegionInvariantsFn &&verifyRegionFn,
     OperationName::ParseAssemblyFn &&parseFn,
     OperationName::PrintAssemblyFn &&printFn,
     OperationName::FoldHookFn &&foldHookFn,
diff --git a/mlir/lib/IR/Verifier.cpp b/mlir/lib/IR/Verifier.cpp
--- a/mlir/lib/IR/Verifier.cpp
+++ b/mlir/lib/IR/Verifier.cpp
@@ -45,8 +45,11 @@
 public:
   /// If `verifyRecursively` is true, then this will also recursively verify
   /// nested operations.
-  explicit OperationVerifier(bool verifyRecursively)
-      : verifyRecursively(verifyRecursively) {}
+  /// `analysisManager` may be null; it is passed to the invariant verifier of
+  /// every operation visited.
+  OperationVerifier(bool verifyRecursively, AnalysisManager *analysisManager)
+      : analysisManager(analysisManager), verifyRecursively(verifyRecursively) {
+  }
 
   /// Verify the given operation.
   LogicalResult verifyOpAndDominance(Operation &op);
@@ -66,6 +69,10 @@
   LogicalResult verifyDominanceOfContainedRegions(Operation &op,
                                                   DominanceInfo &domInfo);
 
+  /// Analysis manager forwarded to the invariant verifier of each visited
+  /// operation. May be null.
+  AnalysisManager *analysisManager;
+
   /// A flag indicating if this verifier should recursively verify nested
   /// operations.
   bool verifyRecursively;
@@ -183,7 +190,8 @@
   // If we can get operation info for this, check the custom hook.
   OperationName opName = op.getName();
   Optional<RegisteredOperationName> registeredInfo = opName.getRegisteredInfo();
-  if (registeredInfo && failed(registeredInfo->verifyInvariants(&op)))
+  if (registeredInfo &&
+      failed(registeredInfo->verifyInvariants(&op, analysisManager)))
     return failure();
 
   SmallVector<Operation *> opsWithIsolatedRegions;
@@ -369,7 +377,8 @@
 // Entrypoint
 //===----------------------------------------------------------------------===//
 
-LogicalResult mlir::verify(Operation *op, bool verifyRecursively) {
-  OperationVerifier verifier(verifyRecursively);
+LogicalResult mlir::verify(Operation *op, bool verifyRecursively,
+                           AnalysisManager *analysisManager) {
+  OperationVerifier verifier(verifyRecursively, analysisManager);
   return verifier.verifyOpAndDominance(*op);
 }
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -489,7 +489,7 @@
     runVerifierNow = !pass->passState->preservedAnalyses.isAll();
 #endif
     if (runVerifierNow)
-      passFailed = failed(verify(op, runVerifierRecursively));
+      passFailed = failed(verify(op, runVerifierRecursively, &am));
   }
 
   // Instrument after the pass has run.
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h
--- a/mlir/test/lib/Dialect/Test/TestDialect.h
+++ b/mlir/test/lib/Dialect/Test/TestDialect.h
@@ -14,9 +14,10 @@
 #ifndef MLIR_TESTDIALECT_H
 #define MLIR_TESTDIALECT_H
 
-#include "TestTypes.h"
 #include "TestAttributes.h"
 #include "TestInterfaces.h"
+#include "TestTypes.h"
+#include "mlir/Analysis/DataLayoutAnalysis.h"
 #include "mlir/Dialect/DLTI/DLTI.h"
 #include "mlir/Dialect/DLTI/Traits.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -41,6 +42,7 @@
 #include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
+#include "mlir/Pass/AnalysisManager.h"
 
 namespace mlir {
 class DLTIDialect;
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -277,7 +277,8 @@
 std::unique_ptr<DynamicOpDefinition>
 getDynamicGenericOp(TestDialect *dialect) {
   return DynamicOpDefinition::get(
-      "dynamic_generic", dialect, [](Operation *op) { return success(); },
+      "dynamic_generic", dialect,
+      [](Operation *op, AnalysisManager *am) { return success(); },
       [](Operation *op) { return success(); });
 }
 
@@ -285,7 +286,7 @@
 getDynamicOneOperandTwoResultsOp(TestDialect *dialect) {
   return DynamicOpDefinition::get(
       "dynamic_one_operand_two_results", dialect,
-      [](Operation *op) {
+      [](Operation *op, AnalysisManager *am) {
         if (op->getNumOperands() != 1) {
           op->emitOpError()
               << "expected 1 operand, but had " << op->getNumOperands();
@@ -303,7 +304,7 @@
 
 std::unique_ptr<DynamicOpDefinition>
 getDynamicCustomParserPrinterOp(TestDialect *dialect) {
-  auto verifier = [](Operation *op) {
+  auto verifier = [](Operation *op, AnalysisManager *am) {
     if (op->getNumOperands() == 0 && op->getNumResults() == 0)
       return success();
     op->emitError() << "operation should have no operands and no results";
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -921,6 +921,23 @@
 // Test Patterns
 //===----------------------------------------------------------------------===//
 
+// Test op whose `verifyAnalyzed` hook reports on stderr whether the verifier
+// was given an analysis manager and, if so, prints the address of the
+// DataLayoutAnalysis obtained from it.
+def OpWithAMVerifier : TEST_Op<"op_with_am_verifier"> {
+  let extraClassDeclaration = [{
+    ::mlir::LogicalResult verifyAnalyzed(::mlir::AnalysisManager *am) {
+      if (am) {
+        llvm::errs() << "has analysis manager\n";
+        llvm::errs() << &am->getAnalysis<::mlir::DataLayoutAnalysis>() << "\n";
+      } else {
+        llvm::errs() << "no analysis manager\n";
+      }
+      return ::mlir::success();
+    }
+  }];
+}
+
 def OpA : TEST_Op<"op_a"> {
   let arguments = (ins I32, I32Attr:$attr);
   let results = (outs I32);
diff --git a/mlir/test/lib/Dialect/TestDyn/TestDynDialect.cpp b/mlir/test/lib/Dialect/TestDyn/TestDynDialect.cpp
--- a/mlir/test/lib/Dialect/TestDyn/TestDynDialect.cpp
+++ b/mlir/test/lib/Dialect/TestDyn/TestDynDialect.cpp
@@ -19,7 +19,8 @@
 void registerTestDynDialect(DialectRegistry &registry) {
   registry.insertDynamic(
       "test_dyn", [](MLIRContext *ctx, DynamicDialect *testDyn) {
-        auto opVerifier = [](Operation *op) -> LogicalResult {
+        auto opVerifier = [](Operation *op,
+                             AnalysisManager *am) -> LogicalResult {
           if (op->getNumOperands() == 0 && op->getNumResults() == 1 &&
               op->getNumRegions() == 0)
             return success();
diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
@@ -337,6 +337,7 @@
         ":TestTypeDefsIncGen",
         "//llvm:Support",
+        "//mlir:Analysis",
         "//mlir:ArithDialect",
         "//mlir:ControlFlowInterfaces",
         "//mlir:CopyOpInterface",
         "//mlir:DLTIDialect",