diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -1823,11 +1823,20 @@ return failure(); // The result strides are exactly the strides of the last entry of each - // reassociation. + // reassociation. The only exception is when the last entry's dim has size 1: + // then the stride of the closest preceding entry whose size is not 1 is used. + // TODO: Consider memref.transpose case where the strides are not sorted. SmallVector resultStrides; resultStrides.reserve(reassociation.size()); - for (ReassociationIndices reassoc : reassociation) - resultStrides.push_back(srcStrides[reassoc.back()]); + for (const ReassociationIndices &reassoc : reassociation) { + ArrayRef ref = llvm::makeArrayRef(reassoc); + while (srcShape[ref.back()] == 1 && ref.size() > 1) + ref = ref.drop_back(); + if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) + resultStrides.push_back(srcStrides[ref.back()]); + else + resultStrides.push_back(ShapedType::kDynamicStrideOrOffset); + } // Validate that each reassociation group is contiguous. 
unsigned resultStrideIndex = resultStrides.size() - 1; diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -331,14 +331,14 @@ func @do_not_compose_collapse_of_expand_non_identity_layout( %arg0: memref) - -> memref { + -> memref { %1 = memref.expand_shape %arg0 [[0, 1], [2]] : memref into memref %2 = memref.collapse_shape %1 [[0, 1, 2]] : memref into - memref - return %2 : memref + memref + return %2 : memref } // CHECK-LABEL: func @do_not_compose_collapse_of_expand_non_identity_layout // CHECK: expand diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir --- a/mlir/test/Dialect/Tensor/bufferize.mlir +++ b/mlir/test/Dialect/Tensor/bufferize.mlir @@ -1,11 +1,16 @@ // RUN: mlir-opt %s -tensor-bufferize | FileCheck %s -// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> -// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1)[s0] -> (d0 * 20 + s0 + d1)> -// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 140 + d1 * 20 + d2 * 5 + d3 + s0)> -// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0) -> (d0 + 1)> -// CHECK-DAG: #[[$MAP4:.*]] = affine_map<() -> (1)> -// CHECK-DAG: #[[$MAP5:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1)> + // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> + // CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1)[s0] -> (d0 * 20 + s0 + d1)> + // CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 140 + d1 * 20 + d2 * 5 + d3 + s0)> + // CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1)> + // CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0) -> (d0 * 2)> + // CHECK-DAG: #[[$MAP5:.*]] = affine_map<(d0) -> (d0 + 1)> + // CHECK-DAG: #[[$MAP6:.*]] = affine_map<() -> (1)> + // CHECK-DAG: #[[$MAP7:.*]] = affine_map<(d0, d1, d2)[s0] -> (d0 * 8 + s0 + d1 * 4 + d2)> + // CHECK-DAG: 
#[[$MAP8:.*]] = affine_map<(d0)[s0] -> (d0 * 4 + s0)> + // CHECK-DAG: #[[$MAP9:.*]] = affine_map<(d0, d1, d2)[s0, s1] -> (d0 * s1 + s0 + d1 * 4 + d2)> + // CHECK-DAG: #[[$MAP10:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)> // CHECK-LABEL: func @dim( // CHECK-SAME: %[[TENSOR:.*]]: tensor, @@ -338,17 +343,6 @@ return %1 : tensor } -// CHECK-LABEL: func @tensor.expand_shape_of_slice2( -// CHECK-SAME: %[[t1:.*]]: tensor<1x2xf32> -func @tensor.expand_shape_of_slice2(%t1: tensor<1x2xf32>) -> tensor<1xf32> { - // CHECK: memref.subview {{.*}} : memref<1x2xf32> to memref<1x1xf32, #[[$MAP5]]> - %0 = tensor.extract_slice %t1[0, 0][1, 1][1, 1] : tensor<1x2xf32> to tensor<1x1xf32> - // CHECK: memref.collapse_shape %{{.*}} [ - // CHECK-SAME: [0, 1]] : memref<1x1xf32, #[[$MAP5]]> into memref<1xf32> - %1 = tensor.collapse_shape %0 [[0, 1]] : tensor<1x1xf32> into tensor<1xf32> - return %1 : tensor<1xf32> -} - // CHECK-LABEL: func @tensor.collapse_shape( // CHECK-SAME: %[[t1:.*]]: tensor<2x?x?xf32> func @tensor.collapse_shape(%t1: tensor<2x?x?xf32>) -> tensor { @@ -378,9 +372,45 @@ // CHECK-LABEL: func @tensor.collapse_shape_of_slice( func @tensor.collapse_shape_of_slice(%arg0: tensor<2xi32>) -> tensor { - // CHECK: memref.subview %{{.*}}[1] [1] [1] : memref<2xi32> to memref<1xi32, #[[$MAP3]]> + // CHECK: memref.subview %{{.*}}[1] [1] [1] : memref<2xi32> to memref<1xi32, #[[$MAP5]]> %0 = tensor.extract_slice %arg0[1] [1] [1] : tensor<2xi32> to tensor<1xi32> - // CHECK: memref.collapse_shape %{{.*}} [] : memref<1xi32, #[[$MAP3]]> into memref + // CHECK: memref.collapse_shape %{{.*}} [] : memref<1xi32, #[[$MAP5]]> into memref %1 = tensor.collapse_shape %0 [] : tensor<1xi32> into tensor return %1 : tensor } + +// CHECK-LABEL: func @tensor.collapse_shape_of_slice2( +// CHECK-SAME: %[[t1:.*]]: tensor<1x2xf32> +func @tensor.collapse_shape_of_slice2(%t1: tensor<1x2xf32>) -> tensor<1xf32> { + // CHECK: memref.subview {{.*}} : memref<1x2xf32> to memref<1x1xf32, #[[$MAP3]]> + %0 = 
tensor.extract_slice %t1[0, 0][1, 1][1, 1] : tensor<1x2xf32> to tensor<1x1xf32> + // CHECK: memref.collapse_shape %{{.*}} [ + // CHECK-SAME: [0, 1]] : memref<1x1xf32, #[[$MAP3]]> into memref<1xf32, #[[$MAP4]]> + %1 = tensor.collapse_shape %0 [[0, 1]] : tensor<1x1xf32> into tensor<1xf32> + return %1 : tensor<1xf32> +} + +// CHECK-LABEL: func @tensor.collapse_shape_of_slice3( +// CHECK-SAME: %[[t1:.*]]: tensor, +// CHECK-SAME: %[[OFFSET:.*]]: index) -> tensor<8xf32> { +func @tensor.collapse_shape_of_slice3(%arg0: tensor, %offset: index, %size: index) -> tensor<8xf32> { + // CHECK: memref.subview %{{.*}} : memref to memref<4x2x1xf32, #[[$MAP7]]> + %0 = tensor.extract_slice %arg0[0, 0, %offset] [4, 2, 1] [1, 1, 1] : tensor to tensor<4x2x1xf32> + // CHECK: memref.collapse_shape %{{.*}} [ + // CHECK-SAME: [0, 1, 2]] : memref<4x2x1xf32, #[[$MAP7]]> into memref<8xf32, #[[$MAP8]]> + %ret = tensor.collapse_shape %0 [[0, 1, 2]] : tensor<4x2x1xf32> into tensor<8xf32> + return %ret: tensor<8xf32> +} + +// CHECK-LABEL: func @tensor.collapse_shape_of_slice4( +// CHECK-SAME: %[[t1:.*]]: tensor, +// CHECK-SAME: %[[OFFSET:.*]]: index, +// CHECK-SAME: %[[SIZE:.*]]: index) -> tensor { +func @tensor.collapse_shape_of_slice4(%arg0: tensor, %offset: index, %size: index) -> tensor { + // CHECK: memref.subview %{{.*}} : memref to memref<4x?x1xf32, #[[$MAP9]]> + %0 = tensor.extract_slice %arg0[0, 0, %offset] [4, %size, 1] [1, 1, 1] : tensor to tensor<4x?x1xf32> + // CHECK: memref.collapse_shape %{{.*}} [ + // CHECK-SAME: [0, 1, 2]] : memref<4x?x1xf32, #[[$MAP9]]> into memref + %ret = tensor.collapse_shape %0 [[0, 1, 2]] : tensor<4x?x1xf32> into tensor + return %ret: tensor +}