diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h @@ -32,11 +32,18 @@ protected: /// Must be called by the subclass with the appropriate type ID. - explicit TransformDialectDataBase(TypeID typeID) : typeID(typeID) {} + explicit TransformDialectDataBase(TypeID typeID, MLIRContext *ctx) + : typeID(typeID), ctx(ctx) {} + + /// Return the MLIR context that was passed to the constructor. + MLIRContext *getContext() const { return ctx; } private: /// The type ID of the subclass. const TypeID typeID; + + /// The MLIR context passed in at construction. + MLIRContext *ctx; }; } // namespace detail @@ -55,7 +62,8 @@ class TransformDialectData : public detail::TransformDialectDataBase { protected: - /// Forward the TypeID of the derived class to the base. - TransformDialectData() : TransformDialectDataBase(TypeID::get()) {} + /// Forward the TypeID of the derived class and the MLIR context to the base. + explicit TransformDialectData(MLIRContext *ctx) + : TransformDialectDataBase(TypeID::get(), ctx) {} }; #ifndef NDEBUG @@ -294,7 +302,8 @@ if (it != extraData.end()) return static_cast(*it->getSecond()); - auto emplaced = extraData.try_emplace(typeID, std::make_unique()); + auto emplaced = + extraData.try_emplace(typeID, std::make_unique(getContext())); return static_cast(*emplaced.first->getSecond()); } diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h @@ -12,6 +12,7 @@ #include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/Transform/IR/MatchInterfaces.h" #include "mlir/Dialect/Transform/IR/TransformAttrs.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Dialect/Transform/IR/TransformTypes.h" #include "mlir/IR/FunctionInterfaces.h" 
@@ -25,6 +26,8 @@ namespace mlir { namespace transform { +class ApplyPatternsOp; + enum class FailurePropagationMode : uint32_t; class FailurePropagationModeAttr; @@ -120,9 +123,73 @@ TransformOpInterface transformOp; }; +/// A specialized listener that keeps track of cases in which no replacement +/// payload could be found. The error state of this listener must be checked +/// before the end of its lifetime. +class ErrorCheckingTrackingListener : public TrackingListener { +public: + using transform::TrackingListener::TrackingListener; + + ~ErrorCheckingTrackingListener() override; + + /// Check and return the current error state of this listener. Afterwards, + /// resets the error state to "success". + DiagnosedSilenceableFailure checkAndResetError(); + + /// Return "true" if this tracking listener had a failure. + bool failed() const; + +protected: + void notifyPayloadReplacementNotFound(Operation *op, + ValueRange values) override; + +private: + /// The error state of this listener. "Success" indicates that no error + /// happened so far. + DiagnosedSilenceableFailure status = DiagnosedSilenceableFailure::success(); + + /// The number of errors that have been encountered. + int64_t errorCounter = 0; +}; + +/// The PatternRegistry stores callbacks to functions that populate a +/// `RewritePatternSet`. Registered patterns can be applied with the +/// "transform.apply_patterns" op. +class PatternRegistry : public TransformDialectData { +public: + explicit PatternRegistry(MLIRContext *ctx) + : TransformDialectData(ctx), builder(ctx) {} + + /// A function that populates a `RewritePatternSet`. + using PopulatePatternsFn = std::function; + + /// Registers patterns with the specified identifier. The identifier should + /// be prefixed with the dialect to which the patterns belong. + void registerPatterns(StringRef identifier, PopulatePatternsFn &&fn); + +protected: + friend class ApplyPatternsOp; + + /// Returns "true" if patterns are registered with the specified identifier. 
+ bool hasPatterns(StringAttr identifier) const; + + /// Populates the given pattern set with the specified patterns. + void populatePatterns(StringAttr identifier, + RewritePatternSet &patternSet) const; + +private: + /// A builder for creating StringAttrs. + Builder builder; + + /// Registered populate functions, keyed by pattern identifier. + DenseMap patterns; +}; + } // namespace transform } // namespace mlir +MLIR_DECLARE_EXPLICIT_TYPE_ID(mlir::transform::PatternRegistry) + #define GET_OP_CLASSES #include "mlir/Dialect/Transform/IR/TransformOps.h.inc" diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -126,6 +126,49 @@ "`:` type($target) (`,` type($param)^)?"; } +def ApplyPatternsOp : TransformDialectOp<"apply_patterns", + [TransformOpInterface, TransformEachOpTrait, + DeclareOpInterfaceMethods]> { + let summary = "Greedily applies patterns to the body of the targeted op"; + let description = [{ + This transform greedily applies the specified patterns to the body of the + targeted op until a fixpoint is reached. Patterns are not applied to the + targeted op itself. + + Only patterns that were registered in the transform dialect's + `PatternRegistry` are available. Additional patterns can be registered as + part of transform dialect extensions. + + This transform only reads the target handle and modifies the payload. If a + pattern erases or replaces a tracked op, the mapping is updated accordingly. + + Only replacements via `RewriterBase::replaceOp` or `replaceOpWithNewOp` are + considered "payload op replacements". Furthermore, only if the replacement + values are defined by the same op and that op has the same type as the + original op, the mapping is updated. Otherwise, this transform fails + silently unless `fail_on_payload_replacement_not_found` is set to "false". 
+ More details can be found at the documentation site of `TrackingListener`. + + This transform also fails silently if the pattern application did not + converge within the default number of iterations/rewrites of the greedy + pattern rewrite driver. + }]; + + let arguments = (ins + TransformHandleTypeInterface:$target, ArrayAttr:$patterns, + DefaultValuedAttr:$fail_on_payload_replacement_not_found); + let results = (outs); + let assemblyFormat = "$patterns `to` $target attr-dict `:` type($target)"; + let hasVerifier = 1; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + def CastOp : TransformDialectOp<"cast", [TransformOpInterface, TransformEachOpTrait, DeclareOpInterfaceMethods, diff --git a/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h b/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h --- a/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h +++ b/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h @@ -27,6 +27,8 @@ /// populated by extensions. class PDLMatchHooks : public TransformDialectData { public: + explicit PDLMatchHooks(MLIRContext *ctx) : TransformDialectData(ctx) {} + /// Takes ownership of the named PDL constraint function from the given /// map and makes them available for use by the operations in the dialect. 
void diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -16,6 +16,7 @@ #include "mlir/IR/FunctionImplementation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/SmallPtrSet.h" @@ -31,6 +32,8 @@ using namespace mlir; +MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::transform::PatternRegistry) + static ParseResult parseSequenceOpOperands( OpAsmParser &parser, std::optional &root, Type &rootType, @@ -175,6 +178,62 @@ (void)replacePayloadOp(op, replacement); } +transform::ErrorCheckingTrackingListener::~ErrorCheckingTrackingListener() { + // The state of the ErrorCheckingTrackingListener must be checked and reset + // if there was an error. This is to prevent errors from accidentally being + // missed. 
+ assert(status.succeeded() && "listener state was not checked"); +} + +DiagnosedSilenceableFailure +transform::ErrorCheckingTrackingListener::checkAndResetError() { + DiagnosedSilenceableFailure s = std::move(status); + status = DiagnosedSilenceableFailure::success(); + errorCounter = 0; // Note numbering restarts at [0]. + return s; +} + +bool transform::ErrorCheckingTrackingListener::failed() const { + return !status.succeeded(); +} + +void transform::ErrorCheckingTrackingListener::notifyPayloadReplacementNotFound( + Operation *op, ValueRange values) { + if (status.succeeded()) { // First error: create the failure. + status = emitSilenceableFailure( + getTransformOp(), "tracking listener failed to find replacement op"); + } + + status.attachNote(op->getLoc()) << "[" << errorCounter << "] replaced op"; + for (auto &&[index, value] : llvm::enumerate(values)) + status.attachNote(value.getLoc()) + << "[" << errorCounter << "] replacement value " << index; + + ++errorCounter; +} + +//===----------------------------------------------------------------------===// +// PatternRegistry +//===----------------------------------------------------------------------===// + +void transform::PatternRegistry::registerPatterns(StringRef identifier, + PopulatePatternsFn &&fn) { + StringAttr attr = builder.getStringAttr(identifier); + assert(!patterns.contains(attr) && "patterns identifier is already in use"); + patterns.try_emplace(attr, std::move(fn)); // NDEBUG: duplicates are ignored. +} + +void transform::PatternRegistry::populatePatterns( + StringAttr identifier, RewritePatternSet &patternSet) const { + auto it = patterns.find(identifier); + assert(it != patterns.end() && "patterns not registered in registry"); + it->second(patternSet); // Run the registered populate callback. 
+} + +bool transform::PatternRegistry::hasPatterns(StringAttr identifier) const { + return patterns.contains(identifier); +} + //===----------------------------------------------------------------------===// // AlternativesOp //===----------------------------------------------------------------------===// @@ -356,6 +415,77 @@ 
modifiesPayload(effects); } +//===----------------------------------------------------------------------===// +// ApplyPatternsOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure +transform::ApplyPatternsOp::applyToOne(Operation *target, + ApplyToEachResultList &results, + transform::TransformState &state) { + // Gather all specified patterns. + RewritePatternSet patterns(target->getContext()); + const auto &registry = getContext() + ->getLoadedDialect() + ->getExtraData(); + for (Attribute attr : getPatterns()) + registry.populatePatterns(attr.cast(), patterns); + + // Configure the GreedyPatternRewriteDriver. + ErrorCheckingTrackingListener listener(state, *this); + GreedyRewriteConfig config; + config.listener = &listener; + + // Manually gather list of ops because the other GreedyPatternRewriteDriver + // overloads only accept ops that are isolated from above. This way, patterns + // can be applied to ops that are not isolated from above. + SmallVector ops; + target->walk([&](Operation *nestedOp) { + if (target != nestedOp) + ops.push_back(nestedOp); + }); + LogicalResult result = + applyOpPatternsAndFold(ops, std::move(patterns), config); + // A failure typically indicates that pattern application did not converge. + if (failed(result)) { + // Reset tracking errors first; the listener asserts they were checked. + (void)listener.checkAndResetError().silence(); + return emitSilenceableFailure(target) + << "greedy pattern application failed"; + } + + // Check listener state for tracking errors. 
+ if (listener.failed()) { + DiagnosedSilenceableFailure status = listener.checkAndResetError(); + if (getFailOnPayloadReplacementNotFound()) + return status; + (void)status.silence(); + } + + return DiagnosedSilenceableFailure::success(); +} + +LogicalResult transform::ApplyPatternsOp::verify() { + const auto &registry = getContext() + ->getLoadedDialect() + ->getExtraData(); + for (Attribute attr : getPatterns()) { + auto strAttr = attr.dyn_cast(); + if (!strAttr) + return emitOpError() << "expected " << getPatternsAttrName() + << " to be an array of strings"; + if (!registry.hasPatterns(strAttr)) + return emitOpError() << "patterns not registered: " << strAttr.strref(); + } + return success(); +} + +void transform::ApplyPatternsOp::getEffects( + SmallVectorImpl &effects) { + transform::onlyReadsHandle(getTarget(), effects); + transform::modifiesPayload(effects); +} + //===----------------------------------------------------------------------===// // CastOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir --- a/mlir/test/Dialect/Transform/ops-invalid.mlir +++ b/mlir/test/Dialect/Transform/ops-invalid.mlir @@ -672,3 +672,19 @@ @match -> @action : (!transform.any_op) -> !transform.any_op } } + +// ----- + +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op): + // expected-error @below {{patterns not registered: transform.invalid_pattern_identifier}} + transform.apply_patterns ["transform.invalid_pattern_identifier"] to %arg0 : !transform.any_op +} + +// ----- + +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op): + // expected-error @below {{expected "patterns" to be an array of strings}} + transform.apply_patterns [3, 9] to %arg0 : !transform.any_op +} diff --git a/mlir/test/Dialect/Transform/test-pattern-application.mlir b/mlir/test/Dialect/Transform/test-pattern-application.mlir new file 
mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Transform/test-pattern-application.mlir @@ -0,0 +1,123 @@ +// RUN: mlir-opt %s --test-transform-dialect-interpreter -allow-unregistered-dialect --split-input-file --verify-diagnostics | FileCheck %s + +// CHECK-LABEL: func @update_tracked_op_mapping() +// CHECK: "test.container"() ({ +// CHECK: %0 = "test.foo"() {annotated} : () -> i32 +// CHECK: }) : () -> () +func.func @update_tracked_op_mapping() { + "test.container"() ({ + %0 = "test.foo"() {replace_with_new_op = "test.foo"} : () -> (i32) + }) : () -> () + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns ["transform.test"] to %0 : !transform.any_op + // Add an attribute to %1, which is now mapped to a new op. + transform.annotate %1 "annotated" : !transform.any_op +} + +// ----- + +func.func @replacement_op_not_found() { + "test.container"() ({ + // expected-note @below {{[0] replaced op}} + // expected-note @below {{[0] replacement value 0}} + %0 = "test.foo"() {replace_with_new_op = "test.bar"} : () -> (i32) + }) : () -> () + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // expected-error @below {{tracking listener failed to find replacement op}} + transform.apply_patterns ["transform.test"] to %0 : !transform.any_op + // %1 must be used afterwards: if no replacement payload op can be found, + // an error is reported only if the affected handle is still live. 
+ transform.annotate %1 "annotated" : !transform.any_op +} + +// ----- + +// CHECK-LABEL: func @replacement_op_for_dead_handle_not_found() +// CHECK: "test.container"() ({ +// CHECK: %0 = "test.bar"() : () -> i32 +// CHECK: }) : () -> () +func.func @replacement_op_for_dead_handle_not_found() { + "test.container"() ({ + %0 = "test.foo"() {replace_with_new_op = "test.bar"} : () -> (i32) + }) : () -> () + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // No error because %1 is dead: it is not used after this op. + transform.apply_patterns ["transform.test"] to %0 : !transform.any_op +} + +// ----- + +// CHECK-LABEL: func @replacement_op_not_found_silenced() +// CHECK: "test.container"() ({ +// CHECK: %0 = "test.bar"() : () -> i32 +// CHECK: }) : () -> () +func.func @replacement_op_not_found_silenced() { + "test.container"() ({ + %0 = "test.foo"() {replace_with_new_op = "test.bar"} : () -> (i32) + }) : () -> () + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns ["transform.test"] to %0 {fail_on_payload_replacement_not_found = false}: !transform.any_op + transform.annotate %1 "annotated" : !transform.any_op +} + +// ----- + +// CHECK-LABEL: func @patterns_apply_only_to_target_body() +// CHECK: %0 = "test.foo"() {replace_with_new_op = "test.bar"} : () -> i32 +func.func @patterns_apply_only_to_target_body() { + %0 = "test.foo"() {replace_with_new_op = "test.bar"} : () -> (i32) + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !transform.any_op): + %0 = 
transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns ["transform.test"] to %0 : !transform.any_op +} + +// ----- + +// CHECK-LABEL: func @erase_tracked_op() +// CHECK: "test.container"() ({ +// CHECK-NEXT: ^bb0: +// CHECK-NEXT: }) : () -> () +func.func @erase_tracked_op() { + "test.container"() ({ + // expected-remark @below {{matched op}} + %0 = "test.erase_op"() {replace_with_new_op = "test.foo"} : () -> (i32) + }) : () -> () + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.structured.match ops{["test.erase_op"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.test_print_remark_at_operand %1, "matched op" : !transform.any_op + transform.apply_patterns ["transform.test"] to %0 : !transform.any_op + transform.test_print_remark_at_operand %1, "op was deleted" : !transform.any_op // No remark is expected here: the erased op is dropped from %1. +} diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -746,6 +746,41 @@ } namespace { +// Test pattern to replace an operation with a new op. 
+class ReplaceWithNewOp : public RewritePattern { +public: + explicit ReplaceWithNewOp(MLIRContext *context) + : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + auto newName = op->getAttrOfType("replace_with_new_op"); + if (!newName) + return failure(); + Operation *newOp = rewriter.create( + op->getLoc(), OperationName(newName, op->getContext()).getIdentifier(), + op->getOperands(), op->getResultTypes()); // Attrs dropped: no re-match. + rewriter.replaceOp(op, newOp->getResults()); + return success(); + } +}; + +// Test pattern to erase an operation. +class EraseOp : public RewritePattern { +public: + explicit EraseOp(MLIRContext *context) + : RewritePattern("test.erase_op", /*benefit=*/1, context) {} + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + rewriter.eraseOp(op); // NOTE(review): assumes the op's results are unused. + return success(); + } +}; + +void populateTestPatterns(RewritePatternSet &patterns) { + patterns.insert(patterns.getContext()); +} + /// Test extension of the Transform dialect. Registers additional ops and /// declares PDL as dependent dialect since the additional ops are using PDL /// types for operands and results. @@ -783,6 +818,11 @@ constraints.try_emplace("verbose_constraint", verboseConstraint); hooks.mergeInPDLMatchHooks(std::move(constraints)); }); + + addDialectDataInitializer( + [&](transform::PatternRegistry &registry) { + registry.registerPatterns("transform.test", populateTestPatterns); + }); } }; } // namespace