diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/STLExtras.h"
@@ -422,14 +423,63 @@
   assert(hasKnownStridesAndOffset &&
          "getStridesAndOffset must work on valid collapse_shape");
 
-  SmallVector<OpFoldResult> collapsedStride;
-  int64_t innerMostDimForGroup = reassocGroup.back();
-  int64_t innerMostStrideForGroup = strides[innerMostDimForGroup];
-  collapsedStride.push_back(
-      ShapedType::isDynamic(innerMostStrideForGroup)
-          ? origStrides[innerMostDimForGroup]
-          : builder.getIndexAttr(innerMostStrideForGroup));
+  SmallVector<OpFoldResult> groupStrides;
+  ArrayRef<int64_t> srcShape = sourceType.getShape();
+  for (int64_t currentDim : reassocGroup) {
+    // Skip size-of-1 dimensions, since right now their strides may be
+    // meaningless.
+    // FIXME: size-of-1 dimensions shouldn't be used in collapse shape, unless
+    // they are truly contiguous. When they are truly contiguous, we shouldn't
+    // need to skip them.
+    if (srcShape[currentDim] == 1)
+      continue;
+    int64_t currentStride = strides[currentDim];
+    groupStrides.push_back(ShapedType::isDynamic(currentStride)
+                               ? origStrides[currentDim]
+                               : builder.getIndexAttr(currentStride));
+  }
+  SmallVector<OpFoldResult> collapsedStride;
+  if (groupStrides.empty()) {
+    // We're dealing with a 1x1x...x1 shape. The stride is meaningless,
+    // but we still have to make the type system happy.
+    MemRefType collapsedType = collapseShape.getResultType();
+    SmallVector<int64_t> collapsedStrides;
+    int64_t collapsedOffset;
+    bool hasKnownCollapsedStridesAndOffset = succeeded(
+        getStridesAndOffset(collapsedType, collapsedStrides, collapsedOffset));
+    (void)hasKnownCollapsedStridesAndOffset;
+    assert(hasKnownCollapsedStridesAndOffset &&
+           "getStridesAndOffset must work on valid collapse_shape");
+    int64_t finalStride = collapsedStrides[groupId];
+    if (ShapedType::isDynamic(finalStride)) {
+      // Look for a dynamic stride. At this point we don't know which one is
+      // desired, but they are all equally good/bad.
+      for (int64_t currentDim : reassocGroup) {
+        assert(srcShape[currentDim] == 1 &&
+               "We should be dealing with 1x1x...x1");
+
+        int64_t currentStride = strides[currentDim];
+        if (ShapedType::isDynamic(currentStride)) {
+          collapsedStride.push_back(origStrides[currentDim]);
+          break;
+        }
+      }
+      assert(!collapsedStride.empty() &&
+             "We should have found a dynamic stride");
+    } else
+      collapsedStride.push_back(builder.getIndexAttr(finalStride));
+  } else if (groupStrides.size() == 1) {
+    // TODO: affine.min should probably fold `min(a)` => `a`.
+    collapsedStride.push_back(groupStrides[0]);
+  } else {
+    // For the general case, we just want the minimum stride
+    // since the collapsed dimensions are contiguous.
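+    // E.g., for a group {1, 2, 3} with strides {stride1, stride2, 42}, the
+    // collapsed stride is min(stride1, stride2, 42): since the group is
+    // contiguous, the smallest stride is the one of its innermost non-unit
+    // dimension, which is the stride of the collapsed dimension.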
+    auto minMap = AffineMap::getMultiDimIdentityMap(groupStrides.size(),
+                                                    builder.getContext());
+    collapsedStride.push_back(makeComposedFoldedAffineMin(
+        builder, collapseShape.getLoc(), minMap, groupStrides));
+  }
 
   return collapsedStride;
 }
 
 /// Replace `baseBuffer, offset, sizes, strides =
diff --git a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
--- a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
@@ -908,20 +908,23 @@
 //        = 6 * 7
 //        = 42
 // Stride 0 = origStride0
-// Stride 1 = origStride3 (orig stride of the inner most dimension)
-//          = 42
-// Stride 2 = origStride5
+// Stride 1 = min(origStride1, origStride2, origStride3)
+//          = min(origStride1, origStride2, 42)
+// Stride 2 = min(origStride4, origStride5)
+//          = min(7, 1)
 //          = 1
 //
 // CHECK-DAG: #[[$SIZE0_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 4)>
+// CHECK-DAG: #[[$STRIDE1_MIN_MAP:.*]] = affine_map<()[s0, s1] -> (s0, s1, 42)>
 // CHECK-LABEL: func @simplify_collapse(
 // CHECK-SAME: %[[ARG:.*]]: memref<?x?x4x?x6x7xi32>)
 //
 // CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:6, %[[STRIDES:.*]]:6 = memref.extract_strided_metadata %[[ARG]] : memref<?x?x4x?x6x7xi32>
 //
-// CHECK: %[[DYN_SIZE1:.*]] = affine.apply #[[$SIZE0_MAP]]()[%[[SIZES]]#1, %[[SIZES]]#3]
+// CHECK-DAG: %[[DYN_SIZE1:.*]] = affine.apply #[[$SIZE0_MAP]]()[%[[SIZES]]#1, %[[SIZES]]#3]
+// CHECK-DAG: %[[DYN_STRIDE1:.*]] = affine.min #[[$STRIDE1_MIN_MAP]]()[%[[STRIDES]]#1, %[[STRIDES]]#2]
 //
-// CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [0], sizes: [%[[SIZES]]#0, %[[DYN_SIZE1]], 42], strides: [%[[STRIDES]]#0, 42, 1]
+// CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [0], sizes: [%[[SIZES]]#0, %[[DYN_SIZE1]], 42], strides: [%[[STRIDES]]#0, %[[DYN_STRIDE1]], 1]
 func.func @simplify_collapse(%arg : memref<?x?x4x?x6x7xi32>)
     -> memref<?x?x42xi32> {
@@ -934,6 +937,118 @@
 
 // -----
 
+// Check that we simplify collapse_shape into
+// reinterpret_cast(extract_strided_metadata)
+// when there are dimensions of size 1 involved.
+//
+// We transform: 3x1 to [0, 1]
+//
+// The tricky bit here is that the strides between dimensions 0 and 1 are not
+// truly contiguous, but since we are dealing with a dimension of size 1 this
+// is actually fine (i.e., we are not going to jump around).
+//
+// As a result, the resulting stride needs to ignore the strides of the
+// dimensions of size 1.
+//
+// Size 0 = origSize0 * origSize1
+//        = 3 * 1
+//        = 3
+// Stride 0 = min(origStride_i, for all i in reassociation group and dim_i != 1)
+//          = min(origStride0)
+//          = min(2)
+//          = 2
+//
+// CHECK-LABEL: func @simplify_collapse_with_dim_of_size1(
+// CHECK-SAME: %[[ARG:.*]]: memref<3x1xf32, strided<[2, 1]>>,
+//
+// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<3x1xf32, strided<[2, 1]>>
+//
+//
+// CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [0], sizes: [3], strides: [2]
+func.func @simplify_collapse_with_dim_of_size1(%arg0: memref<3x1xf32, strided<[2,1]>>, %arg1: memref<3xf32>) {
+
+  %collapse_shape = memref.collapse_shape %arg0 [[0, 1]] :
+    memref<3x1xf32, strided<[2, 1]>> into memref<3xf32, strided<[2]>>
+
+  memref.copy %collapse_shape, %arg1 : memref<3xf32, strided<[2]>> to memref<3xf32>
+
+  return
+}
+
+
+// -----
+
+// Check that we simplify collapse_shape with an edge case group of 1x1x...x1.
+//
+// The tricky bit here is that the resulting stride is meaningless; we still
+// have to please the type system.
+//
+// In this case, we're collapsing two strides of respectively 2 and 1, and the
+// resulting type wants a stride of 2.
+//
+// CHECK-LABEL: func @simplify_collapse_with_dim_of_size1_and_non_1_stride(
+// CHECK-SAME: %[[ARG:.*]]: memref<1x1xi32, strided<[2, 1]
+//
+// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<1x1xi32, strided<[2, 1], offset: ?>>
+//
+// CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[OFFSET]]], sizes: [1], strides: [2]
+func.func @simplify_collapse_with_dim_of_size1_and_non_1_stride
+    (%arg0: memref<1x1xi32, strided<[2, 1], offset: ?>>)
+    -> memref<1xi32, strided<[2], offset: ?>> {
+
+  %collapse_shape = memref.collapse_shape %arg0 [[0, 1]] :
+    memref<1x1xi32, strided<[2, 1], offset: ?>>
+    into memref<1xi32, strided<[2], offset: ?>>
+
+  return %collapse_shape : memref<1xi32, strided<[2], offset: ?>>
+}
+
+// -----
+
+// Check that we simplify collapse_shape with an edge case group of 1x1x...x1.
+// We also have a couple of collapsed dimensions before the 1x1x...x1 group
+// to make sure we properly index into the dynamic strides based on the
+// group ID.
+//
+// The tricky bit in this test is that the 1x1x...x1 group stride is dynamic,
+// so we have to propagate one of the dynamic strides of this group.
+//
+// For this test we have:
+// Size0 = origSize0 * origSize1
+//       = 2 * 3
+//       = 6
+// Size1 = origSize2 * origSize3 * origSize4
+//       = 1 * 1 * 1
+//       = 1
+//
+// Stride0 = min(origStride0, origStride1)
+// Stride1 = the strides in this group are dynamic, so we don't know which one
+//           to pick; we just return the first dynamic one for this group.
+//
+//
+// CHECK-DAG: #[[$STRIDE0_MIN_MAP:.*]] = affine_map<()[s0, s1] -> (s0, s1)>
+// CHECK-LABEL: func @simplify_collapse_with_dim_of_size1_and_resulting_dyn_stride(
+// CHECK-SAME: %[[ARG:.*]]: memref<2x3x1x1x1xi32, strided<[?, ?, ?, ?, 2]
+//
+// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:5, %[[STRIDES:.*]]:5 = memref.extract_strided_metadata %[[ARG]] : memref<2x3x1x1x1xi32, strided<[?, ?, ?, ?, 2], offset: ?>>
+//
+// CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.min #[[$STRIDE0_MIN_MAP]]()[%[[STRIDES]]#0, %[[STRIDES]]#1]
+//
+// CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[OFFSET]]], sizes: [6, 1], strides: [%[[DYN_STRIDE0]], %[[STRIDES]]#2]
+func.func @simplify_collapse_with_dim_of_size1_and_resulting_dyn_stride
+    (%arg0: memref<2x3x1x1x1xi32, strided<[?, ?, ?, ?, 2], offset: ?>>)
+    -> memref<6x1xi32, strided<[?, ?], offset: ?>> {
+
+  %collapse_shape = memref.collapse_shape %arg0 [[0, 1], [2, 3, 4]] :
+    memref<2x3x1x1x1xi32, strided<[?, ?, ?, ?, 2], offset: ?>>
+    into memref<6x1xi32, strided<[?, ?], offset: ?>>
+
+  return %collapse_shape : memref<6x1xi32, strided<[?, ?], offset: ?>>
+}
+
+// -----
+
 // Check that we simplify extract_strided_metadata of collapse_shape.
 //
 // We transform: ?x?x4x?x6x7xi32 to [0][1,2,3][4,5]