diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
@@ -55,6 +55,11 @@
 /// terms of shapes of its input operands.
 void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns);
 
+/// Appends patterns for simplifying extract_strided_metadata(other_op) into
+/// easier to analyze constructs.
+void populateSimplifyExtractStridedMetadataOpPatterns(
+    RewritePatternSet &patterns);
+
 /// Transformation to do multi-buffering/array expansion to remove dependencies
 /// on the temporary allocation between consecutive loop iterations.
 /// It return success if the allocation was multi-buffered and returns failure()
@@ -118,6 +123,11 @@
 /// in terms of shapes of its input operands.
 std::unique_ptr<Pass> createResolveShapedTypeResultDimsPass();
 
+/// Creates an operation pass to simplify
+/// `extract_strided_metadata(other_op(memref))` into
+/// `extract_strided_metadata(memref)`.
+std::unique_ptr<Pass> createSimplifyExtractStridedMetadataPass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
@@ -173,5 +173,19 @@
   ];
 }
 
+def SimplifyExtractStridedMetadata : Pass<"simplify-extract-strided-metadata"> {
+  let summary = "Simplify extract_strided_metadata ops";
+  let description = [{
+    The pass simplifies extract_strided_metadata(other_op(memref)) to
+    extract_strided_metadata(memref) when it is possible to model the effect
+    of other_op directly with affine maps applied to the result of
+    extract_strided_metadata.
+  }];
+  let constructor = "mlir::memref::createSimplifyExtractStridedMetadataPass()";
+  let dependentDialects = [
+      // The pattern materializes arith.constant ops for static values.
+      "AffineDialect", "arith::ArithmeticDialect", "memref::MemRefDialect"
+  ];
+}
 
 #endif // MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/IR/AffineExpr.h b/mlir/include/mlir/IR/AffineExpr.h
--- a/mlir/include/mlir/IR/AffineExpr.h
+++ b/mlir/include/mlir/IR/AffineExpr.h
@@ -320,6 +320,16 @@
   e = getAffineSymbolExpr(N, ctx);
   bindSymbols<N + 1, AffineExprTy2 &...>(ctx, exprs...);
 }
+
+/// Bind each entry of `exprs`, in order, to the SymbolExpr at positions
+/// [0 .. exprs.size() - 1].
+template <typename AffineExprTy>
+void bindSymbolsList(MLIRContext *ctx, SmallVectorImpl<AffineExprTy> &exprs) {
+  unsigned idx = 0;
+  for (AffineExprTy &e : exprs)
+    e = getAffineSymbolExpr(idx++, ctx);
+}
+
 } // namespace detail
 
 /// Bind a list of AffineExpr references to DimExpr at positions:
diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -5,6 +5,7 @@
   MultiBuffer.cpp
   NormalizeMemRefs.cpp
   ResolveShapedTypeResultDims.cpp
+  SimplifyExtractStridedMetadata.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/MemRef
diff --git a/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
@@ -0,0 +1,211 @@
+//===- SimplifyExtractStridedMetadata.cpp - Simplify this operation -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+/// This pass simplifies extract_strided_metadata(other_op(memref)) to
+/// extract_strided_metadata(memref) when it is possible to express the effect
+/// of other_op using affine apply on the results of
+/// extract_strided_metadata(memref).
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallBitVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+namespace mlir {
+namespace memref {
+#define GEN_PASS_DEF_SIMPLIFYEXTRACTSTRIDEDMETADATA
+#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
+} // namespace memref
+} // namespace mlir
+using namespace mlir;
+
+namespace {
+/// Replace `baseBuffer, offset, sizes, strides =
+///              extract_strided_metadata(subview(memref, subOffset,
+///                                               subSizes, subStrides))`
+/// With
+///
+/// \verbatim
+/// baseBuffer, baseOffset, baseSizes, baseStrides =
+///     extract_strided_metadata(memref)
+/// strides#i = baseStrides#i * subStrides#i
+/// offset = baseOffset + sum(subOffset#i * baseStrides#i)
+/// sizes = subSizes
+/// \endverbatim
+///
+/// The subview strides only scale the distance between the elements of the
+/// view; they do not move its start, so they do not contribute to `offset`.
+/// The sizes and strides of the dimensions dropped by a rank-reducing subview
+/// are filtered out of the result.
+///
+/// In other words, get rid of the subview in that expression and canonicalize
+/// on its effects on the offset, the sizes, and the strides using affine apply.
+struct ExtractStridedMetadataOpSubviewFolder
+    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
+  using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
+                                PatternRewriter &rewriter) const override {
+    auto subview = op.getSource().getDefiningOp<memref::SubViewOp>();
+    if (!subview)
+      return failure();
+
+    // Build a plain extract_strided_metadata(memref) from
+    // extract_strided_metadata(subview(memref)).
+    Location origLoc = op.getLoc();
+    IndexType indexType = rewriter.getIndexType();
+    Value source = subview.getSource();
+    auto sourceType = source.getType().cast<MemRefType>();
+    unsigned sourceRank = sourceType.getRank();
+    SmallVector<Type> sizeStrideTypes(sourceRank, indexType);
+
+    auto newExtractStridedMetadata =
+        rewriter.create<memref::ExtractStridedMetadataOp>(
+            origLoc, op.getBaseBuffer().getType(), indexType, sizeStrideTypes,
+            sizeStrideTypes, source);
+
+    SmallVector<int64_t> sourceStrides;
+    int64_t sourceOffset;
+
+    bool hasKnownStridesAndOffset =
+        succeeded(getStridesAndOffset(sourceType, sourceStrides, sourceOffset));
+    (void)hasKnownStridesAndOffset;
+    assert(hasKnownStridesAndOffset &&
+           "getStridesAndOffset must work on valid subviews");
+    // Helper function to get either the value of the newly created
+    // extract_strided_metadata or, when the source type knows it statically,
+    // `altValue` as an index attribute. Returning an attribute instead of
+    // materializing an arith.constant avoids creating ops that would
+    // immediately be folded back into attributes and left dead.
+    auto getAlternativeValue = [&rewriter](OpFoldResult opr,
+                                           int64_t altValue) -> OpFoldResult {
+      if (!ShapedType::isDynamicStrideOrOffset(altValue))
+        return rewriter.getIndexAttr(altValue);
+      return opr;
+    };
+
+    // Compute the new strides and offset from the base strides and offset:
+    // newStride#i = baseStride#i * subStride#i
+    // offset = baseOffset + sum(subOffsets#i * baseStrides#i)
+    SmallVector<OpFoldResult> strides;
+    SmallVector<OpFoldResult> subStrides = subview.getMixedStrides();
+    auto origStrides = newExtractStridedMetadata.getStrides();
+    auto getOrigStrideAtIdx = [&getAlternativeValue, &origStrides,
+                               &sourceStrides](unsigned idx) {
+      return getAlternativeValue(origStrides[idx], sourceStrides[idx]);
+    };
+
+    // Compute the strides.
+    AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
+    AffineExpr s1 = rewriter.getAffineSymbolExpr(1);
+    for (unsigned i = 0; i < sourceRank; ++i)
+      strides.push_back(makeComposedFoldedAffineApply(
+          rewriter, origLoc, s0 * s1, {subStrides[i], getOrigStrideAtIdx(i)}));
+
+    // Compute the offset. Only the base strides take part: the subview
+    // strides must not scale the subview offsets.
+    SmallVector<OpFoldResult> values(2 * sourceRank + 1);
+    SmallVector<AffineExpr> symbols(2 * sourceRank + 1);
+
+    detail::bindSymbolsList(rewriter.getContext(), symbols);
+    AffineExpr expr = symbols.front();
+    values[0] = getAlternativeValue(
+        getAsOpFoldResult(newExtractStridedMetadata.getOffset()), sourceOffset);
+    SmallVector<OpFoldResult> subOffsets = subview.getMixedOffsets();
+    for (unsigned i = 0; i < sourceRank; ++i) {
+      unsigned baseIdxForDim = 1 + 2 * i;
+      unsigned subOffsetForDim = baseIdxForDim;
+      unsigned origStrideForDim = baseIdxForDim + 1;
+      expr = expr + symbols[subOffsetForDim] * symbols[origStrideForDim];
+      values[subOffsetForDim] = subOffsets[i];
+      values[origStrideForDim] = getOrigStrideAtIdx(i);
+    }
+
+    OpFoldResult finalOffset =
+        makeComposedFoldedAffineApply(rewriter, origLoc, expr, values);
+
+    SmallVector<Value> results;
+    // The final result is <baseBuffer, offset, sizes, strides>.
+    // Thus we need 1 + 1 + subview.getRank() + subview.getRank(), to hold all
+    // the values.
+    auto subType = subview.getType().cast<MemRefType>();
+    unsigned subRank = subType.getRank();
+    // Properly size the array so that we can do random insertions
+    // at the right indices.
+    // We do that to populate the non-dropped sizes and strides in one go.
+    results.resize_for_overwrite(subRank * 2 + 2);
+
+    results[0] = newExtractStridedMetadata.getBaseBuffer();
+    results[1] =
+        getValueOrCreateConstantIndexOp(rewriter, origLoc, finalOffset);
+
+    // The sizes of the final type are defined directly by the input sizes of
+    // the subview. Moreover, subviews can drop some dimensions, so some sizes
+    // and strides may not end up in the final value that we are replacing.
+    // Do the filtering here.
+    SmallVector<OpFoldResult> subSizes = subview.getMixedSizes();
+    const unsigned sizeStartIdx = 2;
+    const unsigned strideStartIdx = sizeStartIdx + subRank;
+    unsigned insertedDims = 0;
+    llvm::SmallBitVector droppedDims = subview.getDroppedDims();
+    for (unsigned i = 0; i < sourceRank; ++i) {
+      if (droppedDims.test(i))
+        continue;
+
+      results[sizeStartIdx + insertedDims] =
+          getValueOrCreateConstantIndexOp(rewriter, origLoc, subSizes[i]);
+      results[strideStartIdx + insertedDims] =
+          getValueOrCreateConstantIndexOp(rewriter, origLoc, strides[i]);
+      ++insertedDims;
+    }
+    assert(insertedDims == subRank &&
+           "Should have populated all the values at this point");
+
+    rewriter.replaceOp(op, results);
+    return success();
+  }
+};
+} // namespace
+
+void memref::populateSimplifyExtractStridedMetadataOpPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<ExtractStridedMetadataOpSubviewFolder>(patterns.getContext());
+}
+
+//===----------------------------------------------------------------------===//
+// Pass registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+struct SimplifyExtractStridedMetadataPass final
+    : public memref::impl::SimplifyExtractStridedMetadataBase<
+          SimplifyExtractStridedMetadataPass> {
+  void runOnOperation() override;
+};
+
+} // namespace
+
+void SimplifyExtractStridedMetadataPass::runOnOperation() {
+  RewritePatternSet patterns(&getContext());
+  memref::populateSimplifyExtractStridedMetadataOpPatterns(patterns);
+  (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(),
+                                     std::move(patterns));
+}
+
+std::unique_ptr<Pass> memref::createSimplifyExtractStridedMetadataPass() {
+  return std::make_unique<SimplifyExtractStridedMetadataPass>();
+}
diff --git a/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
@@ -0,0 +1,283 @@
+// RUN: mlir-opt --simplify-extract-strided-metadata -split-input-file %s -o - | FileCheck %s
+
+// -----
+
+// Check that we simplify extract_strided_metadata of subview to
+// base_buf, base_offset, base_sizes, base_strides = extract_strided_metadata
+// strides = base_stride_i * subview_stride_i
+// offset = base_offset + sum(subview_offsets_i * base_strides_i).
+//
+// This test also checks that we don't create useless arith operations
+// when subview_offsets_i is 0.
+//
+// CHECK-LABEL: func @extract_strided_metadata_of_subview
+//  CHECK-SAME: (%[[ARG:.*]]: memref<5x4xf32>)
+//
+// Materialize the offset for dimension 1.
+//   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+//   CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+//   CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+//
+// Plain extract_strided_metadata.
+//   CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
+//
+// Final offset is:
+//   origOffset +                     (== 0)
+//   base_stride0 * subview_offset0 + (== 4 * 0 == 0)
+//   base_stride1 * subview_offset1   (== 1 * 2)
+//                == 2
+//
+// Return the new tuple.
+//       CHECK: return %[[BASE]], %[[C2]], %[[C2]], %[[C2]], %[[C4]], %[[C1]]
+func.func @extract_strided_metadata_of_subview(%base: memref<5x4xf32>)
+  -> (memref<f32>, index, index, index, index, index) {
+
+  %subview = memref.subview %base[0, 2][2, 2][1, 1] :
+    memref<5x4xf32> to memref<2x2xf32, strided<[4, 1], offset: 2>>
+
+  %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %subview :
+    memref<2x2xf32, strided<[4,1], offset:2>>
+    -> memref<f32>, index, index, index, index, index
+
+  return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 :
+    memref<f32>, index, index, index, index, index
+}
+
+// -----
+
+// Check that we simplify extract_strided_metadata of subview properly
+// when dynamic sizes are involved.
+// See extract_strided_metadata_of_subview for an explanation of the actual
+// expansion.
+// Orig strides: [64, 4, 1]
+// Sub strides: [1, 1, 1]
+// => New strides: [64, 4, 1]
+//
+// Orig offset: 0
+// Sub offsets: [3, 4, 2]
+// => Final offset: 3 * 64 + 4 * 4 + 2 * 1 + 0 == 210
+//
+// Final sizes == subview sizes == [%size, 6, 3]
+//
+// CHECK-LABEL: func @extract_strided_metadata_of_subview_with_dynamic_size
+//  CHECK-SAME: (%[[ARG:.*]]: memref<8x16x4xf32>,
+//  CHECK-SAME: %[[DYN_SIZE:.*]]: index)
+//
+//   CHECK-DAG: %[[C210:.*]] = arith.constant 210 : index
+//   CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index
+//   CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
+//   CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+//   CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+//   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+//
+//   CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[ARG]]
+//
+//       CHECK: return %[[BASE]], %[[C210]], %[[DYN_SIZE]], %[[C6]], %[[C3]], %[[C64]], %[[C4]], %[[C1]]
+func.func @extract_strided_metadata_of_subview_with_dynamic_size(
+    %base: memref<8x16x4xf32>, %size: index)
+    -> (memref<f32>, index, index, index, index, index, index, index) {
+
+  %subview = memref.subview %base[3, 4, 2][%size, 6, 3][1, 1, 1] :
+    memref<8x16x4xf32> to memref<?x6x3xf32, strided<[64, 4, 1], offset: 210>>
+
+  %base_buffer, %offset, %sizes:3, %strides:3 = memref.extract_strided_metadata %subview :
+    memref<?x6x3xf32, strided<[64,4,1], offset: 210>>
+    -> memref<f32>, index, index, index, index, index, index, index
+
+  return %base_buffer, %offset, %sizes#0, %sizes#1, %sizes#2, %strides#0, %strides#1, %strides#2 :
+    memref<f32>, index, index, index, index, index, index, index
+}
+
+// -----
+
+// Check that we simplify extract_strided_metadata of subview properly
+// when the subview reduces the ranks.
+// In particular the returned strides must come from #1 and #2 of the %strides
+// value of the new extract_strided_metadata, not #0 and #1.
+// See extract_strided_metadata_of_subview for an explanation of the actual
+// expansion.
+//
+// Orig strides: [64, 4, 1]
+// Sub strides: [1, 1, 1]
+// => New strides: [64, 4, 1]
+// Final strides == filterOutReducedDim(new strides, 0) == [4 , 1]
+//
+// Orig offset: 0
+// Sub offsets: [3, 4, 2]
+// => Final offset: 3 * 64 + 4 * 4 + 2 * 1 + 0 == 210
+//
+// Final sizes == filterOutReducedDim(subview sizes, 0) == [6, 3]
+//
+// CHECK-LABEL: func @extract_strided_metadata_of_rank_reduced_subview
+//  CHECK-SAME: (%[[ARG:.*]]: memref<8x16x4xf32>)
+//
+//   CHECK-DAG: %[[C210:.*]] = arith.constant 210 : index
+//   CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
+//   CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+//   CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+//   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+//
+//   CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[ARG]]
+//
+//       CHECK: return %[[BASE]], %[[C210]], %[[C6]], %[[C3]], %[[C4]], %[[C1]]
+func.func @extract_strided_metadata_of_rank_reduced_subview(%base: memref<8x16x4xf32>)
+  -> (memref<f32>, index, index, index, index, index) {
+
+  %subview = memref.subview %base[3, 4, 2][1, 6, 3][1, 1, 1] :
+    memref<8x16x4xf32> to memref<6x3xf32, strided<[4, 1], offset: 210>>
+
+  %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %subview :
+    memref<6x3xf32, strided<[4,1], offset: 210>>
+    -> memref<f32>, index, index, index, index, index
+
+  return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 :
+    memref<f32>, index, index, index, index, index
+}
+
+// -----
+
+// Check that we simplify extract_strided_metadata of subview properly
+// when the subview reduces the rank and some of the strides are variable.
+// In particular, we check that:
+// A. The dynamic stride is multiplied with the base stride to create the new
+//    stride for dimension 1.
+// B. The first returned stride is the value computed in #A.
+// C. The dynamic subview stride does not leak into the final offset.
+// See extract_strided_metadata_of_subview for an explanation of the actual
+// expansion.
+//
+// Orig strides: [64, 4, 1]
+// Sub strides: [1, %stride, 1]
+// => New strides: [64, 4 * %stride, 1]
+// Final strides == filterOutReducedDim(new strides, 0) == [4 * %stride , 1]
+//
+// Orig offset: 0
+// Sub offsets: [3, 4, 2]
+// => Final offset: 3 * 64 + 4 * 4 + 2 * 1 + 0 == 210
+//
+//   CHECK-DAG: #[[$STRIDE1_MAP:.*]] = affine_map<()[s0] -> (s0 * 4)>
+// CHECK-LABEL: func @extract_strided_metadata_of_rank_reduced_subview_w_variable_strides
+//  CHECK-SAME: (%[[ARG:.*]]: memref<8x16x4xf32>,
+//  CHECK-SAME: %[[DYN_STRIDE:.*]]: index)
+//
+//   CHECK-DAG: %[[C210:.*]] = arith.constant 210 : index
+//   CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
+//   CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+//   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+//
+//   CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[ARG]]
+//
+//   CHECK-DAG: %[[DIM1_STRIDE:.*]] = affine.apply #[[$STRIDE1_MAP]]()[%[[DYN_STRIDE]]]
+//
+//       CHECK: return %[[BASE]], %[[C210]], %[[C6]], %[[C3]], %[[DIM1_STRIDE]], %[[C1]]
+func.func @extract_strided_metadata_of_rank_reduced_subview_w_variable_strides(
+    %base: memref<8x16x4xf32>, %stride: index)
+    -> (memref<f32>, index, index, index, index, index) {
+
+  %subview = memref.subview %base[3, 4, 2][1, 6, 3][1, %stride, 1] :
+    memref<8x16x4xf32> to memref<6x3xf32, strided<[?, 1], offset: 210>>
+
+  %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %subview :
+    memref<6x3xf32, strided<[?, 1], offset: 210>>
+    -> memref<f32>, index, index, index, index, index
+
+  return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 :
+    memref<f32>, index, index, index, index, index
+}
+
+// -----
+
+// Check that we simplify extract_strided_metadata of subview properly
+// when the subview uses variable offsets.
+// See extract_strided_metadata_of_subview for an explanation of the actual
+// expansion.
+//
+// Orig strides: [128, 1]
+// Sub strides: [1, 1]
+// => New strides: [128, 1]
+//
+// Orig offset: 0
+// Sub offsets: [%arg1, %arg2]
+// => Final offset: 128 * %arg1 + 1 * %arg2 + 0
+//
+//   CHECK-DAG: #[[$OFFSETS_MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 128 + s1)>
+// CHECK-LABEL: func @extract_strided_metadata_of_subview_w_variable_offset
+//  CHECK-SAME: (%[[ARG:.*]]: memref<384x128xf32>,
+//  CHECK-SAME: %[[DYN_OFFSET0:.*]]: index,
+//  CHECK-SAME: %[[DYN_OFFSET1:.*]]: index)
+//
+//   CHECK-DAG: %[[C128:.*]] = arith.constant 128 : index
+//   CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index
+//   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+//   CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
+//
+//   CHECK-DAG: %[[FINAL_OFFSET:.*]] = affine.apply #[[$OFFSETS_MAP]]()[%[[DYN_OFFSET0]], %[[DYN_OFFSET1]]]
+//
+//       CHECK: return %[[BASE]], %[[FINAL_OFFSET]], %[[C64]], %[[C64]], %[[C128]], %[[C1]]
+#map0 = affine_map<(d0, d1)[s0] -> (d0 * 128 + s0 + d1)>
+func.func @extract_strided_metadata_of_subview_w_variable_offset(
+    %arg0: memref<384x128xf32>, %arg1 : index, %arg2 : index)
+    -> (memref<f32>, index, index, index, index, index) {
+
+  %subview = memref.subview %arg0[%arg1, %arg2] [64, 64] [1, 1] :
+    memref<384x128xf32> to memref<64x64xf32, #map0>
+
+  %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %subview :
+    memref<64x64xf32, #map0> -> memref<f32>, index, index, index, index, index
+
+  return %base_buffer, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 :
+    memref<f32>, index, index, index, index, index
+}
+
+// -----
+
+// Check that all the math is correct for all types of computations.
+// We achieve that by using dynamic values for all the different types:
+// - Offsets
+// - Sizes
+// - Strides
+//
+// Orig strides: [s0, s1, s2]
+// Sub strides: [subS0, subS1, subS2]
+// => New strides: [s0 * subS0, s1 * subS1, s2 * subS2]
+// ==> 1 affine map (used for each stride) with two values.
+//
+// Orig offset: origOff
+// Sub offsets: [subO0, subO1, subO2]
+// => Final offset: s0 * subO0 + ... + s2 * subO2 + origOff
+// ==> 1 affine map with (rank * 2 + 1) symbols
+//
+//   CHECK-DAG: #[[$STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> (s0 * s1)>
+//   CHECK-DAG: #[[$OFFSET_MAP:.*]] = affine_map<()[s0, s1, s2, s3, s4, s5, s6] -> (s0 + s1 * s2 + s3 * s4 + s5 * s6)>
+// CHECK-LABEL: func @extract_strided_metadata_of_subview_all_dynamic
+//  CHECK-SAME: (%[[ARG:.*]]: memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>, %[[DYN_OFFSET0:.*]]: index, %[[DYN_OFFSET1:.*]]: index, %[[DYN_OFFSET2:.*]]: index, %[[DYN_SIZE0:.*]]: index, %[[DYN_SIZE1:.*]]: index, %[[DYN_SIZE2:.*]]: index, %[[DYN_STRIDE0:.*]]: index, %[[DYN_STRIDE1:.*]]: index, %[[DYN_STRIDE2:.*]]: index)
+//
+//   CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[ARG]]
+//
+//   CHECK-DAG: %[[FINAL_STRIDE0:.*]] = affine.apply #[[$STRIDE_MAP]]()[%[[DYN_STRIDE0]], %[[STRIDES]]#0]
+//   CHECK-DAG: %[[FINAL_STRIDE1:.*]] = affine.apply #[[$STRIDE_MAP]]()[%[[DYN_STRIDE1]], %[[STRIDES]]#1]
+//   CHECK-DAG: %[[FINAL_STRIDE2:.*]] = affine.apply #[[$STRIDE_MAP]]()[%[[DYN_STRIDE2]], %[[STRIDES]]#2]
+//
+//   CHECK-DAG: %[[FINAL_OFFSET:.*]] = affine.apply #[[$OFFSET_MAP]]()[%[[OFFSET]], %[[DYN_OFFSET0]], %[[STRIDES]]#0, %[[DYN_OFFSET1]], %[[STRIDES]]#1, %[[DYN_OFFSET2]], %[[STRIDES]]#2]
+//
+//       CHECK: return %[[BASE]], %[[FINAL_OFFSET]], %[[DYN_SIZE0]], %[[DYN_SIZE1]], %[[DYN_SIZE2]], %[[FINAL_STRIDE0]], %[[FINAL_STRIDE1]], %[[FINAL_STRIDE2]]
+func.func @extract_strided_metadata_of_subview_all_dynamic(
+    %base: memref<?x?x?xf32, strided<[?,?,?], offset:?>>,
+    %offset0: index, %offset1: index, %offset2: index,
+    %size0: index, %size1: index, %size2: index,
+    %stride0: index, %stride1: index, %stride2: index)
+    -> (memref<f32>, index, index, index, index, index, index, index) {
+
+  %subview = memref.subview %base[%offset0, %offset1, %offset2]
+                                 [%size0, %size1, %size2]
+                                 [%stride0, %stride1, %stride2] :
+    memref<?x?x?xf32, strided<[?,?,?], offset: ?>> to
+      memref<?x?x?xf32, strided<[?,?,?], offset: ?>>
+
+  %base_buffer, %offset, %sizes:3, %strides:3 = memref.extract_strided_metadata %subview :
+    memref<?x?x?xf32, strided<[?,?,?], offset: ?>>
+    -> memref<f32>, index, index, index, index, index, index, index
+
+  return %base_buffer, %offset, %sizes#0, %sizes#1, %sizes#2, %strides#0, %strides#1, %strides#2 :
+    memref<f32>, index, index, index, index, index, index, index
+}