diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp --- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/Transforms/Passes.h" +#include "mlir/IR/AffineMap.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/STLExtras.h" @@ -422,14 +423,38 @@ assert(hasKnownStridesAndOffset && "getStridesAndOffset must work on valid collapse_shape"); - SmallVector<OpFoldResult> collapsedStride; - int64_t innerMostDimForGroup = reassocGroup.back(); - int64_t innerMostStrideForGroup = strides[innerMostDimForGroup]; - collapsedStride.push_back( - ShapedType::isDynamic(innerMostStrideForGroup) - ? origStrides[innerMostDimForGroup] - : builder.getIndexAttr(innerMostStrideForGroup)); + SmallVector<OpFoldResult> groupStrides; + ArrayRef<int64_t> srcShape = sourceType.getShape(); + for (int64_t currentDim : reassocGroup) { + // Skip size-of-1 dimensions, since right now their strides may be + // meaningless. + // FIXME: size-of-1 dimensions shouldn't be used in collapse shape, unless + // they are truly contiguous. When they are truly contiguous, we shouldn't + // need to skip them. + if (srcShape[currentDim] == 1) + continue; + int64_t currentStride = strides[currentDim]; + groupStrides.push_back(ShapedType::isDynamic(currentStride) + ? origStrides[currentDim] + : builder.getIndexAttr(currentStride)); + } + SmallVector<OpFoldResult> collapsedStride; + if (groupStrides.empty()) { + // We're dealing with a 1x1x...x1 shape. The stride is meaningless. + // Just put one. FIXME: this may disagree with the result type's stride. + collapsedStride.push_back(builder.getIndexAttr(1)); + } else if (groupStrides.size() == 1) { + // TODO: affine.min should probably fold `min(a)` => `a`.
+ collapsedStride.push_back(groupStrides[0]); + } else { + // For the general case, we just want the minimum stride + // since the collapsed dimensions are contiguous. + auto minMap = AffineMap::getMultiDimIdentityMap(groupStrides.size(), + builder.getContext()); + collapsedStride.push_back(makeComposedFoldedAffineMin( + builder, collapseShape.getLoc(), minMap, groupStrides)); + } return collapsedStride; } /// Replace `baseBuffer, offset, sizes, strides = diff --git a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir --- a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir +++ b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir @@ -908,20 +908,23 @@ // = 6 * 7 // = 42 // Stride 0 = origStride0 -// Stride 1 = origStride3 (orig stride of the inner most dimension) -// = 42 -// Stride 2 = origStride5 +// Stride 1 = min(origStride1, origStride2, origStride3) +// = min(origStride1, origStride2, 42) +// Stride 2 = min(origStride4, origStride5) +// = min(7, 1) // = 1 // // CHECK-DAG: #[[$SIZE0_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 * s1) * 4)> +// CHECK-DAG: #[[$STRIDE1_MIN_MAP:.*]] = affine_map<()[s0, s1] -> (s0, s1, 42)> // CHECK-LABEL: func @simplify_collapse( // CHECK-SAME: %[[ARG:.*]]: memref<?x?x4x?x6x7xi32>) // // CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:6, %[[STRIDES:.*]]:6 = memref.extract_strided_metadata %[[ARG]] : memref<?x?x4x?x6x7xi32> // -// CHECK: %[[DYN_SIZE1:.*]] = affine.apply #[[$SIZE0_MAP]]()[%[[SIZES]]#1, %[[SIZES]]#3] +// CHECK-DAG: %[[DYN_SIZE1:.*]] = affine.apply #[[$SIZE0_MAP]]()[%[[SIZES]]#1, %[[SIZES]]#3] +// CHECK-DAG: %[[DYN_STRIDE1:.*]] = affine.min #[[$STRIDE1_MIN_MAP]]()[%[[STRIDES]]#1, %[[STRIDES]]#2] // -// CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [0], sizes: [%[[SIZES]]#0, %[[DYN_SIZE1]], 42],
strides: [%[[STRIDES]]#0, %[[DYN_STRIDE1]], 1] func.func @simplify_collapse(%arg : memref<?x?x4x?x6x7xi32>) -> memref<?x?x42xi32> { @@ -934,6 +937,46 @@ // ----- +// Check that we simplify collapse_shape into +// reinterpret_cast(extract_strided_metadata) + <some math> +// when there are dimensions of size 1 involved. +// +// We transform: 3x1 to [0, 1] +// +// The tricky bit here is that the strides between dimensions 0 and 1 +// are not truly contiguous, but since we're dealing with a dimension of size 1 +// this is actually fine (i.e., we are not going to jump around). +// +// As a result, the collapsed stride needs to ignore the strides of the +// dimensions of size 1. +// +// Size 0 = origSize0 * origSize1 +// = 3 * 1 +// = 3 +// Stride 0 = min(origStride_i, for all i in reassociation group and dim_i != 1) +// = min(origStride0) +// = min(2) +// = 2 +// +// CHECK-LABEL: func @simplify_collapse_with_dim_of_size1( +// CHECK-SAME: %[[ARG:.*]]: memref<3x1xf32, strided<[2, 1]>>, +// +// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<3x1xf32, strided<[2, 1]>> +// +// +// CHECK: %[[COLLAPSE_VIEW:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [0], sizes: [3], strides: [2] +func.func @simplify_collapse_with_dim_of_size1(%arg0: memref<3x1xf32, strided<[2,1]>>, %arg1: memref<3xf32>) { + + %collapse_shape = memref.collapse_shape %arg0 [[0, 1]] : + memref<3x1xf32, strided<[2, 1]>> into memref<3xf32, strided<[2]>> + + memref.copy %collapse_shape, %arg1 : memref<3xf32, strided<[2]>> to memref<3xf32> + + return +} + +// ----- + // Check that we simplify extract_strided_metadata of collapse_shape. // // We transform: ?x?x4x?x6x7xi32 to [0][1,2,3][4,5]