diff --git a/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
--- a/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
@@ -977,6 +977,15 @@
   }
 };
 
+/// Return true if the last dimension of the MemRefType has unit stride.
+/// Returns false if the strides cannot be computed or the memref has rank 0.
+static bool isLastMemrefDimUnitStride(MemRefType type) {
+  int64_t offset;
+  SmallVector<int64_t, 4> strides;
+  auto successStrides = getStridesAndOffset(type, strides, offset);
+  return succeeded(successStrides) && !strides.empty() && strides.back() == 1;
+}
+
 /// Lower a 1D vector transfer op to SCF using scalar loads/stores. This is
 /// necessary in cases where a 1D vector transfer op cannot be lowered into
 /// vector load/stores due to non-unit strides or broadcasts:
@@ -1016,11 +1025,16 @@
                                 PatternRewriter &rewriter) const override {
     ScopedContext scope(rewriter, xferOp.getLoc());
     auto map = xferOp.permutation_map();
+    auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();
+    if (!memRefType)
+      return failure();
 
     if (xferOp.getVectorType().getRank() != 1)
       return failure();
-    if (map.isMinorIdentity()) // Handled by ConvertVectorToLLVM
-      return failure();
+    // Minor identity maps with a unit-stride innermost dimension become
+    // vector loads/stores; only the remaining cases are scalarized here.
+    if (map.isMinorIdentity() && isLastMemrefDimUnitStride(memRefType))
+      return failure(); // Handled by ConvertVectorToLLVM
 
     // Loop bounds, step, state...
     auto vecType = xferOp.getVectorType();