diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td @@ -228,4 +228,25 @@ ]; } +def PatternDescriptorOpInterface : OpInterface<"PatternDescriptorOpInterface"> { + let description = [{ + This interface should be implemented by ops that select patterns of a + `transform.apply_patterns` op. It provides a method to populate a rewrite + pattern set with patterns. + }]; + + let cppNamespace = "::mlir::transform"; + + let methods = [ + InterfaceMethod< + /*desc=*/[{ + Populate patterns into the given pattern set. + }], + /*returnType=*/"void", + /*name=*/"populatePatterns", + /*arguments=*/(ins "RewritePatternSet &":$patterns) + >, + ]; +} + #endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORM_INTERFACES_TD diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -16,6 +16,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/FunctionInterfaces.td" include "mlir/IR/OpAsmInterface.td" +include "mlir/IR/RegionKindInterface.td" include "mlir/IR/SymbolInterfaces.td" include "mlir/Dialect/Transform/IR/MatchInterfaces.td" include "mlir/Dialect/Transform/IR/TransformAttrs.td" @@ -128,17 +129,20 @@ def ApplyPatternsOp : TransformDialectOp<"apply_patterns", [TransformOpInterface, TransformEachOpTrait, - DeclareOpInterfaceMethods]> { + DeclareOpInterfaceMethods] + # GraphRegionNoTerminator.traits> { let summary = "Greedily applies patterns to the body of the targeted op"; let description = [{ This transform greedily applies the specified patterns to the body of the targeted op until a fixpoint was reached. Patterns are not applied to the targeted op itself. - Only patterns that were registered in the transform dialect's - `PatternRegistry` are available. Additional patterns can be registered as - part of transform dialect extensions. "canonicalization" is a special set - of patterns that refers to all canonicalization patterns of all loaded + The patterns that should be applied are specified in the graph region of + this op. They must implement the `PatternDescriptorOpInterface`. + + (Deprecated) In addition, patterns that were registered in the transform + dialect's `PatternRegistry` are available. "canonicalization" is a special + set of patterns that refers to all canonicalization patterns of all loaded dialects. This transform only reads the target handle and modifies the payload. If a @@ -160,7 +164,9 @@ TransformHandleTypeInterface:$target, ArrayAttr:$patterns, DefaultValuedAttr:$fail_on_payload_replacement_not_found); let results = (outs); - let assemblyFormat = "$patterns `to` $target attr-dict `:` type($target)"; + let regions = (region MaxSizedRegion<1>:$region); + + let assemblyFormat = "$patterns `to` $target $region attr-dict `:` type($target)"; let hasVerifier = 1; let extraClassDeclaration = [{ @@ -171,6 +177,17 @@ }]; } +def ApplyCanonicalizationPatternsOp + : TransformDialectOp<"apply_patterns.canonicalization", + [DeclareOpInterfaceMethods]> { + let summary = "Populates canonicalization patterns"; + let description = [{ + This op populates all canonicalization patterns of all loaded dialects in + an `apply_patterns` transform. + }]; + let assemblyFormat = "attr-dict"; +} + def CastOp : TransformDialectOp<"cast", [TransformOpInterface, TransformEachOpTrait, DeclareOpInterfaceMethods, diff --git a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp @@ -27,12 +27,15 @@ RegisteredOperationName opName = *RegisteredOperationName::lookup(name, context); assert((opName.hasInterface() || + opName.hasInterface() || opName.hasTrait()) && "non-terminator ops injected into the transform dialect must " - "implement TransformOpInterface"); - assert(opName.hasInterface() && - "ops injected into the transform dialect must implement " - "MemoryEffectsOpInterface"); + "implement TransformOpInterface or PatternDescriptorOpInterface"); + if (!opName.hasInterface()) { + assert(opName.hasInterface() && + "ops injected into the transform dialect must implement " + "MemoryEffectsOpInterface"); + } } void transform::detail::checkImplementsTransformHandleTypeInterface( @@ -57,16 +60,6 @@ #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc" >(); initializeTypes(); - - // Register all canonicalization patterns. - getOrCreateExtraData().registerPatterns( - "canonicalization", [](RewritePatternSet &patterns) { - MLIRContext *ctx = patterns.getContext(); - for (Dialect *dialect : ctx->getLoadedDialects()) - dialect->getCanonicalizationPatterns(patterns); - for (RegisteredOperationName op : ctx->getRegisteredOperations()) - op.getCanonicalizationPatterns(patterns, ctx); - }); } Type transform::TransformDialect::parseType(DialectAsmParser &parser) const { diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -445,6 +445,12 @@ ->getExtraData(); for (Attribute attr : getPatterns()) registry.populatePatterns(attr.cast(), patterns); + if (!getRegion().empty()) { + for (Operation &op : getRegion().front()) { + cast(&op).populatePatterns( + patterns); + } + } // Configure the GreedyPatternRewriteDriver. ErrorCheckingTrackingListener listener(state, *this); @@ -491,6 +497,17 @@ if (!registry.hasPatterns(strAttr)) return emitOpError() << "patterns not registered: " << strAttr.strref(); } + if (!getRegion().empty()) { + for (Operation &op : getRegion().front()) { + if (!isa(&op)) { + InFlightDiagnostic diag = emitOpError() + << "expected children ops to implement " + "PatternDescriptorOpInterface"; + diag.attachNote(op.getLoc()) << "op without interface"; + return diag; + } + } + } return success(); } @@ -500,6 +517,19 @@ transform::modifiesPayload(effects); } +//===----------------------------------------------------------------------===// +// ApplyCanonicalizationPatternsOp +//===----------------------------------------------------------------------===// + +void transform::ApplyCanonicalizationPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + MLIRContext *ctx = patterns.getContext(); + for (Dialect *dialect : ctx->getLoadedDialects()) + dialect->getCanonicalizationPatterns(patterns); + for (RegisteredOperationName op : ctx->getRegisteredOperations()) + op.getCanonicalizationPatterns(patterns, ctx); +} + //===----------------------------------------------------------------------===// // CastOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir --- a/mlir/test/Dialect/Transform/ops-invalid.mlir +++ b/mlir/test/Dialect/Transform/ops-invalid.mlir @@ -678,7 +678,7 @@ transform.sequence failures(propagate) { ^bb0(%arg0: !transform.any_op): // expected-error @below {{patterns not registered: transform.invalid_pattern_identifier}} - transform.apply_patterns ["transform.invalid_pattern_identifier"] to %arg0 : !transform.any_op + transform.apply_patterns ["transform.invalid_pattern_identifier"] to %arg0 {} : !transform.any_op } // ----- @@ -686,5 +686,15 @@ transform.sequence failures(propagate) { ^bb0(%arg0: !transform.any_op): // expected-error @below {{expected "patterns" to be an array of strings}} - transform.apply_patterns [3, 9] to %arg0 : !transform.any_op + transform.apply_patterns [3, 9] to %arg0 {} : !transform.any_op +} + +// ----- +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op): + // expected-error @below {{expected children ops to implement PatternDescriptorOpInterface}} + transform.apply_patterns [] to %arg0 { + // expected-note @below {{op without interface}} + transform.named_sequence @foo() + } : !transform.any_op } diff --git a/mlir/test/Dialect/Transform/test-pattern-application.mlir b/mlir/test/Dialect/Transform/test-pattern-application.mlir --- a/mlir/test/Dialect/Transform/test-pattern-application.mlir +++ b/mlir/test/Dialect/Transform/test-pattern-application.mlir @@ -15,7 +15,31 @@ ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op %1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns ["transform.test"] to %0 : !transform.any_op + transform.apply_patterns ["transform.test"] to %0 {} : !transform.any_op + // Add an attribute to %1, which is now mapped to a new op. + transform.annotate %1 "annotated" : !transform.any_op +} + +// ----- + +// CHECK-LABEL: func @update_tracked_op_mapping_region() +// CHECK: "test.container"() ({ +// CHECK: %0 = "test.foo"() {annotated} : () -> i32 +// CHECK: }) : () -> () +func.func @update_tracked_op_mapping_region() { + "test.container"() ({ + %0 = "test.foo"() {replace_with_new_op = "test.foo"} : () -> (i32) + }) : () -> () + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns [] to %0 { + transform.apply_patterns.transform.test_patterns + } : !transform.any_op // Add an attribute to %1, which is now mapped to a new op. transform.annotate %1 "annotated" : !transform.any_op } @@ -36,7 +60,7 @@ %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op %1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op // expected-error @below {{tracking listener failed to find replacement op}} - transform.apply_patterns ["transform.test"] to %0 : !transform.any_op + transform.apply_patterns ["transform.test"] to %0 {} : !transform.any_op // %1 must be used in some way. If no replacement payload op could be found, // an error is thrown only if the handle is not dead. transform.annotate %1 "annotated" : !transform.any_op @@ -60,7 +84,7 @@ %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op %1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op // No error because %1 is dead. - transform.apply_patterns ["transform.test"] to %0 : !transform.any_op + transform.apply_patterns ["transform.test"] to %0 {} : !transform.any_op } // ----- @@ -80,7 +104,7 @@ ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op %1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns ["transform.test"] to %0 {fail_on_payload_replacement_not_found = false}: !transform.any_op + transform.apply_patterns ["transform.test"] to %0 {} {fail_on_payload_replacement_not_found = false}: !transform.any_op transform.annotate %1 "annotated" : !transform.any_op } @@ -95,8 +119,8 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): - %0 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns ["transform.test"] to %0 : !transform.any_op +%0 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns ["transform.test"] to %0 {} : !transform.any_op } // ----- @@ -118,7 +142,7 @@ %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op %1 = transform.structured.match ops{["test.erase_op"]} in %arg1 : (!transform.any_op) -> !transform.any_op transform.test_print_remark_at_operand %1, "matched op" : !transform.any_op - transform.apply_patterns ["transform.test"] to %0 : !transform.any_op + transform.apply_patterns ["transform.test"] to %0 {} : !transform.any_op transform.test_print_remark_at_operand %1, "op was deleted" : !transform.any_op } @@ -138,6 +162,8 @@ ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["tensor.dim"]} in %arg1 : (!transform.any_op) -> !transform.any_op %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns ["canonicalization"] to %1 : !transform.any_op + transform.apply_patterns [] to %1 { + transform.apply_patterns.canonicalization + } : !transform.any_op transform.test_print_remark_at_operand %0, "op was replaced" : !transform.any_op } diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -776,7 +776,14 @@ return success(); } }; +} // namespace +void mlir::test::ApplyTestPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + patterns.insert(patterns.getContext()); +} + +namespace { void populateTestPatterns(RewritePatternSet &patterns) { patterns.insert(patterns.getContext()); } diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td @@ -59,7 +59,6 @@ let results = (outs TransformValueHandleTypeInterface:$out); let assemblyFormat = "$in attr-dict `:` functional-type(operands, results)"; let cppNamespace = "::mlir::test"; - } def TestProduceValueHandleToResult @@ -478,4 +477,13 @@ let cppNamespace = "::mlir::test"; } +def ApplyTestPatternsOp + : Op]> { + let arguments = (ins); + let results = (outs); + let assemblyFormat = "attr-dict"; + let cppNamespace = "::mlir::test"; +} + #endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD