[mlir][x86vector] Restrict AVX2 transpose lowering to f32 vectors

The AVX2 transpose rewrite now bails out with a match failure
("Unsupported vector element type") unless the source vector's element
type is f32. A negative test checks that an i32 vector.transpose is not
rewritten into vector.shuffle ops by the AVX2 lowering; the test carries
its own AVX2-LABEL so the AVX2-NOT check is scoped to that function only.

diff --git a/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp b/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp
--- a/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp
@@ -250,8 +250,11 @@
     auto loc = op.getLoc();
 
     // Check if the source vector type is supported. AVX2 patterns can only be
-    // applied if the vector type has two dimensions greater than one.
+    // applied to f32 vector types with two dimensions greater than one.
     VectorType srcType = op.getVectorType();
+    if (!srcType.getElementType().isF32())
+      return rewriter.notifyMatchFailure(op, "Unsupported vector element type");
+
     SmallVector<int64_t> srcGtOneDims;
     for (auto &en : llvm::enumerate(srcType.getShape()))
       if (en.value() > 1)
diff --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
--- a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
@@ -548,6 +548,16 @@
 
 // -----
 
+// AVX2-LABEL: func @do_not_lower_nonf32_to_avx2
+func @do_not_lower_nonf32_to_avx2(%arg0: vector<4x8xi32>) -> vector<8x4xi32> {
+  %0 = vector.transpose %arg0, [1, 0] : vector<4x8xi32> to vector<8x4xi32>
+  return %0 : vector<8x4xi32>
+}
+
+// AVX2-NOT: vector.shuffle
+
+// -----
+
 // AVX2-LABEL: func @transpose021_8x1x8
 func @transpose021_8x1x8xf32(%arg0: vector<8x1x8xf32>) -> vector<8x8x1xf32> {
   %0 = vector.transpose %arg0, [0, 2, 1] : vector<8x1x8xf32> to vector<8x8x1xf32>