diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -1734,9 +1734,10 @@ .asStride(); } } - return makeStridedLinearLayoutMap( - llvm::to_vector<8>(llvm::reverse(reverseResultStrides)), srcOffset, - srcType.getContext()); + auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides)); + resultStrides.resize(resultShape.size(), 1); + return makeStridedLinearLayoutMap(resultStrides, srcOffset, + srcType.getContext()); } static FailureOr diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir --- a/mlir/test/Dialect/Tensor/bufferize.mlir +++ b/mlir/test/Dialect/Tensor/bufferize.mlir @@ -9,6 +9,8 @@ // CHECK-DAG: #[[$MAP6:.*]] = affine_map<(d0) -> (d0 * 2)> // CHECK-DAG: #[[$MAP7:.*]] = affine_map<(d0, d1, d2)[s0] -> (d0 * 8 + s0 + d1 * 4 + d2)> // CHECK-DAG: #[[$MAP8:.*]] = affine_map<(d0)[s0] -> (d0 * 4 + s0)> + // CHECK-DAG: #[[$MAP9:.*]] = affine_map<()[s0] -> (s0)> + // CHECK-DAG: #[[$MAP10:.*]] = affine_map<(d0)[s0] -> (d0 + s0)> // CHECK-LABEL: func @dim( // CHECK-SAME: %[[TENSOR:.*]]: tensor, @@ -333,6 +335,20 @@ return %1 : tensor } +// CHECK-LABEL: func @tensor.expand_shape_of_scalar_slice( +// CHECK-SAME: %[[t1:.*]]: tensor +func.func @tensor.expand_shape_of_scalar_slice( + %t1: tensor, %o1: index, %s1: index) -> tensor<1xf32> { + // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref + // CHECK: %[[subview:.*]] = memref.subview %[[m1]][%{{.*}}] [1] [1] : memref to memref + %0 = tensor.extract_slice %t1[%o1][1][1] : tensor to tensor + // CHECK: %[[expanded:.*]] = memref.expand_shape %[[subview]] [] : memref into memref<1xf32, #[[$MAP10]]> + %1 = tensor.expand_shape %0 [] : tensor into tensor<1xf32> + // CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]] + // CHECK: return %[[r]] + return %1 : tensor<1xf32> +} + // CHECK-LABEL: func @tensor.collapse_shape( // CHECK-SAME: %[[t1:.*]]: tensor<2x?x?xf32> func.func @tensor.collapse_shape(%t1: tensor<2x?x?xf32>) -> tensor {