diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -481,6 +481,10 @@ return failure(); if (xferOp.getVectorType().getRank() <= targetRank) return failure(); + // Transfer ops that modify the element type are not supported atm. + if (xferOp.getVectorType().getElementType() != + xferOp.getShapedType().getElementType()) + return failure(); return success(); } @@ -794,6 +798,10 @@ PatternRewriter &rewriter) const override { if (xferOp.getVectorType().getRank() <= options.targetRank) return failure(); + // Transfer ops that modify the element type are not supported atm. + if (xferOp.getVectorType().getElementType() != + xferOp.getShapedType().getElementType()) + return failure(); ScopedContext scope(rewriter, xferOp.getLoc()); auto insertOp = getInsertOp(xferOp); @@ -917,6 +925,10 @@ PatternRewriter &rewriter) const override { if (xferOp.getVectorType().getRank() <= options.targetRank) return failure(); + // Transfer ops that modify the element type are not supported atm. + if (xferOp.getVectorType().getElementType() != + xferOp.getShapedType().getElementType()) + return failure(); ScopedContext scope(rewriter, xferOp.getLoc()); auto vec = getDataVector(xferOp);