diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp --- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/Transforms/Passes.h" +#include "mlir/IR/AffineMap.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/STLExtras.h" @@ -403,15 +404,50 @@ auto [strides, offset] = getStridesAndOffset(sourceType); - SmallVector collapsedStride; - int64_t innerMostDimForGroup = reassocGroup.back(); - int64_t innerMostStrideForGroup = strides[innerMostDimForGroup]; - collapsedStride.push_back( - ShapedType::isDynamic(innerMostStrideForGroup) - ? origStrides[innerMostDimForGroup] - : builder.getIndexAttr(innerMostStrideForGroup)); + SmallVector groupStrides; + ArrayRef srcShape = sourceType.getShape(); + for (int64_t currentDim : reassocGroup) { + // Skip size-of-1 dimensions, since right now their strides may be + // meaningless. + // FIXME: size-of-1 dimensions shouldn't be used in collapse shape, unless + // they are truly contiguous. When they are truly contiguous, we shouldn't + // need to skip them. + if (srcShape[currentDim] == 1) + continue; + + int64_t currentStride = strides[currentDim]; + groupStrides.push_back(ShapedType::isDynamic(currentStride) + ? origStrides[currentDim] + : builder.getIndexAttr(currentStride)); + } + if (groupStrides.empty()) { + // We're dealing with a 1x1x...x1 shape. The stride is meaningless, + // but we still have to make the type system happy. 
+ MemRefType collapsedType = collapseShape.getResultType(); + auto [collapsedStrides, collapsedOffset] = + getStridesAndOffset(collapsedType); + int64_t finalStride = collapsedStrides[groupId]; + if (ShapedType::isDynamic(finalStride)) { + // Look for a dynamic stride. At this point we don't know which one is + // desired, but they are all equally good/bad. + for (int64_t currentDim : reassocGroup) { + assert(srcShape[currentDim] == 1 && + "We should be dealing with 1x1x...x1"); + + if (ShapedType::isDynamic(strides[currentDim])) + return {origStrides[currentDim]}; + } + llvm_unreachable("We should have found a dynamic stride"); + } + return {builder.getIndexAttr(finalStride)}; + } - return collapsedStride; + // For the general case, we just want the minimum stride + // since the collapsed dimensions are contiguous. + auto minMap = AffineMap::getMultiDimIdentityMap(groupStrides.size(), + builder.getContext()); + return {makeComposedFoldedAffineMin(builder, collapseShape.getLoc(), minMap, + groupStrides)}; } /// Replace `baseBuffer, offset, sizes, strides = /// extract_strided_metadata(reshapeLike(memref))` diff --git a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir --- a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir +++ b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir @@ -907,21 +907,28 @@ // Size 2 = origSize4 * origSize5 // = 6 * 7 // = 42 -// Stride 0 = origStride0 -// Stride 1 = origStride3 (orig stride of the inner most dimension) -// = 42 -// Stride 2 = origStride5 +// Stride 0 = min(origStride0) +// = Right now the folder of affine.min is not smart +// enough to just return origStride0 +// Stride 1 = min(origStride1, origStride2, origStride3) +// = min(origStride1, origStride2, 42) +// Stride 2 = min(origStride4, origStride5) +// = min(7, 1) // = 1 // +// CHECK-DAG: #[[$STRIDE0_MIN_MAP:.*]] = affine_map<()[s0] -> (s0)> // CHECK-DAG: #[[$SIZE0_MAP:.*]] = affine_map<()[s0, s1] -> 
((s0 * s1) * 4)> +// CHECK-DAG: #[[$STRIDE1_MIN_MAP:.*]] = affine_map<()[s0, s1] -> (s0, s1, 42)> // CHECK-LABEL: func @simplify_collapse( // CHECK-SAME: %[[ARG:.*]]: memref) // // CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:6, %[[STRIDES:.*]]:6 = memref.extract_strided_metadata %[[ARG]] : memref // -// CHECK: %[[DYN_SIZE1:.*]] = affine.apply #[[$SIZE0_MAP]]()[%[[SIZES]]#1, %[[SIZES]]#3] +// CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.min #[[$STRIDE0_MIN_MAP]]()[%[[STRIDES]]#0] +// CHECK-DAG: %[[DYN_SIZE1:.*]] = affine.apply #[[$SIZE0_MAP]]()[%[[SIZES]]#1, %[[SIZES]]#3] +// CHECK-DAG: %[[DYN_STRIDE1:.*]] = affine.min #[[$STRIDE1_MIN_MAP]]()[%[[STRIDES]]#1, %[[STRIDES]]#2] // -// CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [0], sizes: [%[[SIZES]]#0, %[[DYN_SIZE1]], 42], strides: [%[[STRIDES]]#0, 42, 1] +// CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [0], sizes: [%[[SIZES]]#0, %[[DYN_SIZE1]], 42], strides: [%[[DYN_STRIDE0]], %[[DYN_STRIDE1]], 1] func.func @simplify_collapse(%arg : memref) -> memref { @@ -934,6 +941,118 @@ // ----- +// Check that we simplify collapse_shape into +// reinterpret_cast(extract_strided_metadata) + +// when there are dimensions of size 1 involved. +// +// We transform: 3x1 to [0, 1] +// +// The tricky bit here is that the strides between dimensions 0 and 1 +// are not truly contiguous, but since we are dealing with a dimension of size 1 +// this is actually fine (i.e., we are not going to jump around.) +// +// As a result, the collapsed stride needs to ignore the strides of the +// dimensions of size 1.
+// +// Size 0 = origSize0 * origSize1 +// = 3 * 1 +// = 3 +// Stride 0 = min(origStride_i, for all i in reassociation group and origSize_i != 1) +// = min(origStride0) +// = min(2) +// = 2 +// +// CHECK-LABEL: func @simplify_collapse_with_dim_of_size1( +// CHECK-SAME: %[[ARG:.*]]: memref<3x1xf32, strided<[2, 1]>>, +// +// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<3x1xf32, strided<[2, 1]>> +// +// +// CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [0], sizes: [3], strides: [2] +func.func @simplify_collapse_with_dim_of_size1(%arg0: memref<3x1xf32, strided<[2,1]>>, %arg1: memref<3xf32>) { + + %collapse_shape = memref.collapse_shape %arg0 [[0, 1]] : + memref<3x1xf32, strided<[2, 1]>> into memref<3xf32, strided<[2]>> + + memref.copy %collapse_shape, %arg1 : memref<3xf32, strided<[2]>> to memref<3xf32> + + return +} + + +// ----- + +// Check that we simplify collapse_shape with an edge case group of 1x1x...x1. +// +// The tricky bit here is that the resulting stride is meaningless, but we still +// have to please the type system. +// +// In this case, we're collapsing two dimensions with strides 2 and 1, and the +// resulting type wants a stride of 2.
+// +// CHECK-LABEL: func @simplify_collapse_with_dim_of_size1_and_non_1_stride( +// CHECK-SAME: %[[ARG:.*]]: memref<1x1xi32, strided<[2, 1] +// +// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<1x1xi32, strided<[2, 1], offset: ?>> +// +// CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[OFFSET]]], sizes: [1], strides: [2] +func.func @simplify_collapse_with_dim_of_size1_and_non_1_stride + (%arg0: memref<1x1xi32, strided<[2, 1], offset: ?>>) + -> memref<1xi32, strided<[2], offset: ?>> { + + %collapse_shape = memref.collapse_shape %arg0 [[0, 1]] : + memref<1x1xi32, strided<[2, 1], offset: ?>> + into memref<1xi32, strided<[2], offset: ?>> + + return %collapse_shape : memref<1xi32, strided<[2], offset: ?>> +} + +// ----- + +// Check that we simplify collapse_shape with an edge case group of 1x1x...x1. +// We also have a couple of collapsed dimensions before the 1x1x...x1 group +// to make sure we properly index into the dynamic strides based on the +// group ID. +// +// The tricky bit in this test is that the 1x1x...x1 group stride is dynamic, +// so we have to propagate one of the dynamic strides of this group. +// +// For this test we have: +// Size0 = origSize0 * origSize1 +// = 2 * 3 +// = 6 +// Size1 = origSize2 * origSize3 * origSize4 +// = 1 * 1 * 1 +// = 1 +// +// Stride0 = min(origStride0, origStride1) +// Stride1 = dynamic, and we cannot tell which of the group's strides is +// intended since every dimension in the group has size 1. +// We just return the first dynamic one for this group.
+// +// +// CHECK-DAG: #[[$STRIDE0_MIN_MAP:.*]] = affine_map<()[s0, s1] -> (s0, s1)> +// CHECK-LABEL: func @simplify_collapse_with_dim_of_size1_and_resulting_dyn_stride( +// CHECK-SAME: %[[ARG:.*]]: memref<2x3x1x1x1xi32, strided<[?, ?, ?, ?, 2] +// +// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:5, %[[STRIDES:.*]]:5 = memref.extract_strided_metadata %[[ARG]] : memref<2x3x1x1x1xi32, strided<[?, ?, ?, ?, 2], offset: ?>> +// +// CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.min #[[$STRIDE0_MIN_MAP]]()[%[[STRIDES]]#0, %[[STRIDES]]#1] +// +// CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[OFFSET]]], sizes: [6, 1], strides: [%[[DYN_STRIDE0]], %[[STRIDES]]#2] +func.func @simplify_collapse_with_dim_of_size1_and_resulting_dyn_stride + (%arg0: memref<2x3x1x1x1xi32, strided<[?, ?, ?, ?, 2], offset: ?>>) + -> memref<6x1xi32, strided<[?, ?], offset: ?>> { + + %collapse_shape = memref.collapse_shape %arg0 [[0, 1], [2, 3, 4]] : + memref<2x3x1x1x1xi32, strided<[?, ?, ?, ?, 2], offset: ?>> + into memref<6x1xi32, strided<[?, ?], offset: ?>> + + return %collapse_shape : memref<6x1xi32, strided<[?, ?], offset: ?>> +} + +// ----- + // Check that we simplify extract_strided_metadata of collapse_shape.
// // We transform: ?x?x4x?x6x7xi32 to [0][1,2,3][4,5] @@ -950,6 +1069,7 @@ // = 1 // // CHECK-DAG: #[[$SIZE0_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 4)> +// CHECK-DAG: #[[$STRIDE0_MAP:.*]] = affine_map<()[s0] -> (s0)> // CHECK-LABEL: func @extract_strided_metadata_of_collapse( // CHECK-SAME: %[[ARG:.*]]: memref) // @@ -959,9 +1079,10 @@ // // CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:6, %[[STRIDES:.*]]:6 = memref.extract_strided_metadata %[[ARG]] : memref // +// CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.min #[[$STRIDE0_MAP]]()[%[[STRIDES]]#0] // CHECK-DAG: %[[DYN_SIZE1:.*]] = affine.apply #[[$SIZE0_MAP]]()[%[[SIZES]]#1, %[[SIZES]]#3] // -// CHECK: return %[[BASE]], %[[C0]], %[[SIZES]]#0, %[[DYN_SIZE1]], %[[C42]], %[[STRIDES]]#0, %[[C42]], %[[C1]] +// CHECK: return %[[BASE]], %[[C0]], %[[SIZES]]#0, %[[DYN_SIZE1]], %[[C42]], %[[DYN_STRIDE0]], %[[C42]], %[[C1]] func.func @extract_strided_metadata_of_collapse(%arg : memref) -> (memref, index, index, index, index,