diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
@@ -35,6 +35,7 @@
   let dependentDialects = [
     "AffineDialect", "StandardOpsDialect", "tensor::TensorDialect"
   ];
+  let hasCanonicalizer = 1;
   let hasOperationAttrVerify = 1;
   let extraClassDeclaration = [{
     /// Attribute name used to to memoize indexing maps for named ops.
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -178,7 +178,6 @@
   }];
 
   let hasFolder = 1;
-  let hasCanonicalizer = 1;
   let skipDefaultBuilders = 1;
 }
 
@@ -230,7 +229,6 @@
 
   let verifier = [{ return ::verify(*this); }];
   let hasFolder = 1;
-  let hasCanonicalizer = 1;
 }
 
 /// A base class for pooling operation such as conv. The arguments must contain
@@ -427,7 +425,6 @@
 
   let verifier = [{ return ::verify(*this); }];
   let hasFolder = 1;
-  let hasCanonicalizer = 1;
 }
 
 // Only support buffer semantics.
@@ -490,7 +487,6 @@
 
   let verifier = [{ return ::verify(*this); }];
   let hasFolder = 1;
-  let hasCanonicalizer = 1;
 }
 
 def PoolingMaxOp: SingleInputPoolingBase_Op<"pooling_max"> {
@@ -673,7 +669,6 @@
 
   let verifier = [{ return ::verify(*this); }];
   let hasFolder = 1;
-  let hasCanonicalizer = 1;
 }
 
 /// GenericOp with Indexing (i.e. multi-for style in which the region is passed
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2787,11 +2787,6 @@
 DEFINE_POOLING_OP_GET_EFFECTS(PoolingMinOp)
 DEFINE_POOLING_OP_GET_EFFECTS(PoolingSumOp)
 
-namespace {
-struct EraseDeadLinalgOp;
-struct FoldTensorCastOp;
-} // namespace
-
 #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.tcgen.cpp.inc"
 
 #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
@@ -3374,25 +3369,29 @@
 };
 } // namespace
 
-#define CANONICALIZERS_AND_FOLDERS(XXX)                                        \
-  void XXX::getCanonicalizationPatterns(RewritePatternSet &results,            \
-                                        MLIRContext *context) {                \
-    results.add<EraseDeadLinalgOp, FoldTensorCastOp>(context);                 \
-  }                                                                            \
-                                                                               \
+#define LINALGOP_FOLDERS(XXX)                                                  \
   LogicalResult XXX::fold(ArrayRef<Attribute>,                                 \
                           SmallVectorImpl<OpFoldResult> &) {                   \
     return foldMemRefCast(*this);                                              \
   }
 
-CANONICALIZERS_AND_FOLDERS(ConvOp)
-CANONICALIZERS_AND_FOLDERS(PoolingMaxOp)
-CANONICALIZERS_AND_FOLDERS(PoolingMinOp)
-CANONICALIZERS_AND_FOLDERS(PoolingSumOp)
-CANONICALIZERS_AND_FOLDERS(CopyOp)
-CANONICALIZERS_AND_FOLDERS(FillOp)
-CANONICALIZERS_AND_FOLDERS(GenericOp)
+LINALGOP_FOLDERS(ConvOp)
+LINALGOP_FOLDERS(PoolingMaxOp)
+LINALGOP_FOLDERS(PoolingMinOp)
+LINALGOP_FOLDERS(PoolingSumOp)
+LINALGOP_FOLDERS(CopyOp)
+LINALGOP_FOLDERS(FillOp)
+LINALGOP_FOLDERS(GenericOp)
 
-// All named ops canonicalizers and folders are auto-generated in the
+// Folders for all named ops are auto-generated in the
 // .cpp.inc.
+
+//===----------------------------------------------------------------------===//
+// LinalgDialect
+//===----------------------------------------------------------------------===//
+
+void LinalgDialect::getCanonicalizationPatterns(
+    RewritePatternSet &results) const {
+  results.add<EraseDeadLinalgOp, FoldTensorCastOp>(getContext());
+}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -1405,6 +1405,9 @@
   IndexedGenericOp::getCanonicalizationPatterns(patterns, context);
   TensorExpandShapeOp::getCanonicalizationPatterns(patterns, context);
   TensorCollapseShapeOp::getCanonicalizationPatterns(patterns, context);
+  // Null when Linalg is not loaded; then no Linalg ops exist to canonicalize.
+  if (auto *linalgDialect = context->getLoadedDialect<LinalgDialect>())
+    linalgDialect->getCanonicalizationPatterns(patterns);
 }
 
 void mlir::linalg::populatePushReshapeOpsPatterns(RewritePatternSet &patterns) {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -414,6 +414,9 @@
   memref::SubViewOp::getCanonicalizationPatterns(patterns, ctx);
   tensor::CastOp::getCanonicalizationPatterns(patterns, ctx);
   memref::ViewOp::getCanonicalizationPatterns(patterns, ctx);
+  // Null when Linalg is not loaded; then no Linalg ops exist to canonicalize.
+  if (auto *linalgDialect = ctx->getLoadedDialect<LinalgDialect>())
+    linalgDialect->getCanonicalizationPatterns(patterns);
   CanonicalizationPatternList<
 #define GET_OP_LIST
 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
@@ -1959,7 +1959,6 @@
         return ::parseNamedStructuredOp<{0}>(parser, result/*TODO:, captures*/);
       }];
       let hasFolder = 1;
-      let hasCanonicalizer = 1;
 
       let extraClassDeclaration = structuredOpsBaseDecls # [{{
         // Auto-generated.
@@ -2094,13 +2093,7 @@
 
 void TCParser::printCanonicalizersAndFolders(llvm::raw_ostream &os,
                                              StringRef cppOpName) {
-  const char *canonicalizersAndFoldersFmt = R"FMT(
-    void {0}::getCanonicalizationPatterns(
-        RewritePatternSet &results,
-        MLIRContext *context) {{
-      results.add<EraseDeadLinalgOp>(context);
-      results.add<FoldTensorCastOp>(context);
-    }
+  const char *foldersFmt = R"FMT(
     LogicalResult {0}::fold(ArrayRef<Attribute>,
                             SmallVectorImpl<OpFoldResult> &) {{
       return foldMemRefCast(*this);
@@ -2112,7 +2105,7 @@
       getGenericEffectsImpl(effects,
         getOperation()->getResults(), inputBuffers, outputBuffers);
     })FMT";
-  os << llvm::formatv(canonicalizersAndFoldersFmt, cppOpName);
+  os << llvm::formatv(foldersFmt, cppOpName);
 }
 
 // Prints methods for querying whether the current named op has attributes that
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -503,7 +503,6 @@
       return ::parseNamedStructuredOp<{0}>(parser, result/*TODO:, captures*/);
     }];
     let hasFolder = 1;
-    let hasCanonicalizer = 1;
 
     let extraClassDeclaration = structuredOpsBaseDecls # [{{
       // Auto-generated.
@@ -535,16 +534,10 @@
 }
 )FMT";
 
-// Implementations of getCanonicalizationPatterns, fold and getEffects.
+// Implementations of fold and getEffects.
 // Parameters:
 // {0}: Class name
-const char structuredOpCanonicalizersAndFoldersFormat[] = R"FMT(
-void {0}::getCanonicalizationPatterns(
-    RewritePatternSet &results,
-    MLIRContext *context) {{
-  results.add<EraseDeadLinalgOp>(context);
-  results.add<FoldTensorCastOp>(context);
-}
+const char structuredOpFoldersFormat[] = R"FMT(
 LogicalResult {0}::fold(ArrayRef<Attribute>,
                         SmallVectorImpl<OpFoldResult> &) {{
   return foldMemRefCast(*this);
@@ -880,7 +873,7 @@
   }
 
-  // Canonicalizers and folders.
-  os << llvm::formatv(structuredOpCanonicalizersAndFoldersFormat, className);
+  // Folders and getEffects.
+  os << llvm::formatv(structuredOpFoldersFormat, className);
 
   return success();
 }