diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h @@ -10,7 +10,6 @@ #define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" -#include "mlir/IR/Dialect.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/StringMap.h" @@ -68,6 +67,7 @@ for (const Initializer &init : opInitializers) init(transformDialect); transformDialect->mergeInPDLMatchHooks(std::move(pdlMatchConstraintFns)); + llvm::append_range(transformDialect->generatedDialects, dialectRegistrars); } protected: @@ -85,13 +85,29 @@ /// provided as template parameter. When the Transform dialect is loaded, /// dependent dialects will be loaded as well. This is intended for dialects /// that contain attributes and types used in creation and canonicalization of - /// the injected operations. + /// the injected operations. This is **not** intended for dialects that + /// contain IR entities produced in the Payload IR by transformations included + /// in this extension, use `declareGeneratedDialect` instead. template void declareDependentDialect() { dialectLoaders.push_back( [](MLIRContext *context) { context->loadDialect(); }); } + /// Declares that the transformations provided by this Transform dialect + /// extension may generate IR entities (attributes, operations or types) from + /// the dialect provided as template parameter. Passes that apply Transform + /// dialect operations to payload IR can use these to list the dialects they + /// "depend on" in the sense of creating IR entities, as expected by the pass + /// manager. This is **not** intended for dialects entities from which are + /// used in the Transform IR operations, use `declareDependentDialect` + /// instead. 
+ template + void declareGeneratedDialect() { + dialectRegistrars.push_back( + [](DialectRegistry &registry) { registry.insert(); }); + } + /// Injects the named constraint to make it available for use with the /// PDLMatchOp in the transform dialect. void registerPDLMatchConstraintFn(StringRef name, @@ -109,6 +125,7 @@ private: SmallVector opInitializers; SmallVector dialectLoaders; + SmallVector dialectRegistrars; /// A list of constraints that should be made availble to PDL patterns /// processed by PDLMatchOp in the Transform dialect. diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td @@ -304,6 +304,11 @@ const ::llvm::StringMap<::mlir::PDLConstraintFunction> & getPDLConstraintHooks() const; + /// Appends dialects entities from which can be generated in the Payload + /// IR when Transform ops from any of the registered extensions are + /// applied. + void registerGeneratedDialects(DialectRegistry &registry) const; + private: template void addOperationIfNotRegistered() { @@ -344,9 +349,17 @@ void mergeInPDLMatchHooks( ::llvm::StringMap<::mlir::PDLConstraintFunction> &&constraintFns); + /// Type of callbacks adding dialects into the dialect registry. + using DialectRegistrar = std::function; + /// A container for PDL constraint function that can be used by /// operations in this dialect. PDLPatternModule pdlMatchHooks; + + /// Callbacks registering the dialects that can be produced when applying + /// the transformation operations registered by this dialect and its + /// extensions. 
+ SmallVector generatedDialects; }]; } diff --git a/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt @@ -8,6 +8,7 @@ MLIRLinalgTransformOpsIncGen LINK_LIBS PUBLIC + MLIRArithmeticDialect MLIRIR MLIRLinalgDialect MLIRLinalgTransforms diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -8,6 +8,7 @@ #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/PDL/IR/PDL.h" @@ -502,14 +503,17 @@ public: LinalgTransformDialectExtension() { declareDependentDialect(); - declareDependentDialect(); - declareDependentDialect(); + + declareGeneratedDialect(); + declareGeneratedDialect(); + declareGeneratedDialect(); registerTransformOps< #define GET_OP_LIST #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc" >(); } }; + } // namespace #define GET_OP_CLASSES diff --git a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp @@ -35,3 +35,9 @@ transform::TransformDialect::getPDLConstraintHooks() const { return pdlMatchHooks.getConstraintFunctions(); } + +void transform::TransformDialect::registerGeneratedDialects( + DialectRegistry &registry) const { + for (const DialectRegistrar &callback : generatedDialects) + callback(registry); +} diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp 
b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp @@ -11,6 +11,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" @@ -52,6 +53,12 @@ } } + void getDependentDialects(DialectRegistry &registry, + MLIRContext &context) const override { + auto *dialect = context.getOrLoadDialect(); + dialect->registerGeneratedDialects(registry); + } + Option enableExpensiveChecks{ *this, "enable-expensive-checks", llvm::cl::init(false), llvm::cl::desc("perform expensive checks to better report errors in the " diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -7392,6 +7392,7 @@ ], includes = ["include"], deps = [ + ":ArithmeticDialect", ":IR", ":LinalgDialect", ":LinalgTransformOpsIncGen",