diff --git a/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp b/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp
--- a/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp
@@ -186,6 +186,33 @@
   vs[7] = mm256Permute2f128Ps(ib, s3, s7, MaskHelper::permute<3, 1>());
 }
 
+/// Returns true if the vector type is effectively a 2D vector (i.e., it has at
+/// least two dimensions and the size of all the dimensions is one except for
+/// the last two dimensions).
+static bool isSupportedVectorType(VectorType vecType) {
+  if (vecType.getRank() < 2)
+    return false;
+  // Every dimension ahead of the trailing two must be a unit dimension.
+  auto shape = vecType.getShape();
+  return std::all_of(shape.begin(), std::prev(shape.end(), 2),
+                     [](int64_t dimSize) { return dimSize == 1; });
+}
+
+/// Returns true if the permutation only permutes the last two dimensions.
+static bool isSupportedPermutation(ArrayRef<int64_t> perm) {
+  // Signed size keeps every comparison against the int64_t entries signed.
+  int64_t permSize = static_cast<int64_t>(perm.size());
+  if (permSize < 2)
+    return false;
+  for (int64_t i = 0, end = permSize - 2; i < end; ++i) {
+    if (perm[i] != i)
+      return false;
+  }
+
+  return (perm[permSize - 2] == permSize - 1) &&
+         (perm[permSize - 1] == permSize - 2);
+}
+
 /// Rewrite avx2-specific 2-D vector.transpose, for the supported cases and
 /// depending on the `TransposeLoweringOptions`.
class TransposeOpLowering : public OpRewritePattern { @@ -202,36 +229,49 @@ auto loc = op.getLoc(); VectorType srcType = op.getVectorType(); - if (srcType.getRank() != 2) - return rewriter.notifyMatchFailure(op, "Not a 2-D transpose"); + if (!isSupportedVectorType(srcType)) + return rewriter.notifyMatchFailure(op, "Unsupported vector type"); SmallVector transp; for (auto attr : op.transp()) transp.push_back(attr.cast().getInt()); - if (transp[0] != 1 && transp[1] != 0) + if (!isSupportedPermutation(transp)) return rewriter.notifyMatchFailure(op, "Not a 2-D transpose permutation"); - int64_t m = srcType.getShape().front(), n = srcType.getShape().back(); + int64_t rank = srcType.getRank(); + int64_t m = srcType.getShape()[rank - 2], n = srcType.getShape()[rank - 1]; auto applyRewrite = [&]() { ImplicitLocOpBuilder ib(loc, rewriter); SmallVector vs; - for (int64_t i = 0; i < m; ++i) - vs.push_back(ib.create(op.vector(), i)); + for (int64_t i = 0; i < m; ++i) { + // Set all the indexes for leading one-size dimensions to zero. + SmallVector extractIdxs(rank - 1, 0); + extractIdxs.back() = i; + vs.push_back(ib.create(op.vector(), extractIdxs)); + } if (m == 4) transpose4x8xf32(ib, vs); if (m == 8) transpose8x8xf32(ib, vs); auto flattenedType = VectorType::get({n * m}, op.getVectorType().getElementType()); - auto transposedType = - VectorType::get({n, m}, op.getVectorType().getElementType()); + // Set all the leading one-size dimensions to one, if any; + SmallVector transposedTypeSizes(rank, 1); + transposedTypeSizes[rank - 2] = n; + transposedTypeSizes[rank - 1] = m; + auto transposedType = VectorType::get( + transposedTypeSizes, op.getVectorType().getElementType()); Value res = ib.create( op.getVectorType(), ib.getZeroAttr(op.getVectorType())); - // The transposed form is still 4x8 and needs to be reinterpreted as 8x4 - // via shape_casts. 
- for (int64_t i = 0; i < m; ++i) - res = ib.create(vs[i], res, i); + // The transposed form is still 1x...x1x4x8 and needs to be reinterpreted + // as 1x...x1x8x4 via shape_casts. + for (int64_t i = 0; i < m; ++i) { + // Set all the indexes for the leading one-size dimensions to zero. + SmallVector insertIdxs(rank - 1, 0); + insertIdxs.back() = i; + res = ib.create(vs[i], res, insertIdxs); + } if (m == 4) { res = ib.create(flattenedType, res); res = ib.create(transposedType, res); diff --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir --- a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir @@ -70,6 +70,36 @@ return %0 : vector<8x4xf32> } +// ----- + +// AVX2-LABEL: func @transpose1x4x8 +func @transpose1x4x8xf32(%arg0: vector<1x4x8xf32>) -> vector<1x8x4xf32> { + // AVX2: vector.extract {{.*}}[0, 0] + // AVX2-NEXT: vector.extract {{.*}}[0, 1] + // AVX2-NEXT: vector.extract {{.*}}[0, 2] + // AVX2-NEXT: vector.extract {{.*}}[0, 3] + // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32> + // AVX2-NEXT: vector.shuffle {{.*}} [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32> + // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32> + // AVX2-NEXT: vector.shuffle {{.*}} [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32> + // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> + // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 
2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> + // AVX2-NEXT: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> + // AVX2-NEXT: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> + // AVX2-NEXT: vector.insert {{.*}}[0, 0] + // AVX2-NEXT: vector.insert {{.*}}[0, 1] + // AVX2-NEXT: vector.insert {{.*}}[0, 2] + // AVX2-NEXT: vector.insert {{.*}}[0, 3] + // AVX2-NEXT: vector.shape_cast {{.*}} vector<1x4x8xf32> to vector<32xf32> + // AVX2-NEXT: vector.shape_cast {{.*}} vector<32xf32> to vector<1x8x4xf32> + %0 = vector.transpose %arg0, [0, 2, 1] : vector<1x4x8xf32> to vector<1x8x4xf32> + return %0 : vector<1x8x4xf32> +} + // AVX2-LABEL: func @transpose8x8 func @transpose8x8xf32(%arg0: vector<8x8xf32>) -> vector<8x8xf32> { // AVX2: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> @@ -94,3 +124,30 @@ %0 = vector.transpose %arg0, [1, 0] : vector<8x8xf32> to vector<8x8xf32> return %0 : vector<8x8xf32> } + +// ----- + +// AVX2-LABEL: func @transpose1x8x8 +func @transpose1x8x8xf32(%arg0: vector<1x8x8xf32>) -> vector<1x8x8xf32> { + // AVX2: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // AVX2-COUNT-4: vector.shuffle {{.*}} [2, 3, 8, 9, 6, 7, 12, 13] : vector<8xf32>, 
vector<8xf32> + // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // AVX2-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + // AVX2-COUNT-4: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> + // AVX2-COUNT-4: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> + %0 = vector.transpose %arg0, [0, 2, 1] : vector<1x8x8xf32> to vector<1x8x8xf32> + return %0 : vector<1x8x8xf32> +}