diff --git a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp --- a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp @@ -575,6 +575,33 @@ return success(); } }; + +/// Rewrite memref.extract_aligned_pointer_as_index of a ViewLikeOp to the +/// source of the ViewLikeOp. +/// +/// This is only sound when the view shares the aligned pointer of its +/// source. memref.view is excluded: its lowering to LLVM advances the +/// aligned pointer by `byte_shift`, so the two pointers generally differ. +class RewriteExtractAlignedPointerAsIndexOfViewLikeOp + : public OpRewritePattern<memref::ExtractAlignedPointerAsIndexOp> { + using OpRewritePattern::OpRewritePattern; + + LogicalResult + matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp, + PatternRewriter &rewriter) const override { + auto viewLikeOp = + extractOp.getSource().getDefiningOp<ViewLikeOpInterface>(); + if (!viewLikeOp) + return rewriter.notifyMatchFailure(extractOp, "not a ViewLike source"); + if (isa<memref::ViewOp>(viewLikeOp.getOperation())) + return rewriter.notifyMatchFailure( + extractOp, "memref.view offsets the aligned pointer"); + rewriter.updateRootInPlace(extractOp, [&]() { + extractOp.sourceMutable().assign(viewLikeOp.getViewSource()); + }); + return success(); + } +}; } // namespace void memref::populateSimplifyExtractStridedMetadataOpPatterns( @@ -582,7 +609,8 @@ patterns.add, - ExtractStridedMetadataOpAllocFolder>( + ExtractStridedMetadataOpAllocFolder, + RewriteExtractAlignedPointerAsIndexOfViewLikeOp>( patterns.getContext()); } diff --git a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir --- a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir +++ b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir @@ -750,3 +750,30 @@ return %base, %offset, %size, %stride : memref, index, index, index } + +// ----- + +// CHECK-LABEL: extract_aligned_pointer_as_index +// CHECK-SAME: (%[[ARG0:.*]]: memref<f32> +func.func @extract_aligned_pointer_as_index(%arg0: memref<f32>) -> index { + // CHECK-NOT: memref.subview + // CHECK: memref.extract_aligned_pointer_as_index %[[ARG0]] + %c = memref.subview %arg0[] [] [] : memref<f32> to memref<f32> + %r = 
memref.extract_aligned_pointer_as_index %c: memref<f32> -> index + return %r : index +} + +// ----- + +// memref.view applies its byte shift to the aligned pointer, so the pattern +// must not look through it. +// CHECK-LABEL: extract_aligned_pointer_as_index_of_view +// CHECK-SAME: (%[[ARG0:.*]]: memref<64xi8> +// CHECK: %[[VIEW:.*]] = memref.view %[[ARG0]] +// CHECK: memref.extract_aligned_pointer_as_index %[[VIEW]] +func.func @extract_aligned_pointer_as_index_of_view(%arg0: memref<64xi8>) -> index { + %c16 = arith.constant 16 : index + %v = memref.view %arg0[%c16][] : memref<64xi8> to memref<4xf32> + %r = memref.extract_aligned_pointer_as_index %v : memref<4xf32> -> index + return %r : index +}