diff --git a/mlir/include/mlir/Dialect/Transform/CMakeLists.txt b/mlir/include/mlir/Dialect/Transform/CMakeLists.txt --- a/mlir/include/mlir/Dialect/Transform/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Transform/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(IR) +add_subdirectory(PDLExtension) add_subdirectory(Transforms) diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h @@ -12,12 +12,52 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Support/LLVM.h" +#include "mlir/Support/TypeID.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/StringMap.h" #include namespace mlir { namespace transform { + +namespace detail { +/// Concrete base class for the CRTP TransformDialectData. Must not be used +/// directly. +class TransformDialectDataBase { +public: + virtual ~TransformDialectDataBase() = default; + + /// Returns the dynamic type ID of the subclass. + TypeID getTypeID() const { return typeID; } + +protected: + /// Must be called by the subclass with the appropriate type ID. + explicit TransformDialectDataBase(TypeID typeID) : typeID(typeID) {} + +private: + /// The type ID of the subclass. + const TypeID typeID; +}; +} // namespace detail + +/// Base class for additional data owned by the Transform dialect. Extensions +/// may communicate with each other using this data. The data object is +/// identified by the TypeID of the specific data subclass; querying the data of +/// the same subclass returns a reference to the same object. When a Transform +/// dialect extension is initialized, it can populate the data in the specific +/// subclass. When a Transform op is applied, it can read (but not mutate) the +/// data in the specific subclass, including the data provided by other +/// extensions. 
+/// +/// This follows CRTP: derived classes must list themselves as template +/// argument. +template +class TransformDialectData : public detail::TransformDialectDataBase { +protected: + /// Forward the TypeID of the derived class to the base. + TransformDialectData() : TransformDialectDataBase(TypeID::get()) {} +}; + #ifndef NDEBUG namespace detail { /// Asserts that the operations provided as template arguments implement the @@ -85,9 +125,8 @@ for (const DialectLoader &loader : generatedDialectLoaders) loader(context); - for (const Initializer &init : opInitializers) + for (const Initializer &init : initializers) init(transformDialect); - transformDialect->mergeInPDLMatchHooks(std::move(pdlMatchConstraintFns)); } protected: @@ -100,6 +139,41 @@ static_cast(this)->init(); } + /// Registers a custom initialization step to be performed when the extension + /// is applied to the dialect while loading. This is discouraged in favor of + /// more specific calls `declareGeneratedDialect`, `addDialectDataInitializer` + /// etc. `Func` must be convertible to the `void (MLIRContext *)` form. It + /// will be called during the extension initialization and given the current + /// MLIR context. This may be used to attach additional interfaces that cannot + /// be attached elsewhere. + template + void addCustomInitializationStep(Func &&func) { + std::function initializer = func; + dialectLoaders.push_back( + [init = std::move(initializer)](MLIRContext *ctx) { init(ctx); }); + } + + /// Registers the given function as one of the initializers for the + /// dialect-owned data of the kind specified as template argument. The + /// function must be convertible to the `void (DataTy &)` form. It will be + /// called during the extension initialization and will be given a mutable + /// reference to `DataTy`. The callback is expected to append data to the + /// given storage, and is not allowed to remove or destructively mutate the + /// existing data. 
The order in which callbacks from different extensions are + /// executed is unspecified so the callbacks may not rely on data being + /// already present. `DataTy` must be a class deriving `TransformDialectData`. + template + void addDialectDataInitializer(Func &&func) { + static_assert(std::is_base_of_v, + "only classes deriving TransformDialectData are accepted"); + + std::function initializer = func; + initializers.push_back( + [init = std::move(initializer)](TransformDialect *transformDialect) { + init(transformDialect->getOrCreateExtraData()); + }); + } + /// Hook for derived classes to inject constructor behavior. void init() {} @@ -108,7 +182,7 @@ /// implementations must be already available when the operation is injected. template void registerTransformOps() { - opInitializers.push_back([](TransformDialect *transformDialect) { + initializers.push_back([](TransformDialect *transformDialect) { transformDialect->addOperationsChecked(); }); } @@ -120,7 +194,7 @@ /// `StringRef` that is unique across all injected types. template void registerTypes() { - opInitializers.push_back([](TransformDialect *transformDialect) { + initializers.push_back([](TransformDialect *transformDialect) { transformDialect->addTypesChecked(); }); } @@ -151,22 +225,10 @@ [](MLIRContext *context) { context->loadDialect(); }); } - /// Injects the named constraint to make it available for use with the - /// PDLMatchOp in the transform dialect. - void registerPDLMatchConstraintFn(StringRef name, - PDLConstraintFunction &&fn) { - pdlMatchConstraintFns.try_emplace(name, - std::forward(fn)); - } - template - void registerPDLMatchConstraintFn(StringRef name, ConstraintFnTy &&fn) { - pdlMatchConstraintFns.try_emplace( - name, ::mlir::detail::pdl_function_builder::buildConstraintFn( - std::forward(fn))); - } - private: - SmallVector opInitializers; + /// Callbacks performing extension initialization, e.g., registering ops, + /// types and defining the additional data. 
+ SmallVector initializers; /// Callbacks loading the dependent dialects, i.e. the dialect needed for the /// extension ops. @@ -176,13 +238,6 @@ /// applying the transformations. SmallVector generatedDialectLoaders; - /// A list of constraints that should be made available to PDL patterns - /// processed by PDLMatchOp in the Transform dialect. - /// - /// Declared as mutable so its contents can be moved in the `apply` const - /// method, which is only called once. - mutable llvm::StringMap pdlMatchConstraintFns; - /// Indicates that the extension is in build-only mode. bool buildOnly; }; @@ -232,6 +287,17 @@ #endif // NDEBUG } +template +DataTy &TransformDialect::getOrCreateExtraData() { + TypeID typeID = TypeID::get(); + auto it = extraData.find(typeID); + if (it != extraData.end()) + return static_cast(*it->getSecond()); + + auto emplaced = extraData.try_emplace(typeID, std::make_unique()); + return static_cast(*emplaced.first->getSecond()); +} + /// A wrapper for transform dialect extensions that forces them to be /// constructed in the build-only mode. template diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td @@ -18,36 +18,31 @@ let name = "transform"; let cppNamespace = "::mlir::transform"; - let dependentDialects = [ - "::mlir::pdl::PDLDialect", - "::mlir::pdl_interp::PDLInterpDialect", - ]; - let hasOperationAttrVerify = 1; let usePropertiesForAttributes = 1; let extraClassDeclaration = [{ /// Name of the attribute attachable to the symbol table operation /// containing named sequences. This is used to trigger verification. 
- constexpr const static llvm::StringLiteral + constexpr const static ::llvm::StringLiteral kWithNamedSequenceAttrName = "transform.with_named_sequence"; /// Names of the attribute attachable to an operation so it can be /// identified as root by the default interpreter pass. - constexpr const static llvm::StringLiteral + constexpr const static ::llvm::StringLiteral kTargetTagAttrName = "transform.target_tag"; /// Names of the attributes indicating whether an argument of an external /// transform dialect symbol is consumed or only read. - constexpr const static llvm::StringLiteral + constexpr const static ::llvm::StringLiteral kArgConsumedAttrName = "transform.consumed"; - constexpr const static llvm::StringLiteral + constexpr const static ::llvm::StringLiteral kArgReadOnlyAttrName = "transform.readonly"; - /// Returns the named PDL constraint functions available in the dialect - /// as a map from their name to the function. - const ::llvm::StringMap<::mlir::PDLConstraintFunction> & - getPDLConstraintHooks() const; + template + const DataTy &getExtraData() const { + return *static_cast(extraData.at(::mlir::TypeID::get()).get()); + } /// Parses a type registered by this dialect or one of its extensions. ::mlir::Type parseType(::mlir::DialectAsmParser &parser) const override; @@ -92,23 +87,27 @@ /// mnemonic. [[noreturn]] void reportDuplicateTypeRegistration(StringRef mnemonic); + /// Registers dialect types with the context. void initializeTypes(); + // Give extensions access to injection functions. template friend class TransformDialectExtension; - /// Takes ownership of the named PDL constraint function from the given - /// map and makes them available for use by the operations in the dialect. - void mergeInPDLMatchHooks( - ::llvm::StringMap<::mlir::PDLConstraintFunction> &&constraintFns); + /// Gets a mutable reference to extra data of the kind specified as + /// template argument. Allocates the data on the first call. 
+ template + DataTy &getOrCreateExtraData(); //===----------------------------------------------------------------===// // Data fields //===----------------------------------------------------------------===// - /// A container for PDL constraint function that can be used by - /// operations in this dialect. - ::mlir::PDLPatternModule pdlMatchHooks; + /// Additional data associated with and owned by the dialect. Accessible + /// to extensions. + ::llvm::DenseMap<::mlir::TypeID, std::unique_ptr< + ::mlir::transform::detail::TransformDialectDataBase>> + extraData; /// A map from type mnemonic to its parsing function for the remainder of /// the syntax. The parser has access to the mnemonic, so it is used for diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -38,6 +38,14 @@ /// Verification hook for PossibleTopLevelTransformOpTrait. LogicalResult verifyPossibleTopLevelTransformOpTrait(Operation *op); +/// Populates `effects` with side effects implied by +/// PossibleTopLevelTransformOpTrait for the given operation. The operation may +/// have an optional `root` operand, indicating it is not in fact top-level. It +/// is also expected to have a single-block body. +void getPotentialTopLevelEffects( + Operation *operation, Value root, Block &body, + SmallVectorImpl &effects); + /// Verification hook for TransformOpInterface. LogicalResult verifyTransformOpInterface(Operation *op); @@ -753,15 +761,16 @@ /// can be standalone top-level transforms. Such operations typically contain /// other Transform dialect operations that can be executed following some /// control flow logic specific to the current operation. The operations with -/// this trait are expected to have at least one single-block region with one -/// argument of PDL Operation type. 
The operations are also expected to be valid -/// without operands, in which case they are considered top-level, and with one -/// or more arguments, in which case they are considered nested. Top-level -/// operations have the block argument of the entry block in the Transform IR -/// correspond to the root operation of Payload IR. Nested operations have the -/// block argument of the entry block in the Transform IR correspond to a list -/// of Payload IR operations mapped to the first operand of the Transform IR -/// operation. The operation must implement TransformOpInterface. +/// this trait are expected to have at least one single-block region with at +/// least one argument of type implementing TransformHandleTypeInterface. The +/// operations are also expected to be valid without operands, in which case +/// they are considered top-level, and with one or more arguments, in which case +/// they are considered nested. Top-level operations have the block argument of +/// the entry block in the Transform IR correspond to the root operation of +/// Payload IR. Nested operations have the block argument of the entry block in +/// the Transform IR correspond to a list of Payload IR operations mapped to the +/// first operand of the Transform IR operation. The operation must implement +/// TransformOpInterface. template class PossibleTopLevelTransformOpTrait : public OpTrait::TraitBase { @@ -777,6 +786,14 @@ return &this->getOperation()->getRegion(region).front(); } + /// Populates `effects` with side effects implied by this trait. + void getPotentialTopLevelEffects( + SmallVectorImpl &effects) { + detail::getPotentialTopLevelEffects( + this->getOperation(), cast(this->getOperation()).getRoot(), + *getBodyBlock(), effects); + } + /// Sets up the mapping between the entry block of the given region of this op /// and the relevant list of Payload IR operations in the given state. The /// state is expected to be already scoped at the region of this operation. 
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h @@ -9,7 +9,6 @@ #ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS_H #define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS_H -#include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/Dialect/Transform/IR/MatchInterfaces.h" #include "mlir/Dialect/Transform/IR/TransformAttrs.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -575,37 +575,6 @@ let assemblyFormat = "$value attr-dict `->` type($param)"; } -def PDLMatchOp : TransformDialectOp<"pdl_match", - [DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods]> { - let summary = "Finds ops that match the named PDL pattern"; - let description = [{ - Find Payload IR ops nested within the Payload IR op associated with the - operand that match the PDL pattern identified by its name. The pattern is - expected to be defined in the closest surrounding `WithPDLPatternsOp`. - - Produces a Transform IR value associated with the list of Payload IR ops - that matched the pattern. The order of results in the list is that of the - Operation::walk, clients are advised not to rely on a specific order though. - If the operand is associated with multiple Payload IR ops, finds matching - ops nested within each of those and produces a single list containing all - of the matched ops. - - The transformation is considered successful regardless of whether some - Payload IR ops actually matched the pattern and only fails if the pattern - could not be looked up or compiled. 
- }]; - - let arguments = (ins - Arg:$root, - SymbolRefAttr:$pattern_name); - let results = (outs - Res:$matched); - - let assemblyFormat = "$pattern_name `in` $root attr-dict `:` " - "functional-type(operands, results)"; -} - def PrintOp : TransformDialectOp<"print", [DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { @@ -753,61 +722,6 @@ let hasVerifier = 1; } -def WithPDLPatternsOp : TransformDialectOp<"with_pdl_patterns", - [DeclareOpInterfaceMethods, NoTerminator, - OpAsmOpInterface, PossibleTopLevelTransformOpTrait, - DeclareOpInterfaceMethods, - SymbolTable]> { - let summary = "Contains PDL patterns available for use in transforms"; - let description = [{ - This op contains a set of named PDL patterns that are available for the - Transform dialect operations to be used for pattern matching. For example, - PDLMatchOp can be used to produce a Transform IR value associated with all - Payload IR operations that match the pattern as follows: - - ```mlir - transform.with_pdl_patterns { - ^bb0(%arg0: !transform.any_op): - pdl.pattern @my_pattern : benefit(1) { - %0 = pdl.operation //... - // Regular PDL goes here. - pdl.rewrite %0 with "transform.dialect" - } - - sequence %arg0 failures(propagate) { - ^bb0(%arg1: !transform.any_op): - %1 = pdl_match @my_pattern in %arg1 - // Use %1 as handle - } - } - ``` - - Note that the pattern is expected to finish with a `pdl.rewrite` terminator - that points to the custom rewriter named "transform.dialect". The rewriter - actually does nothing, but the transform application will keep track of the - operations that matched the pattern. - - This op is expected to contain `pdl.pattern` operations and exactly one - another Transform dialect operation that gets executed with all patterns - available. This op is a possible top-level Transform IR op, the argument of - its entry block corresponds to either the root op of the payload IR or the - ops associated with its operand when provided. 
- }]; - - let arguments = (ins - Arg, "Root operation of the Payload IR" - >:$root); - let regions = (region SizedRegion<1>:$body); - let assemblyFormat = "($root^ `:` type($root))? attr-dict-with-keyword regions"; - - let hasVerifier = 1; - - let extraClassDeclaration = [{ - /// Allow the dialect prefix to be omitted. - static StringRef getDefaultDialect() { return "transform"; } - }]; -} - def YieldOp : TransformDialectOp<"yield", [Terminator, DeclareOpInterfaceMethods]> { let summary = "Yields operation handles from a transform IR region"; diff --git a/mlir/include/mlir/Dialect/Transform/PDLExtension/CMakeLists.txt b/mlir/include/mlir/Dialect/Transform/PDLExtension/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Transform/PDLExtension/CMakeLists.txt @@ -0,0 +1,6 @@ +set(LLVM_TARGET_DEFINITIONS PDLExtensionOps.td) +mlir_tablegen(PDLExtensionOps.h.inc -gen-op-decls) +mlir_tablegen(PDLExtensionOps.cpp.inc -gen-op-defs) +add_public_tablegen_target(MLIRTransformDialectPDLExtensionOpsIncGen) + +add_mlir_doc(PDLExtensionOps PDLExtensionOps Dialects/ -gen-op-doc) diff --git a/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtension.h b/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtension.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtension.h @@ -0,0 +1,16 @@ +//===- PDLExtension.h - PDL extension for Transform dialect -----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +namespace mlir { +class DialectRegistry; + +namespace transform { +/// Registers the PDL extension of the Transform dialect in the given registry. 
+void registerPDLExtension(DialectRegistry &dialectRegistry); +} // namespace transform +} // namespace mlir diff --git a/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h b/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h @@ -0,0 +1,49 @@ +//===- PDLExtensionOps.h - PDL extension for Transform dialect --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_TRANSFORM_PDLEXTENSION_PDLEXTENSIONOPS_H +#define MLIR_DIALECT_TRANSFORM_PDLEXTENSION_PDLEXTENSIONOPS_H + +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" + +#define GET_OP_CLASSES +#include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h.inc" + +namespace mlir { +namespace transform { +/// PDL constraint callbacks that can be used by the PDL extension of the +/// Transform dialect. These are owned by the Transform dialect and can be +/// populated by extensions. +class PDLMatchHooks : public TransformDialectData { +public: + /// Takes ownership of the named PDL constraint function from the given + /// map and makes them available for use by the operations in the dialect. + void + mergeInPDLMatchHooks(llvm::StringMap &&constraintFns); + + /// Returns the named PDL constraint functions available in the dialect + /// as a map from their name to the function. 
+ const llvm::StringMap<::mlir::PDLConstraintFunction> & + getPDLConstraintHooks() const; + +private: + /// A container for PDL constraint functions that can be used by + /// operations in this dialect. + PDLPatternModule pdlMatchHooks; +}; +} // namespace transform +} // namespace mlir + +MLIR_DECLARE_EXPLICIT_TYPE_ID(mlir::transform::PDLMatchHooks) + +#endif // MLIR_DIALECT_TRANSFORM_PDLEXTENSION_PDLEXTENSIONOPS_H diff --git a/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.td b/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.td @@ -0,0 +1,104 @@ +//===- PDLExtensionOps.td - PDL extension ops --------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_TRANSFORM_PDLEXTENSION_PDLEXTENSIONOPS +#define MLIR_DIALECT_TRANSFORM_PDLEXTENSION_PDLEXTENSIONOPS + +include "mlir/Dialect/Transform/IR/TransformDialect.td" +include "mlir/Dialect/Transform/IR/TransformInterfaces.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/IR/OpAsmInterface.td" +include "mlir/IR/SymbolInterfaces.td" + +def PDLMatchOp : TransformDialectOp<"pdl_match", + [DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { + let summary = "Finds ops that match the named PDL pattern"; + let description = [{ + Find Payload IR ops nested within the Payload IR op associated with the + operand that match the PDL pattern identified by its name. The pattern is + expected to be defined in the closest surrounding `WithPDLPatternsOp`. + + Produces a Transform IR value associated with the list of Payload IR ops + that matched the pattern. 
The order of results in the list is that of the + Operation::walk, clients are advised not to rely on a specific order though. + If the operand is associated with multiple Payload IR ops, finds matching + ops nested within each of those and produces a single list containing all + of the matched ops. + + The transformation is considered successful regardless of whether some + Payload IR ops actually matched the pattern and only fails if the pattern + could not be looked up or compiled. + }]; + + let arguments = (ins + Arg:$root, + SymbolRefAttr:$pattern_name); + let results = (outs + Res:$matched); + + let assemblyFormat = "$pattern_name `in` $root attr-dict `:` " + "functional-type(operands, results)"; +} + +def WithPDLPatternsOp : TransformDialectOp<"with_pdl_patterns", + [DeclareOpInterfaceMethods, NoTerminator, + OpAsmOpInterface, PossibleTopLevelTransformOpTrait, + DeclareOpInterfaceMethods, + SymbolTable]> { + let summary = "Contains PDL patterns available for use in transforms"; + let description = [{ + This op contains a set of named PDL patterns that are available for the + Transform dialect operations to be used for pattern matching. For example, + PDLMatchOp can be used to produce a Transform IR value associated with all + Payload IR operations that match the pattern as follows: + + ```mlir + transform.with_pdl_patterns { + ^bb0(%arg0: !transform.any_op): + pdl.pattern @my_pattern : benefit(1) { + %0 = pdl.operation //... + // Regular PDL goes here. + pdl.rewrite %0 with "transform.dialect" + } + + sequence %arg0 failures(propagate) { + ^bb0(%arg1: !transform.any_op): + %1 = pdl_match @my_pattern in %arg1 + // Use %1 as handle + } + } + ``` + + Note that the pattern is expected to finish with a `pdl.rewrite` terminator + that points to the custom rewriter named "transform.dialect". The rewriter + actually does nothing, but the transform application will keep track of the + operations that matched the pattern. 
+ + This op is expected to contain `pdl.pattern` operations and exactly one + other Transform dialect operation that gets executed with all patterns + available. This op is a possible top-level Transform IR op; the argument of + its entry block corresponds to either the root op of the payload IR or the + ops associated with its operand when provided. + }]; + + let arguments = (ins + Arg, "Root operation of the Payload IR" + >:$root); + let regions = (region SizedRegion<1>:$body); + let assemblyFormat = "($root^ `:` type($root))? attr-dict-with-keyword regions"; + + let hasVerifier = 1; + + let extraClassDeclaration = [{ + /// Allow the dialect prefix to be omitted. + static StringRef getDefaultDialect() { return "transform"; } + }]; +} + +#endif // MLIR_DIALECT_TRANSFORM_PDLEXTENSION_PDLEXTENSIONOPS diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -76,6 +76,7 @@ #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h" #include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h" @@ -135,6 +136,7 @@ memref::registerTransformDialectExtension(registry); scf::registerTransformDialectExtension(registry); tensor::registerTransformDialectExtension(registry); + transform::registerPDLExtension(registry); vector::registerTransformDialectExtension(registry); // Register all external models. 
diff --git a/mlir/lib/Dialect/Transform/CMakeLists.txt b/mlir/lib/Dialect/Transform/CMakeLists.txt --- a/mlir/lib/Dialect/Transform/CMakeLists.txt +++ b/mlir/lib/Dialect/Transform/CMakeLists.txt @@ -1,3 +1,4 @@ add_subdirectory(IR) +add_subdirectory(PDLExtension) add_subdirectory(Transforms) add_subdirectory(Utils) diff --git a/mlir/lib/Dialect/Transform/IR/CMakeLists.txt b/mlir/lib/Dialect/Transform/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Transform/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Transform/IR/CMakeLists.txt @@ -14,8 +14,6 @@ LINK_LIBS PUBLIC MLIRIR MLIRParser - MLIRPDLDialect - MLIRPDLInterpDialect MLIRRewrite MLIRSideEffectInterfaces MLIRTransforms diff --git a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp @@ -8,8 +8,6 @@ #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Analysis/CallGraph.h" -#include "mlir/Dialect/PDL/IR/PDL.h" -#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Dialect/Transform/IR/TransformOps.h" #include "mlir/Dialect/Transform/IR/TransformTypes.h" @@ -51,18 +49,6 @@ } #endif // NDEBUG -namespace { -struct PDLOperationTypeTransformHandleTypeInterfaceImpl - : public transform::TransformHandleTypeInterface::ExternalModel< - PDLOperationTypeTransformHandleTypeInterfaceImpl, - pdl::OperationType> { - DiagnosedSilenceableFailure - checkPayload(Type type, Location loc, ArrayRef payload) const { - return DiagnosedSilenceableFailure::success(); - } -}; -} // namespace - void transform::TransformDialect::initialize() { // Using the checked versions to enable the same assertions as for the ops // from extensions. 
@@ -71,21 +57,6 @@ #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc" >(); initializeTypes(); - - pdl::OperationType::attachInterface< - PDLOperationTypeTransformHandleTypeInterfaceImpl>(*getContext()); -} - -void transform::TransformDialect::mergeInPDLMatchHooks( - llvm::StringMap &&constraintFns) { - // Steal the constraint functions from the given map. - for (auto &it : constraintFns) - pdlMatchHooks.registerConstraintFunction(it.getKey(), std::move(it.second)); -} - -const llvm::StringMap & -transform::TransformDialect::getPDLConstraintHooks() const { - return pdlMatchHooks.getConstraintFunctions(); } Type transform::TransformDialect::parseType(DialectAsmParser &parser) const { diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -1242,6 +1242,61 @@ // Utilities for PossibleTopLevelTransformOpTrait. //===----------------------------------------------------------------------===// +/// Appends to `effects` the memory effect instances on `target` with the same +/// resource and effect as the ones the operation `iface` has on `source`. 
+static void +remapEffects(MemoryEffectOpInterface iface, BlockArgument source, Value target, + SmallVectorImpl &effects) { + SmallVector nestedEffects; + iface.getEffectsOnValue(source, nestedEffects); + for (const auto &effect : nestedEffects) + effects.emplace_back(effect.getEffect(), target, effect.getResource()); +} + +/// Appends to `effects` the same effects as the operations of `block` have on +/// block arguments but associated with `operands`. +static void +remapArgumentEffects(Block &block, ValueRange operands, + SmallVectorImpl &effects) { + for (Operation &op : block) { + auto iface = dyn_cast(&op); + if (!iface) + continue; + + for (auto &&[source, target] : llvm::zip(block.getArguments(), operands)) { + remapEffects(iface, source, target, effects); + } + + SmallVector nestedEffects; + iface.getEffectsOnResource(transform::PayloadIRResource::get(), + nestedEffects); + llvm::append_range(effects, nestedEffects); + } +} + +void transform::detail::getPotentialTopLevelEffects( + Operation *operation, Value root, Block &body, + SmallVectorImpl &effects) { + transform::onlyReadsHandle(operation->getOperands(), effects); + transform::producesHandle(operation->getResults(), effects); + + if (!root) { + for (Operation &op : body) { + auto iface = dyn_cast(&op); + if (!iface) + continue; + + SmallVector nestedEffects; + iface.getEffects(effects); + } + return; + } + + // Carry over all effects on arguments of the entry block as those on the + // operands, this is the same value just remapped. 
+ remapArgumentEffects(body, operation->getOperands(), effects); +} + LogicalResult transform::detail::mapPossibleTopLevelTransformOpBlockArguments( TransformState &state, Operation *op, Region &region) { SmallVector targets; diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -7,7 +7,6 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Transform/IR/TransformOps.h" -#include "mlir/Dialect/PDL/IR/PDLOps.h" #include "mlir/Dialect/Transform/IR/MatchInterfaces.h" #include "mlir/Dialect/Transform/IR/TransformAttrs.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" @@ -17,8 +16,6 @@ #include "mlir/IR/FunctionImplementation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" -#include "mlir/Rewrite/FrozenRewritePatternSet.h" -#include "mlir/Rewrite/PatternApplicator.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/SmallPtrSet.h" @@ -52,99 +49,6 @@ #define GET_OP_CLASSES #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc" -//===----------------------------------------------------------------------===// -// PatternApplicatorExtension -//===----------------------------------------------------------------------===// - -namespace { -/// A TransformState extension that keeps track of compiled PDL pattern sets. -/// This is intended to be used along the WithPDLPatterns op. The extension -/// can be constructed given an operation that has a SymbolTable trait and -/// contains pdl::PatternOp instances. The patterns are compiled lazily and one -/// by one when requested; this behavior is subject to change. 
-class PatternApplicatorExtension : public transform::TransformState::Extension { -public: - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PatternApplicatorExtension) - - /// Creates the extension for patterns contained in `patternContainer`. - explicit PatternApplicatorExtension(transform::TransformState &state, - Operation *patternContainer) - : Extension(state), patterns(patternContainer) {} - - /// Appends to `results` the operations contained in `root` that matched the - /// PDL pattern with the given name. Note that `root` may or may not be the - /// operation that contains PDL patterns. Reports an error if the pattern - /// cannot be found. Note that when no operations are matched, this still - /// succeeds as long as the pattern exists. - LogicalResult findAllMatches(StringRef patternName, Operation *root, - SmallVectorImpl &results); - -private: - /// Map from the pattern name to a singleton set of rewrite patterns that only - /// contains the pattern with this name. Populated when the pattern is first - /// requested. - // TODO: reconsider the efficiency of this storage when more usage data is - // available. Storing individual patterns in a set and triggering compilation - // for each of them has overhead. So does compiling a large set of patterns - // only to apply a handlful of them. - llvm::StringMap compiledPatterns; - - /// A symbol table operation containing the relevant PDL patterns. - SymbolTable patterns; -}; - -LogicalResult PatternApplicatorExtension::findAllMatches( - StringRef patternName, Operation *root, - SmallVectorImpl &results) { - auto it = compiledPatterns.find(patternName); - if (it == compiledPatterns.end()) { - auto patternOp = patterns.lookup(patternName); - if (!patternOp) - return failure(); - - // Copy the pattern operation into a new module that is compiled and - // consumed by the PDL interpreter. 
- OwningOpRef pdlModuleOp = ModuleOp::create(patternOp.getLoc()); - auto builder = OpBuilder::atBlockEnd(pdlModuleOp->getBody()); - builder.clone(*patternOp); - PDLPatternModule patternModule(std::move(pdlModuleOp)); - - // Merge in the hooks owned by the dialect. Make a copy as they may be - // also used by the following operations. - auto *dialect = - root->getContext()->getLoadedDialect(); - for (const auto &[name, constraintFn] : dialect->getPDLConstraintHooks()) - patternModule.registerConstraintFunction(name, constraintFn); - - // Register a noop rewriter because PDL requires patterns to end with some - // rewrite call. - patternModule.registerRewriteFunction( - "transform.dialect", [](PatternRewriter &, Operation *) {}); - - it = compiledPatterns - .try_emplace(patternOp.getName(), std::move(patternModule)) - .first; - } - - PatternApplicator applicator(it->second); - // We want to discourage direct use of PatternRewriter in APIs but In this - // very specific case, an IRRewriter is not enough. 
- struct TrivialPatternRewriter : public PatternRewriter { - public: - explicit TrivialPatternRewriter(MLIRContext *context) - : PatternRewriter(context) {} - }; - TrivialPatternRewriter rewriter(root->getContext()); - applicator.applyDefaultCostModel(); - root->walk([&](Operation *op) { - if (succeeded(applicator.matchAndRewrite(op, rewriter))) - results.push_back(op); - }); - - return success(); -} -} // namespace - //===----------------------------------------------------------------------===// // TrackingListener //===----------------------------------------------------------------------===// @@ -420,10 +324,7 @@ assert(outputs.size() == 1 && "expected one output"); return llvm::all_of( std::initializer_list{inputs.front(), outputs.front()}, - [](Type ty) { - return llvm::isa(ty); - }); + [](Type ty) { return isa(ty); }); } //===----------------------------------------------------------------------===// @@ -1031,38 +932,6 @@ return result; } -/// Appends to `effects` the memory effect instances on `target` with the same -/// resource and effect as the ones the operation `iface` having on `source`. 
-static void -remapEffects(MemoryEffectOpInterface iface, BlockArgument source, Value target, - SmallVectorImpl &effects) { - SmallVector nestedEffects; - iface.getEffectsOnValue(source, nestedEffects); - for (const auto &effect : nestedEffects) - effects.emplace_back(effect.getEffect(), target, effect.getResource()); -} - -/// Appends to `effects` the same effects as the operations of `block` have on -/// block arguments but associated with `operands.` -static void -remapArgumentEffects(Block &block, ValueRange operands, - SmallVectorImpl &effects) { - for (Operation &op : block) { - auto iface = dyn_cast(&op); - if (!iface) - continue; - - for (auto &&[source, target] : llvm::zip(block.getArguments(), operands)) { - remapEffects(iface, source, target, effects); - } - - SmallVector nestedEffects; - iface.getEffectsOnResource(transform::PayloadIRResource::get(), - nestedEffects); - llvm::append_range(effects, nestedEffects); - } -} - static DiagnosedSilenceableFailure verifyNamedSequenceOp(transform::NamedSequenceOp op); @@ -1474,8 +1343,7 @@ void transform::SplitHandleOp::build(OpBuilder &builder, OperationState &result, Value target, int64_t numResultHandles) { result.addOperands(target); - auto pdlOpType = pdl::OperationType::get(builder.getContext()); - result.addTypes(SmallVector(numResultHandles, pdlOpType)); + result.addTypes(SmallVector(numResultHandles, target.getType())); } DiagnosedSilenceableFailure @@ -1535,35 +1403,6 @@ return success(); } -//===----------------------------------------------------------------------===// -// PDLMatchOp -//===----------------------------------------------------------------------===// - -DiagnosedSilenceableFailure -transform::PDLMatchOp::apply(transform::TransformResults &results, - transform::TransformState &state) { - auto *extension = state.getExtension(); - assert(extension && - "expected PatternApplicatorExtension to be attached by the parent op"); - SmallVector targets; - for (Operation *root : 
state.getPayloadOps(getRoot())) { - if (failed(extension->findAllMatches( - getPatternName().getLeafReference().getValue(), root, targets))) { - emitDefiniteFailure() - << "could not find pattern '" << getPatternName() << "'"; - } - } - results.set(llvm::cast(getResult()), targets); - return DiagnosedSilenceableFailure::success(); -} - -void transform::PDLMatchOp::getEffects( - SmallVectorImpl &effects) { - onlyReadsHandle(getRoot(), effects); - producesHandle(getMatched(), effects); - onlyReadsPayload(effects); -} - //===----------------------------------------------------------------------===// // ReplicateOp //===----------------------------------------------------------------------===// @@ -1776,37 +1615,9 @@ return success(); } -/// Populate `effects` with transform dialect memory effects for the potential -/// top-level operation. Such operations have recursive effects from nested -/// operations. When they have an operand, we can additionally remap effects on -/// the block argument to be effects on the operand. -template -static void getPotentialTopLevelEffects( - OpTy operation, SmallVectorImpl &effects) { - transform::onlyReadsHandle(operation->getOperands(), effects); - transform::producesHandle(operation->getResults(), effects); - - if (!operation.getRoot()) { - for (Operation &op : *operation.getBodyBlock()) { - auto iface = dyn_cast(&op); - if (!iface) - continue; - - SmallVector nestedEffects; - iface.getEffects(effects); - } - return; - } - - // Carry over all effects on arguments of the entry block as those on the - // operands, this is the same value just remapped. 
- remapArgumentEffects(*operation.getBodyBlock(), operation->getOperands(), - effects); -} - void transform::SequenceOp::getEffects( SmallVectorImpl &effects) { - getPotentialTopLevelEffects(*this, effects); + getPotentialTopLevelEffects(effects); } OperandRange transform::SequenceOp::getSuccessorEntryOperands( @@ -1908,77 +1719,6 @@ buildSequenceBody(builder, state, bbArgType, extraBindingTypes, bodyBuilder); } -//===----------------------------------------------------------------------===// -// WithPDLPatternsOp -//===----------------------------------------------------------------------===// - -DiagnosedSilenceableFailure -transform::WithPDLPatternsOp::apply(transform::TransformResults &results, - transform::TransformState &state) { - TransformOpInterface transformOp = nullptr; - for (Operation &nested : getBody().front()) { - if (!isa(nested)) { - transformOp = cast(nested); - break; - } - } - - state.addExtension(getOperation()); - auto guard = llvm::make_scope_exit( - [&]() { state.removeExtension(); }); - - auto scope = state.make_region_scope(getBody()); - if (failed(mapBlockArguments(state))) - return DiagnosedSilenceableFailure::definiteFailure(); - return state.applyTransform(transformOp); -} - -void transform::WithPDLPatternsOp::getEffects( - SmallVectorImpl &effects) { - getPotentialTopLevelEffects(*this, effects); -} - -LogicalResult transform::WithPDLPatternsOp::verify() { - Block *body = getBodyBlock(); - Operation *topLevelOp = nullptr; - for (Operation &op : body->getOperations()) { - if (isa(op)) - continue; - - if (op.hasTrait<::mlir::transform::PossibleTopLevelTransformOpTrait>()) { - if (topLevelOp) { - InFlightDiagnostic diag = - emitOpError() << "expects only one non-pattern op in its body"; - diag.attachNote(topLevelOp->getLoc()) << "first non-pattern op"; - diag.attachNote(op.getLoc()) << "second non-pattern op"; - return diag; - } - topLevelOp = &op; - continue; - } - - InFlightDiagnostic diag = - emitOpError() - << "expects only pattern 
and top-level transform ops in its body"; - diag.attachNote(op.getLoc()) << "offending op"; - return diag; - } - - if (auto parent = getOperation()->getParentOfType()) { - InFlightDiagnostic diag = emitOpError() << "cannot be nested"; - diag.attachNote(parent.getLoc()) << "parent operation"; - return diag; - } - - if (!topLevelOp) { - InFlightDiagnostic diag = emitOpError() - << "expects at least one non-pattern op"; - return diag; - } - - return success(); -} - //===----------------------------------------------------------------------===// // PrintOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Transform/PDLExtension/CMakeLists.txt b/mlir/lib/Dialect/Transform/PDLExtension/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Transform/PDLExtension/CMakeLists.txt @@ -0,0 +1,13 @@ +add_mlir_dialect_library(MLIRTransformPDLExtension + PDLExtension.cpp + PDLExtensionOps.cpp + + DEPENDS + MLIRTransformDialectPDLExtensionOpsIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRTransformDialect + MLIRPDLDialect + MLIRPDLInterpDialect +) diff --git a/mlir/lib/Dialect/Transform/PDLExtension/PDLExtension.cpp b/mlir/lib/Dialect/Transform/PDLExtension/PDLExtension.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Transform/PDLExtension/PDLExtension.cpp @@ -0,0 +1,69 @@ +//===- PDLExtension.cpp - PDL extension for the Transform dialect ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h" +#include "mlir/Dialect/PDL/IR/PDL.h" +#include "mlir/Dialect/PDL/IR/PDLTypes.h" +#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h" +#include "mlir/IR/DialectRegistry.h" + +using namespace mlir; + +namespace { +/// Implementation of the TransformHandleTypeInterface for the PDL +/// OperationType. Accepts any payload operation. +struct PDLOperationTypeTransformHandleTypeInterfaceImpl + : public transform::TransformHandleTypeInterface::ExternalModel< + PDLOperationTypeTransformHandleTypeInterfaceImpl, + pdl::OperationType> { + + /// Accept any operation. + DiagnosedSilenceableFailure + checkPayload(Type type, Location loc, ArrayRef payload) const { + return DiagnosedSilenceableFailure::success(); + } +}; +} // namespace + +namespace { +/// PDL extension of the Transform dialect. This provides transform operations +/// that connect to PDL matching as well as interfaces for PDL types to be used +/// with Transform dialect operations. +class PDLExtension : public transform::TransformDialectExtension { +public: + void init() { + registerTransformOps< +#define GET_OP_LIST +#include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp.inc" + >(); + + addDialectDataInitializer( + [](transform::PDLMatchHooks &) {}); + + // Declare PDL as dependent so we can attach an interface to its type in the + // later step. + declareDependentDialect(); + + // PDLInterp is only relevant if we actually apply the transform IR so + // declare it as generated. + declareGeneratedDialect(); + + // Make PDL OperationType usable as a transform dialect type. 
+ addCustomInitializationStep([](MLIRContext *context) { + pdl::OperationType::attachInterface< + PDLOperationTypeTransformHandleTypeInterfaceImpl>(*context); + }); + } +}; +} // namespace + +void mlir::transform::registerPDLExtension(DialectRegistry &dialectRegistry) { + dialectRegistry.addExtensions(); +} diff --git a/mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp b/mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp @@ -0,0 +1,234 @@ +//===- PDLExtensionOps.cpp - PDL extension for the Transform dialect ------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h" +#include "mlir/Dialect/PDL/IR/PDLOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/Rewrite/FrozenRewritePatternSet.h" +#include "mlir/Rewrite/PatternApplicator.h" +#include "llvm/ADT/ScopeExit.h" + +using namespace mlir; + +MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::transform::PDLMatchHooks) + +#define GET_OP_CLASSES +#include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp.inc" + +//===----------------------------------------------------------------------===// +// PatternApplicatorExtension +//===----------------------------------------------------------------------===// + +namespace { +/// A TransformState extension that keeps track of compiled PDL pattern sets. +/// This is intended to be used along the WithPDLPatterns op. The extension +/// can be constructed given an operation that has a SymbolTable trait and +/// contains pdl::PatternOp instances. 
The patterns are compiled lazily and one +/// by one when requested; this behavior is subject to change. +class PatternApplicatorExtension : public transform::TransformState::Extension { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PatternApplicatorExtension) + + /// Creates the extension for patterns contained in `patternContainer`. + explicit PatternApplicatorExtension(transform::TransformState &state, + Operation *patternContainer) + : Extension(state), patterns(patternContainer) {} + + /// Appends to `results` the operations contained in `root` that matched the + /// PDL pattern with the given name. Note that `root` may or may not be the + /// operation that contains PDL patterns. Reports an error if the pattern + /// cannot be found. Note that when no operations are matched, this still + /// succeeds as long as the pattern exists. + LogicalResult findAllMatches(StringRef patternName, Operation *root, + SmallVectorImpl &results); + +private: + /// Map from the pattern name to a singleton set of rewrite patterns that only + /// contains the pattern with this name. Populated when the pattern is first + /// requested. + // TODO: reconsider the efficiency of this storage when more usage data is + // available. Storing individual patterns in a set and triggering compilation + // for each of them has overhead. So does compiling a large set of patterns + // only to apply a handful of them. + llvm::StringMap compiledPatterns; + + /// A symbol table operation containing the relevant PDL patterns. + SymbolTable patterns; +}; + +LogicalResult PatternApplicatorExtension::findAllMatches( + StringRef patternName, Operation *root, + SmallVectorImpl &results) { + auto it = compiledPatterns.find(patternName); + if (it == compiledPatterns.end()) { + auto patternOp = patterns.lookup(patternName); + if (!patternOp) + return failure(); + + // Copy the pattern operation into a new module that is compiled and + // consumed by the PDL interpreter. 
+ OwningOpRef pdlModuleOp = ModuleOp::create(patternOp.getLoc()); + auto builder = OpBuilder::atBlockEnd(pdlModuleOp->getBody()); + builder.clone(*patternOp); + PDLPatternModule patternModule(std::move(pdlModuleOp)); + + // Merge in the hooks owned by the dialect. Make a copy as they may be + // also used by the following operations. + auto *dialect = + root->getContext()->getLoadedDialect(); + for (const auto &[name, constraintFn] : + dialect->getExtraData() + .getPDLConstraintHooks()) { + patternModule.registerConstraintFunction(name, constraintFn); + } + + // Register a noop rewriter because PDL requires patterns to end with some + // rewrite call. + patternModule.registerRewriteFunction( + "transform.dialect", [](PatternRewriter &, Operation *) {}); + + it = compiledPatterns + .try_emplace(patternOp.getName(), std::move(patternModule)) + .first; + } + + PatternApplicator applicator(it->second); + // We want to discourage direct use of PatternRewriter in APIs but in this + // very specific case, an IRRewriter is not enough. + struct TrivialPatternRewriter : public PatternRewriter { + public: + explicit TrivialPatternRewriter(MLIRContext *context) + : PatternRewriter(context) {} + }; + TrivialPatternRewriter rewriter(root->getContext()); + applicator.applyDefaultCostModel(); + root->walk([&](Operation *op) { + if (succeeded(applicator.matchAndRewrite(op, rewriter))) + results.push_back(op); + }); + + return success(); +} +} // namespace + +//===----------------------------------------------------------------------===// +// PDLMatchHooks +//===----------------------------------------------------------------------===// + +void transform::PDLMatchHooks::mergeInPDLMatchHooks( + llvm::StringMap &&constraintFns) { + // Steal the constraint functions from the given map. 
+ for (auto &it : constraintFns) + pdlMatchHooks.registerConstraintFunction(it.getKey(), std::move(it.second)); +} + +const llvm::StringMap & +transform::PDLMatchHooks::getPDLConstraintHooks() const { + return pdlMatchHooks.getConstraintFunctions(); +} + +//===----------------------------------------------------------------------===// +// PDLMatchOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure +transform::PDLMatchOp::apply(transform::TransformResults &results, + transform::TransformState &state) { + auto *extension = state.getExtension(); + assert(extension && + "expected PatternApplicatorExtension to be attached by the parent op"); + SmallVector targets; + for (Operation *root : state.getPayloadOps(getRoot())) { + if (failed(extension->findAllMatches( + getPatternName().getLeafReference().getValue(), root, targets))) { + emitDefiniteFailure() + << "could not find pattern '" << getPatternName() << "'"; + } + } + results.set(llvm::cast(getResult()), targets); + return DiagnosedSilenceableFailure::success(); +} + +void transform::PDLMatchOp::getEffects( + SmallVectorImpl &effects) { + onlyReadsHandle(getRoot(), effects); + producesHandle(getMatched(), effects); + onlyReadsPayload(effects); +} + +//===----------------------------------------------------------------------===// +// WithPDLPatternsOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure +transform::WithPDLPatternsOp::apply(transform::TransformResults &results, + transform::TransformState &state) { + TransformOpInterface transformOp = nullptr; + for (Operation &nested : getBody().front()) { + if (!isa(nested)) { + transformOp = cast(nested); + break; + } + } + + state.addExtension(getOperation()); + auto guard = llvm::make_scope_exit( + [&]() { state.removeExtension(); }); + + auto scope = state.make_region_scope(getBody()); + if (failed(mapBlockArguments(state))) + return 
DiagnosedSilenceableFailure::definiteFailure(); + return state.applyTransform(transformOp); +} + +void transform::WithPDLPatternsOp::getEffects( + SmallVectorImpl &effects) { + getPotentialTopLevelEffects(effects); +} + +LogicalResult transform::WithPDLPatternsOp::verify() { + Block *body = getBodyBlock(); + Operation *topLevelOp = nullptr; + for (Operation &op : body->getOperations()) { + if (isa(op)) + continue; + + if (op.hasTrait<::mlir::transform::PossibleTopLevelTransformOpTrait>()) { + if (topLevelOp) { + InFlightDiagnostic diag = + emitOpError() << "expects only one non-pattern op in its body"; + diag.attachNote(topLevelOp->getLoc()) << "first non-pattern op"; + diag.attachNote(op.getLoc()) << "second non-pattern op"; + return diag; + } + topLevelOp = &op; + continue; + } + + InFlightDiagnostic diag = + emitOpError() + << "expects only pattern and top-level transform ops in its body"; + diag.attachNote(op.getLoc()) << "offending op"; + return diag; + } + + if (auto parent = getOperation()->getParentOfType()) { + InFlightDiagnostic diag = emitOpError() << "cannot be nested"; + diag.attachNote(parent.getLoc()) << "parent operation"; + return diag; + } + + if (!topLevelOp) { + InFlightDiagnostic diag = emitOpError() + << "expects at least one non-pattern op"; + return diag; + } + + return success(); +} diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -114,6 +114,16 @@ DIALECT_NAME linalg DEPENDS LinalgOdsGen) +declare_mlir_dialect_extension_python_bindings( +ADD_TO_PARENT MLIRPythonSources.Dialects +ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/TransformPDLExtensionOps.td + SOURCES + dialects/_transform_pdl_extension_ops_ext.py + dialects/transform/pdl.py + DIALECT_NAME transform + EXTENSION_NAME transform_pdl_extension) + declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" diff --git 
a/mlir/python/mlir/dialects/TransformPDLExtensionOps.td b/mlir/python/mlir/dialects/TransformPDLExtensionOps.td new file mode 100644 --- /dev/null +++ b/mlir/python/mlir/dialects/TransformPDLExtensionOps.td @@ -0,0 +1,20 @@ +//===-- TransformPDLExtensionOps.td - Binding entry point --*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Entry point of the generated Python bindings for the PDL extension of the +// Transform dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_TRANSFORM_PDL_EXTENSION_OPS +#define PYTHON_BINDINGS_TRANSFORM_PDL_EXTENSION_OPS + +include "mlir/Bindings/Python/Attributes.td" +include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.td" + +#endif // PYTHON_BINDINGS_TRANSFORM_PDL_EXTENSION_OPS diff --git a/mlir/python/mlir/dialects/_transform_ops_ext.py b/mlir/python/mlir/dialects/_transform_ops_ext.py --- a/mlir/python/mlir/dialects/_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_transform_ops_ext.py @@ -60,26 +60,6 @@ ) -class PDLMatchOp: - - def __init__( - self, - result_type: Type, - target: Union[Operation, Value], - pattern_name: Union[Attribute, str], - *, - loc=None, - ip=None, - ): - super().__init__( - result_type, - _get_op_result_or_value(target), - pattern_name, - loc=loc, - ip=ip, - ) - - class ReplicateOp: def __init__( @@ -152,28 +132,6 @@ return self.body.arguments[1:] -class WithPDLPatternsOp: - - def __init__(self, - target: Union[Operation, Value, Type], - *, - loc=None, - ip=None): - root = _get_op_result_or_value(target) if not isinstance(target, - Type) else None - root_type = target if isinstance(target, Type) else root.type - super().__init__(root=root, loc=loc, ip=ip) - 
self.regions[0].blocks.append(root_type) - - @property - def body(self) -> Block: - return self.regions[0].blocks[0] - - @property - def bodyTarget(self) -> Value: - return self.body.arguments[0] - - class YieldOp: def __init__( diff --git a/mlir/python/mlir/dialects/_transform_pdl_extension_ops_ext.py b/mlir/python/mlir/dialects/_transform_pdl_extension_ops_ext.py new file mode 100644 --- /dev/null +++ b/mlir/python/mlir/dialects/_transform_pdl_extension_ops_ext.py @@ -0,0 +1,55 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +try: + from ..ir import * + from ._ods_common import ( + get_op_result_or_value as _get_op_result_or_value, + get_op_results_or_values as _get_op_results_or_values, + ) +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from typing import Union + +class PDLMatchOp: + + def __init__( + self, + result_type: Type, + target: Union[Operation, Value], + pattern_name: Union[Attribute, str], + *, + loc=None, + ip=None, + ): + super().__init__( + result_type, + _get_op_result_or_value(target), + pattern_name, + loc=loc, + ip=ip, + ) + + +class WithPDLPatternsOp: + + def __init__(self, + target: Union[Operation, Value, Type], + *, + loc=None, + ip=None): + root = _get_op_result_or_value(target) if not isinstance(target, + Type) else None + root_type = target if isinstance(target, Type) else root.type + super().__init__(root=root, loc=loc, ip=ip) + self.regions[0].blocks.append(root_type) + + @property + def body(self) -> Block: + return self.regions[0].blocks[0] + + @property + def bodyTarget(self) -> Value: + return self.body.arguments[0] diff --git a/mlir/python/mlir/dialects/transform/pdl.py b/mlir/python/mlir/dialects/transform/pdl.py new file mode 100644 --- /dev/null +++ b/mlir/python/mlir/dialects/transform/pdl.py @@ -0,0 +1,5 @@ +# Part 
of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from .._transform_pdl_extension_ops_gen import * diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir --- a/mlir/test/Dialect/Transform/test-interpreter.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter.mlir @@ -83,33 +83,6 @@ // ----- -transform.with_pdl_patterns { -^bb0(%arg0: !transform.any_op): - sequence %arg0 : !transform.any_op failures(propagate) { - ^bb0(%arg1: !transform.any_op): - %0 = pdl_match @some in %arg1 : (!transform.any_op) -> !transform.any_op - test_print_remark_at_operand %0, "matched" : !transform.any_op - } - - pdl.pattern @some : benefit(1) { - %0 = pdl.operation "test.some_op" - pdl.rewrite %0 with "transform.dialect" - } - - pdl.pattern @other : benefit(1) { - %0 = pdl.operation "test.other_op" - pdl.rewrite %0 with "transform.dialect" - } -} - -// expected-remark @below {{matched}} -"test.some_op"() : () -> () -"test.other_op"() : () -> () -// expected-remark @below {{matched}} -"test.some_op"() : () -> () - -// ----- - // expected-remark @below {{parent function}} func.func @foo() { %0 = arith.constant 0 : i32 diff --git a/mlir/test/Dialect/Transform/test-pdl-extension.mlir b/mlir/test/Dialect/Transform/test-pdl-extension.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Transform/test-pdl-extension.mlir @@ -0,0 +1,47 @@ +// RUN: mlir-opt %s --test-transform-dialect-interpreter -allow-unregistered-dialect --split-input-file --verify-diagnostics + +transform.with_pdl_patterns { +^bb0(%arg0: !transform.any_op): + sequence %arg0 : !transform.any_op failures(propagate) { + ^bb0(%arg1: !transform.any_op): + %0 = pdl_match @some in %arg1 : (!transform.any_op) -> !transform.any_op + test_print_remark_at_operand %0, "matched" : !transform.any_op + } + + pdl.pattern @some 
: benefit(1) { + %0 = pdl.operation "test.some_op" + pdl.rewrite %0 with "transform.dialect" + } + + pdl.pattern @other : benefit(1) { + %0 = pdl.operation "test.other_op" + pdl.rewrite %0 with "transform.dialect" + } +} + +// expected-remark @below {{matched}} +"test.some_op"() : () -> () +"test.other_op"() : () -> () +// expected-remark @below {{matched}} +"test.some_op"() : () -> () + + +// ----- + +transform.with_pdl_patterns { +^bb0(%arg0: !transform.any_op): + sequence %arg0 : !transform.any_op failures(propagate) { + ^bb1(%arg1: !transform.any_op): + %0 = pdl_match @some in %arg1 : (!transform.any_op) -> !transform.any_op + } + + pdl.pattern @some : benefit(1) { + %0 = pdl.operation "test.some_op" + pdl.apply_native_constraint "verbose_constraint"(%0 : !pdl.operation) + pdl.rewrite %0 with "transform.dialect" + } +} + +// expected-warning @below {{from PDL constraint}} +"test.some_op"() : () -> () +"test.other_op"() : () -> () diff --git a/mlir/test/lib/Dialect/Transform/CMakeLists.txt b/mlir/test/lib/Dialect/Transform/CMakeLists.txt --- a/mlir/test/lib/Dialect/Transform/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Transform/CMakeLists.txt @@ -21,4 +21,5 @@ MLIRPDLDialect MLIRTransformDialect MLIRTransformDialectTransforms + MLIRTransformPDLExtension ) diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -17,7 +17,9 @@ #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Dialect/Transform/IR/TransformOps.h" +#include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/IR/PatternMatch.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Compiler.h" @@ 
-754,6 +756,23 @@ #define GET_TYPEDEF_LIST #include "TestTransformDialectExtensionTypes.cpp.inc" >(); + + auto verboseConstraint = [](PatternRewriter &rewriter, + ArrayRef pdlValues) { + for (const PDLValue &pdlValue : pdlValues) { + if (Operation *op = pdlValue.dyn_cast()) { + op->emitWarning() << "from PDL constraint"; + } + } + return success(); + }; + + addDialectDataInitializer( + [&](transform::PDLMatchHooks &hooks) { + llvm::StringMap constraints; + constraints.try_emplace("verbose_constraint", verboseConstraint); + hooks.mergeInPDLMatchHooks(std::move(constraints)); + }); } }; } // namespace diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py --- a/mlir/test/python/dialects/transform.py +++ b/mlir/test/python/dialects/transform.py @@ -2,7 +2,7 @@ from mlir.ir import * from mlir.dialects import transform -from mlir.dialects import pdl +from mlir.dialects.transform import pdl as transform_pdl def run(f): @@ -103,13 +103,13 @@ @run def testTransformPDLOps(): - withPdl = transform.WithPDLPatternsOp(transform.AnyOpType.get()) + withPdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get()) with InsertionPoint(withPdl.body): sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [transform.AnyOpType.get()], withPdl.bodyTarget) with InsertionPoint(sequence.body): - match = transform.PDLMatchOp(transform.AnyOpType.get(), sequence.bodyTarget, "pdl_matcher") + match = transform_pdl.PDLMatchOp(transform.AnyOpType.get(), sequence.bodyTarget, "pdl_matcher") transform.YieldOp(match) # CHECK-LABEL: TEST: testTransformPDLOps # CHECK: transform.with_pdl_patterns { @@ -148,13 +148,13 @@ @run def testReplicateOp(): - with_pdl = transform.WithPDLPatternsOp(transform.AnyOpType.get()) + with_pdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get()) with InsertionPoint(with_pdl.body): sequence = transform.SequenceOp( transform.FailurePropagationMode.PROPAGATE, [], with_pdl.bodyTarget) with 
InsertionPoint(sequence.body): - m1 = transform.PDLMatchOp(transform.AnyOpType.get(), sequence.bodyTarget, "first") - m2 = transform.PDLMatchOp(transform.AnyOpType.get(), sequence.bodyTarget, "second") + m1 = transform_pdl.PDLMatchOp(transform.AnyOpType.get(), sequence.bodyTarget, "first") + m2 = transform_pdl.PDLMatchOp(transform.AnyOpType.get(), sequence.bodyTarget, "second") transform.ReplicateOp(m1, [m2]) transform.YieldOp() # CHECK-LABEL: TEST: testReplicateOp diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py --- a/mlir/test/python/dialects/transform_structured_ext.py +++ b/mlir/test/python/dialects/transform_structured_ext.py @@ -4,6 +4,7 @@ from mlir.dialects import transform from mlir.dialects import pdl from mlir.dialects.transform import structured +from mlir.dialects.transform import pdl as transform_pdl def run(f): @@ -151,13 +152,13 @@ @run def testTileDynamic(): - with_pdl = transform.WithPDLPatternsOp(pdl.OperationType.get()) + with_pdl = transform_pdl.WithPDLPatternsOp(pdl.OperationType.get()) with InsertionPoint(with_pdl.body): sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], with_pdl.bodyTarget) with InsertionPoint(sequence.body): - m1 = transform.PDLMatchOp(pdl.OperationType.get(), sequence.bodyTarget, "first") - m2 = transform.PDLMatchOp(pdl.OperationType.get(), sequence.bodyTarget, "second") + m1 = transform_pdl.PDLMatchOp(pdl.OperationType.get(), sequence.bodyTarget, "first") + m2 = transform_pdl.PDLMatchOp(pdl.OperationType.get(), sequence.bodyTarget, "second") structured.TileOp(sequence.bodyTarget, sizes=[m1, 3, m2, 0]) transform.YieldOp() diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -7495,6 +7495,7 @@ ":TosaToLinalg", ":TransformDialect", 
":TransformDialectTransforms", + ":TransformPDLExtension", ":Transforms", ":TransformsPassIncGen", ":VectorDialect", @@ -9732,7 +9733,6 @@ ":ControlFlowInterfacesTdFiles", ":InferTypeOpInterfaceTdFiles", ":OpBaseTdFiles", - ":PDLDialectTdFiles", ":SideEffectInterfacesTdFiles", ], ) @@ -9889,8 +9889,6 @@ ":CallOpInterfaces", ":ControlFlowInterfaces", ":IR", - ":PDLDialect", - ":PDLInterpDialect", ":Rewrite", ":SideEffectInterfaces", ":Support", @@ -9906,6 +9904,54 @@ ], ) +td_library( + name = "TransformPDLExtensionTdFiles", + srcs = glob(["include/mlir/Dialect/Transform/PDLExtension/*.td"]), + deps = [ + ":PDLDialectTdFiles", + ":TransformDialectTdFiles", + ], +) + +gentbl_cc_library( + name = "TransformPDLExtensionOpsIncGen", + strip_include_prefix = "include", + tbl_outs = [ + ( + [ + "-gen-op-decls", + ], + "include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h.inc", + ), + ( + [ + "-gen-op-defs", + ], + "include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp.inc", + ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.td", + deps = [":TransformPDLExtensionTdFiles"], +) + +cc_library( + name = "TransformPDLExtension", + srcs = glob(["lib/Dialect/Transform/PDLExtension/*.cpp"]), + hdrs = glob(["include/mlir/Dialect/Transform/PDLExtension/*.h"]), + deps = [ + ":IR", + ":PDLDialect", + ":PDLInterpDialect", + ":SideEffectInterfaces", + ":Support", + ":TransformDialect", + ":TransformPDLExtensionOpsIncGen", + ":Rewrite", + "//llvm:Support", + ], +) + td_library( name = "TransformDialectTransformsTdFiles", srcs = glob(["include/mlir/Dialect/Transform/Transforms/*.td"]), diff --git a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel @@ -927,6 +927,26 @@ ], ) +gentbl_filegroup( + name = "PDLTransformOpsPyGen", 
+ tbl_outs = [ + ( + [ + "-gen-python-op-bindings", + "-bind-dialect=transform", + "-dialect-extension=transform_pdl_extension", + ], + "mlir/dialects/_transform_pdl_extension_ops_gen.py", + ), + ], + tblgen = "//mlir:mlir-tblgen", + td_file = "mlir/dialects/TransformPDLExtensionOps.td", + deps = [ + ":TransformOpsPyTdFiles", + "//mlir:TransformPDLExtensionTdFiles", + ], +) + filegroup( name = "TransformOpsPyFiles", srcs = [ @@ -934,6 +954,7 @@ "mlir/dialects/_structured_transform_ops_ext.py", "mlir/dialects/_transform_ops_ext.py", ":LoopTransformOpsPyGen", + ":PDLTransformOpsPyGen", ":StructuredTransformOpsPyGen", ":TransformOpsPyGen", ], diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel @@ -317,6 +317,7 @@ ":TransformDialectTdFiles", "//mlir:PDLDialectTdFiles", "//mlir:TransformDialectTdFiles", + "//mlir:TransformPDLExtension", ], ) @@ -333,6 +334,7 @@ "//mlir:Pass", "//mlir:TransformDialect", "//mlir:TransformDialectTransforms", + "//mlir:TransformPDLExtension", ], )