diff --git a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
--- a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
@@ -173,6 +173,11 @@
 /// \p origSizes hold the sizes of the source shape as values.
 /// This is used to compute the new sizes in cases of dynamic shapes.
 ///
+/// sizes#i =
+///     baseSizes#groupId / product(expandShapeSizes#j,
+///                                 for j in group excluding reassIdx#i)
+/// Where reassIdx#i is the reassociation index at index i in \p groupId.
+///
 /// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
 ///
 /// TODO: Move this utility function directly within ExpandShapeOp. For now,
@@ -225,6 +230,18 @@
 /// This is used to compute the strides in cases of dynamic shapes and/or
 /// dynamic stride for this reassociation group.
 ///
+/// strides#i =
+///     origStrides#reassDim * product(expandShapeSizes#j, for j in
+///                                    reassIdx#i+1..reassIdx#i+group.size-1)
+///
+/// Where reassIdx#i is the reassociation index at index i in \p groupId
+/// and expandShapeSizes#j is either:
+/// - The constant size at dimension j, derived directly from the result type of
+///   the expand_shape op, or
+/// - An affine expression: baseSizes#reassDim / product of all constant sizes
+///   in expandShapeSizes. (Remember expandShapeSizes has at most one dynamic
+///   element.)
+///
 /// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
 ///
 /// TODO: Move this utility function directly within ExpandShapeOp. For now,
@@ -315,49 +332,162 @@
   return expandedStrides;
 }
 
+/// Produce an OpFoldResult object with \p builder at \p loc representing
+/// `prod(valueOrConstant#i, for i in {indices})`,
+/// where valueOrConstant#i is maybeConstants[i] when \p isDynamic(maybeConstants[i])
+/// is false, and values[i] otherwise.
+///
+/// \pre for all index in indices: index < values.size()
+/// \pre for all index in indices: index < maybeConstants.size()
+static OpFoldResult
+getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder, Location loc,
+                   ArrayRef<int64_t> maybeConstants,
+                   ArrayRef<OpFoldResult> values,
+                   llvm::function_ref<bool(int64_t)> isDynamic) {
+  AffineExpr productOfValues = builder.getAffineConstantExpr(1);
+  SmallVector<OpFoldResult> inputValues;
+  unsigned numberOfSymbols = 0;
+  unsigned groupSize = indices.size();
+  for (unsigned i = 0; i < groupSize; ++i) {
+    productOfValues =
+        productOfValues * builder.getAffineSymbolExpr(numberOfSymbols++);
+    unsigned srcIdx = indices[i];
+    int64_t maybeConstant = maybeConstants[srcIdx];
+
+    inputValues.push_back(isDynamic(maybeConstant)
+                              ? values[srcIdx]
+                              : builder.getIndexAttr(maybeConstant));
+  }
+
+  return makeComposedFoldedAffineApply(builder, loc, productOfValues,
+                                       inputValues);
+}
+
+/// Compute the collapsed size of the given \p collapseShape for the
+/// \p groupId-th reassociation group.
+/// \p origSizes hold the sizes of the source shape as values.
+/// This is used to compute the new sizes in cases of dynamic shapes.
+///
+/// Conceptually this helper function computes:
+/// `prod(origSizes#i, for i in {reassociationGroup[groupId]})`.
+///
+/// \post result.size() == 1, in other words, each group collapses to one
+/// dimension.
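+///
+/// For illustration, using the shapes of the collapse_shape test added below:
+/// collapsing memref<?x?x4x?x6x7xi32> with reassociation [[0], [1, 2, 3], [4, 5]]
+/// gives, for groupId == 1, size#1 = origSizes#1 * 4 * origSizes#3, which is
+/// materialized as an affine.apply of
+/// affine_map<()[s0, s1] -> ((s0 * s1) * 4)> on (origSizes#1, origSizes#3).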
+///
+/// TODO: Move this utility function directly within CollapseShapeOp. For now,
+/// this is not possible because this function uses the Affine dialect and the
+/// MemRef dialect cannot depend on the Affine dialect.
+static SmallVector<OpFoldResult>
+getCollapsedSize(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
+                 ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
+  SmallVector<OpFoldResult> collapsedSize;
+
+  MemRefType collapseShapeType = collapseShape.getResultType();
+
+  uint64_t size = collapseShapeType.getDimSize(groupId);
+  if (!ShapedType::isDynamic(size)) {
+    collapsedSize.push_back(builder.getIndexAttr(size));
+    return collapsedSize;
+  }
+
+  // We are dealing with a dynamic size.
+  // Build the affine expr of the product of the original sizes involved in that
+  // group.
+  Value source = collapseShape.getSrc();
+  auto sourceType = source.getType().cast<MemRefType>();
+
+  SmallVector<int64_t, 2> reassocGroup =
+      collapseShape.getReassociationIndices()[groupId];
+
+  collapsedSize.push_back(getProductOfValues(
+      reassocGroup, builder, collapseShape.getLoc(), sourceType.getShape(),
+      origSizes, ShapedType::isDynamic));
+
+  return collapsedSize;
+}
+
+/// Compute the collapsed stride of the given \p collapseShape for the
+/// \p groupId-th reassociation group.
+/// \p origStrides and \p origSizes hold respectively the strides and sizes
+/// of the source shape as values.
+/// This is used to compute the strides in cases of dynamic shapes and/or
+/// dynamic stride for this reassociation group.
+///
+/// Conceptually this helper function returns the stride of the innermost
+/// dimension of that group in the original shape.
+///
+/// \post result.size() == 1, in other words, each group collapses to one
+/// dimension.
+static SmallVector<OpFoldResult>
+getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
+                   ArrayRef<OpFoldResult> origSizes,
+                   ArrayRef<OpFoldResult> origStrides, unsigned groupId) {
+  SmallVector<int64_t, 2> reassocGroup =
+      collapseShape.getReassociationIndices()[groupId];
+  assert(!reassocGroup.empty() &&
+         "Reassociation group should have at least one dimension");
+
+  Value source = collapseShape.getSrc();
+  auto sourceType = source.getType().cast<MemRefType>();
+
+  SmallVector<int64_t> strides;
+  int64_t offset;
+  bool hasKnownStridesAndOffset =
+      succeeded(getStridesAndOffset(sourceType, strides, offset));
+  (void)hasKnownStridesAndOffset;
+  assert(hasKnownStridesAndOffset &&
+         "getStridesAndOffset must work on valid collapse_shape");
+
+  SmallVector<OpFoldResult> collapsedStride;
+  int64_t innerMostDimForGroup = reassocGroup.back();
+  int64_t innerMostStrideForGroup = strides[innerMostDimForGroup];
+  collapsedStride.push_back(
+      ShapedType::isDynamicStrideOrOffset(innerMostStrideForGroup)
+          ? origStrides[innerMostDimForGroup]
+          : builder.getIndexAttr(innerMostStrideForGroup));
+
+  return collapsedStride;
+}
 /// Replace `baseBuffer, offset, sizes, strides =
-///            extract_strided_metadata(expand_shape(memref))`
+///            extract_strided_metadata(reshapeLike(memref))`
 /// With
 ///
 /// \verbatim
 /// baseBuffer, offset, baseSizes, baseStrides =
 ///     extract_strided_metadata(memref)
-/// sizes#reassIdx =
-///     baseSizes#reassDim / product(expandShapeSizes#j,
-///                                  for j in group excluding reassIdx)
-/// strides#reassIdx =
-///     baseStrides#reassDim * product(expandShapeSizes#j, for j in
-///                                    reassIdx+1..reassIdx+group.size-1)
+/// sizes = getReshapedSizes(reshapeLike)
+/// strides = getReshapedStrides(reshapeLike)
 /// \endverbatim
 ///
-/// Where reassIdx is a reassociation index for the group at reassDim
-/// and expandShapeSizes#j is either:
-/// - The constant size at dimension j, derived directly from the result type of
-///   the expand_shape op, or
-/// - An affine expression: baseSizes#reassDim / product of all constant sizes
-///   in expandShapeSizes. (Remember expandShapeSizes has at most one dynamic
-///   element.)
 ///
 /// Notice that `baseBuffer` and `offset` are unchanged.
 ///
 /// In other words, get rid of the expand_shape in that expression and
 /// materialize its effects on the sizes and the strides using affine apply.
-struct ExtractStridedMetadataOpExpandShapeFolder
+template <typename ReassociativeReshapeLikeOp,
+          SmallVector<OpFoldResult> (*getReshapedSizes)(
+              ReassociativeReshapeLikeOp, OpBuilder &,
+              ArrayRef<OpFoldResult> /*origSizes*/, unsigned /*groupId*/),
+          SmallVector<OpFoldResult> (*getReshapedStrides)(
+              ReassociativeReshapeLikeOp, OpBuilder &,
+              ArrayRef<OpFoldResult> /*origSizes*/,
+              ArrayRef<OpFoldResult> /*origStrides*/, unsigned /*groupId*/)>
+struct ExtractStridedMetadataOpReshapeFolder
     : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
 public:
   using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
                                 PatternRewriter &rewriter) const override {
-    auto expandShape = op.getSource().getDefiningOp<memref::ExpandShapeOp>();
-    if (!expandShape)
+    auto reshape = op.getSource().getDefiningOp<ReassociativeReshapeLikeOp>();
+    if (!reshape)
       return failure();
 
     // Build a plain extract_strided_metadata(memref) from
-    // extract_strided_metadata(expand_shape(memref)).
+    // extract_strided_metadata(reassociative_reshape_like(memref)).
     Location origLoc = op.getLoc();
     IndexType indexType = rewriter.getIndexType();
-    Value source = expandShape.getSrc();
+    Value source = reshape.getSrc();
     auto sourceType = source.getType().cast<MemRefType>();
     unsigned sourceRank = sourceType.getRank();
     SmallVector<Type> sizeStrideTypes(sourceRank, indexType);
@@ -374,13 +504,13 @@
         succeeded(getStridesAndOffset(sourceType, strides, offset));
     (void)hasKnownStridesAndOffset;
     assert(hasKnownStridesAndOffset &&
-           "getStridesAndOffset must work on valid expand_shape");
-    MemRefType expandShapeType = expandShape.getResultType();
-    unsigned expandShapeRank = expandShapeType.getRank();
+           "getStridesAndOffset must work on valid reassociative_reshape_like");
+    MemRefType reshapeType = reshape.getResultType();
+    unsigned reshapeRank = reshapeType.getRank();
     // The result value will start with the base_buffer and offset.
     unsigned baseIdxInResult = 2;
-    SmallVector<OpFoldResult> results(baseIdxInResult + expandShapeRank * 2);
+    SmallVector<OpFoldResult> results(baseIdxInResult + reshapeRank * 2);
     results[0] = newExtractStridedMetadata.getBaseBuffer();
     results[1] = ShapedType::isDynamicStrideOrOffset(offset)
                      ? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
                      : rewriter.getIndexAttr(offset);
@@ -390,7 +520,7 @@
     if (sourceRank == 0) {
       Value constantOne = getValueOrCreateConstantIndexOp(
          rewriter, origLoc, rewriter.getIndexAttr(1));
-      SmallVector<Value> resultValues(baseIdxInResult + expandShapeRank * 2,
+      SmallVector<Value> resultValues(baseIdxInResult + reshapeRank * 2,
                                       constantOne);
      for (unsigned i = 0; i < baseIdxInResult; ++i)
        resultValues[i] =
@@ -399,30 +529,31 @@
       return success();
     }
 
-    // Compute the expanded strides and sizes from the base strides and sizes.
+    // Compute the reshaped strides and sizes from the base strides and sizes.
     SmallVector<OpFoldResult> origSizes =
         getAsOpFoldResult(newExtractStridedMetadata.getSizes());
     SmallVector<OpFoldResult> origStrides =
         getAsOpFoldResult(newExtractStridedMetadata.getStrides());
-    unsigned idx = 0, endIdx = expandShape.getReassociationIndices().size();
+    unsigned idx = 0, endIdx = reshape.getReassociationIndices().size();
     for (; idx != endIdx; ++idx) {
-      SmallVector<OpFoldResult> expandedSizes =
-          getExpandedSizes(expandShape, rewriter, origSizes, /*groupId=*/idx);
-      SmallVector<OpFoldResult> expandedStrides = getExpandedStrides(
-          expandShape, rewriter, origSizes, origStrides, /*groupId=*/idx);
+      SmallVector<OpFoldResult> reshapedSizes =
+          getReshapedSizes(reshape, rewriter, origSizes, /*groupId=*/idx);
+      SmallVector<OpFoldResult> reshapedStrides = getReshapedStrides(
+          reshape, rewriter, origSizes, origStrides, /*groupId=*/idx);
 
-      unsigned groupSize = expandShape.getReassociationIndices()[idx].size();
+      unsigned groupSize = reshapedSizes.size();
       const unsigned sizeStartIdx = baseIdxInResult;
-      const unsigned strideStartIdx = sizeStartIdx + expandShapeRank;
+      const unsigned strideStartIdx = sizeStartIdx + reshapeRank;
       for (unsigned i = 0; i < groupSize; ++i) {
-        results[sizeStartIdx + i] = expandedSizes[i];
-        results[strideStartIdx + i] = expandedStrides[i];
+        results[sizeStartIdx + i] = reshapedSizes[i];
+        results[strideStartIdx + i] = reshapedStrides[i];
       }
       baseIdxInResult += groupSize;
     }
-    assert(idx == sourceRank &&
+    assert(((isa<memref::ExpandShapeOp>(reshape) && idx == sourceRank) ||
+            (isa<memref::CollapseShapeOp>(reshape) && idx == reshapeRank)) &&
           "We should have visited all the input dimensions");
-    assert(baseIdxInResult == expandShapeRank + 2 &&
+    assert(baseIdxInResult == reshapeRank + 2 &&
           "We should have populated all the values");
     rewriter.replaceOp(
         op, getValueOrCreateConstantIndexOp(rewriter, origLoc, results));
@@ -599,12 +730,17 @@
 void memref::populateSimplifyExtractStridedMetadataOpPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<ExtractStridedMetadataOpExpandShapeFolder,
-               ForwardStaticMetadata,
-               ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
-               ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
-               RewriteExtractAlignedPointerAsIndexOfViewLikeOp>(
-      patterns.getContext());
+  patterns
+      .add<ExtractStridedMetadataOpReshapeFolder<
+               memref::ExpandShapeOp, getExpandedSizes, getExpandedStrides>,
+           ExtractStridedMetadataOpReshapeFolder<
+               memref::CollapseShapeOp, getCollapsedSize, getCollapsedStride>,
+           ForwardStaticMetadata,
+           ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
+           ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
+           RewriteExtractAlignedPointerAsIndexOfViewLikeOp>(
+          patterns.getContext());
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
--- a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
@@ -762,3 +762,82 @@
   %r = memref.extract_aligned_pointer_as_index %arg0: memref -> index
   return %r : index
 }
+
+// -----
+
+// Check that we simplify extract_strided_metadata of collapse_shape.
+//
+// We transform: ?x?x4x?x6x7xi32 to [0][1,2,3][4,5]
+// Size 0 = origSize0
+// Size 1 = origSize1 * origSize2 * origSize3
+//        = origSize1 * 4 * origSize3
+// Size 2 = origSize4 * origSize5
+//        = 6 * 7
+//        = 42
+// Stride 0 = origStride0
+// Stride 1 = origStride3 (orig stride of the innermost dimension of that group)
+//          = 42
+// Stride 2 = origStride5
+//          = 1
+//
+// CHECK-DAG: #[[$SIZE0_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 4)>
+// CHECK-LABEL: func @extract_strided_metadata_of_collapse(
+// CHECK-SAME: %[[ARG:.*]]: memref<?x?x4x?x6x7xi32>)
+//
+// CHECK-DAG: %[[C42:.*]] = arith.constant 42 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+//
+// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:6, %[[STRIDES:.*]]:6 = memref.extract_strided_metadata %[[ARG]] : memref<?x?x4x?x6x7xi32>
+//
+// CHECK-DAG: %[[DYN_SIZE1:.*]] = affine.apply #[[$SIZE0_MAP]]()[%[[SIZES]]#1, %[[SIZES]]#3]
+//
+// CHECK: return %[[BASE]], %[[C0]], %[[SIZES]]#0, %[[DYN_SIZE1]], %[[C42]], %[[STRIDES]]#0, %[[C42]], %[[C1]]
+func.func @extract_strided_metadata_of_collapse(%arg : memref<?x?x4x?x6x7xi32>)
+    -> (memref<i32>, index,
+        index, index, index,
+        index, index, index) {
+
+  %collapsed_view = memref.collapse_shape %arg [[0], [1, 2, 3], [4, 5]] :
+    memref<?x?x4x?x6x7xi32> into memref<?x?x42xi32>
+
+  %base, %offset, %sizes:3, %strides:3 =
+    memref.extract_strided_metadata %collapsed_view : memref<?x?x42xi32>
+    -> memref<i32>, index,
+       index, index, index,
+       index, index, index
+
+  return %base, %offset,
+         %sizes#0, %sizes#1, %sizes#2,
+         %strides#0, %strides#1, %strides#2 :
+      memref<i32>, index,
+      index, index, index,
+      index, index, index
+
+}
+
+// -----
+
+// Check that we simplify extract_strided_metadata of collapse_shape to
+// a rank-0 memref.
+// CHECK-LABEL: func @extract_strided_metadata_of_collapse_to_rank0(
+// CHECK-SAME: %[[ARG:.*]]: memref<1x1x1x1x1x1xi32>)
+//
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+//
+// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:6, %[[STRIDES:.*]]:6 = memref.extract_strided_metadata %[[ARG]] : memref<1x1x1x1x1x1xi32>
+//
+// CHECK: return %[[BASE]], %[[C0]]
+func.func @extract_strided_metadata_of_collapse_to_rank0(%arg : memref<1x1x1x1x1x1xi32>)
+    -> (memref<i32>, index) {
+
+  %collapsed_view = memref.collapse_shape %arg [] :
+    memref<1x1x1x1x1x1xi32> into memref<i32>
+
+  %base, %offset =
+    memref.extract_strided_metadata %collapsed_view : memref<i32>
+    -> memref<i32>, index
+
+  return %base, %offset :
+    memref<i32>, index
+}
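
A minimal sketch of how these patterns can be exercised from a pass, for reviewers who
want to try the rewrite locally. Only populateSimplifyExtractStridedMetadataOpPatterns
comes from the file modified above; the pass wrapper, its name, and the exact header
paths are illustrative assumptions and may differ across MLIR revisions.

  // Sketch only (not part of this patch): run the simplification patterns greedily.
  #include "mlir/Dialect/Func/IR/FuncOps.h"
  #include "mlir/Dialect/MemRef/Transforms/Passes.h" // assumed declaration location
  #include "mlir/IR/PatternMatch.h"
  #include "mlir/Pass/Pass.h"
  #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

  namespace {
  // Hypothetical driver pass; upstream exposes an equivalent pass separately.
  struct SimplifyExtractStridedMetadataSketch
      : public mlir::PassWrapper<SimplifyExtractStridedMetadataSketch,
                                 mlir::OperationPass<mlir::func::FuncOp>> {
    void runOnOperation() override {
      mlir::RewritePatternSet patterns(&getContext());
      mlir::memref::populateSimplifyExtractStridedMetadataOpPatterns(patterns);
      // Applies the expand_shape/collapse_shape foldings until a fixpoint is reached.
      if (failed(mlir::applyPatternsAndFoldGreedily(getOperation(),
                                                    std::move(patterns))))
        signalPassFailure();
    }
  };
  } // namespace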