diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1027,6 +1027,15 @@
   }
 };
 
+/// Return true if the last dimension of the MemRefType has unit stride. Also
+/// return true for memrefs with no strides.
+static bool isLastMemrefDimUnitStride(MemRefType type) {
+  int64_t offset;
+  SmallVector<int64_t> strides;
+  auto successStrides = getStridesAndOffset(type, strides, offset);
+  return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
+}
+
 /// Returns the strides if the memory underlying `memRefType` has a contiguous
 /// static layout.
 static llvm::Optional<SmallVector<int64_t, 4>>
@@ -1047,7 +1056,7 @@
   // contiguous dynamic shapes in other ways than with just empty/identity
   // layout.
   auto sizes = memRefType.getShape();
-  for (int index = 0, e = strides.size() - 2; index < e; ++index) {
+  for (int index = 0, e = strides.size() - 1; index < e; ++index) {
     if (ShapedType::isDynamic(sizes[index + 1]) ||
         ShapedType::isDynamicStrideOrOffset(strides[index]) ||
         ShapedType::isDynamicStrideOrOffset(strides[index + 1]))
@@ -1149,8 +1158,7 @@
                   ConversionPatternRewriter &rewriter) const override {
     auto adaptor = getTransferOpAdapter(xferOp, operands);
 
-    if (xferOp.getVectorType().getRank() > 1 ||
-        llvm::size(xferOp.indices()) == 0)
+    if (xferOp.getVectorType().getRank() > 1 || xferOp.indices().empty())
       return failure();
     if (xferOp.permutation_map() !=
         AffineMap::getMinorIdentityMap(xferOp.permutation_map().getNumInputs(),
@@ -1160,9 +1168,8 @@
     auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();
     if (!memRefType)
       return failure();
-    // Only contiguous source tensors supported atm.
-    auto strides = computeContiguousStrides(memRefType);
-    if (!strides)
+    // Last dimension must be contiguous. (Otherwise: Use VectorToSCF.)
+    if (!isLastMemrefDimUnitStride(memRefType))
       return failure();
     // Out-of-bounds dims are handled by MaterializeTransferMask.
     if (xferOp.hasOutOfBoundsDim())
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -1066,7 +1066,7 @@
   int64_t offset;
   SmallVector<int64_t> strides;
   auto successStrides = getStridesAndOffset(type, strides, offset);
-  return succeeded(successStrides) && strides.back() == 1;
+  return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
 }
 
 /// Lower a 1D vector transfer op to SCF using scalar loads/stores. This is
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir
@@ -37,6 +37,58 @@
   return
 }
 
+#map0 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
+#map1 = affine_map<(d0, d1) -> (6 * d0 + 2 * d1)>
+
+// Vector load with unit stride only on last dim.
+func @transfer_read_1d_unit_stride(%A : memref<?x?xf32>) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %c3 = constant 3 : index
+  %c4 = constant 4 : index
+  %c5 = constant 5 : index
+  %c6 = constant 6 : index
+  %fm42 = constant -42.0: f32
+  scf.for %arg2 = %c1 to %c5 step %c2 {
+    scf.for %arg3 = %c0 to %c6 step %c3 {
+      %0 = memref.subview %A[%arg2, %arg3] [1, 2] [1, 1]
+          : memref<?x?xf32> to memref<1x2xf32, #map0>
+      %1 = vector.transfer_read %0[%c0, %c0], %fm42 {in_bounds=[true]}
+          : memref<1x2xf32, #map0>, vector<2xf32>
+      vector.print %1 : vector<2xf32>
+    }
+  }
+  return
+}
+
+// Vector load with unit stride only on last dim. Strides are not static, so
+// codegen must go through VectorToSCF 1D lowering.
+func @transfer_read_1d_non_static_unit_stride(%A : memref<?x?xf32>) {
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %c6 = constant 6 : index
+  %fm42 = constant -42.0: f32
+  %1 = memref.reinterpret_cast %A to offset: [%c6], sizes: [%c1, %c2], strides: [%c6, %c1]
+      : memref<?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
+  %2 = vector.transfer_read %1[%c2, %c1], %fm42 {in_bounds=[true]}
+      : memref<?x?xf32, offset: ?, strides: [?, ?]>, vector<4xf32>
+  vector.print %2 : vector<4xf32>
+  return
+}
+
+// Vector load where last dim has non-unit stride.
+func @transfer_read_1d_non_unit_stride(%A : memref<?x?xf32>) {
+  %B = memref.reinterpret_cast %A to offset: [0], sizes: [4, 3], strides: [6, 2]
+      : memref<?x?xf32> to memref<4x3xf32, #map1>
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %fm42 = constant -42.0: f32
+  %vec = vector.transfer_read %B[%c2, %c1], %fm42 {in_bounds=[false]} : memref<4x3xf32, #map1>, vector<3xf32>
+  vector.print %vec : vector<3xf32>
+  return
+}
+
 // Broadcast.
 func @transfer_read_1d_broadcast(
     %A : memref<?x?xf32>, %base1 : index, %base2 : index) {
@@ -117,42 +169,58 @@
   call @transfer_read_1d(%A, %c1, %c2) : (memref<?x?xf32>, index, index) -> ()
   // CHECK: ( 12, 22, 32, 42, -42, -42, -42, -42, -42 )
 
-  // 2. Write to 2D memref on first dimension. Cannot be lowered to an LLVM
+  // 2.a. Read 1D vector from 2D memref with non-unit stride on first dim.
+  call @transfer_read_1d_unit_stride(%A) : (memref<?x?xf32>) -> ()
+  // CHECK: ( 10, 11 )
+  // CHECK: ( 13, 14 )
+  // CHECK: ( 30, 31 )
+  // CHECK: ( 33, 34 )
+
+  // 2.b. Read 1D vector from 2D memref with non-unit stride on first dim.
+  // Strides are non-static.
+  call @transfer_read_1d_non_static_unit_stride(%A) : (memref<?x?xf32>) -> ()
+  // CHECK: ( 31, 32, 33, 34 )
+
+  // 3. Read 1D vector from 2D memref with non-unit stride on second dim.
+  call @transfer_read_1d_non_unit_stride(%A) : (memref<?x?xf32>) -> ()
+  // CHECK: ( 22, 24, -42 )
+
+  // 4. Write to 2D memref on first dimension. Cannot be lowered to an LLVM
   // vector store. Instead, generates scalar stores.
   call @transfer_write_1d(%A, %c3, %c2) : (memref<?x?xf32>, index, index) -> ()
 
-  // 3. (Same as 1. To check if 2 works correctly.)
+  // 5. (Same as 1. To check if 4 works correctly.)
   call @transfer_read_1d(%A, %c0, %c2) : (memref<?x?xf32>, index, index) -> ()
   // CHECK: ( 2, 12, 22, -1, -1, -42, -42, -42, -42 )
 
-  // 4. Read a scalar from a 2D memref and broadcast the value to a 1D vector.
+  // 6. Read a scalar from a 2D memref and broadcast the value to a 1D vector.
   // Generates a loop with vector.insertelement.
   call @transfer_read_1d_broadcast(%A, %c1, %c2)
       : (memref<?x?xf32>, index, index) -> ()
   // CHECK: ( 12, 12, 12, 12, 12, 12, 12, 12, 12 )
 
-  // 5. Read from 2D memref on first dimension. Accesses are in-bounds, so no
+  // 7. Read from 2D memref on first dimension. Accesses are in-bounds, so no
   // if-check is generated inside the generated loop.
   call @transfer_read_1d_in_bounds(%A, %c1, %c2)
       : (memref<?x?xf32>, index, index) -> ()
   // CHECK: ( 12, 22, -1 )
 
-  // 6. Optional mask attribute is specified and, in addition, there may be
+  // 8. Optional mask attribute is specified and, in addition, there may be
   // out-of-bounds accesses.
   call @transfer_read_1d_mask(%A, %c1, %c2)
       : (memref<?x?xf32>, index, index) -> ()
   // CHECK: ( 12, -42, -1, -42, -42, -42, -42, -42, -42 )
 
-  // 7. Same as 6, but accesses are in-bounds.
+  // 9. Same as 8, but accesses are in-bounds.
   call @transfer_read_1d_mask_in_bounds(%A, %c1, %c2)
       : (memref<?x?xf32>, index, index) -> ()
   // CHECK: ( 12, -42, -1 )
 
-  // 8. Write to 2D memref on first dimension with a mask.
+  // 10. Write to 2D memref on first dimension with a mask.
   call @transfer_write_1d_mask(%A, %c1, %c0)
       : (memref<?x?xf32>, index, index) -> ()
 
-  // 9. (Same as 1. To check if 8 works correctly.)
+  // 11. (Same as 1. To check if 10 works correctly.)
   call @transfer_read_1d(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
   // CHECK: ( 0, -2, 20, -2, 40, -42, -42, -42, -42 )
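
For context, a minimal standalone sketch (not part of the patch; #strided_inner and @example are made-up names) of the case the new check distinguishes: a 1-D vector.transfer_read whose source memref has a non-unit innermost stride. After this change, the VectorToLLVM 1-D pattern returns failure on such an op because isLastMemrefDimUnitStride is false, and the VectorToSCF lowering handles it with scalar loads instead.

// Illustrative sketch only. The layout has stride 2 on the innermost
// dimension, so the transfer_read below is rejected by the VectorToLLVM
// 1-D pattern and lowered via VectorToSCF with scalar loads.
#strided_inner = affine_map<(d0, d1) -> (d0 * 8 + d1 * 2)>

func @example(%m : memref<4x4xf32, #strided_inner>, %i : index) -> vector<4xf32> {
  %pad = constant 0.0 : f32
  %v = vector.transfer_read %m[%i, %i], %pad
      : memref<4x4xf32, #strided_inner>, vector<4xf32>
  return %v : vector<4xf32>
}

By contrast, a layout such as (d0, d1) -> (d0 * 8 + d1) keeps a unit innermost stride, so the same transfer still takes the direct LLVM vector-load path even though the outer stride is non-unit.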