[mlir][transform] Add a "canonicalization" pattern set to apply_patterns

Register a pattern set named "canonicalization" with the transform
dialect's PatternRegistry. When applied, it populates the pattern set with
the canonicalization patterns of every loaded dialect and of every
registered operation, so `transform.apply_patterns ["canonicalization"]`
runs them on the targeted payload.

The tracking listener now also accepts a defining op with the ConstantLike
trait as the replacement for a tracked op, because replacing an op with a
constant is a common canonicalization.

diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -137,7 +137,9 @@
 
     Only patterns that were registered in the transform dialect's
     `PatternRegistry` are available. Additional patterns can be registered as
-    part of transform dialect extensions.
+    part of transform dialect extensions. "canonicalization" is a special set
+    of patterns that refers to all canonicalization patterns of all loaded
+    dialects.
 
     This transform only reads the target handle and modifies the payload. If a
     pattern erases or replaces a tracked op, the mapping is updated accordingly.
diff --git a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
--- a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp
@@ -57,6 +57,16 @@
 #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
       >();
   initializeTypes();
+
+  // Register all canonicalization patterns.
+  getOrCreateExtraData<transform::PatternRegistry>().registerPatterns(
+      "canonicalization", [](RewritePatternSet &patterns) {
+        MLIRContext *ctx = patterns.getContext();
+        for (Dialect *dialect : ctx->getLoadedDialects())
+          dialect->getCanonicalizationPatterns(patterns);
+        for (RegisteredOperationName op : ctx->getRegisteredOperations())
+          op.getCanonicalizationPatterns(patterns, ctx);
+      });
 }
 
 Type transform::TransformDialect::parseType(DialectAsmParser &parser) const {
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -89,6 +89,11 @@
     if (op->getName() == defOp->getName())
       return defOp;
 
+    // Replacing an op with a constant-like equivalent is a common
+    // canonicalization.
+    if (defOp->hasTrait<OpTrait::ConstantLike>())
+      return defOp;
+
     values.clear();
 
     // Skip through ops that implement FindPayloadReplacementOpInterface.
diff --git a/mlir/test/Dialect/Transform/test-pattern-application.mlir b/mlir/test/Dialect/Transform/test-pattern-application.mlir
--- a/mlir/test/Dialect/Transform/test-pattern-application.mlir
+++ b/mlir/test/Dialect/Transform/test-pattern-application.mlir
@@ -121,3 +121,23 @@
   transform.apply_patterns ["transform.test"] to %0 : !transform.any_op
   transform.test_print_remark_at_operand %1, "op was deleted" : !transform.any_op
 }
+
+// -----
+
+// CHECK-LABEL: func @canonicalization(
+//       CHECK:   %[[c5:.*]] = arith.constant 5 : index
+//       CHECK:   return %[[c5]]
+func.func @canonicalization(%t: tensor<5xf32>) -> index {
+  %c0 = arith.constant 0 : index
+  // expected-remark @below {{op was replaced}}
+  %dim = tensor.dim %t, %c0 : tensor<5xf32>
+  return %dim : index
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["tensor.dim"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  transform.apply_patterns ["canonicalization"] to %1 : !transform.any_op
+  transform.test_print_remark_at_operand %0, "op was replaced" : !transform.any_op
+}