diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1025,6 +1025,28 @@
   bool hasBoundedRewriteRecursion() const final { return true; }
 };
 
+// Returns true if the strides computed for `memRefType` describe a contiguous
+// row-major buffer. `strides` is an out-parameter that is populated by
+// getStridesAndOffset.
+static bool isContiguous(MemRefType memRefType,
+                         SmallVectorImpl<int64_t> &strides) {
+  int64_t offset;
+  // Bail out before reading `strides`: it is only meaningful on success.
+  if (failed(getStridesAndOffset(memRefType, strides, offset)))
+    return false;
+  // Conservatively reject an empty stride list rather than call back() on it;
+  // otherwise the innermost dimension must be unit-stride.
+  if (strides.empty() || strides.back() != 1)
+    return false;
+  // Every stride must equal the next inner stride times the next inner size,
+  // including the pair that ends at the innermost dimension.
+  auto sizes = memRefType.getShape();
+  for (int index = 0, e = strides.size() - 1; index < e; ++index)
+    if (strides[index] != strides[index + 1] * sizes[index + 1])
+      return false;
+  return true;
+}
+
 class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
 public:
   explicit VectorTypeCastOpConversion(MLIRContext *context,
@@ -1058,22 +1080,9 @@
     if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
       return failure();
 
-    int64_t offset;
-    SmallVector<int64_t, 4> strides;
-    auto successStrides =
-        getStridesAndOffset(sourceMemRefType, strides, offset);
-    bool isContiguous = (strides.back() == 1);
-    if (isContiguous) {
-      auto sizes = sourceMemRefType.getShape();
-      for (int index = 0, e = strides.size() - 2; index < e; ++index) {
-        if (strides[index] != strides[index + 1] * sizes[index + 1]) {
-          isContiguous = false;
-          break;
-        }
-      }
-    }
     // Only contiguous source tensors supported atm.
-    if (failed(successStrides) || !isContiguous)
+    SmallVector<int64_t, 4> strides;
+    if (!isContiguous(sourceMemRefType, strides))
       return failure();
 
     auto int64Ty = LLVM::LLVMType::getInt64Ty(rewriter.getContext());
@@ -1141,6 +1150,10 @@
                          xferOp.getVectorType().getRank(), op->getContext()))
       return failure();
+    // Only contiguous source tensors supported atm.
+    SmallVector<int64_t, 4> strides;
+    if (!isContiguous(xferOp.getMemRefType(), strides))
+      return failure();
 
     auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };