diff --git a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
--- a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
@@ -43,7 +43,7 @@
 /// \endverbatim
 ///
 /// In other words, get rid of the subview in that expression and canonicalize
-/// on its effects on the offset, the sizes, and the strides using affine apply.
+/// on its effects on the offset, the sizes, and the strides using affine.apply.
 struct ExtractStridedMetadataOpSubviewFolder
     : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
 public:
@@ -166,11 +166,275 @@
     return success();
   }
 };
+
+/// Compute the expanded sizes of the given \p expandShape for the
+/// \p groupId-th reassociation group.
+/// \p origSizes hold the sizes of the source shape as values.
+/// This is used to compute the new sizes in cases of dynamic shapes.
+///
+/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
+///
+/// TODO: Move this utility function directly within ExpandShapeOp. For now,
+/// this is not possible because this function uses the Affine dialect and the
+/// MemRef dialect cannot depend on the Affine dialect.
+static SmallVector<OpFoldResult>
+getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,
+                 ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
+  SmallVector<int64_t, 2> reassocGroup =
+      expandShape.getReassociationIndices()[groupId];
+  assert(!reassocGroup.empty() &&
+         "Reassociation group should have at least one dimension");
+
+  unsigned groupSize = reassocGroup.size();
+  SmallVector<OpFoldResult> expandedSizes(groupSize);
+
+  uint64_t productOfAllStaticSizes = 1;
+  Optional<int64_t> dynSizeIdx;
+  MemRefType expandShapeType = expandShape.getResultType();
+
+  // Fill up all the statically known sizes.
+  for (unsigned i = 0; i < groupSize; ++i) {
+    uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
+    if (ShapedType::isDynamic(dimSize)) {
+      assert(!dynSizeIdx && "There must be at most one dynamic size per group");
+      dynSizeIdx = i;
+      continue;
+    }
+    productOfAllStaticSizes *= dimSize;
+    expandedSizes[i] = builder.getIndexAttr(dimSize);
+  }
+
+  // Compute the dynamic size using the original size and all the other known
+  // static sizes:
+  // expandSize = origSize / productOfAllStaticSizes.
+  if (dynSizeIdx) {
+    AffineExpr s0 = builder.getAffineSymbolExpr(0);
+    expandedSizes[*dynSizeIdx] = makeComposedFoldedAffineApply(
+        builder, expandShape.getLoc(), s0.floorDiv(productOfAllStaticSizes),
+        origSizes[groupId]);
+  }
+
+  return expandedSizes;
+}
+
+/// Compute the expanded strides of the given \p expandShape for the
+/// \p groupId-th reassociation group.
+/// \p origStrides and \p origSizes hold respectively the strides and sizes
+/// of the source shape as values.
+/// This is used to compute the strides in cases of dynamic shapes and/or
+/// dynamic stride for this reassociation group.
+///
+/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
+///
+/// TODO: Move this utility function directly within ExpandShapeOp. For now,
+/// this is not possible because this function uses the Affine dialect and the
+/// MemRef dialect cannot depend on the Affine dialect.
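+///
+/// For example, expanding a dimension of size 30 and stride 4 into the group
+/// of sizes {3, 5, 2} yields the strides [4 * (5 * 2), 4 * 2, 4], i.e.,
+/// [40, 8, 4]: each expanded stride is the original stride multiplied by the
+/// product of the expanded sizes that come after it in the group.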
+SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
+                                             OpBuilder &builder,
+                                             ArrayRef<OpFoldResult> origSizes,
+                                             ArrayRef<OpFoldResult> origStrides,
+                                             unsigned groupId) {
+  SmallVector<int64_t, 2> reassocGroup =
+      expandShape.getReassociationIndices()[groupId];
+  assert(!reassocGroup.empty() &&
+         "Reassociation group should have at least one dimension");
+
+  unsigned groupSize = reassocGroup.size();
+  MemRefType expandShapeType = expandShape.getResultType();
+
+  Optional<int64_t> dynSizeIdx;
+
+  // Fill up the expanded strides with the information we can deduce from the
+  // resulting shape.
+  uint64_t currentStride = 1;
+  SmallVector<OpFoldResult> expandedStrides(groupSize);
+  for (int i = groupSize - 1; i >= 0; --i) {
+    expandedStrides[i] = builder.getIndexAttr(currentStride);
+    uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
+    if (ShapedType::isDynamic(dimSize)) {
+      assert(!dynSizeIdx && "There must be at most one dynamic size per group");
+      dynSizeIdx = i;
+      continue;
+    }
+
+    currentStride *= dimSize;
+  }
+
+  // Collect the statically known information about the original stride.
+  Value source = expandShape.getSrc();
+  auto sourceType = source.getType().cast<MemRefType>();
+  SmallVector<int64_t> strides;
+  int64_t offset;
+  bool hasKnownStridesAndOffset =
+      succeeded(getStridesAndOffset(sourceType, strides, offset));
+  (void)hasKnownStridesAndOffset;
+  assert(hasKnownStridesAndOffset &&
+         "getStridesAndOffset must work on valid expand_shape");
+
+  OpFoldResult origStride =
+      ShapedType::isDynamicStrideOrOffset(strides[groupId])
+          ? origStrides[groupId]
+          : builder.getIndexAttr(strides[groupId]);
+
+  // Apply the original stride to all the strides.
+  int64_t doneStrideIdx = 0;
+  // If we saw a dynamic dimension, we need to fix up all the strides up to
+  // that dimension with the dynamic size.
+  if (dynSizeIdx) {
+    int64_t productOfAllStaticSizes = currentStride;
+    assert(ShapedType::isDynamic(sourceType.getDimSize(groupId)) &&
+           "We shouldn't be able to change dynamicity");
+    OpFoldResult origSize = origSizes[groupId];
+
+    AffineExpr s0 = builder.getAffineSymbolExpr(0);
+    AffineExpr s1 = builder.getAffineSymbolExpr(1);
+    for (; doneStrideIdx < *dynSizeIdx; ++doneStrideIdx) {
+      int64_t baseExpandedStride = expandedStrides[doneStrideIdx]
+                                       .get<Attribute>()
+                                       .cast<IntegerAttr>()
+                                       .getInt();
+      expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
+          builder, expandShape.getLoc(),
+          (s0 * baseExpandedStride).floorDiv(productOfAllStaticSizes) * s1,
+          {origSize, origStride});
+    }
+  }
+
+  // Now apply the origStride to the remaining dimensions.
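+  // Each of these strides is a static multiple of the original stride:
+  // origStride * product(staticSizes#j, for the static sizes j that come
+  // after the current dimension in the group).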
+  AffineExpr s0 = builder.getAffineSymbolExpr(0);
+  for (; doneStrideIdx < groupSize; ++doneStrideIdx) {
+    int64_t baseExpandedStride = expandedStrides[doneStrideIdx]
+                                     .get<Attribute>()
+                                     .cast<IntegerAttr>()
+                                     .getInt();
+    expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
+        builder, expandShape.getLoc(), s0 * baseExpandedStride, {origStride});
+  }
+
+  return expandedStrides;
+}
+
+/// Replace `baseBuffer, offset, sizes, strides =
+///              extract_strided_metadata(expand_shape(memref))`
+/// With
+///
+/// \verbatim
+/// baseBuffer, offset, baseSizes, baseStrides =
+///     extract_strided_metadata(memref)
+/// sizes#reassIdx =
+///     baseSizes#reassDim / product(expandShapeSizes#j,
+///                                  for j in group excluding reassIdx)
+/// strides#reassIdx =
+///     baseStrides#reassDim * product(expandShapeSizes#j, for j in
+///                                    reassIdx+1..reassIdx+group.size-1)
+/// \endverbatim
+///
+/// Where reassIdx is a reassociation index for the group at reassDim
+/// and expandShapeSizes#j is either:
+/// - The constant size at dimension j, derived directly from the result type
+///   of the expand_shape op, or
+/// - An affine expression: baseSizes#reassDim / product of all constant sizes
+///   in expandShapeSizes. (Remember expandShapeSizes has at most one dynamic
+///   element.)
+///
+/// Notice that `baseBuffer` and `offset` are unchanged.
+///
+/// In other words, get rid of the expand_shape in that expression and
+/// materialize its effects on the sizes and the strides using affine.apply.
+struct ExtractStridedMetadataOpExpandShapeFolder
+    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
+                                PatternRewriter &rewriter) const override {
+    auto expandShape = op.getSource().getDefiningOp<memref::ExpandShapeOp>();
+    if (!expandShape)
+      return failure();
+
+    // Build a plain extract_strided_metadata(memref) from
+    // extract_strided_metadata(expand_shape(memref)).
+    Location origLoc = op.getLoc();
+    IndexType indexType = rewriter.getIndexType();
+    Value source = expandShape.getSrc();
+    auto sourceType = source.getType().cast<MemRefType>();
+    unsigned sourceRank = sourceType.getRank();
+    SmallVector<Type> sizeStrideTypes(sourceRank, indexType);
+
+    auto newExtractStridedMetadata =
+        rewriter.create<memref::ExtractStridedMetadataOp>(
+            origLoc, op.getBaseBuffer().getType(), indexType, sizeStrideTypes,
+            sizeStrideTypes, source);
+
+    // Collect statically known information.
+    SmallVector<int64_t> strides;
+    int64_t offset;
+    bool hasKnownStridesAndOffset =
+        succeeded(getStridesAndOffset(sourceType, strides, offset));
+    (void)hasKnownStridesAndOffset;
+    assert(hasKnownStridesAndOffset &&
+           "getStridesAndOffset must work on valid expand_shape");
+    MemRefType expandShapeType = expandShape.getResultType();
+    unsigned expandShapeRank = expandShapeType.getRank();
+
+    // The result value will start with the base_buffer and offset.
+    unsigned baseIdxInResult = 2;
+    SmallVector<OpFoldResult> results(baseIdxInResult + expandShapeRank * 2);
+    results[0] = newExtractStridedMetadata.getBaseBuffer();
+    results[1] = ShapedType::isDynamicStrideOrOffset(offset)
+                     ? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
+                     : rewriter.getIndexAttr(offset);
+
+    // Get the special case of 0-D out of the way.
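+    // With an empty reassociation (i.e., a 0-D source), every expanded size
+    // and stride is statically 1, so aside from the base buffer and the
+    // offset, all the results are a constant 1.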
+    if (sourceRank == 0) {
+      Value constantOne = getValueOrCreateConstantIndexOp(
+          rewriter, origLoc, rewriter.getIndexAttr(1));
+      SmallVector<Value> resultValues(baseIdxInResult + expandShapeRank * 2,
+                                      constantOne);
+      for (unsigned i = 0; i < baseIdxInResult; ++i)
+        resultValues[i] =
+            getValueOrCreateConstantIndexOp(rewriter, origLoc, results[i]);
+      rewriter.replaceOp(op, resultValues);
+      return success();
+    }
+
+    // Compute the expanded strides and sizes from the base strides and sizes.
+    SmallVector<OpFoldResult> origSizes =
+        getAsOpFoldResult(newExtractStridedMetadata.getSizes());
+    SmallVector<OpFoldResult> origStrides =
+        getAsOpFoldResult(newExtractStridedMetadata.getStrides());
+    unsigned idx = 0, endIdx = expandShape.getReassociationIndices().size();
+    for (; idx != endIdx; ++idx) {
+      SmallVector<OpFoldResult> expandedSizes =
+          getExpandedSizes(expandShape, rewriter, origSizes, /*groupId=*/idx);
+      SmallVector<OpFoldResult> expandedStrides = getExpandedStrides(
+          expandShape, rewriter, origSizes, origStrides, /*groupId=*/idx);
+
+      unsigned groupSize = expandShape.getReassociationIndices()[idx].size();
+      const unsigned sizeStartIdx = baseIdxInResult;
+      const unsigned strideStartIdx = sizeStartIdx + expandShapeRank;
+      for (unsigned i = 0; i < groupSize; ++i) {
+        results[sizeStartIdx + i] = expandedSizes[i];
+        results[strideStartIdx + i] = expandedStrides[i];
+      }
+      baseIdxInResult += groupSize;
+    }
+    assert(idx == sourceRank &&
+           "We should have visited all the input dimensions");
+    assert(baseIdxInResult == expandShapeRank + 2 &&
+           "We should have populated all the values");
+    rewriter.replaceOp(
+        op, getValueOrCreateConstantIndexOp(rewriter, origLoc, results));
+    return success();
+  }
+};
 } // namespace
 
 void memref::populateSimplifyExtractStridedMetadataOpPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<ExtractStridedMetadataOpSubviewFolder>(patterns.getContext());
+  patterns.add<ExtractStridedMetadataOpSubviewFolder,
+               ExtractStridedMetadataOpExpandShapeFolder>(
+      patterns.getContext());
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
--- a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
@@ -280,3 +280,231 @@
   return %base_buffer, %offset, %sizes#0, %sizes#1, %sizes#2, %strides#0, %strides#1, %strides#2 :
       memref<f32>, index, index, index, index, index, index, index
 }
+
+// -----
+
+// Check that we properly simplify extract_strided_metadata of expand_shape
+// into:
+// baseBuffer, baseOffset, baseSizes, baseStrides =
+//     extract_strided_metadata(memref)
+// sizes#reassIdx =
+//     baseSizes#reassDim / product(expandShapeSizes#j,
+//                                  for j in group excluding reassIdx)
+// strides#reassIdx =
+//     baseStrides#reassDim * product(expandShapeSizes#j, for j in
+//                                    reassIdx+1..reassIdx+group.size)
+//
+// Here we have:
+// For the group applying to dim0:
+// size 0 = 3
+// size 1 = 5
+// size 2 = 2
+// stride 0 = baseStrides#0 * 5 * 2
+//          = 4 * 5 * 2
+//          = 40
+// stride 1 = baseStrides#0 * 2
+//          = 4 * 2
+//          = 8
+// stride 2 = baseStrides#0
+//          = 4
+//
+// For the group applying to dim1:
+// size 3 = 2
+// size 4 = 2
+// stride 3 = baseStrides#1 * 2
+//          = 1 * 2
+//          = 2
+// stride 4 = baseStrides#1
+//          = 1
+//
+// Base and offset are unchanged.
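+//
+// Since every dimension is static in this test, all the computed sizes and
+// strides fold to constants, which is what the CHECKs below match.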
+//
+// CHECK-LABEL: func @extract_strided_metadata_of_expand_shape_all_static
+// CHECK-SAME: (%[[ARG:.*]]: memref<30x4xi16>)
+//
+// CHECK-DAG: %[[C40:.*]] = arith.constant 40 : index
+// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+//
+// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<30x4xi16> -> memref<i16>, index, index, index, index, index
+//
+// CHECK: return %[[BASE]], %[[C0]], %[[C3]], %[[C5]], %[[C2]], %[[C2]], %[[C2]], %[[C40]], %[[C8]], %[[C4]], %[[C2]], %[[C1]] : memref<i16>, index, index, index, index, index, index, index, index, index, index, index
+func.func @extract_strided_metadata_of_expand_shape_all_static(
+    %arg : memref<30x4xi16>)
+    -> (memref<i16>, index,
+        index, index, index, index, index,
+        index, index, index, index, index) {
+
+  %expand_shape = memref.expand_shape %arg[[0, 1, 2], [3, 4]] :
+      memref<30x4xi16> into memref<3x5x2x2x2xi16>
+
+  %base, %offset, %sizes:5, %strides:5 = memref.extract_strided_metadata %expand_shape :
+      memref<3x5x2x2x2xi16>
+      -> memref<i16>, index,
+         index, index, index, index, index,
+         index, index, index, index, index
+
+  return %base, %offset,
+         %sizes#0, %sizes#1, %sizes#2, %sizes#3, %sizes#4,
+         %strides#0, %strides#1, %strides#2, %strides#3, %strides#4 :
+      memref<i16>, index,
+      index, index, index, index, index,
+      index, index, index, index, index
+}
+
+// -----
+
+// Check that we properly simplify extract_strided_metadata of expand_shape
+// when dynamic sizes, strides, and offsets are involved.
+// See extract_strided_metadata_of_expand_shape_all_static for an explanation
+// of the expansion.
+//
+// One of the important characteristics of this test is that the dynamic
+// dimensions produced by the expand_shape appear both in the first dimension
+// (for group 1) and in a non-first dimension (the second dimension for
+// group 2.) The idea is to make sure that:
+// 1. We properly account for dynamic shapes even when the strides are not
+//    affected by them. (When the dynamic dimension is the first one.)
+// 2. We properly compute the strides affected by dynamic shapes. (When the
+//    dynamic dimension is not the first one.)
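+// In both cases, the values that cannot be folded to constants are expected
+// to materialize as affine.apply on the base sizes and strides, as the
+// CHECKs below verify.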
+//
+// Here we have:
+// For the group applying to dim0:
+// size 0 = baseSizes#0 / (all static sizes in that group)
+//        = baseSizes#0 / (7 * 8 * 9)
+//        = baseSizes#0 / 504
+// size 1 = 7
+// size 2 = 8
+// size 3 = 9
+// stride 0 = baseStrides#0 * 7 * 8 * 9
+//          = baseStrides#0 * 504
+// stride 1 = baseStrides#0 * 8 * 9
+//          = baseStrides#0 * 72
+// stride 2 = baseStrides#0 * 9
+// stride 3 = baseStrides#0
+//
+// For the group applying to dim1:
+// size 4 = 10
+// size 5 = 2
+// size 6 = baseSizes#1 / (all static sizes in that group)
+//        = baseSizes#1 / (10 * 2 * 3)
+//        = baseSizes#1 / 60
+// size 7 = 3
+// stride 4 = baseStrides#1 * size 5 * size 6 * size 7
+//          = baseStrides#1 * 2 * (baseSizes#1 / 60) * 3
+//          = baseStrides#1 * (baseSizes#1 / 60) * 6
+// and since expand_shape guarantees that baseSizes#1 = 10 * 2 * size 6 * 3,
+// i.e., that baseSizes#1 is a multiple of 60:
+//          = baseStrides#1 * (baseSizes#1 / 10)
+// stride 5 = baseStrides#1 * size 6 * size 7
+//          = baseStrides#1 * (baseSizes#1 / 60) * 3
+//          = baseStrides#1 * (baseSizes#1 / 20)
+// stride 6 = baseStrides#1 * size 7
+//          = baseStrides#1 * 3
+// stride 7 = baseStrides#1
+//
+// Base and offset are unchanged.
+//
+// CHECK-DAG: #[[$DIM0_SIZE_MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 504)>
+// CHECK-DAG: #[[$DIM6_SIZE_MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 60)>
+//
+// CHECK-DAG: #[[$DIM0_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 504)>
+// CHECK-DAG: #[[$DIM1_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 72)>
+// CHECK-DAG: #[[$DIM2_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 9)>
+// CHECK-DAG: #[[$DIM4_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 floordiv 10) * s1)>
+// CHECK-DAG: #[[$DIM5_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 floordiv 20) * s1)>
+// CHECK-DAG: #[[$DIM6_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 3)>
+// CHECK-LABEL: func @extract_strided_metadata_of_expand_shape_all_dynamic
+// CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32,
+//
+// CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index
+// CHECK-DAG: %[[C9:.*]] = arith.constant 9 : index
+// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[C7:.*]] = arith.constant 7 : index
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+//
+// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<?x?xf32, strided<[?, ?], offset: ?>> -> memref<f32>, index, index, index, index, index
+//
+// CHECK-DAG: %[[DYN_SIZE0:.*]] = affine.apply #[[$DIM0_SIZE_MAP]]()[%[[SIZES]]#0]
+// CHECK-DAG: %[[DYN_SIZE6:.*]] = affine.apply #[[$DIM6_SIZE_MAP]]()[%[[SIZES]]#1]
+// CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.apply #[[$DIM0_STRIDE_MAP]]()[%[[STRIDES]]#0]
+// CHECK-DAG: %[[DYN_STRIDE1:.*]] = affine.apply #[[$DIM1_STRIDE_MAP]]()[%[[STRIDES]]#0]
+// CHECK-DAG: %[[DYN_STRIDE2:.*]] = affine.apply #[[$DIM2_STRIDE_MAP]]()[%[[STRIDES]]#0]
+// CHECK-DAG: %[[DYN_STRIDE4:.*]] = affine.apply #[[$DIM4_STRIDE_MAP]]()[%[[SIZES]]#1, %[[STRIDES]]#1]
+// CHECK-DAG: %[[DYN_STRIDE5:.*]] = affine.apply #[[$DIM5_STRIDE_MAP]]()[%[[SIZES]]#1, %[[STRIDES]]#1]
+// CHECK-DAG: %[[DYN_STRIDE6:.*]] = affine.apply #[[$DIM6_STRIDE_MAP]]()[%[[STRIDES]]#1]
+
+// CHECK: return %[[BASE]], %[[OFFSET]], %[[DYN_SIZE0]], %[[C7]], %[[C8]], %[[C9]], %[[C10]], %[[C2]], %[[DYN_SIZE6]], %[[C3]], %[[DYN_STRIDE0]], %[[DYN_STRIDE1]], %[[DYN_STRIDE2]], %[[STRIDES]]#0, %[[DYN_STRIDE4]], %[[DYN_STRIDE5]], %[[DYN_STRIDE6]], %[[STRIDES]]#1 : memref<f32>, index, index, index, index, index, index, index, index, index, index, index, index, index, index, index, index, index
+func.func @extract_strided_metadata_of_expand_shape_all_dynamic(
+    %base: memref<?x?xf32, strided<[?, ?], offset: ?>>,
+    %offset0: index, %offset1: index, %offset2: index,
+    %size0: index, %size1: index, %size2: index,
+    %stride0: index, %stride1: index, %stride2: index)
+    -> (memref<f32>, index,
+        index, index, index, index, index, index, index, index,
+        index, index, index, index, index, index, index, index) {
+
+  %subview = memref.expand_shape %base[[0, 1, 2, 3],[4, 5, 6, 7]] :
+      memref<?x?xf32, strided<[?, ?], offset: ?>> into
+      memref<?x7x8x9x10x2x?x3xf32, strided<[?, ?, ?, ?, ?, ?, ?, ?], offset: ?>>
+
+  %base_buffer, %offset, %sizes:8,
+      %strides:8 = memref.extract_strided_metadata %subview :
+      memref<?x7x8x9x10x2x?x3xf32, strided<[?, ?, ?, ?, ?, ?, ?, ?], offset: ?>>
+      -> memref<f32>, index,
+         index, index, index, index, index, index, index, index,
+         index, index, index, index, index, index, index, index
+
+  return %base_buffer, %offset,
+         %sizes#0, %sizes#1, %sizes#2, %sizes#3, %sizes#4, %sizes#5, %sizes#6, %sizes#7,
+         %strides#0, %strides#1, %strides#2, %strides#3, %strides#4, %strides#5, %strides#6, %strides#7 :
+      memref<f32>, index,
+      index, index, index, index, index, index, index, index,
+      index, index, index, index, index, index, index, index
+}
+
+
+// -----
+
+// Check that we properly handle extract_strided_metadata of expand_shape for
+// a 0-D input.
+// The 0-D case is pretty boring:
+// all the expanded sizes are 1, likewise for the strides, and we keep the
+// original base and offset.
+// We still have a test for it because, since the input reassociation map
+// of the expand_shape is empty, the handling of such a shape hits a corner
+// case.
+// CHECK-LABEL: func @extract_strided_metadata_of_expand_shape_all_static_0_rank
+// CHECK-SAME: (%[[ARG:.*]]: memref<i16, strided<[], offset: ?>>)
+//
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+//
+// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]] = memref.extract_strided_metadata %[[ARG]] : memref<i16, strided<[], offset: ?>> -> memref<i16>, index
+//
+// CHECK: return %[[BASE]], %[[OFFSET]], %[[C1]], %[[C1]], %[[C1]], %[[C1]], %[[C1]], %[[C1]], %[[C1]], %[[C1]], %[[C1]], %[[C1]] : memref<i16>, index, index, index, index, index, index, index, index, index, index, index
+func.func @extract_strided_metadata_of_expand_shape_all_static_0_rank(
+    %arg : memref<i16, strided<[], offset: ?>>)
+    -> (memref<i16>, index,
+        index, index, index, index, index,
+        index, index, index, index, index) {
+
+  %expand_shape = memref.expand_shape %arg[] :
+      memref<i16, strided<[], offset: ?>> into
+      memref<1x1x1x1x1xi16, strided<[1,1,1,1,1], offset: ?>>
+
+  %base, %offset, %sizes:5, %strides:5 = memref.extract_strided_metadata %expand_shape :
+      memref<1x1x1x1x1xi16, strided<[1,1,1,1,1], offset: ?>>
+      -> memref<i16>, index,
+         index, index, index, index, index,
+         index, index, index, index, index
+
+  return %base, %offset,
+         %sizes#0, %sizes#1, %sizes#2, %sizes#3, %sizes#4,
+         %strides#0, %strides#1, %strides#2, %strides#3, %strides#4 :
+      memref<i16>, index,
+      index, index, index, index, index,
+      index, index, index, index, index
+}