diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h
--- a/mlir/include/mlir/IR/Dialect.h
+++ b/mlir/include/mlir/IR/Dialect.h
@@ -68,6 +68,13 @@
   /// These are represented with OpaqueType.
   bool allowsUnknownTypes() const { return unknownTypesAllowed; }
 
+  /// Register dialect-wide canonicalization patterns. This method should only
+  /// be used to register canonicalization patterns that do not conceptually
+  /// belong to any single operation in the dialect. (Patterns that belong to a
+  /// single op should be registered via that op's canonicalizer instead.) E.g.,
+  /// canonicalization patterns for op interfaces should be registered here.
+  virtual void getCanonicalizationPatterns(RewritePatternSet &results) const {}
+
   /// Registered hook to materialize a single constant operation from a given
   /// attribute value with the desired resultant type. This method should use
   /// the provided builder to create the operation without changing the
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -275,6 +275,9 @@
 
   // If this dialect overrides the hook for op interface fallback.
   bit hasOperationInterfaceFallback = 0;
+
+  // If this dialect overrides the hook for canonicalization patterns.
+  bit hasCanonicalizer = 0;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/TableGen/Dialect.h b/mlir/include/mlir/TableGen/Dialect.h
--- a/mlir/include/mlir/TableGen/Dialect.h
+++ b/mlir/include/mlir/TableGen/Dialect.h
@@ -51,6 +51,9 @@
   // Returns the dialects extra class declaration code.
   llvm::Optional<StringRef> getExtraClassDeclaration() const;
 
+  // Returns true if this dialect has a canonicalizer.
+  bool hasCanonicalizer() const;
+
   // Returns true if this dialect has a constant materializer.
   bool hasConstantMaterializer() const;
 
diff --git a/mlir/lib/TableGen/Dialect.cpp b/mlir/lib/TableGen/Dialect.cpp
--- a/mlir/lib/TableGen/Dialect.cpp
+++ b/mlir/lib/TableGen/Dialect.cpp
@@ -61,6 +61,10 @@
   return value.empty() ? llvm::Optional<StringRef>() : value;
 }
 
+bool Dialect::hasCanonicalizer() const {
+  return def->getValueAsBit("hasCanonicalizer");
+}
+
 bool Dialect::hasConstantMaterializer() const {
   return def->getValueAsBit("hasConstantMaterializer");
 }
diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
--- a/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -35,6 +35,8 @@
   /// execution.
   LogicalResult initialize(MLIRContext *context) override {
     RewritePatternSet owningPatterns(context);
+    for (auto *dialect : context->getLoadedDialects())
+      dialect->getCanonicalizationPatterns(owningPatterns);
     for (auto *op : context->getRegisteredOperations())
       op->getCanonicalizationPatterns(owningPatterns, context);
     patterns = std::move(owningPatterns);
diff --git a/mlir/test/Transforms/test-canonicalize.mlir b/mlir/test/Transforms/test-canonicalize.mlir
--- a/mlir/test/Transforms/test-canonicalize.mlir
+++ b/mlir/test/Transforms/test-canonicalize.mlir
@@ -104,4 +104,12 @@
   // CHECK-DAG: %[[D1:.+]] = memref.dim %[[ARG_0]], %[[C2]]
   // CHECK: return %[[D0]], %[[C5]], %[[C2]], %[[C3]], %[[D1]]
   return %1, %2, %3, %4, %5 : index, index, index, index, index
-}
\ No newline at end of file
+}
+
+// CHECK-LABEL: test_dialect_canonicalizer
+func @test_dialect_canonicalizer() -> (i32) {
+  %0 = "test.dialect_canonicalizable"() : () -> (i32)
+  // CHECK: %[[CST:.*]] = constant 42 : i32
+  // CHECK: return %[[CST]]
+  return %0 : i32
+}
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -287,6 +287,24 @@
   return targetOperandsMutable();
 }
 
+//===----------------------------------------------------------------------===//
+// TestDialectCanonicalizerOp
+//===----------------------------------------------------------------------===//
+
+/// Rewrites `test.dialect_canonicalizable` into `constant 42 : i32`.
+static LogicalResult
+dialectCanonicalizationPattern(TestDialectCanonicalizerOp op,
+                               PatternRewriter &rewriter) {
+  rewriter.replaceOpWithNewOp<ConstantOp>(op, rewriter.getI32Type(),
+                                          rewriter.getI32IntegerAttr(42));
+  return success();
+}
+
+void TestDialect::getCanonicalizationPatterns(
+    RewritePatternSet &results) const {
+  results.add(&dialectCanonicalizationPattern);
+}
+
 //===----------------------------------------------------------------------===//
 // TestFoldToCallOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -25,6 +25,7 @@
 def Test_Dialect : Dialect {
   let name = "test";
   let cppNamespace = "::mlir::test";
+  let hasCanonicalizer = 1;
   let hasConstantMaterializer = 1;
   let hasOperationAttrVerify = 1;
   let hasRegionArgAttrVerify = 1;
@@ -966,6 +967,11 @@
   let hasFolder = 1;
 }
 
+def TestDialectCanonicalizerOp : TEST_Op<"dialect_canonicalizable"> {
+  let arguments = (ins);
+  let results = (outs I32);
+}
+
 
 //===----------------------------------------------------------------------===//
 // Test Patterns (Symbol Binding)
diff --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp
--- a/mlir/tools/mlir-tblgen/DialectGen.cpp
+++ b/mlir/tools/mlir-tblgen/DialectGen.cpp
@@ -107,6 +107,13 @@
                  ::mlir::DialectAsmPrinter &os) const override;
 )";
 
+/// The code block for the canonicalization pattern registration hook.
+static const char *const canonicalizerDecl = R"(
+  /// Register canonicalization patterns.
+  void getCanonicalizationPatterns(
+      ::mlir::RewritePatternSet &results) const override;
+)";
+
 /// The code block for the constant materializer hook.
 static const char *const constantMaterializerDecl = R"(
   /// Materialize a single constant operation from a given attribute value with
@@ -180,6 +187,8 @@
       os << typeParserDecl;
 
     // Add the decls for the various features of the dialect.
+    if (dialect.hasCanonicalizer())
+      os << canonicalizerDecl;
     if (dialect.hasConstantMaterializer())
       os << constantMaterializerDecl;
     if (dialect.hasOperationAttrVerify())