diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h --- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h +++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h @@ -37,6 +37,11 @@ /// Helper function that creates a memref::DimOp or tensor::DimOp depending on /// the type of `source`. Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim); + +/// Returns the indices of the only two source dims greater than one if they +/// are transposed within a 2D slice. Otherwise, returns failure. +FailureOr> isTranspose2DSlice(vector::TransposeOp op); + } // namespace vector /// Constructs a permutation map of invariant memref indices to vector diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp @@ -332,7 +332,7 @@ transp.push_back(attr.cast().getInt()); if (isShuffleLike(vectorTransformOptions.vectorTransposeLowering) && - resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) + succeeded(isTranspose2DSlice(op))) return rewriter.notifyMatchFailure( op, "Options specifies lowering to shuffle"); @@ -411,36 +411,37 @@ LogicalResult matchAndRewrite(vector::TransposeOp op, PatternRewriter &rewriter) const override { - auto loc = op.getLoc(); + if (!isShuffleLike(vectorTransformOptions.vectorTransposeLowering)) + return rewriter.notifyMatchFailure( + op, "not using vector shuffle based lowering"); + + auto srcGtOneDims = isTranspose2DSlice(op); + if (failed(srcGtOneDims)) + return rewriter.notifyMatchFailure( + op, "expected transposition on a 2D slice"); VectorType srcType = op.getSourceVectorType(); - if (srcType.getRank() != 2) - return rewriter.notifyMatchFailure(op, "Not a 2D transpose"); + int64_t m = srcType.getDimSize(std::get<0>(srcGtOneDims.value())); + int64_t n =
srcType.getDimSize(std::get<1>(srcGtOneDims.value())); - SmallVector transp; - for (auto attr : op.getTransp()) - transp.push_back(attr.cast().getInt()); - if (transp[0] != 1 && transp[1] != 0) - return rewriter.notifyMatchFailure(op, "Not a 2D transpose permutation"); + // Flatten the n-D input vector, which has only two dimensions greater than + // one, to 1-D. The 16x16 path below reshapes it to a 2-D m x n vector. + Location loc = op.getLoc(); + auto flattenedType = VectorType::get({n * m}, srcType.getElementType()); + auto reshInputType = VectorType::get({m, n}, srcType.getElementType()); + auto reshInput = rewriter.create(loc, flattenedType, + op.getVector()); Value res; - int64_t m = srcType.getShape().front(), n = srcType.getShape().back(); - switch (vectorTransformOptions.vectorTransposeLowering) { - case VectorTransposeLowering::Shuffle1D: { - Value casted = rewriter.create( - loc, VectorType::get({m * n}, srcType.getElementType()), - op.getVector()); - res = transposeToShuffle1D(rewriter, casted, m, n); - break; - } - case VectorTransposeLowering::Shuffle16x16: - if (m != 16 || n != 16) - return failure(); - res = transposeToShuffle16x16(rewriter, op.getVector(), m, n); - break; - case VectorTransposeLowering::EltWise: - case VectorTransposeLowering::Flat: - return failure(); + if (vectorTransformOptions.vectorTransposeLowering == + VectorTransposeLowering::Shuffle16x16 && + m == 16 && n == 16) { + reshInput = + rewriter.create(loc, reshInputType, reshInput); + res = transposeToShuffle16x16(rewriter, reshInput, m, n); + } else { + // Fall back to the 1-D shuffle approach.
+ res = transposeToShuffle1D(rewriter, reshInput, m, n); } rewriter.replaceOpWithNewOp( diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp --- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp +++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp @@ -43,6 +43,63 @@ llvm_unreachable("Expected MemRefType or TensorType"); } +/// Given the n-D transpose pattern 'transp', return true if 'dim0' and 'dim1' +/// should be transposed with each other within the context of their 2D +/// transposition slice. +/// +/// Example 1: dim0 = 0, dim1 = 2, transp = [2, 1, 0] +/// Return true: dim0 and dim1 are transposed within the context of their 2D +/// transposition slice ([1, 0]). +/// +/// Example 2: dim0 = 0, dim1 = 1, transp = [2, 1, 0] +/// Return true: dim0 and dim1 are transposed within the context of their 2D +/// transposition slice ([1, 0]). Paradoxically, note how dim1 (1) is *not* +/// transposed within the full context of the transposition. +/// +/// Example 3: dim0 = 0, dim1 = 1, transp = [2, 0, 1] +/// Return false: dim0 and dim1 are *not* transposed within the context of +/// their 2D transposition slice ([0, 1]). Paradoxically, note how dim0 (0) +/// and dim1 (1) are transposed within the full context of the +/// transposition. +static bool areDimsTransposedIn2DSlice(int64_t dim0, int64_t dim1, + ArrayRef transp) { + // Perform a linear scan along the dimensions of the transposed pattern. If + // dim0 is found first, dim0 and dim1 are not transposed within the context of + // their 2D slice. Otherwise, 'dim1' is found first and they are transposed.
+ for (int64_t permDim : transp) { + if (permDim == dim0) + return false; + if (permDim == dim1) + return true; + } + + llvm_unreachable("Ill-formed transpose pattern"); +} + +FailureOr> +mlir::vector::isTranspose2DSlice(vector::TransposeOp op) { + VectorType srcType = op.getSourceVectorType(); + SmallVector srcGtOneDims; + for (auto [index, size] : llvm::enumerate(srcType.getShape())) + if (size > 1) + srcGtOneDims.push_back(index); + + if (srcGtOneDims.size() != 2) + return failure(); + + SmallVector transp; + for (auto attr : op.getTransp()) + transp.push_back(attr.cast().getInt()); + + // Check whether the two source vector dimensions that are greater than one + // must be transposed with each other so that we can apply one of the 2-D + // transpose patterns. Otherwise, these patterns are not applicable. + if (!areDimsTransposedIn2DSlice(srcGtOneDims[0], srcGtOneDims[1], transp)) + return failure(); + + return std::pair(srcGtOneDims[0], srcGtOneDims[1]); +} + /// Constructs a permutation map from memref indices to vector dimension. /// /// The implementation uses the knowledge of the mapping of enclosing loop to diff --git a/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp b/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp --- a/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp +++ b/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Utils/VectorUtils.h" #include "mlir/Dialect/X86Vector/Transforms.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Matchers.h" @@ -187,39 +188,6 @@ vs[7] = mm256Permute2f128Ps(ib, s3, s7, MaskHelper::permute<3, 1>()); } -/// Given the n-D transpose pattern 'transp', return true if 'dim0' and 'dim1' -/// should be transposed with each other within the context of their 2D -/// transposition slice.
-/// -/// Example 1: dim0 = 0, dim1 = 2, transp = [2, 1, 0] -/// Return true: dim0 and dim1 are transposed within the context of their 2D -/// transposition slice ([1, 0]). -/// -/// Example 2: dim0 = 0, dim1 = 1, transp = [2, 1, 0] -/// Return true: dim0 and dim1 are transposed within the context of their 2D -/// transposition slice ([1, 0]). Paradoxically, note how dim1 (1) is *not* -/// transposed within the full context of the transposition. -/// -/// Example 3: dim0 = 0, dim1 = 1, transp = [2, 0, 1] -/// Return false: dim0 and dim1 are *not* transposed within the context of -/// their 2D transposition slice ([0, 1]). Paradoxically, note how dim0 (0) -/// and dim1 (1) are transposed within the full context of the of the -/// transposition. -static bool areDimsTransposedIn2DSlice(int64_t dim0, int64_t dim1, - ArrayRef transp) { - // Perform a linear scan along the dimensions of the transposed pattern. If - // dim0 is found first, dim0 and dim1 are not transposed within the context of - // their 2D slice. Otherwise, 'dim1' is found first and they are transposed. - for (int64_t permDim : transp) { - if (permDim == dim0) - return false; - if (permDim == dim1) - return true; - } - - llvm_unreachable("Ill-formed transpose pattern"); -} - /// Rewrite AVX2-specific vector.transpose, for the supported cases and /// depending on the `TransposeLoweringOptions`.
The lowering supports 2-D /// transpose cases and n-D cases that have been decomposed into 2-D @@ -256,29 +224,16 @@ if (!srcType.getElementType().isF32()) return rewriter.notifyMatchFailure(op, "Unsupported vector element type"); - SmallVector srcGtOneDims; - for (auto [index, size] : llvm::enumerate(srcType.getShape())) - if (size > 1) - srcGtOneDims.push_back(index); - - if (srcGtOneDims.size() != 2) - return rewriter.notifyMatchFailure(op, "Unsupported vector type"); - - SmallVector transp; - for (auto attr : op.getTransp()) - transp.push_back(attr.cast().getInt()); - - // Check whether the two source vector dimensions that are greater than one - // must be transposed with each other so that we can apply one of the 2-D - // AVX2 transpose pattens. Otherwise, these patterns are not applicable. - if (!areDimsTransposedIn2DSlice(srcGtOneDims[0], srcGtOneDims[1], transp)) + auto srcGtOneDims = mlir::vector::isTranspose2DSlice(op); + if (failed(srcGtOneDims)) return rewriter.notifyMatchFailure( - op, "Not applicable to this transpose permutation"); + op, "expected transposition on a 2D slice"); // Retrieve the sizes of the two dimensions greater than one to be // transposed.
auto srcShape = srcType.getShape(); - int64_t m = srcShape[srcGtOneDims[0]], n = srcShape[srcGtOneDims[1]]; + int64_t m = srcType.getDimSize(std::get<0>(srcGtOneDims.value())); + int64_t n = srcType.getDimSize(std::get<1>(srcGtOneDims.value())); auto applyRewrite = [&]() { ImplicitLocOpBuilder ib(loc, rewriter); diff --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir --- a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir @@ -687,3 +687,82 @@ lowering_strategy = "shuffle_16x16" : (!pdl.operation) -> !pdl.operation } + +// ----- + +// CHECK-LABEL: func @transpose021_shuffle16x16xf32 +func.func @transpose021_shuffle16x16xf32(%arg0: vector<1x16x16xf32>) -> vector<1x16x16xf32> { + // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} 
[2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32> + 
// CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : 
vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 
22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> + // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32> + %0 = vector.transpose %arg0, [0, 2, 1] : vector<1x16x16xf32> to vector<1x16x16xf32> + return %0 : vector<1x16x16xf32> +} + +transform.sequence failures(propagate) { +^bb1(%module_op: !pdl.operation): + transform.vector.lower_transpose %module_op + lowering_strategy = "shuffle_16x16" + : (!pdl.operation) -> !pdl.operation +} diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -1865,6 +1865,7 @@ ":LLVMCommonConversion", ":LLVMDialect", ":VectorDialect", + ":VectorUtils", ":X86VectorDialect", "//llvm:Core", "//llvm:Support",