diff --git a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp --- a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp @@ -43,7 +43,7 @@ /// \endverbatim /// /// In other words, get rid of the subview in that expression and canonicalize -/// on its effects on the offset, the sizes, and the strides using affine apply. +/// on its effects on the offset, the sizes, and the strides using affine.apply. struct ExtractStridedMetadataOpSubviewFolder : public OpRewritePattern { public: @@ -166,11 +166,185 @@ return success(); } }; + +/// Replace `baseBuffer, offset, sizes, strides = +/// extract_strided_metadata(expand_shape(memref))` +/// With +/// +/// \verbatim +/// baseBuffer, offset, baseSizes, baseStrides = +/// extract_strided_metadata(memref) +/// sizes#reassIdx = +/// baseSizes#reassDim / product(expandShapeSizes#j, +/// for j in group excluding reassIdx) +/// strides#reassIdx = +/// baseStrides#reassDim * product(expandShapeSizes#j, for j in +/// group after reassIdx) +/// \endverbatim +/// +/// Where reassIdx is a reassociation index for the group at reassDim +/// and expandShapeSizes#j is either: +/// - The constant size at dimension j, derived directly from the result type of +/// the expand_shape op, or +/// - An affine expression: baseSizes#reassDim / product of all constant sizes +/// in expandShapeSizes. (Remember expandShapeSizes has at most one dynamic +/// element.) +/// +/// Notice that `baseBuffer` and `offset` are unchanged. +/// +/// In other words, get rid of the expand_shape in that expression and +/// materialize its effects on the sizes and the strides using affine.apply.
+struct ExtractStridedMetadataOpExpandShapeFolder + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op, + PatternRewriter &rewriter) const override { + auto expandShape = op.getSource().getDefiningOp(); + if (!expandShape) + return failure(); + + // Build a plain extract_strided_metadata(memref) from + // extract_strided_metadata(expand_shape(memref)). + Location origLoc = op.getLoc(); + IndexType indexType = rewriter.getIndexType(); + Value source = expandShape.getSrc(); + auto sourceType = source.getType().cast(); + unsigned sourceRank = sourceType.getRank(); + SmallVector sizeStrideTypes(sourceRank, indexType); + + auto newExtractStridedMetadata = + rewriter.create( + origLoc, op.getBaseBuffer().getType(), indexType, sizeStrideTypes, + sizeStrideTypes, source); + + // Collect statically known information. + SmallVector strides; + int64_t offset; + bool hasKnownStridesAndOffset = + succeeded(getStridesAndOffset(sourceType, strides, offset)); + (void)hasKnownStridesAndOffset; + assert(hasKnownStridesAndOffset && + "getStridesAndOffset must work on valid expand_shape"); + MemRefType expandShapeType = expandShape.getResultType(); + unsigned expandShapeRank = expandShapeType.getRank(); + + // The results are laid out as: base_buffer, offset, sizes..., strides... + unsigned baseIdxInResult = 2; + SmallVector results(baseIdxInResult + expandShapeRank * 2); + results[0] = newExtractStridedMetadata.getBaseBuffer(); + results[1] = ShapedType::isDynamicStrideOrOffset(offset) + ? getAsOpFoldResult(newExtractStridedMetadata.getOffset()) + : rewriter.getIndexAttr(offset); + + // Compute the expanded strides and sizes from the base strides and sizes. + + // Get the special case of 0-D out of the way.
+ if (sourceRank == 0) { + Value constantOne = getValueOrCreateConstantIndexOp( + rewriter, origLoc, rewriter.getIndexAttr(1)); + SmallVector resultValues(baseIdxInResult + expandShapeRank * 2, + constantOne); + for (unsigned i = 0; i < baseIdxInResult; ++i) + resultValues[i] = + getValueOrCreateConstantIndexOp(rewriter, origLoc, results[i]); + rewriter.replaceOp(op, resultValues); + return success(); + } + + // Keep track of which source dimension this group applies to. + unsigned reassDim = 0; + for (const auto &groups : + expandShape.getReassociationIndices()) { + assert(!groups.empty() && "association indices groups cannot be empty"); + unsigned groupSize = groups.size(); + + // Compute the strides in two steps: + // First account only for the static part: + // stride[i] = baseStride * prod(size#j, for j in i+1..groupSize-1) + // Second, if we saw a dynamic dimension, patch up all the strides up to + // this dimension with the dynamic size. + // The dynamic size is equal to origSize / prod(staticSize of this group). + int64_t currentStride = 1; + OpFoldResult baseStride = + ShapedType::isDynamicStrideOrOffset(strides[reassDim]) + ? 
getAsOpFoldResult( + newExtractStridedMetadata.getStrides()[reassDim]) + : rewriter.getIndexAttr(strides[reassDim]); + Optional dynSizeIdx; + SmallVector expandedStrides(groupSize); + SmallVector expandedSizes(groupSize); + AffineExpr s0 = rewriter.getAffineSymbolExpr(0); + for (int i = groupSize - 1; i >= 0; --i) { + expandedStrides[i] = makeComposedFoldedAffineApply( + rewriter, origLoc, s0 * currentStride, baseStride); + int64_t dimSize = expandShapeType.getDimSize(groups[i]); + if (ShapedType::isDynamic(dimSize)) { + dynSizeIdx = i; + continue; + } + + expandedSizes[i] = rewriter.getIndexAttr(dimSize); + + currentStride *= dimSize; + } + if (dynSizeIdx) { + int64_t productOfKnownStaticSizes = currentStride; + assert(ShapedType::isDynamic(sourceType.getDimSize(reassDim)) && + "We shouldn't be able to change dynamicity"); + OpFoldResult baseSize = newExtractStridedMetadata.getSizes()[reassDim]; + expandedSizes[*dynSizeIdx] = makeComposedFoldedAffineApply( + rewriter, origLoc, s0.floorDiv(productOfKnownStaticSizes), + baseSize); + // Now update the stride of all the dimensions that are before the + // dynamic size. + AffineExpr s1 = rewriter.getAffineSymbolExpr(1); + // We are about to compute: + // expandedSize * stride#j * baseStride.
+ // But we know that: + // - expandedSize = oldSize / prod(Size#i, for i in 0..N-1 and Size#i is + // static) + // - stride#j = prod(Size#i, for i in j+1..N-1, and Size#i is static) + // So in the end, what we want is: + // (oldSize / prod(Size#i, for i in 0..j)) * baseStride + int64_t foldedStrideWithSize = 1; + for (int64_t i = 0; i < *dynSizeIdx; ++i) { + int64_t dimSize = expandShapeType.getDimSize(groups[i]); + assert(!ShapedType::isDynamic(dimSize) && + "At most one dimension can be dynamic"); + foldedStrideWithSize *= dimSize; + expandedStrides[i] = makeComposedFoldedAffineApply( + rewriter, origLoc, s0.floorDiv(foldedStrideWithSize) * s1, + {baseSize, baseStride}); + } + } + + const unsigned sizeStartIdx = baseIdxInResult; + const unsigned strideStartIdx = sizeStartIdx + expandShapeRank; + for (unsigned i = 0; i < groupSize; ++i) { + results[sizeStartIdx + i] = expandedSizes[i]; + results[strideStartIdx + i] = expandedStrides[i]; + } + baseIdxInResult += groupSize; + ++reassDim; + } + assert(reassDim == sourceRank && + "We should have visited all the input dimensions"); + assert(baseIdxInResult == expandShapeRank + 2 && + "We should have populated all the values"); + rewriter.replaceOp( + op, getValueOrCreateConstantIndexOp(rewriter, origLoc, results)); + return success(); + } +}; } // namespace void memref::populateSimplifyExtractStridedMetadataOpPatterns( RewritePatternSet &patterns) { - patterns.add(patterns.getContext()); + patterns.add( + patterns.getContext()); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir --- a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir +++ b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir @@ -281,3 +281,231 @@ return %base_buffer, %offset, %sizes#0, %sizes#1, %sizes#2, %strides#0, %strides#1, %strides#2 : memref, index, index,
index, index, index, index, index } + +// ----- + +// Check that we properly simplify extract_strided_metadata of expand_shape +// into: +// baseBuffer, baseOffset, baseSizes, baseStrides = +// extract_strided_metadata(memref) +// sizes#reassIdx = +// baseSizes#reassDim / product(expandShapeSizes#j, +// for j in group excluding reassIdx) +// strides#reassIdx = +// baseStrides#reassDim * product(expandShapeSizes#j, for j in +// group after reassIdx) +// +// Here we have: +// For the group applying to dim0: +// size 0 = 3 +// size 1 = 5 +// size 2 = 2 +// stride 0 = baseStrides#0 * 5 * 2 +// = 4 * 5 * 2 +// = 40 +// stride 1 = baseStrides#0 * 2 +// = 4 * 2 +// = 8 +// stride 2 = baseStrides#0 +// = 4 +// +// For the group applying to dim1: +// size 3 = 2 +// size 4 = 2 +// stride 3 = baseStrides#1 * 2 +// = 1 * 2 +// = 2 +// stride 4 = baseStrides#1 +// = 1 +// +// Base and offset are unchanged. +// +// CHECK-LABEL: func @extract_strided_metadata_of_expand_shape_all_static +// CHECK-SAME: (%[[ARG:.*]]: memref<30x4xi16>) +// +// CHECK-DAG: %[[C40:.*]] = arith.constant 40 : index +// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index +// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index +// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index +// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// +// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<30x4xi16> -> memref, index, index, index, index, index +// +// CHECK: return %[[BASE]], %[[C0]], %[[C3]], %[[C5]], %[[C2]], %[[C2]], %[[C2]], %[[C40]], %[[C8]], %[[C4]], %[[C2]], %[[C1]] : memref, index, index, index, index, index, index, index, index, index, index, index +func.func @extract_strided_metadata_of_expand_shape_all_static( + %arg : memref<30x4xi16>) + -> (memref, index, +
index, index, index, index, index, + index, index, index, index, index) { + + %expand_shape = memref.expand_shape %arg[[0, 1, 2], [3, 4]] : + memref<30x4xi16> into memref<3x5x2x2x2xi16> + + %base, %offset, %sizes:5, %strides:5 = memref.extract_strided_metadata %expand_shape : + memref<3x5x2x2x2xi16> + -> memref, index, + index, index, index, index, index, + index, index, index, index, index + + return %base, %offset, + %sizes#0, %sizes#1, %sizes#2, %sizes#3, %sizes#4, + %strides#0, %strides#1, %strides#2, %strides#3, %strides#4 : + memref, index, + index, index, index, index, index, + index, index, index, index, index +} + +// ----- + +// Check that we properly simplify extract_strided_metadata of expand_shape +// when dynamic sizes, strides, and offsets are involved. +// See extract_strided_metadata_of_expand_shape_all_static for an explanation +// of the expansion. +// +// One of the important characteristics of this test is that the dynamic +// dimensions produced by the expand_shape appear both in the first dimension +// (for group 1) and in a non-first dimension (third dimension for group 2.) +// The idea is to make sure that: +// 1. We properly account for dynamic shapes even when the strides are not +// affected by them. (When the dynamic dimension is the first one.) +// 2. We properly compute the strides affected by dynamic shapes. (When the +// dynamic dimension is not the first one.)
+// +// Here we have: +// For the group applying to dim0: +// size 0 = baseSizes#0 / (all static sizes in that group) +// = baseSizes#0 / (7 * 8 * 9) +// = baseSizes#0 / 504 +// size 1 = 7 +// size 2 = 8 +// size 3 = 9 +// stride 0 = baseStrides#0 * 7 * 8 * 9 +// = baseStrides#0 * 504 +// stride 1 = baseStrides#0 * 8 * 9 +// = baseStrides#0 * 72 +// stride 2 = baseStrides#0 * 9 +// stride 3 = baseStrides#0 +// +// For the group applying to dim1: +// size 4 = 10 +// size 5 = 2 +// size 6 = baseSizes#1 / (all static sizes in that group) +// = baseSizes#1 / (10 * 2 * 3) +// = baseSizes#1 / 60 +// size 7 = 3 +// stride 4 = baseStrides#1 * size 5 * size 6 * size 7 +// = baseStrides#1 * 2 * (baseSizes#1 / 60) * 3 +// = baseStrides#1 * (baseSizes#1 / 60) * 6 +// and since we know that baseSizes#1 is a multiple of 60: +// = baseStrides#1 * (baseSizes#1 / 10) +// stride 5 = baseStrides#1 * size 6 * size 7 +// = baseStrides#1 * (baseSizes#1 / 60) * 3 +// = baseStrides#1 * (baseSizes#1 / 20) +// stride 6 = baseStrides#1 * size 7 +// = baseStrides#1 * 3 +// stride 7 = baseStrides#1 +// +// Base and offset are unchanged.
+// +// CHECK-DAG: #[[$DIM0_SIZE_MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 504)> +// CHECK-DAG: #[[$DIM6_SIZE_MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 60)> +// +// CHECK-DAG: #[[$DIM0_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 504)> +// CHECK-DAG: #[[$DIM1_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 72)> +// CHECK-DAG: #[[$DIM2_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 9)> +// CHECK-DAG: #[[$DIM4_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 floordiv 10) * s1)> +// CHECK-DAG: #[[$DIM5_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 floordiv 20) * s1)> +// CHECK-DAG: #[[$DIM6_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 3)> +// CHECK-LABEL: func @extract_strided_metadata_of_expand_shape_all_dynamic +// CHECK-SAME: (%[[ARG:.*]]: memref> -> memref, index, index, index, index, index +// +// CHECK-DAG: %[[DYN_SIZE0:.*]] = affine.apply #[[$DIM0_SIZE_MAP]]()[%[[SIZES]]#0] +// CHECK-DAG: %[[DYN_SIZE6:.*]] = affine.apply #[[$DIM6_SIZE_MAP]]()[%[[SIZES]]#1] +// CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.apply #[[$DIM0_STRIDE_MAP]]()[%[[STRIDES]]#0] +// CHECK-DAG: %[[DYN_STRIDE1:.*]] = affine.apply #[[$DIM1_STRIDE_MAP]]()[%[[STRIDES]]#0] +// CHECK-DAG: %[[DYN_STRIDE2:.*]] = affine.apply #[[$DIM2_STRIDE_MAP]]()[%[[STRIDES]]#0] +// CHECK-DAG: %[[DYN_STRIDE4:.*]] = affine.apply #[[$DIM4_STRIDE_MAP]]()[%[[SIZES]]#1, %[[STRIDES]]#1] +// CHECK-DAG: %[[DYN_STRIDE5:.*]] = affine.apply #[[$DIM5_STRIDE_MAP]]()[%[[SIZES]]#1, %[[STRIDES]]#1] +// CHECK-DAG: %[[DYN_STRIDE6:.*]] = affine.apply #[[$DIM6_STRIDE_MAP]]()[%[[STRIDES]]#1] + +// CHECK: return %[[BASE]], %[[OFFSET]], %[[DYN_SIZE0]], %[[C7]], %[[C8]], %[[C9]], %[[C10]], %[[C2]], %[[DYN_SIZE6]], %[[C3]], %[[DYN_STRIDE0]], %[[DYN_STRIDE1]], %[[DYN_STRIDE2]], %[[STRIDES]]#0, %[[DYN_STRIDE4]], %[[DYN_STRIDE5]], %[[DYN_STRIDE6]], %[[STRIDES]]#1 : memref, index, index, index, index, index, index, index, index, index, index, index, index, index +func.func @extract_strided_metadata_of_expand_shape_all_dynamic( + %base: memref>,
+ %offset0: index, %offset1: index, %offset2: index, + %size0: index, %size1: index, %size2: index, + %stride0: index, %stride1: index, %stride2: index) + -> (memref, index, + index, index, index, index, index, index, index, index, + index, index, index, index, index, index, index, index) { + + %expand_shape = memref.expand_shape %base[[0, 1, 2, 3],[4, 5, 6, 7]] : + memref> into + memref> + + %base_buffer, %offset, %sizes:8, %strides:8 = memref.extract_strided_metadata %expand_shape : + memref> + -> memref, index, + index, index, index, index, index, index, index, index, + index, index, index, index, index, index, index, index + + return %base_buffer, %offset, + %sizes#0, %sizes#1, %sizes#2, %sizes#3, %sizes#4, %sizes#5, %sizes#6, %sizes#7, + %strides#0, %strides#1, %strides#2, %strides#3, %strides#4, %strides#5, %strides#6, %strides#7 : + memref, index, + index, index, index, index, index, index, index, index, + index, index, index, index, index, index, index, index +} + + +// ----- + +// Check that we properly handle extract_strided_metadata of expand_shape for +// 0-D input. +// The 0-D case is pretty boring: +// All expanded sizes are 1, likewise for the strides, and we keep the +// original base and offset. +// We still have a test for it because the input reassociation map +// of the expand_shape is empty, which makes handling such shapes a corner +// case.
+// CHECK-LABEL: func @extract_strided_metadata_of_expand_shape_all_static_0_rank +// CHECK-SAME: (%[[ARG:.*]]: memref>) +// +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// +// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]] = memref.extract_strided_metadata %[[ARG]] : memref> -> memref, index +// +// CHECK: return %[[BASE]], %[[OFFSET]], %[[C1]], %[[C1]], %[[C1]], %[[C1]], %[[C1]], %[[C1]], %[[C1]], %[[C1]], %[[C1]], %[[C1]] : memref, index, index, index, index, index, index, index, index, index, index, index +func.func @extract_strided_metadata_of_expand_shape_all_static_0_rank( + %arg : memref>) + -> (memref, index, + index, index, index, index, index, + index, index, index, index, index) { + + %expand_shape = memref.expand_shape %arg[] : + memref> into memref<1x1x1x1x1xi16, strided<[1,1,1,1,1], offset: ?>> + + %base, %offset, %sizes:5, %strides:5 = memref.extract_strided_metadata %expand_shape : + memref<1x1x1x1x1xi16, strided<[1,1,1,1,1], offset: ?>> + -> memref, index, + index, index, index, index, index, + index, index, index, index, index + + return %base, %offset, + %sizes#0, %sizes#1, %sizes#2, %sizes#3, %sizes#4, + %strides#0, %strides#1, %strides#2, %strides#3, %strides#4 : + memref, index, + index, index, index, index, index, + index, index, index, index, index +}