diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -2398,13 +2398,15 @@
           {}, srcType.getMemorySpaceAsInt());
     } else {
       AffineMap map = srcType.getLayout().getAffineMap();
-      int numResultDims = map.getNumDims() - dimsToDrop;
       int numSymbols = map.getNumSymbols();
       for (size_t i = 0; i < dimsToDrop; ++i) {
         int dim = srcType.getRank() - i - 1;
+        // `dim` is always the innermost remaining dim of `map`, so each
+        // replacement yields a map with exactly one dim fewer; a fixed final
+        // dim count is invalid once more than one dim is dropped.
         map = map.replace(rewriter.getAffineDimExpr(dim),
-                          rewriter.getAffineConstantExpr(0), numResultDims,
-                          numSymbols);
+                          rewriter.getAffineConstantExpr(0),
+                          map.getNumDims() - 1, numSymbols);
       }
       resultMemrefType = MemRefType::get(
           srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
diff --git a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
--- a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
@@ -47,3 +47,21 @@
 //      CHECK:   %[[V:.+]] = vector.transfer_read %[[SRC_1]]
 // CHECK-SAME:       {in_bounds = [true]}
 // CHECK-SAME:       vector<4xf32>
+
+// -----
+
+// Drops two inner unit dims from a non-identity layout, so the layout map
+// has to be rewritten once per dropped dim.
+func.func @contiguous_inner_most_dim_bounds_2d(%A: memref<1000x1x1xf32>, %i:index, %ii:index) -> (vector<4x1x1xf32>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.0 : f32
+  %0 = memref.subview %A[%i, 0, 0] [40, 1, 1] [1, 1, 1] : memref<1000x1x1xf32> to memref<40x1x1xf32, affine_map<(d0, d1, d2)[s0] -> (d0 + s0 + d1 + d2)>>
+  %1 = vector.transfer_read %0[%ii, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<40x1x1xf32, affine_map<(d0, d1, d2)[s0] -> (d0 + s0 + d1 + d2)>>, vector<4x1x1xf32>
+  return %1 : vector<4x1x1xf32>
+}
+//      CHECK: func @contiguous_inner_most_dim_bounds_2d(%[[SRC:.+]]: memref<1000x1x1xf32>, %[[II:.+]]: index, %[[J:.+]]: index) -> vector<4x1x1xf32>
+//      CHECK:   %[[SRC_0:.+]] = memref.subview %[[SRC]]
+//      CHECK:   %[[SRC_1:.+]] = memref.subview %[[SRC_0]]
+//      CHECK:   %[[V:.+]] = vector.transfer_read %[[SRC_1]]
+// CHECK-SAME:       {in_bounds = [true]}
+// CHECK-SAME:       vector<4xf32>