diff --git a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
--- a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
@@ -166,11 +166,229 @@
     return success();
   }
 };
+
+/// Replace `baseBuffer, offset, sizes, strides =
+///             extract_strided_metadata(expand_shape(memref))`
+/// With
+///
+/// \verbatim
+/// baseBuffer, offset, baseSizes, baseStrides =
+///     extract_strided_metadata(memref)
+/// sizes#reassIdx =
+///     baseSizes#reassDim / product(expandShapeSizes#j,
+///                                  for j in group excluding reassIdx)
+/// strides#reassIdx =
+///     baseStrides#reassDim * product(expandShapeSizes#j, for j in
+///                                    reassIdx+1..reassIdx+group.size-1)
+/// \endverbatim
+///
+/// Where reassIdx is a reassociation index for the group at reassDim
+/// and expandShapeSizes#j is either:
+/// - The constant size at dimension j, derived directly from the result type
+///   of the expand_shape op, or
+/// - An affine expression: baseSizes#reassDim / product of all the constant
+///   sizes in expandShapeSizes. (Remember that expandShapeSizes has at most
+///   one dynamic element.)
+///
+/// Notice that `baseBuffer` and `offset` are unchanged.
+///
+/// In other words, get rid of the expand_shape in that expression and
+/// materialize its effects on the sizes and the strides using affine.apply.
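+///
+/// For example (illustrative only; this mirrors the first test added below),
+/// the metadata of
+///
+/// \verbatim
+/// %e = memref.expand_shape %arg [[0, 1, 2], [3, 4]] :
+///     memref<30x4xi16> into memref<3x5x2x2x2xi16>
+/// \endverbatim
+///
+/// simplifies to the metadata of %arg itself, with constant sizes
+/// [3, 5, 2, 2, 2] and constant strides [40, 8, 4, 2, 1].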
+struct ExtractStridedMetadataOpExpandShapeFolder
+    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
+public:
+  using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
+                                PatternRewriter &rewriter) const override {
+    auto expandShape = op.getSource().getDefiningOp<memref::ExpandShapeOp>();
+    if (!expandShape)
+      return failure();
+
+    // Build a plain extract_strided_metadata(memref) from
+    // extract_strided_metadata(expand_shape(memref)).
+    Location origLoc = op.getLoc();
+    IndexType indexType = rewriter.getIndexType();
+    Value source = expandShape.getSrc();
+    auto sourceType = source.getType().cast<MemRefType>();
+    unsigned sourceRank = sourceType.getRank();
+    SmallVector<Type> sizeStrideTypes(sourceRank, indexType);
+
+    auto newExtractStridedMetadata =
+        rewriter.create<memref::ExtractStridedMetadataOp>(
+            origLoc, op.getBaseBuffer().getType(), indexType, sizeStrideTypes,
+            sizeStrideTypes, source);
+
+    // Collect statically known information.
+    SmallVector<int64_t> strides;
+    int64_t offset;
+    bool hasKnownStridesAndOffset =
+        succeeded(getStridesAndOffset(sourceType, strides, offset));
+    (void)hasKnownStridesAndOffset;
+    assert(hasKnownStridesAndOffset &&
+           "getStridesAndOffset must work on valid expand_shape");
+    MemRefType expandShapeType = expandShape.getResultType();
+    unsigned expandShapeRank = expandShapeType.getRank();
+
+    // The result values will start with the base_buffer and offset.
+    unsigned baseIdxInResult = 2;
+    SmallVector<Value> results(baseIdxInResult + expandShapeRank * 2);
+    results[0] = newExtractStridedMetadata.getBaseBuffer();
+    results[1] = ShapedType::isDynamicStrideOrOffset(offset)
+                     ? newExtractStridedMetadata.getOffset()
+                     : getValueOrCreateConstantIndexOp(
+                           rewriter, origLoc, rewriter.getIndexAttr(offset));
+
+    // Compute the new strides and sizes from the base strides and sizes.
+
+    // Keep track of which source dimension this group applies to.
+    unsigned reassDim = 0;
+    for (SmallVector<int64_t, 2> groups :
+         expandShape.getReassociationIndices()) {
+      assert(!groups.empty() && "association indices groups cannot be empty");
+      unsigned groupSize = groups.size();
+
+      // for each reassociation group:
+      //   let reassDim be the source dimension that this group applies to.
+      //   for each reassIdx in group
+      //     newSize#reassIdx =
+      //         baseSize#reassDim / product(expandShapeSize#j,
+      //                                     for j in group excluding reassIdx)
+      //     newStride#reassIdx =
+      //         baseStride#reassDim * product(expandShapeSize#j, for j in
+      //                                       reassIdx+1..reassIdx+group.size-1)
+      // Note: newSize#reassIdx is actually already available at
+      // expandShapeType.getDimSize(reassIdx) unless the size is dynamic.
+      // Therefore, in practice, we skip the computation for all sizes but the
+      // dynamic ones.
+      // Given that there is at most one dynamic size per group, we collect
+      // the product of all the static sizes in one variable and use it to
+      // compute the dynamic size of this one dimension.
+      bool hasDynamicSize = false;
+      uint64_t productOfKnownStaticSize = 1;
+
+      // Calculate product(expandShapeSize#j, for j in
+      // reassIdx+1..reassIdx+group.size-1) for each dimension of that group.
+      // This will be used to compute the final strides.
+      // Do the computation in reverse order of the dimensions to reuse the
+      // product of the previous dimension as a partial product for the new
+      // one.
+      // Note: we iterate on index 0 as well to collect the information about
+      // the dynamic shapes and static sizes for all dimensions.
+      SmallVector<int64_t> suffixProductOfSizes(groupSize, 1);
+      for (int i = groupSize - 1; i >= 0; --i) {
+        uint64_t dimSize = expandShapeType.getDimSize(groups[i]);
+        if (ShapedType::isDynamic(dimSize)) {
+          assert(!hasDynamicSize &&
+                 "At most one dimension should be dynamic per group");
+          hasDynamicSize = true;
+          // The proper dynamic size will be applied when we build the affine
+          // expression later.
+          dimSize = 1;
+        }
+        productOfKnownStaticSize *= dimSize;
+        if (i != 0)
+          suffixProductOfSizes[i - 1] = suffixProductOfSizes[i] * dimSize;
+      }
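+      // For instance (illustrative, matching the first test below): for a
+      // group with sizes [3, 5, 2], the loop above yields
+      // suffixProductOfSizes = [10, 2, 1]; with a base stride of 4, the
+      // final strides are then [40, 8, 4].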
+      OpFoldResult baseStride =
+          ShapedType::isDynamicStrideOrOffset(strides[reassDim])
+              ? getAsOpFoldResult(
+                    newExtractStridedMetadata.getStrides()[reassDim])
+              : rewriter.getIndexAttr(strides[reassDim]);
+      uint64_t sourceDimSize = sourceType.getDimSize(reassDim);
+      OpFoldResult baseSize =
+          ShapedType::isDynamic(sourceDimSize)
+              ? getAsOpFoldResult(
+                    newExtractStridedMetadata.getSizes()[reassDim])
+              : rewriter.getIndexAttr(sourceDimSize);
+
+      // Value and expression holding the contribution of the dynamic size of
+      // the group to the computation of the strides.
+      // We need this value for all the strides whose dimension appears before
+      // this dynamic dimension.
+      // E.g., assume a shape of 3x4x?x6. The dynamic size will appear in the
+      // computation of the strides of the first and second dimensions (resp.
+      // 4x?x6 and ?x6.)
+      // We also carry the current stride in the expression to allow proper
+      // constant folding.
+      OpFoldResult dynSizeInput;
+      AffineExpr dynSizeTimesCurrentStrideExpr;
+      AffineExpr dynSizeSymbol, suffixProdOfSizesSymbol, baseStrideSymbol;
+      MLIRContext *context = rewriter.getContext();
+      bindSymbols(context, dynSizeSymbol, suffixProdOfSizesSymbol,
+                  baseStrideSymbol);
+      AffineExpr basicDynSizeTimesCurrentStrideExpr =
+          dynSizeSymbol * suffixProdOfSizesSymbol;
+
+      if (hasDynamicSize) {
+        AffineExpr productOfKnownStaticSizeExpr =
+            getAffineConstantExpr(productOfKnownStaticSize, context);
+        // We use the following expression to generate things like
+        // origSize.floorDiv(10) instead of origSize.floorDiv(30) * 3. (The
+        // latter would have been the direct translation of dynSize * stride,
+        // where dynSize == origSize / 30.) We achieve that thanks to constant
+        // propagation in the expression: origSize.floorDiv(30.floorDiv(3)).
+        dynSizeTimesCurrentStrideExpr = dynSizeSymbol.floorDiv(
+            productOfKnownStaticSizeExpr.floorDiv(suffixProdOfSizesSymbol));
+        dynSizeInput = baseSize;
+      } else {
+        // This value will be used in a product.
+        // Use the neutral element when there is no dynamic size involved.
+        dynSizeInput = rewriter.getIndexAttr(1);
+        dynSizeTimesCurrentStrideExpr = basicDynSizeTimesCurrentStrideExpr;
+      }
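+      // Illustrative numbers for the dynamic case, matching the second test
+      // below: for a group with sizes [10, ?, 3], productOfKnownStaticSize
+      // is 30 and the suffix product for the first dimension is 3, so the
+      // stride expression folds to origSize floordiv (30 floordiv 3) ==
+      // origSize floordiv 10, rather than (origSize floordiv 30) * 3.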
+      for (unsigned i = 0; i < groupSize; ++i) {
+        int64_t newSize = expandShapeType.getDimSize(groups[i]);
+        OpFoldResult newSizeOFR;
+        OpFoldResult suffixProductOfSizesAttr =
+            rewriter.getIndexAttr(suffixProductOfSizes[i]);
+        if (ShapedType::isDynamic(newSize)) {
+          assert(hasDynamicSize &&
+                 "Dynamic size should have been captured in the previous loop");
+
+          newSizeOFR = makeComposedFoldedAffineApply(
+              rewriter, origLoc,
+              dynSizeSymbol.floorDiv(productOfKnownStaticSize),
+              {dynSizeInput});
+          // Now that we saw the dynamic size, the computations of the
+          // following strides, including the current dimension, don't need to
+          // account for it: Use the neutral element for the product.
+          dynSizeInput = rewriter.getIndexAttr(1);
+          dynSizeTimesCurrentStrideExpr = basicDynSizeTimesCurrentStrideExpr;
+        } else
+          newSizeOFR = rewriter.getIndexAttr(newSize);
+
+        results[baseIdxInResult] =
+            getValueOrCreateConstantIndexOp(rewriter, origLoc, newSizeOFR);
+
+        AffineExpr strideExpr =
+            dynSizeTimesCurrentStrideExpr * baseStrideSymbol;
+        results[baseIdxInResult + expandShapeRank] =
+            getValueOrCreateConstantIndexOp(
+                rewriter, origLoc,
+                makeComposedFoldedAffineApply(
+                    rewriter, origLoc, strideExpr,
+                    {dynSizeInput, suffixProductOfSizesAttr, baseStride}));
+        ++baseIdxInResult;
+      }
+      ++reassDim;
+    }
+    assert(reassDim == sourceRank &&
+           "We should have visited all the input dimensions");
+    assert(baseIdxInResult == expandShapeRank + 2 &&
+           "We should have populated all the values");
+    rewriter.replaceOp(op, results);
+    return success();
+  }
+};
 } // namespace
 
 void memref::populateSimplifyExtractStridedMetadataOpPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<ExtractStridedMetadataOpSubviewFolder>(patterns.getContext());
+  patterns.add<ExtractStridedMetadataOpSubviewFolder,
+               ExtractStridedMetadataOpExpandShapeFolder>(
+      patterns.getContext());
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
--- a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
@@ -281,3 +281,181 @@
   return %base_buffer, %offset, %sizes#0, %sizes#1, %sizes#2, %strides#0, %strides#1, %strides#2 :
     memref<f32>, index, index, index, index, index, index, index
 }
+
+// -----
+
+// Check that we properly simplify extract_strided_metadata of expand_shape
+// into:
+// baseBuffer, baseOffset, baseSizes, baseStrides =
+//     extract_strided_metadata(memref)
+// sizes#reassIdx =
+//     baseSizes#reassDim / product(expandShapeSizes#j,
+//                                  for j in group excluding reassIdx)
+// strides#reassIdx =
+//     baseStrides#reassDim * product(expandShapeSizes#j, for j in
+//                                    reassIdx+1..reassIdx+group.size-1)
+//
+// Here we have:
+// For the group applying to dim0:
+// size 0 = 3
+// size 1 = 5
+// size 2 = 2
+// stride 0 = baseStrides#0 * 5 * 2
+//          = 4 * 5 * 2
+//          = 40
+// stride 1 = baseStrides#0 * 2
+//          = 4 * 2
+//          = 8
+// stride 2 = baseStrides#0
+//          = 4
+//
+// For the group applying to dim1:
+// size 3 = 2
+// size 4 = 2
+// stride 3 = baseStrides#1 * 2
+//          = 1 * 2
+//          = 2
+// stride 4 = baseStrides#1
+//          = 1
+//
+// Base and offset are unchanged.
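+//
+// As a quick sanity check (hand arithmetic, not produced by the test):
+// element (i0, i1, i2, i3, i4) of the expanded memref maps to row
+// i0*10 + i1*2 + i2 and column i3*2 + i4 of the original memref<30x4xi16>,
+// i.e., to linear offset i0*40 + i1*8 + i2*4 + i3*2 + i4, which is exactly
+// what the strides [40, 8, 4, 2, 1] encode.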
+//
+// CHECK-LABEL: func @extract_strided_metadata_of_expand_shape_all_static
+// CHECK-SAME: (%[[ARG:.*]]: memref<30x4xi16>)
+//
+// CHECK-DAG: %[[C40:.*]] = arith.constant 40 : index
+// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+//
+// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<30x4xi16> -> memref<i16>, index, index, index, index, index
+//
+// CHECK: return %[[BASE]], %[[C0]], %[[C3]], %[[C5]], %[[C2]], %[[C2]], %[[C2]], %[[C40]], %[[C8]], %[[C4]], %[[C2]], %[[C1]] : memref<i16>, index, index, index, index, index, index, index, index, index, index, index
+func.func @extract_strided_metadata_of_expand_shape_all_static(
+    %arg : memref<30x4xi16>)
+    -> (memref<i16>, index,
+        index, index, index, index, index,
+        index, index, index, index, index) {
+
+  %expand_shape = memref.expand_shape %arg[[0, 1, 2], [3, 4]] :
+    memref<30x4xi16> into memref<3x5x2x2x2xi16>
+
+  %base, %offset, %sizes:5, %strides:5 = memref.extract_strided_metadata %expand_shape :
+    memref<3x5x2x2x2xi16>
+    -> memref<i16>, index,
+       index, index, index, index, index,
+       index, index, index, index, index
+
+  return %base, %offset,
+         %sizes#0, %sizes#1, %sizes#2, %sizes#3, %sizes#4,
+         %strides#0, %strides#1, %strides#2, %strides#3, %strides#4 :
+    memref<i16>, index,
+    index, index, index, index, index,
+    index, index, index, index, index
+}
+
+// -----
+
+// Check that we properly simplify extract_strided_metadata of expand_shape
+// when dynamic sizes, strides, and offsets are involved.
+// See extract_strided_metadata_of_expand_shape_all_static for an explanation
+// of the expansion.
+//
+// One of the important characteristics of this test is that the dynamic
+// dimensions produced by the expand_shape appear both in the first position
+// of a group (group 1) and in a non-first position (the second dimension of
+// group 2.)
+// The idea is to make sure that:
+// 1. We properly account for dynamic shapes even when the strides are not
+//    affected by them. (When the dynamic dimension is the first one.)
+// 2. We properly compute the strides affected by dynamic shapes. (When the
+//    dynamic dimension is not the first one.)
+//
+// Here we have:
+// For the group applying to dim0:
+// size 0 = baseSizes#0 / (all static sizes in that group)
+//        = baseSizes#0 / (7 * 8 * 9)
+//        = baseSizes#0 / 504
+// size 1 = 7
+// size 2 = 8
+// size 3 = 9
+// stride 0 = baseStrides#0 * 7 * 8 * 9
+//          = baseStrides#0 * 504
+// stride 1 = baseStrides#0 * 8 * 9
+//          = baseStrides#0 * 72
+// stride 2 = baseStrides#0 * 9
+// stride 3 = baseStrides#0
+//
+// For the group applying to dim1:
+// size 4 = 10
+// size 5 = baseSizes#1 / (all static sizes in that group)
+//        = baseSizes#1 / (10 * 3)
+//        = baseSizes#1 / 30
+// size 6 = 3
+// stride 4 = baseStrides#1 * size 5 * size 6
+//          = baseStrides#1 * (baseSizes#1 / 30) * 3
+//          and since we know that baseSizes#1 is a multiple of 30:
+//          = baseStrides#1 * (baseSizes#1 / 10)
+// stride 5 = baseStrides#1 * size 6
+//          = baseStrides#1 * 3
+// stride 6 = baseStrides#1
+//
+// Base and offset are unchanged.
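+//
+// As a concrete instance (hand arithmetic; assume baseSizes#1 = 60, which is
+// a multiple of 30 as expand_shape requires): size 5 = 60 / 30 = 2 and
+// stride 4 = baseStrides#1 * 2 * 3 = baseStrides#1 * 6
+//          = baseStrides#1 * (60 / 10),
+// which matches the (s0 floordiv 10) * s1 affine map checked below.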
+//
+// CHECK-DAG: #[[$DIM0_SIZE_MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 504)>
+// CHECK-DAG: #[[$DIM5_SIZE_MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 30)>
+//
+// CHECK-DAG: #[[$DIM0_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 504)>
+// CHECK-DAG: #[[$DIM1_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 72)>
+// CHECK-DAG: #[[$DIM2_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 9)>
+// CHECK-DAG: #[[$DIM4_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 floordiv 10) * s1)>
+// CHECK-DAG: #[[$DIM5_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 3)>
+//
+// CHECK-LABEL: func @extract_strided_metadata_of_expand_shape_all_dynamic
+// CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32, strided<[?, ?], offset: ?>>,
+//
+// CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index
+// CHECK-DAG: %[[C9:.*]] = arith.constant 9 : index
+// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[C7:.*]] = arith.constant 7 : index
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+//
+// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<?x?xf32, strided<[?, ?], offset: ?>> -> memref<f32>, index, index, index, index, index
+//
+// CHECK-DAG: %[[DYN_SIZE0:.*]] = affine.apply #[[$DIM0_SIZE_MAP]]()[%[[SIZES]]#0]
+// CHECK-DAG: %[[DYN_SIZE5:.*]] = affine.apply #[[$DIM5_SIZE_MAP]]()[%[[SIZES]]#1]
+// CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.apply #[[$DIM0_STRIDE_MAP]]()[%[[STRIDES]]#0]
+// CHECK-DAG: %[[DYN_STRIDE1:.*]] = affine.apply #[[$DIM1_STRIDE_MAP]]()[%[[STRIDES]]#0]
+// CHECK-DAG: %[[DYN_STRIDE2:.*]] = affine.apply #[[$DIM2_STRIDE_MAP]]()[%[[STRIDES]]#0]
+// CHECK-DAG: %[[DYN_STRIDE4:.*]] = affine.apply #[[$DIM4_STRIDE_MAP]]()[%[[SIZES]]#1, %[[STRIDES]]#1]
+// CHECK-DAG: %[[DYN_STRIDE5:.*]] = affine.apply #[[$DIM5_STRIDE_MAP]]()[%[[STRIDES]]#1]
+
+// CHECK: return %[[BASE]], %[[OFFSET]], %[[DYN_SIZE0]], %[[C7]], %[[C8]], %[[C9]], %[[C10]], %[[DYN_SIZE5]], %[[C3]], %[[DYN_STRIDE0]], %[[DYN_STRIDE1]], %[[DYN_STRIDE2]], %[[STRIDES]]#0, %[[DYN_STRIDE4]], %[[DYN_STRIDE5]], %[[STRIDES]]#1 : memref<f32>, index, index, index, index, index, index, index, index, index, index, index, index, index, index, index
+func.func @extract_strided_metadata_of_expand_shape_all_dynamic(
+    %base: memref<?x?xf32, strided<[?, ?], offset: ?>>,
+    %offset0: index, %offset1: index, %offset2: index,
+    %size0: index, %size1: index, %size2: index,
+    %stride0: index, %stride1: index, %stride2: index)
+    -> (memref<f32>, index,
+        index, index, index, index, index, index, index,
+        index, index, index, index, index, index, index) {
+
+  %subview = memref.expand_shape %base[[0, 1, 2, 3], [4, 5, 6]] :
+    memref<?x?xf32, strided<[?, ?], offset: ?>> into
+    memref<?x7x8x9x10x?x3xf32, strided<[?, ?, ?, ?, ?, ?, ?], offset: ?>>
+
+  %base_buffer, %offset, %sizes:7, %strides:7 = memref.extract_strided_metadata %subview :
+    memref<?x7x8x9x10x?x3xf32, strided<[?, ?, ?, ?, ?, ?, ?], offset: ?>>
+    -> memref<f32>, index,
+       index, index, index, index, index, index, index,
+       index, index, index, index, index, index, index
+
+  return %base_buffer, %offset,
+         %sizes#0, %sizes#1, %sizes#2, %sizes#3, %sizes#4, %sizes#5, %sizes#6,
+         %strides#0, %strides#1, %strides#2, %strides#3, %strides#4, %strides#5, %strides#6 :
+    memref<f32>, index,
+    index, index, index, index, index, index, index,
+    index, index, index, index, index, index, index
+}