diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// // -/// This header declares functions that assit transformations in the MemRef +/// This header declares functions that assist transformations in the MemRef /// dialect. // //===----------------------------------------------------------------------===// @@ -44,9 +44,9 @@ /// Appends patterns that resolve `memref.dim` operations with values that are /// defined by operations that implement the -/// `ReifyRankedShapeTypeShapeOpInterface`, in terms of shapes of its input +/// `ReifyRankedShapedTypeOpInterface`, in terms of shapes of its input /// operands. -void populateResolveRankedShapeTypeResultDimsPatterns( +void populateResolveRankedShapedTypeResultDimsPatterns( RewritePatternSet &patterns); /// Appends patterns that resolve `memref.dim` operations with values that are @@ -68,7 +68,7 @@ arith::WideIntEmulationConverter &typeConverter, RewritePatternSet &patterns); -/// Appends type converions for emulating wide integer memref operations with -/// ops over narrowe integer types. +/// Appends type conversions for emulating wide integer memref operations with +/// ops over narrower integer types.
void populateMemRefWideIntEmulationConversions( arith::WideIntEmulationConverter &typeConverter); diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -675,7 +675,7 @@ tensor::EmptyOp::getCanonicalizationPatterns(patterns, context); tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context); tensor::populateFoldTensorEmptyPatterns(patterns); - memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns); + memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns); memref::populateResolveShapedTypeResultDimsPatterns(patterns); } @@ -689,7 +689,7 @@ linalg::FillOp::getCanonicalizationPatterns(patterns, context); tensor::EmptyOp::getCanonicalizationPatterns(patterns, context); tensor::populateFoldTensorEmptyPatterns(patterns); - memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns); + memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns); memref::populateResolveShapedTypeResultDimsPatterns(patterns); } diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp --- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp +++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/Dialect/Transform/IR/TransformOps.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Interfaces/LoopLikeInterface.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -161,6 +162,23 @@ #define GET_OP_LIST #include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc" >(); + + addDialectDataInitializer<transform::PatternRegistry>( + [&](transform::PatternRegistry &registry) {
+ registry.registerPatterns("memref.expand_ops", + memref::populateExpandOpsPatterns); + registry.registerPatterns("memref.fold_memref_alias_ops", + memref::populateFoldMemRefAliasOpPatterns); + registry.registerPatterns( + "memref.resolve_ranked_shaped_type_result_dims", + memref::populateResolveRankedShapedTypeResultDimsPatterns); + registry.registerPatterns( + "memref.expand_strided_metadata", + memref::populateExpandStridedMetadataPatterns); + registry.registerPatterns( + "memref.extract_address_computations", + memref::populateExtractAddressComputationsPatterns); + }); } }; } // namespace diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp --- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp @@ -121,7 +121,7 @@ } // namespace -void memref::populateResolveRankedShapeTypeResultDimsPatterns( +void memref::populateResolveRankedShapedTypeResultDimsPatterns( RewritePatternSet &patterns) { patterns.add<DimOfReifyRankedShapedTypeOpInterface<memref::DimOp>, DimOfReifyRankedShapedTypeOpInterface<tensor::DimOp>>( @@ -138,14 +138,14 @@ void ResolveRankedShapeTypeResultDimsPass::runOnOperation() { RewritePatternSet patterns(&getContext()); - memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns); + memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns); if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } void ResolveShapedTypeResultDimsPass::runOnOperation() { RewritePatternSet patterns(&getContext()); - memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns); + memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns); memref::populateResolveShapedTypeResultDimsPatterns(patterns); if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) return signalPassFailure();