diff --git a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
--- a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
@@ -661,6 +661,74 @@
   }
 };
 
+/// Replace `base, offset, sizes, strides =
+///              extract_strided_metadata(
+///                reinterpret_cast(src, srcOffset, srcSizes, srcStrides))`
+/// With
+/// ```
+/// base, ... = extract_strided_metadata(src)
+/// offset = srcOffset
+/// sizes = srcSizes
+/// strides = srcStrides
+/// ```
+///
+/// In other words, consume the `reinterpret_cast` and apply its effects
+/// on the offset, sizes, and strides.
+///
+/// The rewrite only fires when the source of the `reinterpret_cast` is a
+/// ranked memref; otherwise the pattern reports a match failure.
+class ExtractStridedMetadataOpReinterpretCastFolder
+    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult
+  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
+                  PatternRewriter &rewriter) const override {
+    auto reinterpretCastOp = extractStridedMetadataOp.getSource()
+                                 .getDefiningOp<memref::ReinterpretCastOp>();
+    if (!reinterpretCastOp)
+      return failure();
+
+    // Check if the source is suitable for extract_strided_metadata.
+    if (!reinterpretCastOp.getSource().getType().isa<MemRefType>())
+      return rewriter.notifyMatchFailure(
+          reinterpretCastOp, "reinterpret_cast source's type is incompatible");
+
+    auto memrefType =
+        reinterpretCastOp.getResult().getType().cast<MemRefType>();
+    unsigned rank = memrefType.getRank();
+    // Results are laid out as [base_buffer, offset, sizes..., strides...].
+    SmallVector<OpFoldResult> results;
+    results.resize_for_overwrite(rank * 2 + 2);
+
+    Location loc = extractStridedMetadataOp.getLoc();
+    auto newExtractStridedMetadata =
+        rewriter.create<memref::ExtractStridedMetadataOp>(
+            loc, reinterpretCastOp.getSource());
+
+    // Register the base_buffer.
+    results[0] = newExtractStridedMetadata.getBaseBuffer();
+
+    // Register the new offset.
+    results[1] = getValueOrCreateConstantIndexOp(
+        rewriter, loc, reinterpretCastOp.getMixedOffsets()[0]);
+
+    const unsigned sizeStartIdx = 2;
+    const unsigned strideStartIdx = sizeStartIdx + rank;
+
+    SmallVector<OpFoldResult> sizes = reinterpretCastOp.getMixedSizes();
+    SmallVector<OpFoldResult> strides = reinterpretCastOp.getMixedStrides();
+    for (unsigned i = 0; i < rank; ++i) {
+      results[sizeStartIdx + i] = sizes[i];
+      results[strideStartIdx + i] = strides[i];
+    }
+    // Turn any remaining attribute entries into constant index values.
+    rewriter.replaceOp(extractStridedMetadataOp,
+                       getValueOrCreateConstantIndexOp(rewriter, loc, results));
+    return success();
+  }
+};
+
 /// Replace `base, offset =
 ///            extract_strided_metadata(extract_strided_metadata(src)#0)`
 /// With
@@ -701,6 +769,7 @@
                ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
                ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
                RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
+               ExtractStridedMetadataOpReinterpretCastFolder,
                ExtractStridedMetadataOpExtractStridedMetadataFolder>(
       patterns.getContext());
 }
diff --git a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
--- a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
@@ -870,3 +870,136 @@
 
   return %base2, %offset2 : memref<i32>, index
 }
+
+// -----
+
+// Check that we simplify extract_strided_metadata of reinterpret_cast
+// when the source of the reinterpret_cast is compatible with what
+// `extract_strided_metadata`s accept.
+//
+// When we apply the transformation the resulting offset, sizes and strides
+// should come straight from the inputs of the reinterpret_cast.
+//
+// CHECK-LABEL: func @extract_strided_metadata_of_reinterpret_cast
+//  CHECK-SAME: %[[ARG:.*]]: memref<?x?xi32, strided<[?, ?], offset: ?>>, %[[DYN_OFFSET:.*]]: index, %[[DYN_SIZE0:.*]]: index, %[[DYN_SIZE1:.*]]: index, %[[DYN_STRIDE0:.*]]: index, %[[DYN_STRIDE1:.*]]: index)
+//
+//       CHECK: %[[BASE:.*]], %{{.*}}, %{{.*}}:2, %{{.*}}:2 = memref.extract_strided_metadata %[[ARG]]
+//
+//       CHECK: return %[[BASE]], %[[DYN_OFFSET]], %[[DYN_SIZE0]], %[[DYN_SIZE1]], %[[DYN_STRIDE0]], %[[DYN_STRIDE1]]
+func.func @extract_strided_metadata_of_reinterpret_cast(
+  %arg : memref<?x?xi32, strided<[?, ?], offset: ?>>,
+  %offset: index,
+  %size0 : index, %size1 : index,
+  %stride0 : index, %stride1 : index)
+    -> (memref<i32>, index,
+        index, index,
+        index, index) {
+
+  %cast =
+    memref.reinterpret_cast %arg to
+      offset: [%offset],
+      sizes: [%size0, %size1],
+      strides: [%stride0, %stride1] :
+      memref<?x?xi32, strided<[?, ?], offset: ?>> to
+      memref<?x?xi32, strided<[?, ?], offset: ?>>
+
+  %base, %base_offset, %sizes:2, %strides:2 =
+    memref.extract_strided_metadata %cast:memref<?x?xi32, strided<[?, ?], offset: ?>>
+    -> memref<i32>, index,
+       index, index,
+       index, index
+
+  return %base, %base_offset,
+    %sizes#0, %sizes#1,
+    %strides#0, %strides#1 :
+      memref<i32>, index,
+      index, index,
+      index, index
+}
+
+// -----
+
+// Check that we don't simplify extract_strided_metadata of
+// reinterpret_cast when the source of the cast is unranked.
+// Unranked memrefs cannot feed into extract_strided_metadata operations.
+// Note: Technically we could still fold the sizes and strides.
+//
+// CHECK-LABEL: func @extract_strided_metadata_of_reinterpret_cast_unranked
+//  CHECK-SAME: %[[ARG:.*]]: memref<*xi32>, %[[DYN_OFFSET:.*]]: index, %[[DYN_SIZE0:.*]]: index, %[[DYN_SIZE1:.*]]: index, %[[DYN_STRIDE0:.*]]: index, %[[DYN_STRIDE1:.*]]: index)
+//
+//       CHECK: %[[CAST:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [%[[DYN_OFFSET]]], sizes: [%[[DYN_SIZE0]], %[[DYN_SIZE1]]], strides: [%[[DYN_STRIDE0]], %[[DYN_STRIDE1]]]
+//       CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[CAST]]
+//
+//       CHECK: return %[[BASE]], %[[OFFSET]], %[[SIZES]]#0, %[[SIZES]]#1, %[[STRIDES]]#0, %[[STRIDES]]#1
+func.func @extract_strided_metadata_of_reinterpret_cast_unranked(
+  %arg : memref<*xi32>,
+  %offset: index,
+  %size0 : index, %size1 : index,
+  %stride0 : index, %stride1 : index)
+    -> (memref<i32>, index,
+        index, index,
+        index, index) {
+
+  %cast =
+    memref.reinterpret_cast %arg to
+      offset: [%offset],
+      sizes: [%size0, %size1],
+      strides: [%stride0, %stride1] :
+      memref<*xi32> to
+      memref<?x?xi32, strided<[?, ?], offset: ?>>
+
+  %base, %base_offset, %sizes:2, %strides:2 =
+    memref.extract_strided_metadata %cast:memref<?x?xi32, strided<[?, ?], offset: ?>>
+    -> memref<i32>, index,
+       index, index,
+       index, index
+
+  return %base, %base_offset,
+    %sizes#0, %sizes#1,
+    %strides#0, %strides#1 :
+      memref<i32>, index,
+      index, index,
+      index, index
+}
+
+// -----
+
+// Similar to @extract_strided_metadata_of_reinterpret_cast, just make sure
+// we handle 0-D properly.
+//
+// CHECK-LABEL: func @extract_strided_metadata_of_reinterpret_cast_rank0
+//  CHECK-SAME: %[[ARG:.*]]: memref<i32, strided<[], offset: ?>>, %[[DYN_OFFSET:.*]]: index, %[[DYN_SIZE0:.*]]: index, %[[DYN_SIZE1:.*]]: index, %[[DYN_STRIDE0:.*]]: index, %[[DYN_STRIDE1:.*]]: index)
+//
+//       CHECK: %[[BASE:.*]], %[[BASE_OFFSET:.*]] = memref.extract_strided_metadata %[[ARG]]
+//
+//       CHECK: return %[[BASE]], %[[DYN_OFFSET]], %[[DYN_SIZE0]], %[[DYN_SIZE1]], %[[DYN_STRIDE0]], %[[DYN_STRIDE1]]
+func.func @extract_strided_metadata_of_reinterpret_cast_rank0(
+  %arg : memref<i32, strided<[], offset: ?>>,
+  %offset: index,
+  %size0 : index, %size1 : index,
+  %stride0 : index, %stride1 : index)
+    -> (memref<i32>, index,
+        index, index,
+        index, index) {
+
+  %cast =
+    memref.reinterpret_cast %arg to
+      offset: [%offset],
+      sizes: [%size0, %size1],
+      strides: [%stride0, %stride1] :
+      memref<i32, strided<[], offset: ?>> to
+      memref<?x?xi32, strided<[?, ?], offset: ?>>
+
+  %base, %base_offset, %sizes:2, %strides:2 =
+    memref.extract_strided_metadata %cast:memref<?x?xi32, strided<[?, ?], offset: ?>>
+    -> memref<i32>, index,
+       index, index,
+       index, index
+
+  return %base, %base_offset,
+    %sizes#0, %sizes#1,
+    %strides#0, %strides#1 :
+      memref<i32>, index,
+      index, index,
+      index, index
+}