diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -339,34 +339,26 @@
   }
 };
 
-/// Returns the position of the first inner dimension that has contiguous layout
-/// with at least `requiredContiguousSize` contiguous elements.
-/// When such a dimension is found, the return value satisfies:
-/// 0 <= return_value <= memrefType.getRank() - 1.
-/// When no such dimension is found, the return value is memrefType.getRank().
-static int64_t getContiguousInnerDim(MemRefType memrefType,
-                                     int64_t requiredContiguousSize) {
+/// Return true if the memref type has its inner dimension matching the given
+/// shape. Otherwise return false.
+static int64_t hasMatchingInnerContigousShape(MemRefType memrefType,
+                                              ArrayRef<int64_t> targetShape) {
   auto shape = memrefType.getShape();
   SmallVector<int64_t> strides;
   int64_t offset;
-  int64_t innerDim = shape.size();
-  if (succeeded(getStridesAndOffset(memrefType, strides, offset))) {
-    int64_t innerSize = 1;
-    while (true) {
-      if (innerDim == 0)
-        break;
-      const int64_t nextDim = innerDim - 1;
-      if (shape[nextDim] == ShapedType::kDynamicSize)
-        break;
-      if (strides[nextDim] != innerSize)
-        break;
-      innerSize *= shape[nextDim];
-      innerDim = nextDim;
-      if (innerSize >= requiredContiguousSize)
-        break;
-    }
+  if (!succeeded(getStridesAndOffset(memrefType, strides, offset)))
+    return false;
+  if (strides.back() != 1)
+    return false;
+  strides.pop_back();
+  int64_t flatDim = 1;
+  for (auto [targetDim, memrefDim, memrefStride] :
+       llvm::reverse(llvm::zip(targetShape, shape, strides))) {
+    flatDim *= memrefDim;
+    if (flatDim != memrefStride || targetDim != memrefDim)
+      return false;
   }
-  return innerDim;
+  return true;
 }
 
 /// Creates a memref.collapse_shape collapsing all inner dimensions of the
@@ -427,10 +419,12 @@
     if (vectorType.getRank() <= 1)
       // Already 0D/1D, nothing to do.
      return failure();
-    int64_t firstContiguousInnerDim =
-        getContiguousInnerDim(sourceType, vectorType.getNumElements());
-    if (firstContiguousInnerDim >= sourceType.getRank() - 1)
+    if (!hasMatchingInnerContigousShape(
+            sourceType,
+            vectorType.getShape().take_back(vectorType.getRank() - 1)))
       return failure();
+    int64_t firstContiguousInnerDim =
+        sourceType.getRank() - vectorType.getRank();
     // TODO: generalize this pattern, relax the requirements here.
     if (transferReadOp.hasOutOfBoundsDim())
       return failure();
@@ -485,10 +479,12 @@
     if (vectorType.getRank() <= 1)
      // Already 0D/1D, nothing to do.
      return failure();
-    int64_t firstContiguousInnerDim =
-        getContiguousInnerDim(sourceType, vectorType.getNumElements());
-    if (firstContiguousInnerDim >= sourceType.getRank() - 1)
+    if (!hasMatchingInnerContigousShape(
+            sourceType,
+            vectorType.getShape().take_back(vectorType.getRank() - 1)))
       return failure();
+    int64_t firstContiguousInnerDim =
+        sourceType.getRank() - vectorType.getRank();
     // TODO: generalize this pattern, relax the requirements here.
     if (transferWriteOp.hasOutOfBoundsDim())
       return failure();
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -12,9 +12,9 @@
 // CHECK-LABEL: func @transfer_read_flattenable_with_offset
 // CHECK-SAME:    %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
 // CHECK:         %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]
-// C-HECK:        %[[READ1D:.+]] = vector.transfer_read %[[COLLAPSED]]
-// C-HECK:        %[[VEC2D:.+]] = vector.shape_cast %[[READ1D]] : vector<120xi8> to vector<5x4x3x2xi8>
-// C-HECK:        return %[[VEC2D]]
+// CHECK:         %[[READ1D:.+]] = vector.transfer_read %[[COLLAPSED]]
+// CHECK:         %[[VEC2D:.+]] = vector.shape_cast %[[READ1D]] : vector<120xi8> to vector<5x4x3x2xi8>
+// CHECK:         return %[[VEC2D]]
 
 // -----
 
@@ -26,12 +26,12 @@
     return
 }
 
-// C-HECK-LABEL: func @transfer_write_flattenable_with_offset
-// C-HECK-SAME:    %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
-// C-HECK-SAME:    %[[VEC:[0-9a-zA-Z]+]]: vector<5x4x3x2xi8>
-// C-HECK-DAG:     %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]{{.}} : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}>
-// C-HECK-DAG:     %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<5x4x3x2xi8> to vector<120xi8>
-// C-HECK:         vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
+// CHECK-LABEL: func @transfer_write_flattenable_with_offset
+// CHECK-SAME:    %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
+// CHECK-SAME:    %[[VEC:[0-9a-zA-Z]+]]: vector<5x4x3x2xi8>
+// CHECK-DAG:     %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]{{.}} : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}>
+// CHECK-DAG:     %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<5x4x3x2xi8> to vector<120xi8>
+// CHECK:         vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
 
 // -----
 
@@ -104,3 +104,31 @@
 // CHECK-SAME:      [%[[ARG2]], %[[ARG3]], %[[C0]]]
 // CHECK-SAME:      {in_bounds = [true]}
 // CHECK-SAME:      : vector<32xi8>, memref
+
+// -----
+
+func.func @transfer_read_flattenable_negative(
+      %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x2x2x2xi8> {
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant 0 : i8
+    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+      memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<2x2x2x2xi8>
+    return %v : vector<2x2x2x2xi8>
+}
+
+// CHECK-LABEL: func @transfer_read_flattenable_negative
+// CHECK:         vector.transfer_read {{.*}} vector<2x2x2x2xi8>
+
+// -----
+
+func.func @transfer_read_flattenable_negative2(
+      %arg : memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant 0 : i8
+    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+      memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>, vector<5x4x3x2xi8>
+    return %v : vector<5x4x3x2xi8>
+}
+
+// CHECK-LABEL: func @transfer_read_flattenable_negative2
+// CHECK:         vector.transfer_read {{.*}} vector<5x4x3x2xi8>
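
Note (not part of the patch): the standalone sketch below illustrates the row-major contiguity condition that hasMatchingInnerContigousShape enforces on the trailing memref dimensions, using the layouts from the tests above. It is a simplified model under stated assumptions: the helper name innerDimsAreContiguous is hypothetical, and it only checks the stride condition, not the additional match of the trailing memref dims against the vector's inner shape (which is what rejects the first negative test).

// Hypothetical sketch, not the MLIR helper itself: the trailing
// `numInnerDims` dimensions of a memref are contiguous in row-major order
// when the innermost stride is 1 and each next-outer stride equals the
// product of the sizes of all dimensions inside it.
#include <cassert>
#include <cstdint>
#include <vector>

static bool innerDimsAreContiguous(const std::vector<int64_t> &shape,
                                   const std::vector<int64_t> &strides,
                                   size_t numInnerDims) {
  assert(shape.size() == strides.size() && numInnerDims >= 1 &&
         numInnerDims <= shape.size());
  int64_t expectedStride = 1;
  // Walk from the innermost dimension outwards.
  for (size_t i = 0; i < numInnerDims; ++i) {
    size_t dim = shape.size() - 1 - i;
    if (strides[dim] != expectedStride)
      return false;
    expectedStride *= shape[dim];
  }
  return true;
}

int main() {
  // memref<5x4x3x2xi8, strided<[24, 6, 2, 1]>>: fully contiguous, so the
  // existing positive tests may be flattened to a 1-D transfer.
  assert(innerDimsAreContiguous({5, 4, 3, 2}, {24, 6, 2, 1}, 4));
  // memref<5x4x3x2xi8, strided<[24, 8, 2, 1]>> (second negative test): the
  // stride of dim 1 would have to be 3 * 2 = 6 but is 8, so the flattening
  // pattern must not fire.
  assert(!innerDimsAreContiguous({5, 4, 3, 2}, {24, 8, 2, 1}, 4));
  return 0;
}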