diff --git a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp --- a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp @@ -30,9 +30,7 @@ using namespace mlir; namespace { -/// Replace `baseBuffer, offset, sizes, strides = -/// extract_strided_metadata(subview(memref, subOffset, -/// subSizes, subStrides))` +/// Replace `dst = subview(memref, subOffset, subSizes, subStrides)` /// With /// /// \verbatim @@ -41,24 +39,19 @@ /// strides#i = baseStrides#i * subSizes#i /// offset = baseOffset + sum(subOffset#i * strides#i) /// sizes = subSizes +/// dst = reinterpret_cast baseBuffer, offset, sizes, strides /// \endverbatim /// /// In other words, get rid of the subview in that expression and canonicalize /// on its effects on the offset, the sizes, and the strides using affine.apply. -struct ExtractStridedMetadataOpSubviewFolder - : public OpRewritePattern { +struct SubviewFolder : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op, + LogicalResult matchAndRewrite(memref::SubViewOp subview, PatternRewriter &rewriter) const override { - auto subview = op.getSource().getDefiningOp(); - if (!subview) - return failure(); - - // Build a plain extract_strided_metadata(memref) from - // extract_strided_metadata(subview(memref)). - Location origLoc = op.getLoc(); + // Build a plain extract_strided_metadata(memref) from subview(memref). + Location origLoc = subview.getLoc(); Value source = subview.getSource(); auto sourceType = source.getType().cast(); unsigned sourceRank = sourceType.getRank(); @@ -120,20 +113,11 @@ OpFoldResult finalOffset = makeComposedFoldedAffineApply(rewriter, origLoc, expr, values); - SmallVector results; // The final result is . 
// Thus we need 1 + 1 + subview.getRank() + subview.getRank(), to hold all // the values. auto subType = subview.getType().cast(); unsigned subRank = subType.getRank(); - // Properly size the array so that we can do random insertions - // at the right indices. - // We do that to populate the non-dropped sizes and strides in one go. - results.resize_for_overwrite(subRank * 2 + 2); - - results[0] = newExtractStridedMetadata.getBaseBuffer(); - results[1] = - getValueOrCreateConstantIndexOp(rewriter, origLoc, finalOffset); // The sizes of the final type are defined directly by the input sizes of // the subview. @@ -142,24 +126,30 @@ // replacing. // Do the filtering here. SmallVector subSizes = subview.getMixedSizes(); - const unsigned sizeStartIdx = 2; - const unsigned strideStartIdx = sizeStartIdx + subRank; - unsigned insertedDims = 0; llvm::SmallBitVector droppedDims = subview.getDroppedDims(); + + SmallVector finalSizes; + finalSizes.reserve(subRank); + + SmallVector finalStrides; + finalStrides.reserve(subRank); + for (unsigned i = 0; i < sourceRank; ++i) { if (droppedDims.test(i)) continue; - results[sizeStartIdx + insertedDims] = - getValueOrCreateConstantIndexOp(rewriter, origLoc, subSizes[i]); - results[strideStartIdx + insertedDims] = - getValueOrCreateConstantIndexOp(rewriter, origLoc, strides[i]); - ++insertedDims; + finalSizes.push_back(subSizes[i]); + finalStrides.push_back(strides[i]); } - assert(insertedDims == subRank && + assert(finalSizes.size() == subRank && "Should have populated all the values at this point"); - rewriter.replaceOp(op, results); + auto memrefDesc = rewriter.create( + origLoc, subType, newExtractStridedMetadata.getBaseBuffer(), + finalOffset, + /*sizes=*/finalSizes, + /*strides=*/finalStrides); + rewriter.replaceOp(subview, memrefDesc.getResult()); return success(); } }; @@ -773,7 +763,7 @@ void memref::populateSimplifyExtractStridedMetadataOpPatterns( RewritePatternSet &patterns) { patterns - .add, 
ExtractStridedMetadataOpReshapeFolder< diff --git a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir --- a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir +++ b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir @@ -21,6 +21,57 @@ // ----- +// Check that we simplify subview(src) into: +// base, offset, sizes, strides = extract_strided_metadata src +// final_sizes = subSizes +// final_strides = strides * subStrides (per dimension) +// final_offset = offset + sum(subOffsets#i * final_strides#i) +// reinterpret_cast base to final_offset, final_sizes, final_strides +// +// Orig strides: [s0, s1, s2] +// Sub strides: [subS0, subS1, subS2] +// => New strides: [s0 * subS0, s1 * subS1, s2 * subS2] +// ==> 1 affine map (used for each stride) with two values. +// +// Orig offset: origOff +// Sub offsets: [subO0, subO1, subO2] +// => Final offset: s0 * subS0 * subO0 + ... + s2 * subS2 * subO2 + origOff +// ==> 1 affine map with (rank * 3 + 1) symbols +// +// CHECK-DAG: #[[$STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> (s0 * s1)> +// CHECK-DAG: #[[$OFFSET_MAP:.*]] = affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0 + (s1 * s2) * s3 + (s4 * s5) * s6 + (s7 * s8) * s9)> +// CHECK-LABEL: func @simplify_subview_all_dynamic +// CHECK-SAME: (%[[ARG:.*]]: memref>, %[[DYN_OFFSET0:.*]]: index, %[[DYN_OFFSET1:.*]]: index, %[[DYN_OFFSET2:.*]]: index, %[[DYN_SIZE0:.*]]: index, %[[DYN_SIZE1:.*]]: index, %[[DYN_SIZE2:.*]]: index, %[[DYN_STRIDE0:.*]]: index, %[[DYN_STRIDE1:.*]]: index, %[[DYN_STRIDE2:.*]]: index) +// +// CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[ARG]] +// +// CHECK-DAG: %[[FINAL_STRIDE0:.*]] = affine.apply #[[$STRIDE_MAP]]()[%[[DYN_STRIDE0]], %[[STRIDES]]#0] +// CHECK-DAG: %[[FINAL_STRIDE1:.*]] = affine.apply #[[$STRIDE_MAP]]()[%[[DYN_STRIDE1]], %[[STRIDES]]#1] +// CHECK-DAG: %[[FINAL_STRIDE2:.*]] = affine.apply #[[$STRIDE_MAP]]()[%[[DYN_STRIDE2]], 
%[[STRIDES]]#2] +// +// CHECK-DAG: %[[FINAL_OFFSET:.*]] = affine.apply #[[$OFFSET_MAP]]()[%[[OFFSET]], %[[DYN_OFFSET0]], %[[DYN_STRIDE0]], %[[STRIDES]]#0, %[[DYN_OFFSET1]], %[[DYN_STRIDE1]], %[[STRIDES]]#1, %[[DYN_OFFSET2]], %[[DYN_STRIDE2]], %[[STRIDES]]#2] +// +// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[FINAL_OFFSET]]], sizes: [%[[DYN_SIZE0]], %[[DYN_SIZE1]], %[[DYN_SIZE2]]], strides: [%[[FINAL_STRIDE0]], %[[FINAL_STRIDE1]], %[[FINAL_STRIDE2]]] +// +// CHECK: return %[[RES]] +func.func @simplify_subview_all_dynamic( + %base: memref>, + %offset0: index, %offset1: index, %offset2: index, + %size0: index, %size1: index, %size2: index, + %stride0: index, %stride1: index, %stride2: index) + -> memref> { + + %subview = memref.subview %base[%offset0, %offset1, %offset2] + [%size0, %size1, %size2] + [%stride0, %stride1, %stride2] : + memref> to + memref> + + return %subview : memref> +} + +// ----- + // Check that we simplify extract_strided_metadata of subview to // base_buf, base_offset, base_sizes, base_strides = extract_strided_metadata // strides = base_stride_i * subview_stride_i