diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h @@ -9,9 +9,13 @@ #ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H #define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H +#include "mlir/Dialect/PDL/IR/PDL.h" +#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/IR/Dialect.h" +#include "mlir/IR/PatternMatch.h" #include "mlir/Support/LLVM.h" +#include "llvm/ADT/StringMap.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h.inc" @@ -57,6 +61,7 @@ loader(context); for (const Initializer &init : opInitializers) init(transformDialect); + transformDialect->mergeInPDLMatchHooks(std::move(pdlMatchConstraintFns)); } protected: @@ -88,9 +93,30 @@ [](MLIRContext *context) { context->loadDialect(); }); } + /// Injects the named constraint to make it available for use with the + /// PDLMatchOp in the transform dialect. + void registerPDLMatchConstraintFn(StringRef name, + PDLConstraintFunction &&fn) { + pdlMatchConstraintFns.try_emplace(name, + std::forward(fn)); + } + template + void registerPDLMatchConstraintFn(StringRef name, ConstraintFnTy &&fn) { + pdlMatchConstraintFns.try_emplace( + name, ::mlir::detail::pdl_function_builder::buildConstraintFn( + std::forward(fn))); + } + private: SmallVector opInitializers; SmallVector dialectLoaders; + + /// A map of named constraints that should be made available to PDL patterns + /// processed by PDLMatchOp in the Transform dialect. + /// + /// Declared as mutable so its contents can be moved in the `apply` const + /// method, which is only called once.
+ mutable llvm::StringMap pdlMatchConstraintFns; }; } // namespace transform diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td @@ -84,6 +84,13 @@ `LoopTransformDialectExtension` in the cases above. Unprefixed operation names are reserved for ops defined directly in the Transform dialect. + Overall, Transform IR ops are expected to be contained in a single top-level + op. Such top-level ops specify how to apply the transformations described + by operations they contain, e.g., `transform.sequence` executes + transformations one by one and fails if any of them fails. Such ops are + expected to have the `PossibleTopLevelTransformOpTrait` and may be used + without operands. + ## Intended Use and Integrations The transformation control infrastructure provided by this dialect is @@ -163,13 +170,32 @@ let cppNamespace = "::mlir::transform"; let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed; + let dependentDialects = [ + "::mlir::pdl::PDLDialect", + "::mlir::pdl_interp::PDLInterpDialect", + ]; + let extraClassDeclaration = [{ - // Make addOperations available to the TransformDialectExtension class. + /// Returns the named PDL constraint functions available in the dialect + /// as a map from their name to the function. + const ::llvm::StringMap<::mlir::PDLConstraintFunction> & + getPDLConstraintHooks() const; + private: + // Make addOperations available to the TransformDialectExtension class. using ::mlir::Dialect::addOperations; template friend class TransformDialectExtension; + + /// Takes ownership of the named PDL constraint functions from the given + /// map and makes them available for use by the operations in the dialect.
+ void mergeInPDLMatchHooks( + ::llvm::StringMap<::mlir::PDLConstraintFunction> &&constraintFns); + + /// A container for PDL constraint functions that can be used by + /// operations in this dialect. + PDLPatternModule pdlMatchHooks; }]; } @@ -178,4 +204,12 @@ class TransformDialectOp traits = []> : Op; +// Trait for operations that may be top-level operations in Transform IR. +// Operations must have one single-block region and must be usable without +// operands. See the C++ definition of the trait for more information. +def PossibleTopLevelTransformOpTrait + : NativeOpTrait<"PossibleTopLevelTransformOpTrait"> { + let cppNamespace = "::mlir::transform"; +} + #endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -140,6 +140,89 @@ }; friend class RegionScope; + /// Base class for TransformState extensions that allow TransformState to + /// contain user-specified information in the state object. Clients are + /// expected to derive from this class, add the desired fields, and make the + /// derived class compatible with the MLIR TypeID mechanism: + /// + /// ```c++ + /// class MyExtension final : public TransformState::Extension { + /// public: + /// MyExtension(TransformState &state, int myData) + /// : Extension(state) {...} + /// private: + /// int mySupplementaryData; + /// }; + /// ``` + /// + /// Instances of this and derived classes are not expected to be created by + /// the user; instead, they are directly constructed within a TransformState. A + /// TransformState can only contain one extension with the given TypeID. + /// Extensions can be obtained from a TransformState instance, and can be + /// removed when they are no longer required.
+ /// + /// ```c++ + /// transformState.addExtension(/*myData=*/42); + /// MyExtension *ext = transformState.getExtension(); + /// ext->doSomething(); + /// ``` + class Extension { + // Allow TransformState to allocate Extensions. + friend class TransformState; + + public: + /// Base virtual destructor. + // Out-of-line definition ensures symbols are emitted in a single object + // file. + virtual ~Extension(); + + protected: + /// Constructs an extension of the given TransformState object. + explicit Extension(TransformState &state) : state(state) {} + + private: + /// Back-reference to the state that is being extended. + TransformState &state; + }; + + /// Adds a new Extension of the type specified as template parameter, + /// constructing it with the arguments provided. The extension is owned by the + /// TransformState. It is expected that the state does not already have an + /// extension of the same type. Extension constructors are expected to take + /// a reference to TransformState as first argument, automatically supplied + /// by this call. + template + Ty &addExtension(Args &&...args) { + static_assert( + std::is_base_of::value, + "only a class derived from TransformState::Extension is allowed here"); + auto ptr = std::make_unique(*this, std::forward(args)...); + auto result = extensions.try_emplace(TypeID::get(), std::move(ptr)); + assert(result.second && "extension already added"); + return *static_cast(result.first->second.get()); + } + + /// Returns the extension of the specified type, or nullptr if the state + /// does not currently hold an extension of that type. + template + Ty *getExtension() { + static_assert( + std::is_base_of::value, + "only a class derived from TransformState::Extension is allowed here"); + auto iter = extensions.find(TypeID::get()); + if (iter == extensions.end()) + return nullptr; + return static_cast(iter->second.get()); + } + + /// Removes the extension of the specified type.
+ template + void removeExtension() { + static_assert( + std::is_base_of::value, + "only a class derived from TransformState::Extension is allowed here"); + extensions.erase(TypeID::get()); + } + private: /// Identifier for storing top-level value in the `operations` mapping. static constexpr Value kTopLevelValue = Value(); @@ -196,6 +279,10 @@ /// the region in which the transform IR values are defined. llvm::SmallDenseMap mappings; + /// Extensions attached to the TransformState, identified by the TypeID of + /// their type. Only one extension of any given type is allowed. + DenseMap> extensions; + /// The top-level operation that contains all payload IR, typically a module. Operation *topLevel; @@ -241,6 +328,54 @@ return RegionScope(*this, region); } +namespace detail { +/// Maps the only block argument of the op with PossibleTopLevelTransformOpTrait +/// to either the list of operations associated with its operand or the root of +/// the payload IR, depending on what is available in the context. +LogicalResult +mapPossibleTopLevelTransformOpBlockArguments(TransformState &state, + Operation *op); + +/// Verification hook for PossibleTopLevelTransformOpTrait. +LogicalResult verifyPossibleTopLevelTransformOpTrait(Operation *op); +} // namespace detail + +/// This trait is supposed to be attached to Transform dialect operations that +/// can be standalone top-level transforms. Such operations typically contain +/// other Transform dialect operations that can be executed following some +/// control flow logic specific to the current operation. The operations with +/// this trait are expected to have exactly one single-block region with one +/// argument of PDL Operation type. The operations are also expected to be valid +/// without operands, in which case they are considered top-level, and with one +/// or more operands, in which case they are considered nested.
Top-level +/// operations have the block argument of the entry block in the Transform IR +/// correspond to the root operation of Payload IR. Nested operations have the +/// block argument of the entry block in the Transform IR correspond to a list +/// of Payload IR operations mapped to the first operand of the Transform IR +/// operation. The operation must implement TransformOpInterface. +template +class PossibleTopLevelTransformOpTrait + : public OpTrait::TraitBase { +public: + /// Verifies that `op` satisfies the invariants of this trait. Not expected to + /// be called directly. + static LogicalResult verifyTrait(Operation *op) { + return detail::verifyPossibleTopLevelTransformOpTrait(op); + } + + /// Returns the single block of the op's only region. + Block *getBodyBlock() { return &this->getOperation()->getRegion(0).front(); } + + /// Sets up the mapping between the entry block of the only region of this op + /// and the relevant list of Payload IR operations in the given state. The + /// state is expected to be already scoped at the region of this operation. + /// Returns failure if the mapping failed, e.g., the value is already mapped.
+ LogicalResult mapBlockArguments(TransformState &state) { + return detail::mapPossibleTopLevelTransformOpBlockArguments( + state, this->getOperation()); + } +}; + } // namespace transform } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h @@ -13,6 +13,7 @@ #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/IR/SymbolTable.h" #define GET_OP_CLASSES #include "mlir/Dialect/Transform/IR/TransformOps.h.inc" diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -10,12 +10,40 @@ #define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS include "mlir/IR/OpAsmInterface.td" +include "mlir/IR/SymbolInterfaces.td" include "mlir/Dialect/PDL/IR/PDLTypes.td" include "mlir/Dialect/Transform/IR/TransformDialect.td" include "mlir/Dialect/Transform/IR/TransformInterfaces.td" +def PDLMatchOp : TransformDialectOp<"pdl_match", + [DeclareOpInterfaceMethods]> { + let summary = "Finds ops that match the named PDL pattern"; + let description = [{ + Find Payload IR ops nested within the Payload IR op associated with the + operand that match the PDL pattern identified by its name. The pattern is + expected to be defined in the closest surrounding `WithPDLPatternsOp`. + + Produces a Transform IR value associated with the list of Payload IR ops + that matched the pattern. The order of results in the list is that of + Operation::walk; clients are advised not to rely on a specific order though.
+ If the operand is associated with multiple Payload IR ops, finds matching + ops nested within each of those and produces a single list containing all + of the matched ops. + + The transformation is considered successful regardless of whether some + Payload IR ops actually matched the pattern and only fails if the pattern + could not be looked up or compiled. + }]; + + let arguments = (ins PDL_Operation:$root, SymbolRefAttr:$pattern_name); + let results = (outs PDL_Operation:$matched); + + let assemblyFormat = "$pattern_name `in` $root attr-dict"; +} + def SequenceOp : TransformDialectOp<"sequence", [DeclareOpInterfaceMethods, OpAsmOpInterface, + PossibleTopLevelTransformOpTrait, SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">]> { let summary = "Contains a sequence of other transform ops to apply"; let description = [{ @@ -48,13 +76,60 @@ let extraClassDeclaration = [{ /// Allow the dialect prefix to be omitted. static StringRef getDefaultDialect() { return "transform"; } + }]; + + let hasVerifier = 1; +} - Block *getBodyBlock() { - return &getBody().front(); +def WithPDLPatternsOp : TransformDialectOp<"with_pdl_patterns", + [DeclareOpInterfaceMethods, NoTerminator, + OpAsmOpInterface, PossibleTopLevelTransformOpTrait, SymbolTable]> { + let summary = "Contains PDL patterns available for use in transforms"; + let description = [{ + This op contains a set of named PDL patterns that are available for the + Transform dialect operations to be used for pattern matching. For example, + PDLMatchOp can be used to produce a Transform IR value associated with all + Payload IR operations that match the pattern as follows: + + ```mlir + transform.with_pdl_patterns { + ^bb0(%arg0: !pdl.operation): + pdl.pattern @my_pattern : benefit(1) { + %0 = pdl.operation //... + // Regular PDL goes here.
+ pdl.rewrite %0 with "transform.dialect" + } + + sequence %arg0 { + ^bb0(%arg1: !pdl.operation): + %1 = pdl_match @my_pattern in %arg1 + // Use %1 as handle + } } + ``` + + Note that the pattern is expected to finish with a `pdl.rewrite` terminator + that points to the custom rewriter named "transform.dialect". The rewriter + actually does nothing, but the transform application will keep track of the + operations that matched the pattern. + + This op is expected to contain `pdl.pattern` operations and at most one + other Transform dialect operation that gets executed with all patterns + available. This op is a possible top-level Transform IR op; the argument of + its entry block corresponds to either the root op of the payload IR or the + ops associated with its operand when provided. }]; + let arguments = (ins Optional:$root); + let regions = (region SizedRegion<1>:$body); + let assemblyFormat = "($root^)? attr-dict-with-keyword regions"; + let hasVerifier = 1; + + let extraClassDeclaration = [{ + /// Allow the dialect prefix to be omitted. + static StringRef getDefaultDialect() { return "transform"; } + }]; } def YieldOp : TransformDialectOp<"yield", [Terminator]> { diff --git a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp @@ -19,3 +19,15 @@ #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc" >(); } + +void transform::TransformDialect::mergeInPDLMatchHooks( + llvm::StringMap &&constraintFns) { + // Steal the constraint functions from the given map.
+ for (auto &it : constraintFns) + pdlMatchHooks.registerConstraintFunction(it.getKey(), std::move(it.second)); +} + +const llvm::StringMap & +transform::TransformDialect::getPDLConstraintHooks() const { + return pdlMatchHooks.getConstraintFunctions(); +} diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Operation.h" #include "llvm/ADT/ScopeExit.h" @@ -117,6 +118,8 @@ return success(); } +transform::TransformState::Extension::~Extension() = default; + //===----------------------------------------------------------------------===// // TransformResults //===----------------------------------------------------------------------===// @@ -145,6 +148,61 @@ return segments[resultNumber]; } +//===----------------------------------------------------------------------===// +// Utilities for PossibleTopLevelTransformOpTrait.
+//===----------------------------------------------------------------------===// + +LogicalResult transform::detail::mapPossibleTopLevelTransformOpBlockArguments( + TransformState &state, Operation *op) { + SmallVector targets; + if (op->getNumOperands() != 0) + llvm::append_range(targets, state.getPayloadOps(op->getOperand(0))); + else + targets.push_back(state.getTopLevel()); + + return state.mapBlockArguments(op->getRegion(0).front().getArgument(0), + targets); +} + +LogicalResult +transform::detail::verifyPossibleTopLevelTransformOpTrait(Operation *op) { + // Attaching this trait without the interface is a misuse of the API, but it + // cannot be caught via a static_assert because interface registration is + // dynamic. + assert(isa(op) && + "should implement TransformOpInterface to have " + "PossibleTopLevelTransformOpTrait"); + + if (op->getNumRegions() != 1) + return op->emitOpError() << "expects one region"; + + Region *bodyRegion = &op->getRegion(0); + if (!llvm::hasNItems(*bodyRegion, 1)) + return op->emitOpError() << "expects a single-block region"; + + Block *body = &bodyRegion->front(); + if (body->getNumArguments() != 1 || + !body->getArgumentTypes()[0].isa()) { + return op->emitOpError() + << "expects the entry block to have one argument of type " + << pdl::OperationType::get(op->getContext()); + } + + if (auto *parent = + op->getParentWithTrait()) { + if (op->getNumOperands() == 0) { + InFlightDiagnostic diag = + op->emitOpError() + << "expects the root operation to be provided for a nested op"; + diag.attachNote(parent->getLoc()) + << "nested in another possible top-level op"; + return diag; + } + } + + return success(); +} + //===----------------------------------------------------------------------===// // Generated interface implementation.
//===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -7,26 +7,142 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Transform/IR/TransformOps.h" +#include "mlir/Dialect/PDL/IR/PDLOps.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/IR/Builders.h" - #include "mlir/IR/OpImplementation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Rewrite/FrozenRewritePatternSet.h" +#include "mlir/Rewrite/PatternApplicator.h" +#include "llvm/ADT/ScopeExit.h" using namespace mlir; #define GET_OP_CLASSES #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc" -LogicalResult transform::SequenceOp::apply(transform::TransformResults &results, +//===----------------------------------------------------------------------===// +// PatternApplicatorExtension +//===----------------------------------------------------------------------===// + +namespace { +/// A simple pattern rewriter that can be constructed from a context. This is +/// necessary to apply patterns to a specific op locally. +class TrivialPatternRewriter : public PatternRewriter { +public: + explicit TrivialPatternRewriter(MLIRContext *context) + : PatternRewriter(context) {} +}; + +/// A TransformState extension that keeps track of compiled PDL pattern sets. +/// This is intended to be used alongside WithPDLPatternsOp. The extension +/// can be constructed given an operation that has a SymbolTable trait and +/// contains pdl::PatternOp instances. The patterns are compiled lazily and one +/// by one when requested; this behavior is subject to change.
+class PatternApplicatorExtension : public transform::TransformState::Extension { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PatternApplicatorExtension) + + /// Creates the extension for patterns contained in `patternContainer`. + explicit PatternApplicatorExtension(transform::TransformState &state, + Operation *patternContainer) + : Extension(state), patterns(patternContainer) {} + + /// Appends to `results` the operations contained in `root` that matched the + /// PDL pattern with the given name. Note that `root` may or may not be the + /// operation that contains PDL patterns. Returns failure, without emitting a + /// diagnostic, if the pattern cannot be found; the caller reports the error. + /// Note that when no operations are matched, this still succeeds as long as + /// the pattern exists. The pattern op is copied for compilation, so the + /// Transform IR containing it is left unmodified. + LogicalResult findAllMatches(StringRef patternName, Operation *root, + SmallVectorImpl &results); + +private: + /// Map from the pattern name to a singleton set of rewrite patterns that only + /// contains the pattern with this name. Populated when the pattern is first + /// requested. + // TODO: reconsider the efficiency of this storage when more usage data is + // available. Storing individual patterns in a set and triggering compilation + // for each of them has overhead. So does compiling a large set of patterns + // only to apply a handful of them. + llvm::StringMap compiledPatterns; + + /// A symbol table operation containing the relevant PDL patterns.
+ SymbolTable patterns; +}; + +LogicalResult PatternApplicatorExtension::findAllMatches( + StringRef patternName, Operation *root, + SmallVectorImpl &results) { + auto it = compiledPatterns.find(patternName); + if (it == compiledPatterns.end()) { + auto patternOp = patterns.lookup(patternName); + if (!patternOp) + return failure(); + + // Clone the pattern into a fresh module owned by the compiled pattern set + // instead of moving it there: moving would detach the pattern from the + // Transform IR, leave the symbol table stale, and erase the pattern for + // good once this extension (and thus the compiled set) is destroyed. + OwningOpRef pdlModuleOp = ModuleOp::create(patternOp.getLoc()); + auto builder = OpBuilder::atBlockEnd(pdlModuleOp->getBody()); + builder.clone(*patternOp.getOperation()); + PDLPatternModule patternModule(std::move(pdlModuleOp)); + + // Merge in the hooks owned by the dialect. Make a copy as they may be + // also used by the following operations. + auto *dialect = + root->getContext()->getLoadedDialect(); + for (const auto &pair : dialect->getPDLConstraintHooks()) + patternModule.registerConstraintFunction(pair.first(), pair.second); + + // Register a noop rewriter because PDL requires patterns to end with some + // rewrite call. + patternModule.registerRewriteFunction( + "transform.dialect", [](PatternRewriter &, Operation *) {}); + + // Key the cache by the name used for the lookup above so that subsequent + // requests for the same name hit the cache. + it = compiledPatterns + .try_emplace(patternName, std::move(patternModule)) + .first; + } + + PatternApplicator applicator(it->second); + TrivialPatternRewriter rewriter(root->getContext()); + applicator.applyDefaultCostModel(); + root->walk([&](Operation *op) { + if (succeeded(applicator.matchAndRewrite(op, rewriter))) + results.push_back(op); + }); + + return success(); +} +} // namespace + +//===----------------------------------------------------------------------===// +// PDLMatchOp +//===----------------------------------------------------------------------===// + +LogicalResult transform::PDLMatchOp::apply(transform::TransformResults &results, transform::TransformState &state) { + auto *extension = state.getExtension(); + assert(extension && + "expected PatternApplicatorExtension to be attached by the parent op"); SmallVector targets; - if (getRoot()) - llvm::append_range(targets, state.getPayloadOps(getRoot()));
- else - targets.push_back(state.getTopLevel()); + for (Operation *root : state.getPayloadOps(getRoot())) { + if (failed(extension->findAllMatches( + getPatternName().getLeafReference().getValue(), root, targets))) { + return emitOpError() << "could not find pattern '" << getPatternName() + << "'"; + } + } + results.set(getResult().cast(), targets); + return success(); +} +//===----------------------------------------------------------------------===// +// SequenceOp +//===----------------------------------------------------------------------===// + +LogicalResult transform::SequenceOp::apply(transform::TransformResults &results, + transform::TransformState &state) { // Map the entry block argument to the list of operations. auto scope = state.make_region_scope(*getBodyBlock()->getParent()); - if (failed(state.mapBlockArguments(getBodyBlock()->getArgument(0), targets))) + if (failed(mapBlockArguments(state))) return failure(); // Apply the sequenced ops one by one. @@ -48,23 +164,6 @@ } LogicalResult transform::SequenceOp::verify() { - if (getBodyBlock()->getNumArguments() != 1 || - !getBodyBlock()->getArgumentTypes()[0].isa()) { - return emitOpError() - << "expected the entry block to have one argument of type " - << pdl::OperationType::get(getContext()); - } - - if (auto parent = getOperation()->getParentOfType()) { - if (!getRoot()) { - InFlightDiagnostic diag = - emitOpError() - << "expected the root operation to be provided for a nested sequence"; - diag.attachNote(parent.getLoc()) << "nested in another sequence"; - return diag; - } - } - for (Operation &child : *getBodyBlock()) { if (!isa(child) && &child != &getBodyBlock()->back()) { @@ -99,3 +198,65 @@ } return success(); } + +//===----------------------------------------------------------------------===// +// WithPDLPatternsOp +//===----------------------------------------------------------------------===// + +LogicalResult +transform::WithPDLPatternsOp::apply(transform::TransformResults &results,
+ transform::TransformState &state) { + // Find the transform op to execute. The verifier guarantees that the body + // contains at most one op that is not a PDL pattern. + TransformOpInterface transformOp = nullptr; + for (Operation &nested : getBody().front()) { + if (!isa(nested)) { + transformOp = cast(nested); + break; + } + } + + // A body that only contains patterns passes verification; there is nothing + // to apply in that case, so do not hand a null op to applyTransform. + if (!transformOp) + return success(); + + state.addExtension(getOperation()); + auto guard = llvm::make_scope_exit( + [&]() { state.removeExtension(); }); + + auto scope = state.make_region_scope(getBody()); + if (failed(mapBlockArguments(state))) + return failure(); + return state.applyTransform(transformOp); +} + +LogicalResult transform::WithPDLPatternsOp::verify() { + Block *body = getBodyBlock(); + Operation *topLevelOp = nullptr; + for (Operation &op : body->getOperations()) { + if (isa(op)) + continue; + + if (op.hasTrait()) { + if (topLevelOp) { + InFlightDiagnostic diag = + emitOpError() << "expects only one non-pattern op in its body"; + diag.attachNote(topLevelOp->getLoc()) << "first non-pattern op"; + diag.attachNote(op.getLoc()) << "second non-pattern op"; + return diag; + } + topLevelOp = &op; + continue; + } + + InFlightDiagnostic diag = + emitOpError() + << "expects only pattern and top-level transform ops in its body"; + diag.attachNote(op.getLoc()) << "offending op"; + return diag; + } + + if (auto parent = getOperation()->getParentOfType()) { + InFlightDiagnostic diag = emitOpError() << "cannot be nested"; + diag.attachNote(parent.getLoc()) << "parent operation"; + return diag; + } + + return success(); +} diff --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir --- a/mlir/test/Dialect/Transform/ops-invalid.mlir +++ b/mlir/test/Dialect/Transform/ops-invalid.mlir @@ -1,15 +1,15 @@ // RUN: mlir-opt %s -split-input-file -verify-diagnostics -// expected-error @below {{expected the entry block to have one argument of type '!pdl.operation'}} +// expected-error @below {{expects the entry block to have one argument of type '!pdl.operation'}} transform.sequence { } // ----- -//
expected-note @below {{nested in another sequence}} +// expected-note @below {{nested in another possible top-level op}} transform.sequence { ^bb0(%arg0: !pdl.operation): - // expected-error @below {{expected the root operation to be provided for a nested sequence}} + // expected-error @below {{expects the root operation to be provided for a nested op}} transform.sequence { ^bb1(%arg1: !pdl.operation): } @@ -50,3 +50,64 @@ // expected-note @below {{terminator}} transform.yield } : !pdl.operation + +// ----- + +// expected-note @below {{nested in another possible top-level op}} +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + // expected-error @below {{expects the root operation to be provided for a nested op}} + transform.sequence { + ^bb1(%arg1: !pdl.operation): + } +} + +// ----- + +// expected-error @below {{expects only one non-pattern op in its body}} +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + // expected-note @below {{first non-pattern op}} + transform.sequence { + ^bb1(%arg1: !pdl.operation): + } + // expected-note @below {{second non-pattern op}} + transform.sequence { + ^bb1(%arg1: !pdl.operation): + } +} + +// ----- + +// expected-error @below {{expects only pattern and top-level transform ops in its body}} +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + // expected-note @below {{offending op}} + "test.something"() : () -> () +} + +// ----- + +// expected-note @below {{parent operation}} +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + // expected-error @below {{op cannot be nested}} + transform.with_pdl_patterns %arg0 { + ^bb1(%arg1: !pdl.operation): + } +} + +// ----- + +// expected-error @below {{expects one region}} +"transform.test_transform_unrestricted_op_no_interface"() : () -> () + +// ----- + +// expected-error @below {{expects a single-block region}} +"transform.test_transform_unrestricted_op_no_interface"() ({ +^bb0(%arg0: !pdl.operation): + "test.potential_terminator"() : () -> ()
+^bb1: + "test.potential_terminator"() : () -> () +}) : () -> () diff --git a/mlir/test/Dialect/Transform/ops.mlir b/mlir/test/Dialect/Transform/ops.mlir --- a/mlir/test/Dialect/Transform/ops.mlir +++ b/mlir/test/Dialect/Transform/ops.mlir @@ -10,3 +10,23 @@ ^bb1(%arg1: !pdl.operation): } } + +// CHECK: transform.with_pdl_patterns +// CHECK: ^{{.+}}(%[[ARG:.+]]: !pdl.operation): +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + // CHECK: sequence %[[ARG]] + sequence %arg0 { + ^bb1(%arg1: !pdl.operation): + } +} + +// CHECK: transform.sequence +// CHECK: ^{{.+}}(%[[ARG:.+]]: !pdl.operation): +transform.sequence { +^bb0(%arg0: !pdl.operation): + // CHECK: with_pdl_patterns %[[ARG]] + with_pdl_patterns %arg0 { + ^bb1(%arg1: !pdl.operation): + } +} diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir --- a/mlir/test/Dialect/Transform/test-interpreter.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter.mlir @@ -69,3 +69,31 @@ // expected-remark @below {{succeeded}} test_consume_operand_if_matches_param_or_fail %0[42] } + +// ----- + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + sequence %arg0 { + ^bb0(%arg1: !pdl.operation): + %0 = pdl_match @some in %arg1 + test_print_remark_at_operand %0, "matched" + } + + pdl.pattern @some : benefit(1) { + %0 = pdl.operation "test.some_op" + pdl.rewrite %0 with "transform.dialect" + } + + pdl.pattern @other : benefit(1) { + %0 = pdl.operation "test.other_op" + pdl.rewrite %0 with "transform.dialect" + } +} + +// expected-remark @below {{matched}} +"test.some_op"() : () -> () +"test.other_op"() : () -> () +// expected-remark @below {{matched}} +"test.some_op"() : () -> () + diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++
b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -22,7 +22,8 @@ namespace { /// Simple transform op defined outside of the dialect. Just emits a remark when -/// applied. +/// applied. This op is defined in C++ to test that C++ definitions also work +/// for op injection into the Transform dialect. class TestTransformOp : public Op { public: @@ -63,6 +64,33 @@ printer << " " << getMessage(); } }; + +/// A test op to exercise the verifier of the PossibleTopLevelTransformOpTrait +/// in cases where it is attached to ops that do not comply with the trait +/// requirements. This op cannot be defined in ODS because ODS generates strict +/// verifiers that overlap with those in the trait and run earlier. +class TestTransformUnrestrictedOpNoInterface + : public Op { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + TestTransformUnrestrictedOpNoInterface) + + using Op::Op; + + static ArrayRef getAttributeNames() { return {}; } + + static constexpr llvm::StringLiteral getOperationName() { + return llvm::StringLiteral( + "transform.test_transform_unrestricted_op_no_interface"); + } + + LogicalResult apply(transform::TransformResults &results, + transform::TransformState &state) { + return success(); + } +}; } // namespace LogicalResult mlir::test::TestProduceParamOrForwardOperandOp::apply( @@ -97,6 +125,15 @@ return success(); } +LogicalResult mlir::test::TestPrintRemarkAtOperandOp::apply( + transform::TransformResults &results, transform::TransformState &state) { + ArrayRef payload = state.getPayloadOps(getOperand()); + for (Operation *op : payload) + op->emitRemark() << getMessage(); + + return success(); +} + namespace { /// Test extension of the Transform dialect.
Registers additional ops and /// declares PDL as dependent dialect since the additional ops are using PDL @@ -108,6 +145,7 @@ TestTransformDialectExtension() { declareDependentDialect(); registerTransformOps(); diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td @@ -38,4 +38,12 @@ let cppNamespace = "::mlir::test"; } +def TestPrintRemarkAtOperandOp + : Op]> { + let arguments = (ins PDL_Operation:$operand, StrAttr:$message); + let assemblyFormat = "$operand `,` $message attr-dict"; + let cppNamespace = "::mlir::test"; +} + #endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -7772,6 +7772,8 @@ deps = [ ":IR", ":PDLDialect", + ":PDLInterpDialect", + ":Rewrite", ":Support", ":TransformDialectIncGen", ":TransformDialectInterfacesIncGen",