diff --git a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp --- a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp @@ -455,20 +455,15 @@ ReassociativeReshapeLikeOp, OpBuilder &, ArrayRef /*origSizes*/, ArrayRef /*origStrides*/, unsigned /*groupId*/)> -struct ExtractStridedMetadataOpReshapeFolder - : public OpRewritePattern { +struct ReshapeFolder : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op, + LogicalResult matchAndRewrite(ReassociativeReshapeLikeOp reshape, PatternRewriter &rewriter) const override { - auto reshape = op.getSource().getDefiningOp(); - if (!reshape) - return failure(); - // Build a plain extract_strided_metadata(memref) from // extract_strided_metadata(reassociative_reshape_like(memref)). - Location origLoc = op.getLoc(); + Location origLoc = reshape.getLoc(); Value source = reshape.getSrc(); auto sourceType = source.getType().cast(); unsigned sourceRank = sourceType.getRank(); @@ -487,27 +482,26 @@ MemRefType reshapeType = reshape.getResultType(); unsigned reshapeRank = reshapeType.getRank(); - // The result value will start with the base_buffer and offset. - unsigned baseIdxInResult = 2; - SmallVector results(baseIdxInResult + reshapeRank * 2); - results[0] = newExtractStridedMetadata.getBaseBuffer(); - results[1] = ShapedType::isDynamicStrideOrOffset(offset) - ? getAsOpFoldResult(newExtractStridedMetadata.getOffset()) - : rewriter.getIndexAttr(offset); + OpFoldResult offsetOfr = + ShapedType::isDynamicStrideOrOffset(offset) + ? getAsOpFoldResult(newExtractStridedMetadata.getOffset()) + : rewriter.getIndexAttr(offset); // Get the special case of 0-D out of the way. if (sourceRank == 0) { - Value constantOne = getValueOrCreateConstantIndexOp( - rewriter, origLoc, rewriter.getIndexAttr(1)); - SmallVector resultValues(baseIdxInResult + reshapeRank * 2, - constantOne); - for (unsigned i = 0; i < baseIdxInResult; ++i) - resultValues[i] = - getValueOrCreateConstantIndexOp(rewriter, origLoc, results[i]); - rewriter.replaceOp(op, resultValues); + SmallVector ones(reshapeRank, rewriter.getIndexAttr(1)); + auto memrefDesc = rewriter.create( + origLoc, reshapeType, newExtractStridedMetadata.getBaseBuffer(), + offsetOfr, /*sizes=*/ones, /*strides=*/ones); + rewriter.replaceOp(reshape, memrefDesc.getResult()); return success(); } + SmallVector finalSizes; + finalSizes.reserve(reshapeRank); + SmallVector finalStrides; + finalStrides.reserve(reshapeRank); + // Compute the reshaped strides and sizes from the base strides and sizes. SmallVector origSizes = getAsOpFoldResult(newExtractStridedMetadata.getSizes()); @@ -521,21 +515,20 @@ reshape, rewriter, origSizes, origStrides, /*groupId=*/idx); unsigned groupSize = reshapedSizes.size(); - const unsigned sizeStartIdx = baseIdxInResult; - const unsigned strideStartIdx = sizeStartIdx + reshapeRank; for (unsigned i = 0; i < groupSize; ++i) { - results[sizeStartIdx + i] = reshapedSizes[i]; - results[strideStartIdx + i] = reshapedStrides[i]; + finalSizes.push_back(reshapedSizes[i]); + finalStrides.push_back(reshapedStrides[i]); } - baseIdxInResult += groupSize; } assert(((isa(reshape) && idx == sourceRank) || (isa(reshape) && idx == reshapeRank)) && "We should have visited all the input dimensions"); - assert(baseIdxInResult == reshapeRank + 2 && + assert(finalSizes.size() == reshapeRank && "We should have populated all the values"); - rewriter.replaceOp( - op, getValueOrCreateConstantIndexOp(rewriter, origLoc, results)); + auto memrefDesc = rewriter.create( + origLoc, reshapeType, newExtractStridedMetadata.getBaseBuffer(), + offsetOfr, finalSizes, finalStrides); + rewriter.replaceOp(reshape, memrefDesc.getResult()); return success(); } }; @@ -745,18 +738,17 @@ void memref::populateSimplifyExtractStridedMetadataOpPatterns( RewritePatternSet &patterns) { - patterns - .add, - ExtractStridedMetadataOpReshapeFolder< - memref::CollapseShapeOp, getCollapsedSize, getCollapsedStride>, - ExtractStridedMetadataOpAllocFolder, - ExtractStridedMetadataOpAllocFolder, - RewriteExtractAlignedPointerAsIndexOfViewLikeOp, - ExtractStridedMetadataOpReinterpretCastFolder, - ExtractStridedMetadataOpExtractStridedMetadataFolder>( - patterns.getContext()); + patterns.add, + ReshapeFolder, + ExtractStridedMetadataOpAllocFolder, + ExtractStridedMetadataOpAllocFolder, + RewriteExtractAlignedPointerAsIndexOfViewLikeOp, + ExtractStridedMetadataOpReinterpretCastFolder, + ExtractStridedMetadataOpExtractStridedMetadataFolder>( + patterns.getContext()); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir --- a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir +++ b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir @@ -352,6 +352,88 @@ // ----- +// Check that we properly simplify expand_shape into: +// reinterpret_cast(extract_strided_metadata) + +// +// Here we have: +// For the group applying to dim0: +// size 0 = baseSizes#0 / (all static sizes in that group) +// = baseSizes#0 / (7 * 8 * 9) +// = baseSizes#0 / 504 +// size 1 = 7 +// size 2 = 8 +// size 3 = 9 +// stride 0 = baseStrides#0 * 7 * 8 * 9 +// = baseStrides#0 * 504 +// stride 1 = baseStrides#0 * 8 * 9 +// = baseStrides#0 * 72 +// stride 2 = baseStrides#0 * 9 +// stride 3 = baseStrides#0 +// +// For the group applying to dim1: +// size 4 = 10 +// size 5 = 2 +// size 6 = baseSizes#1 / (all static sizes in that group) +// = baseSizes#1 / (10 * 2 * 3) +// = baseSizes#1 / 60 +// size 7 = 3 +// stride 4 = baseStrides#1 * size 5 * size 6 * size 7 +// = baseStrides#1 * 2 * (baseSizes#1 / 60) * 3 +// = baseStrides#1 * (baseSizes#1 / 60) * 6 +// and since we know that baseSizes#1 is a multiple of 60: +// = baseStrides#1 * (baseSizes#1 / 10) +// stride 5 = baseStrides#1 * size 6 * size 7 +// = baseStrides#1 * (baseSizes#1 / 60) * 3 +// = baseStrides#1 * (baseSizes#1 / 20) +// stride 6 = baseStrides#1 * size 7 +// = baseStrides#1 * 3 +// stride 7 = baseStrides#1 +// +// Base and offset are unchanged. +// +// CHECK-DAG: #[[$DIM0_SIZE_MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 504)> +// CHECK-DAG: #[[$DIM6_SIZE_MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 60)> +// +// CHECK-DAG: #[[$DIM0_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 504)> +// CHECK-DAG: #[[$DIM1_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 72)> +// CHECK-DAG: #[[$DIM2_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 9)> +// CHECK-DAG: #[[$DIM4_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 floordiv 10) * s1)> +// CHECK-DAG: #[[$DIM5_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 floordiv 20) * s1)> +// CHECK-DAG: #[[$DIM6_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 3)> +// CHECK-LABEL: func @simplify_expand_shape +// CHECK-SAME: (%[[ARG:.*]]: memref> -> memref, index, index, index, index, index +// +// CHECK-DAG: %[[DYN_SIZE0:.*]] = affine.apply #[[$DIM0_SIZE_MAP]]()[%[[SIZES]]#0] +// CHECK-DAG: %[[DYN_SIZE6:.*]] = affine.apply #[[$DIM6_SIZE_MAP]]()[%[[SIZES]]#1] +// CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.apply #[[$DIM0_STRIDE_MAP]]()[%[[STRIDES]]#0] +// CHECK-DAG: %[[DYN_STRIDE1:.*]] = affine.apply #[[$DIM1_STRIDE_MAP]]()[%[[STRIDES]]#0] +// CHECK-DAG: %[[DYN_STRIDE2:.*]] = affine.apply #[[$DIM2_STRIDE_MAP]]()[%[[STRIDES]]#0] +// CHECK-DAG: %[[DYN_STRIDE4:.*]] = affine.apply #[[$DIM4_STRIDE_MAP]]()[%[[SIZES]]#1, %[[STRIDES]]#1] +// CHECK-DAG: %[[DYN_STRIDE5:.*]] = affine.apply #[[$DIM5_STRIDE_MAP]]()[%[[SIZES]]#1, %[[STRIDES]]#1] +// CHECK-DAG: %[[DYN_STRIDE6:.*]] = affine.apply #[[$DIM6_STRIDE_MAP]]()[%[[STRIDES]]#1] +// +// CHECK-DAG: %[[REINTERPRET_CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[OFFSET]]], sizes: [%[[DYN_SIZE0]], 7, 8, 9, 10, 2, %[[DYN_SIZE6]], 3], strides: [%[[DYN_STRIDE0]], %[[DYN_STRIDE1]], %[[DYN_STRIDE2]], %[[STRIDES]]#0, %[[DYN_STRIDE4]], %[[DYN_STRIDE5]], %[[DYN_STRIDE6]], %[[STRIDES]]#1] +// +// CHECK: return %[[REINTERPRET_CAST]] +func.func @simplify_expand_shape( + %base: memref>, + %offset0: index, %offset1: index, %offset2: index, + %size0: index, %size1: index, %size2: index, + %stride0: index, %stride1: index, %stride2: index) + -> memref> { + + %subview = memref.expand_shape %base[[0, 1, 2, 3],[4, 5, 6, 7]] : + memref> into + memref> + + return %subview : + memref> +} + +// ----- + // Check that we properly simplify extract_strided_metadata of expand_shape // into: // baseBuffer, baseOffset, baseSizes, baseStrides = @@ -815,6 +897,43 @@ // ----- +// Check that we simplify collapse_shape into +// reinterpret_cast(extract_strided_metadata) + +// +// We transform: ?x?x4x?x6x7xi32 to [0][1,2,3][4,5] +// Size 0 = origSize0 +// Size 1 = origSize1 * origSize2 * origSize3 +// = origSize1 * 4 * origSize3 +// Size 2 = origSize4 * origSize5 +// = 6 * 7 +// = 42 +// Stride 0 = origStride0 +// Stride 1 = origStride3 (orig stride of the inner most dimension) +// = 42 +// Stride 2 = origStride5 +// = 1 +// +// CHECK-DAG: #[[$SIZE0_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 4)> +// CHECK-LABEL: func @simplify_collapse( +// CHECK-SAME: %[[ARG:.*]]: memref) +// +// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:6, %[[STRIDES:.*]]:6 = memref.extract_strided_metadata %[[ARG]] : memref +// +// CHECK: %[[DYN_SIZE1:.*]] = affine.apply #[[$SIZE0_MAP]]()[%[[SIZES]]#1, %[[SIZES]]#3] +// +// CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [0], sizes: [%[[SIZES]]#0, %[[DYN_SIZE1]], 42], strides: [%[[STRIDES]]#0, 42, 1] +func.func @simplify_collapse(%arg : memref) + -> memref { + + %collapsed_view = memref.collapse_shape %arg [[0], [1, 2, 3], [4, 5]] : + memref into memref + + return %collapsed_view : memref + +} + +// ----- + // Check that we simplify extract_strided_metadata of collapse_shape. // // We transform: ?x?x4x?x6x7xi32 to [0][1,2,3][4,5]