diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -24,6 +24,7 @@
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/IR/BuiltinAttributeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/Matchers.h"
@@ -2631,22 +2632,29 @@
         targetType.getElementType());
 
     MemRefType resultMemrefType;
-    if (srcType.getLayout().getAffineMap().isIdentity()) {
+    MemRefLayoutAttrInterface layout = srcType.getLayout();
+    if (layout.isa<AffineMapAttr>() && layout.isIdentity()) {
       resultMemrefType = MemRefType::get(
           srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
-          {}, srcType.getMemorySpaceAsInt());
+          nullptr, srcType.getMemorySpace());
     } else {
-      AffineMap map = srcType.getLayout().getAffineMap();
-      int numSymbols = map.getNumSymbols();
-      for (size_t i = 0; i < dimsToDrop; ++i) {
-        int dim = srcType.getRank() - i - 1;
-        map = map.replace(rewriter.getAffineDimExpr(dim),
-                          rewriter.getAffineConstantExpr(0),
-                          map.getNumDims() - 1, numSymbols);
+      MemRefLayoutAttrInterface updatedLayout;
+      if (auto strided = layout.dyn_cast<StridedLayoutAttr>()) {
+        auto strides = llvm::to_vector(strided.getStrides().drop_back(dimsToDrop));
+        updatedLayout = StridedLayoutAttr::get(strided.getContext(), strided.getOffset(), strides);
+      } else {
+        AffineMap map = srcType.getLayout().getAffineMap();
+        int numSymbols = map.getNumSymbols();
+        for (size_t i = 0; i < dimsToDrop; ++i) {
+          int dim = srcType.getRank() - i - 1;
+          map = map.replace(rewriter.getAffineDimExpr(dim),
+                            rewriter.getAffineConstantExpr(0),
+                            map.getNumDims() - 1, numSymbols);
+        }
       }
       resultMemrefType = MemRefType::get(
           srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
-          map, srcType.getMemorySpaceAsInt());
+          updatedLayout, srcType.getMemorySpace());
     }
 
     auto loc = readOp.getLoc();
diff --git a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
--- a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
@@ -8,9 +8,9 @@
 }
 // CHECK: func @contiguous_inner_most_view(%[[SRC:.+]]: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>
 // CHECK: %[[SRC_0:.+]] = memref.subview %[[SRC]]
-// CHECK-SAME: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>> to memref<1x1x8xf32
+// CHECK-SAME: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>> to memref<1x1x8xf32, strided<[3072, 8, 1], offset: ?>>
 // CHECK: %[[VEC:.+]] = vector.transfer_read %[[SRC_0]]
-// CHECK-SAME: memref<1x1x8xf32, {{.*}}>, vector<1x8xf32>
+// CHECK-SAME: memref<1x1x8xf32, strided<[3072, 8, 1], offset: ?>>, vector<1x8xf32>
 // CHECK: %[[RESULT:.+]] = vector.shape_cast %[[VEC]]
 // CHECK: return %[[RESULT]]
 
@@ -34,8 +34,8 @@
 func.func @contiguous_inner_most_dim_bounds(%A: memref<1000x1xf32>, %i:index, %ii:index) -> (vector<4x1xf32>) {
   %c0 = arith.constant 0 : index
   %cst = arith.constant 0.0 : f32
-  %0 = memref.subview %A[%i, 0] [40, 1] [1, 1] : memref<1000x1xf32> to memref<40x1xf32, affine_map<(d0, d1)[s0] -> (d0 + s0 + d1)>>
-  %1 = vector.transfer_read %0[%ii, %c0], %cst {in_bounds = [true, true]} : memref<40x1xf32, affine_map<(d0, d1)[s0] -> (d0 + s0 + d1)>>, vector<4x1xf32>
+  %0 = memref.subview %A[%i, 0] [40, 1] [1, 1] : memref<1000x1xf32> to memref<40x1xf32, strided<[1, 1], offset: ?>>
+  %1 = vector.transfer_read %0[%ii, %c0], %cst {in_bounds = [true, true]} : memref<40x1xf32, strided<[1, 1], offset: ?>>, vector<4x1xf32>
   return %1 : vector<4x1xf32>
 }
 // CHECK: func @contiguous_inner_most_dim_bounds(%[[SRC:.+]]: memref<1000x1xf32>, %[[II:.+]]: index, %[[J:.+]]: index) -> vector<4x1xf32>
@@ -50,8 +50,8 @@
 func.func @contiguous_inner_most_dim_bounds_2d(%A: memref<1000x1x1xf32>, %i:index, %ii:index) -> (vector<4x1x1xf32>) {
   %c0 = arith.constant 0 : index
   %cst = arith.constant 0.0 : f32
-  %0 = memref.subview %A[%i, 0, 0] [40, 1, 1] [1, 1, 1] : memref<1000x1x1xf32> to memref<40x1x1xf32, affine_map<(d0, d1, d2)[s0] -> (d0 + s0 + d1 + d2)>>
-  %1 = vector.transfer_read %0[%ii, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<40x1x1xf32, affine_map<(d0, d1, d2)[s0] -> (d0 + s0 + d1 + d2)>>, vector<4x1x1xf32>
+  %0 = memref.subview %A[%i, 0, 0] [40, 1, 1] [1, 1, 1] : memref<1000x1x1xf32> to memref<40x1x1xf32, strided<[1, 1, 1], offset: ?>>
+  %1 = vector.transfer_read %0[%ii, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<40x1x1xf32, strided<[1, 1, 1], offset: ?>>, vector<4x1x1xf32>
   return %1 : vector<4x1x1xf32>
 }
 // CHECK: func @contiguous_inner_most_dim_bounds_2d(%[[SRC:.+]]: memref<1000x1x1xf32>, %[[II:.+]]: index, %[[J:.+]]: index) -> vector<4x1x1xf32>
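Note (not part of the patch): below is a minimal standalone sketch of what the new StridedLayoutAttr branch computes, using the memref<1x1x8x1xf32, strided<[3072, 8, 1, 1]>> shape from the first test with a static offset of 0 for simplicity. The file name and main() driver are hypothetical; the sketch only assumes the MLIR C++ API as used in the patch itself (StridedLayoutAttr, member-style dyn_cast, MemRefType::get taking a MemRefLayoutAttrInterface) plus linking against the MLIR IR library.

// drop_inner_unit_dims_sketch.cpp (hypothetical): reproduce the result-type
// computation of the StridedLayoutAttr branch added above.
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"

int main() {
  mlir::MLIRContext ctx;

  // Source: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1]>> (static offset 0
  // instead of the test's dynamic offset; the strides are what matter here).
  auto f32 = mlir::Float32Type::get(&ctx);
  mlir::MemRefLayoutAttrInterface layout =
      mlir::StridedLayoutAttr::get(&ctx, /*offset=*/0, {3072, 8, 1, 1});
  auto srcType = mlir::MemRefType::get({1, 1, 8, 1}, f32, layout);

  // Collapse one trailing unit dimension, as in the first test above.
  int64_t dimsToDrop = 1;

  // Same computation as the patch: keep the offset, drop the trailing
  // strides together with the trailing shape dimensions.
  auto strided = layout.dyn_cast<mlir::StridedLayoutAttr>();
  auto strides = llvm::to_vector(strided.getStrides().drop_back(dimsToDrop));
  auto updatedLayout =
      mlir::StridedLayoutAttr::get(&ctx, strided.getOffset(), strides);
  auto resultType = mlir::MemRefType::get(
      srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
      updatedLayout, srcType.getMemorySpace());

  // Prints: memref<1x1x8xf32, strided<[3072, 8, 1]>>
  llvm::outs() << resultType << "\n";
  return 0;
}

With dimsToDrop = 1 the printed type matches the collapsed subview type the updated CHECK lines expect, up to the dynamic offset; the previous code would instead have produced an affine_map layout by substituting 0 for the dropped dimension.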