diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -547,7 +547,15 @@
   using namespace mlir::edsc::op;
 
   TransferReadOp transfer = cast<TransferReadOp>(op);
-  if (transfer.permutation_map().isMinorIdentity()) {
+
+  // Fall back to a loop if the fastest varying stride is not 1 or it is
+  // permuted.
+  int64_t offset;
+  SmallVector<int64_t, 4> strides;
+  auto successStrides =
+      getStridesAndOffset(transfer.getMemRefType(), strides, offset);
+  if (succeeded(successStrides) && strides.back() == 1 &&
+      transfer.permutation_map().isMinorIdentity()) {
     // If > 1D, emit a bunch of loops around 1-D vector transfers.
     if (transfer.getVectorType().getRank() > 1)
       return NDTransferOpHelper<TransferReadOp>(rewriter, transfer, options)
@@ -621,7 +629,15 @@
   using namespace edsc::op;
 
   TransferWriteOp transfer = cast<TransferWriteOp>(op);
-  if (transfer.permutation_map().isMinorIdentity()) {
+
+  // Fall back to a loop if the fastest varying stride is not 1 or it is
+  // permuted.
+  int64_t offset;
+  SmallVector<int64_t, 4> strides;
+  auto successStrides =
+      getStridesAndOffset(transfer.getMemRefType(), strides, offset);
+  if (succeeded(successStrides) && strides.back() == 1 &&
+      transfer.permutation_map().isMinorIdentity()) {
     // If > 1D, emit a bunch of loops around 1-D vector transfers.
     if (transfer.getVectorType().getRank() > 1)
       return NDTransferOpHelper<TransferWriteOp>(rewriter, transfer, options)
diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-loops.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-loops.mlir
--- a/mlir/test/Conversion/VectorToSCF/vector-to-loops.mlir
+++ b/mlir/test/Conversion/VectorToSCF/vector-to-loops.mlir
@@ -457,3 +457,28 @@
 // CHECK: }
 // CHECK: }
 // CHECK: return
+
+// -----
+
+func @transfer_read_strided(%A : memref<8x4xf32, affine_map<(d0, d1) -> (d0 + d1 * 8)>>) -> vector<4xf32> {
+  %c0 = constant 0 : index
+  %f0 = constant 0.0 : f32
+  %0 = vector.transfer_read %A[%c0, %c0], %f0
+      : memref<8x4xf32, affine_map<(d0, d1) -> (d0 + d1 * 8)>>, vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: transfer_read_strided(
+// CHECK: scf.for
+// CHECK: load
+
+func @transfer_write_strided(%A : vector<4xf32>, %B : memref<8x4xf32, affine_map<(d0, d1) -> (d0 + d1 * 8)>>) {
+  %c0 = constant 0 : index
+  vector.transfer_write %A, %B[%c0, %c0] :
+    vector<4xf32>, memref<8x4xf32, affine_map<(d0, d1) -> (d0 + d1 * 8)>>
+  return
+}
+
+// CHECK-LABEL: transfer_write_strided(
+// CHECK: scf.for
+// CHECK: store
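
Note on the test coverage: the layout map affine_map<(d0, d1) -> (d0 + d1 * 8)>
gives the memref strides [1, 8], so strides.back() is 8 and the new check
routes both transfers through the scalar scf.for fallback (hence the CHECKs on
scf.for plus load/store). For contrast, a hypothetical variant with the default
identity layout (not part of this patch) has strides [4, 1]; its innermost
stride is 1, so the 1-D vector fast path would still be taken:

  // Identity layout: strides [4, 1], strides.back() == 1, so the
  // minor-identity fast path applies and no scf.for loop is emitted.
  func @transfer_read_contiguous(%A : memref<8x4xf32>) -> vector<4xf32> {
    %c0 = constant 0 : index
    %f0 = constant 0.0 : f32
    %0 = vector.transfer_read %A[%c0, %c0], %f0 : memref<8x4xf32>, vector<4xf32>
    return %0 : vector<4xf32>
  }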