diff --git a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
--- a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
@@ -660,6 +660,36 @@
     return success();
   }
 };
+
+/// Replace `base, offset =
+///            extract_strided_metadata(extract_strided_metadata(src)#0)`
+/// With
+/// ```
+/// base, ... = extract_strided_metadata(src)
+/// offset = 0
+/// ```
+class ExtractStridedMetadataOpExtractStridedMetadataFolder
+    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult
+  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
+                  PatternRewriter &rewriter) const override {
+    // The base buffer (result #0) is the only memref-typed result of
+    // extract_strided_metadata, so no result-number check is needed here.
+    auto sourceExtractStridedMetadataOp =
+        extractStridedMetadataOp.getSource()
+            .getDefiningOp<memref::ExtractStridedMetadataOp>();
+    if (!sourceExtractStridedMetadataOp)
+      return failure();
+    Location loc = extractStridedMetadataOp.getLoc();
+    rewriter.replaceOp(extractStridedMetadataOp,
+                       {sourceExtractStridedMetadataOp.getBaseBuffer(),
+                        getValueOrCreateConstantIndexOp(
+                            rewriter, loc, rewriter.getIndexAttr(0))});
+    return success();
+  }
+};
 } // namespace
 
 void memref::populateSimplifyExtractStridedMetadataOpPatterns(
@@ -672,7 +702,8 @@
                memref::CollapseShapeOp, getCollapsedSize, getCollapsedStride>,
            ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
            ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
-           RewriteExtractAlignedPointerAsIndexOfViewLikeOp>(
+           RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
+           ExtractStridedMetadataOpExtractStridedMetadataFolder>(
           patterns.getContext());
 }
 
diff --git a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
--- a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
@@ -844,3 +844,29 @@
   return %base, %offset :
     memref<i32>, index
 }
+
+// -----
+
+// Check that we simplify extract_strided_metadata of
+// extract_strided_metadata.
+//
+// CHECK-LABEL: func @extract_strided_metadata_of_extract_strided_metadata(
+//  CHECK-SAME: %[[ARG:.*]]: memref<i32>)
+//
+//   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]] = memref.extract_strided_metadata %[[ARG]]
+//
+//       CHECK: return %[[BASE]], %[[C0]]
+func.func @extract_strided_metadata_of_extract_strided_metadata(%arg : memref<i32>)
+  -> (memref<i32>, index) {
+
+  %base, %offset =
+    memref.extract_strided_metadata %arg:memref<i32>
+    -> memref<i32>, index
+  %base2, %offset2 =
+    memref.extract_strided_metadata %base:memref<i32>
+    -> memref<i32>, index
+
+  return %base2, %offset2 :
+    memref<i32>, index
+}