diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h @@ -52,6 +52,26 @@ /// expected to derive this class and register operations in the constructor. /// They can be registered with the DialectRegistry and automatically applied /// to the Transform dialect when it is loaded. +/// +/// Derived classes are expected to define a `void init()` function in which +/// they can call various protected methods of the base class to register +/// extension operations and declare their dependencies. +/// +/// By default, the extension is configured both for construction of the +/// Transform IR and for its application to some payload. If only the +/// construction is desired, the extension can be switched to "build-only" mode +/// that avoids loading the dialects that are only necessary for transforming +/// the payload. To perform the switch, the extension must be wrapped into the +/// `BuildOnly` class template (see below) when it is registered, as in: +/// +/// dialectRegistry.addExtensions>(); +/// +/// instead of: +/// +/// dialectRegistry.addExtensions(); +/// +/// Derived classes must reexport the constructor of this class (e.g., via +/// `using Base::Base;`) or otherwise forward its boolean argument. template class TransformDialectExtension : public DialectExtension { @@ -65,12 +85,31 @@ ExtraDialects *...) const final { for (const DialectLoader &loader : dialectLoaders) loader(context); + + // Only load generated dialects if the user intends to apply + // transformations specified by the extension. 
+ if (!buildOnly) + for (const DialectLoader &loader : generatedDialectLoaders) + loader(context); + for (const Initializer &init : opInitializers) init(transformDialect); transformDialect->mergeInPDLMatchHooks(std::move(pdlMatchConstraintFns)); } protected: + using Base = TransformDialectExtension; + + /// Extension constructor. The argument indicates whether to skip generated + /// dialects when applying the extension. + explicit TransformDialectExtension(bool buildOnly = false) + : buildOnly(buildOnly) { + static_cast(this)->init(); + } + + /// No-op default; derived `init()` runs before derived members are built. + void init() {} + /// Injects the operations into the Transform dialect. The operations must /// implement the TransformOpInterface and MemoryEffectsOpInterface, and the /// implementations must be already available when the operation is injected. @@ -85,13 +124,28 @@ /// provided as template parameter. When the Transform dialect is loaded, /// dependent dialects will be loaded as well. This is intended for dialects /// that contain attributes and types used in creation and canonicalization of - /// the injected operations. + /// the injected operations, similarly to how the dialect definition may list + /// dependent dialects. This is *not* intended for dialects whose entities + /// may be produced when applying the transformations specified by ops + /// registered by this extension. template void declareDependentDialect() { dialectLoaders.push_back( [](MLIRContext *context) { context->loadDialect(); }); } + /// Declares that the transformations associated with the operations + /// registered by this dialect extension may produce operations from the + /// dialect provided as template parameter while processing payload IR that + /// does not contain the operations from said dialect. This is similar to + /// dependent dialects of a pass. These dialects will be loaded along with the + /// transform dialect unless the extension is in the build-only mode. 
+ template 
+ void declareGeneratedDialect() { + generatedDialectLoaders.push_back( + [](MLIRContext *context) { context->loadDialect(); }); + } + /// Injects the named constraint to make it available for use with the /// PDLMatchOp in the transform dialect. void registerPDLMatchConstraintFn(StringRef name, @@ -108,14 +162,32 @@ private: SmallVector opInitializers; + + /// Callbacks loading the dependent dialects, i.e. the dialects needed for + /// the extension ops. SmallVector dialectLoaders; - /// A list of constraints that should be made availble to PDL patterns + /// Callbacks loading the generated dialects, i.e. the dialects that may be + /// produced when applying the transformations. + SmallVector generatedDialectLoaders; + + /// A list of constraints that should be made available to PDL patterns /// processed by PDLMatchOp in the Transform dialect. /// /// Declared as mutable so its contents can be moved in the `apply` const /// method, which is only called once. mutable llvm::StringMap pdlMatchConstraintFns; + + /// Build-only mode flag; when set, `apply` skips the generated dialects. + bool buildOnly; +}; + +/// A wrapper for transform dialect extensions that forces them to be +/// constructed in the build-only mode. 
+template +class BuildOnly : public DerivedTy { +public: + BuildOnly() : DerivedTy(/*buildOnly=*/true) {} }; } // namespace transform diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp --- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp +++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp @@ -75,10 +75,14 @@ : public transform::TransformDialectExtension< BufferizationTransformDialectExtension> { public: - BufferizationTransformDialectExtension() { - declareDependentDialect(); + using Base::Base; // Allows wrapping in transform::BuildOnly. + // Invoked from the TransformDialectExtension constructor. + void init() { declareDependentDialect(); - declareDependentDialect(); + // Dialects whose ops may be produced when applying the transforms. + declareGeneratedDialect(); + declareGeneratedDialect(); + + registerTransformOps< #define GET_OP_LIST #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc" diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -1069,12 +1069,16 @@ : public transform::TransformDialectExtension< LinalgTransformDialectExtension> { public: - LinalgTransformDialectExtension() { - declareDependentDialect(); - declareDependentDialect(); + using Base::Base; // Allows wrapping in transform::BuildOnly. + // Invoked from the TransformDialectExtension constructor. + void init() { declareDependentDialect(); - declareDependentDialect(); - declareDependentDialect(); + // Dialects whose ops may be produced when applying the transforms. + declareGeneratedDialect(); + declareGeneratedDialect(); + declareGeneratedDialect(); + declareGeneratedDialect(); + + registerTransformOps< #define GET_OP_LIST #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc" diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp --- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp +++ 
b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp @@ -9,6 +9,7 @@ #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/PDL/IR/PDL.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/Transforms/Patterns.h" #include "mlir/Dialect/SCF/Transforms/Transforms.h" @@ -233,9 +234,14 @@ : public transform::TransformDialectExtension< SCFTransformDialectExtension> { public: - SCFTransformDialectExtension() { - declareDependentDialect(); - declareDependentDialect(); + using Base::Base; // Allows wrapping in transform::BuildOnly. + // Invoked from the TransformDialectExtension constructor. + void init() { + declareDependentDialect(); + // Dialects whose ops may be produced when applying the transforms. + declareGeneratedDialect(); + declareGeneratedDialect(); + + registerTransformOps< #define GET_OP_LIST #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc" diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -295,7 +295,9 @@ : public transform::TransformDialectExtension< TestTransformDialectExtension> { public: - TestTransformDialectExtension() { + using Base::Base; // Allows wrapping in transform::BuildOnly. + // Invoked from the TransformDialectExtension constructor. + void init() { declareDependentDialect(); registerTransformOps { +public: + using Base::Base; // Required by the BuildOnly wrapper used below. + void init() { declareGeneratedDialect(); } +}; +} // end namespace + +TEST(BuildOnlyExtensionTest, buildOnlyExtension) { + // Register the build-only version of the transform dialect extension. The + // func dialect is declared as generated so it should not be loaded along with + // the transform dialect. + DialectRegistry registry; + registry.addExtensions>(); + MLIRContext ctx(registry); + ctx.getOrLoadDialect(); + ASSERT_FALSE(ctx.getLoadedDialect()); +} + +TEST(BuildOnlyExtensionTest, buildAndApplyExtension) { + // Register the full version of the transform dialect extension. 
The func + // dialect should be loaded along with the transform dialect. + DialectRegistry registry; + registry.addExtensions(); + MLIRContext ctx(registry); + ctx.getOrLoadDialect(); + ASSERT_TRUE(ctx.getLoadedDialect()); +} diff --git a/mlir/unittests/Dialect/Transform/CMakeLists.txt b/mlir/unittests/Dialect/Transform/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/unittests/Dialect/Transform/CMakeLists.txt @@ -0,0 +1,8 @@ +add_mlir_unittest(MLIRTransformDialectTests + BuildOnlyExtensionTest.cpp +) +target_link_libraries(MLIRTransformDialectTests + PRIVATE + MLIRFuncDialect + MLIRTransformDialect +) diff --git a/utils/bazel/llvm-project-overlay/mlir/unittests/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/unittests/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/unittests/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/unittests/BUILD.bazel @@ -179,6 +179,21 @@ ], ) +cc_test( + name = "transform_dialect_tests", + size = "small", + srcs = glob([ + "Dialect/Transform/*.cpp", + "Dialect/Transform/*.h", + ]), + deps = [ + "//llvm:TestingSupport", + "//llvm:gtest_main", + "//mlir:FuncDialect", + "//mlir:TransformDialect", + ], +) + cc_test( name = "dialect_utils_tests", size = "small",