diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h @@ -59,10 +59,9 @@ /// terms of shapes of its input operands. void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns); -/// Appends patterns for simplifying extract_strided_metadata(other_op) into -/// easier to analyze constructs. -void populateSimplifyExtractStridedMetadataOpPatterns( - RewritePatternSet &patterns); +/// Appends patterns for expanding memref operations that modify the metadata +/// (sizes, offset, strides) of a memref into easier-to-analyze constructs. +void populateExpandStridedMetadataPatterns(RewritePatternSet &patterns); /// Appends patterns for emulating wide integer memref operations with ops over /// narrower integer types. @@ -135,10 +134,9 @@ /// in terms of shapes of its input operands. std::unique_ptr createResolveShapedTypeResultDimsPass(); -/// Creates an operation pass to simplify -/// `extract_strided_metadata(other_op(memref))` into -/// `extract_strided_metadata(memref)`. -std::unique_ptr createSimplifyExtractStridedMetadataPass(); +/// Creates an operation pass to expand some memref operations into +/// operations that are easier to reason about.
+std::unique_ptr createExpandStridedMetadataPass(); //===----------------------------------------------------------------------===// // Registration diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td @@ -189,15 +189,16 @@ ]; } -def SimplifyExtractStridedMetadata : Pass<"simplify-extract-strided-metadata"> { - let summary = "Simplify extract_strided_metadata ops"; +def ExpandStridedMetadata : Pass<"expand-strided-metadata"> { + let summary = "Expand memref operations into easier to analyze constructs"; let description = [{ - The pass simplifies extract_strided_metadata(other_op(memref)) to - extract_strided_metadata(memref) when it is possible to model the effect - of other_op directly with affine maps applied to the result of - extract_strided_metadata. + The pass expands memref operations that modify the metadata of a memref + (sizes, offset, strides) into a sequence of easier-to-analyze constructs. + In particular, this pass transforms each operation into an explicit + sequence of operations that model its effect on the different metadata. + This pass uses affine constructs to materialize these effects.
}]; - let constructor = "mlir::memref::createSimplifyExtractStridedMetadataPass()"; + let constructor = "mlir::memref::createExpandStridedMetadataPass()"; let dependentDialects = [ "AffineDialect", "memref::MemRefDialect" ]; diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt @@ -1,12 +1,12 @@ add_mlir_dialect_library(MLIRMemRefTransforms ComposeSubView.cpp ExpandOps.cpp + ExpandStridedMetadata.cpp EmulateWideInt.cpp FoldMemRefAliasOps.cpp MultiBuffer.cpp NormalizeMemRefs.cpp ResolveShapedTypeResultDims.cpp - SimplifyExtractStridedMetadata.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/MemRef diff --git a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp rename from mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp rename to mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp --- a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp @@ -1,4 +1,4 @@ -//===- SimplifyExtractStridedMetadata.cpp - Simplify this operation -------===// +//===- ExpandStridedMetadata.cpp - Expand memref metadata operations ------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,10 +6,11 @@ // //===----------------------------------------------------------------------===// // -/// This pass simplifies extract_strided_metadata(other_op(memref) to -/// extract_strided_metadata(memref) when it is possible to express the effect -// of other_op using affine apply on the results of -// extract_strided_metadata(memref).
+/// The pass expands memref operations that modify the metadata of a memref +/// (sizes, offset, strides) into a sequence of easier-to-analyze constructs. +/// In particular, this pass transforms each operation into an explicit +/// sequence of operations that model its effect on the different metadata. +/// This pass uses affine constructs to materialize these effects. //===----------------------------------------------------------------------===// #include "mlir/Dialect/Affine/IR/AffineOps.h" @@ -23,7 +24,7 @@ namespace mlir { namespace memref { -#define GEN_PASS_DEF_SIMPLIFYEXTRACTSTRIDEDMETADATA +#define GEN_PASS_DEF_EXPANDSTRIDEDMETADATA #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc" } // namespace memref } // namespace mlir @@ -736,7 +737,7 @@ }; } // namespace -void memref::populateSimplifyExtractStridedMetadataOpPatterns( +void memref::populateExpandStridedMetadataPatterns( RewritePatternSet &patterns) { patterns.add { +struct ExpandStridedMetadataPass final + : public memref::impl::ExpandStridedMetadataBase< + ExpandStridedMetadataPass> { void runOnOperation() override; }; } // namespace -void SimplifyExtractStridedMetadataPass::runOnOperation() { +void ExpandStridedMetadataPass::runOnOperation() { RewritePatternSet patterns(&getContext()); - memref::populateSimplifyExtractStridedMetadataOpPatterns(patterns); + memref::populateExpandStridedMetadataPatterns(patterns); (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(), std::move(patterns)); } -std::unique_ptr memref::createSimplifyExtractStridedMetadataPass() { - return std::make_unique(); +std::unique_ptr memref::createExpandStridedMetadataPass() { + return std::make_unique(); } diff --git a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir rename from mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir rename to mlir/test/Dialect/MemRef/expand-strided-metadata.mlir --- 
a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir +++ b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt --simplify-extract-strided-metadata -split-input-file %s -o - | FileCheck %s +// RUN: mlir-opt --expand-strided-metadata -split-input-file %s -o - | FileCheck %s // CHECK-LABEL: func @extract_strided_metadata_constants // CHECK-SAME: (%[[ARG:.*]]: memref<5x4xf32, strided<[4, 1], offset: 2>>)