diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.h b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.h
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.h
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.h
@@ -74,6 +74,12 @@
     return *this;
   }
 
+  bool transposeAVX512Lowering = false;
+  LowerVectorsOptions &setTransposeAVX512Lowering(bool opt) {
+    transposeAVX512Lowering = opt;
+    return *this;
+  }
+
   bool unrollVectorTransfers = true;
   LowerVectorsOptions &setUnrollVectorTransfers(bool opt) {
     unrollVectorTransfers = opt;
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -284,7 +284,8 @@
   let arguments = (ins TransformHandleTypeInterface:$target,
      DefaultValuedAttr<VectorTransposeLoweringAttr,
        "vector::VectorTransposeLowering::EltWise">:$lowering_strategy,
-     DefaultValuedAttr<BoolAttr, "false">:$avx2_lowering_strategy
+     DefaultValuedAttr<BoolAttr, "false">:$avx2_lowering_strategy,
+     DefaultValuedAttr<BoolAttr, "false">:$avx512_lowering_strategy
   );
   let results = (outs TransformHandleTypeInterface:$results);
@@ -293,6 +294,7 @@
     oilist (
       `lowering_strategy` `=` $lowering_strategy
       | `avx2_lowering_strategy` `=` $avx2_lowering_strategy
+      | `avx512_lowering_strategy` `=` $avx512_lowering_strategy
     )
     attr-dict
     `:` functional-type($target, results)
diff --git a/mlir/include/mlir/Dialect/X86Vector/Transforms.h b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
--- a/mlir/include/mlir/Dialect/X86Vector/Transforms.h
+++ b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
@@ -91,6 +91,34 @@
 /// All intrinsics correspond 1-1 to the Intel definition.
 //===----------------------------------------------------------------------===//
 
+namespace avx512 {
+namespace intrin {
+/// Lower to vector.shuffle v1, v2, [0, 1, 16, 17,
+///                                  0+4, 1+4, 16+4, 17+4,
+///                                  0+8, 1+8, 16+8, 17+8,
+///                                  0+12, 1+12, 16+12, 17+12].
+Value mm512UnpackLoPd(ImplicitLocOpBuilder &b, Value v1, Value v2);
+/// Lower to vector.shuffle v1, v2, [2, 3, 18, 19,
+///                                  2+4, 3+4, 18+4, 19+4,
+///                                  2+8, 3+8, 18+8, 19+8,
+///                                  2+12, 3+12, 18+12, 19+12].
+Value mm512UnpackHiPd(ImplicitLocOpBuilder &b, Value v1, Value v2);
+/// Lower to vector.shuffle v1, v2, [0, 16, 1, 17,
+///                                  0+4, 16+4, 1+4, 17+4,
+///                                  0+8, 16+8, 1+8, 17+8,
+///                                  0+12, 16+12, 1+12, 17+12].
+Value mm512UnpackLoPs(ImplicitLocOpBuilder &b, Value v1, Value v2);
+/// Lower to vector.shuffle v1, v2, [2, 18, 3, 19,
+///                                  2+4, 18+4, 3+4, 19+4,
+///                                  2+8, 18+8, 3+8, 19+8,
+///                                  2+12, 18+12, 3+12, 19+12].
+Value mm512UnpackHiPs(ImplicitLocOpBuilder &b, Value v1, Value v2);
+} // namespace intrin
+
+/// 16x16xf32-specific AVX512 transpose lowering.
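+/// The 16 rows are passed in `vs` as vector<16xf32> values and are replaced
+/// in place by the transposed rows, i.e. vs[i] afterwards holds what was
+/// column i of the input.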
+void transpose16x16xf32(ImplicitLocOpBuilder &ib, MutableArrayRef<Value> vs);
+} // namespace avx512
+
 namespace avx2 {
 
 namespace inline_asm {
@@ -141,6 +169,8 @@
 /// 8x8xf32-specific AVX2 transpose lowering.
 void transpose8x8xf32(ImplicitLocOpBuilder &ib, MutableArrayRef<Value> vs);
 
+} // namespace avx2
+
 /// Structure to control the behavior of specialized AVX2 transpose lowering.
 struct TransposeLoweringOptions {
   bool lower4x8xf32_ = false;
@@ -153,9 +183,14 @@
     lower8x8xf32_ = lower;
     return *this;
   }
+  bool lower16x16xf32_ = false;
+  TransposeLoweringOptions &lower16x16xf32(bool lower = true) {
+    lower16x16xf32_ = lower;
+    return *this;
+  }
 };
 
-/// Options for controlling specialized AVX2 lowerings.
+/// Options for controlling specialized AVX lowerings.
 struct LoweringOptions {
   /// Configure specialized vector lowerings.
   TransposeLoweringOptions transposeOptions;
@@ -170,7 +205,6 @@
     RewritePatternSet &patterns, LoweringOptions options = LoweringOptions(),
     int benefit = 10);
 
-} // namespace avx2
 } // namespace x86vector
 
 /// Collect a set of patterns to lower X86Vector ops to ops that map to LLVM
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -136,15 +136,18 @@
   vector::populateVectorTransposeLoweringPatterns(
       patterns, vector::VectorTransformsOptions().setVectorTransposeLowering(
                     getLoweringStrategy()));
+  auto transposeOptions = x86vector::TransposeLoweringOptions();
   if (getAvx2LoweringStrategy()) {
-    auto avx2LoweringOptions =
-        x86vector::avx2::LoweringOptions().setTransposeOptions(
-            x86vector::avx2::TransposeLoweringOptions()
-                .lower4x8xf32(true)
-                .lower8x8xf32(true));
-    x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
-        patterns, avx2LoweringOptions, /*benefit=*/10);
+    transposeOptions = transposeOptions.lower4x8xf32().lower8x8xf32();
   }
+  if (getAvx512LoweringStrategy()) {
+    transposeOptions = transposeOptions.lower16x16xf32();
+  }
+
+  auto avxLoweringOptions =
+      x86vector::LoweringOptions().setTransposeOptions(transposeOptions);
+  x86vector::populateSpecializedTransposeLoweringPatterns(
+      patterns, avxLoweringOptions, /*benefit=*/10);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp b/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp
--- a/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp
@@ -27,6 +27,8 @@
 using namespace mlir::x86vector::avx2;
 using namespace mlir::x86vector::avx2::inline_asm;
 using namespace mlir::x86vector::avx2::intrin;
+using namespace mlir::x86vector::avx512::intrin;
+using namespace mlir::x86vector::avx512;
 
 Value mlir::x86vector::avx2::inline_asm::mm256BlendPsAsm(
     ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) {
@@ -45,6 +47,149 @@
   return asmOp.getResult(0);
 }
 
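+/// Build the 16-element shuffle mask for one of the unpack intrinsics below:
+/// the 4-element pattern `vals` is repeated in each 128-bit lane, i.e. offset
+/// by 0, 4, 8, and 12.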
+static SmallVector<int64_t>
+getMm512UnpackShufflePerm(ArrayRef<int64_t> vals) {
+  SmallVector<int64_t> res;
+  for (int i = 0; i < 16; i += 4)
+    for (int64_t v : vals)
+      res.push_back(v + i);
+  return res;
+}
+
+Value mlir::x86vector::avx512::intrin::mm512UnpackLoPd(ImplicitLocOpBuilder &b,
+                                                       Value v1, Value v2) {
+  return b.create<vector::ShuffleOp>(
+      v1, v2, getMm512UnpackShufflePerm({0, 1, 16, 17}));
+}
+
+Value mlir::x86vector::avx512::intrin::mm512UnpackHiPd(ImplicitLocOpBuilder &b,
+                                                       Value v1, Value v2) {
+  return b.create<vector::ShuffleOp>(
+      v1, v2, getMm512UnpackShufflePerm({2, 3, 18, 19}));
+}
+
+Value mlir::x86vector::avx512::intrin::mm512UnpackLoPs(ImplicitLocOpBuilder &b,
+                                                       Value v1, Value v2) {
+  return b.create<vector::ShuffleOp>(
+      v1, v2, getMm512UnpackShufflePerm({0, 16, 1, 17}));
+}
+
+Value mlir::x86vector::avx512::intrin::mm512UnpackHiPs(ImplicitLocOpBuilder &b,
+                                                       Value v1, Value v2) {
+  return b.create<vector::ShuffleOp>(
+      v1, v2, getMm512UnpackShufflePerm({2, 18, 3, 19}));
+}
+
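+/// Mirror _mm512_shuffle_i32x4: select four 128-bit lanes (groups of four
+/// f32 elements) according to the four 2-bit fields of `mask`; the two low
+/// lanes of the result come from v1, the two high lanes from v2.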
+Value mm512ShuffleI32x4(ImplicitLocOpBuilder &b, Value v1, Value v2,
+                        uint8_t mask) {
+  assert(v1.getType().cast<VectorType>().getShape()[0] == 16 &&
+         "expected a vector with length=16");
+  SmallVector<int64_t> shuffleMask;
+  auto appendToMask = [&](int64_t base, uint8_t control) {
+    if (control == 0)
+      llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 0, base + 1,
+                                                        base + 2, base + 3});
+    else if (control == 1)
+      llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 4, base + 5,
+                                                        base + 6, base + 7});
+    else if (control == 2)
+      llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 8, base + 9,
+                                                        base + 10, base + 11});
+    else if (control == 3)
+      llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 12, base + 13,
+                                                        base + 14, base + 15});
+    else
+      llvm_unreachable("control > 3 : overflow");
+  };
+  uint8_t b01, b23, b45, b67;
+  MaskHelper::extractShuffle(mask, b01, b23, b45, b67);
+  appendToMask(0, b01);
+  appendToMask(0, b23);
+  appendToMask(16, b45);
+  appendToMask(16, b67);
+  return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
+}
+
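+/// Transpose the 16 rows of `vs` with the classic unpack/shuffle scheme:
+/// four rounds of 16 shuffles each, i.e. 64 vector.shuffle ops in total.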
+void mlir::x86vector::avx512::transpose16x16xf32(ImplicitLocOpBuilder &b,
+                                                 MutableArrayRef<Value> vs) {
+  // Interleave 32-bit lanes using
+  //   8x _mm512_unpacklo_epi32
+  //   8x _mm512_unpackhi_epi32
+  Value t0 = mm512UnpackLoPs(b, vs[0], vs[1]);
+  Value t1 = mm512UnpackHiPs(b, vs[0], vs[1]);
+  Value t2 = mm512UnpackLoPs(b, vs[2], vs[3]);
+  Value t3 = mm512UnpackHiPs(b, vs[2], vs[3]);
+  Value t4 = mm512UnpackLoPs(b, vs[4], vs[5]);
+  Value t5 = mm512UnpackHiPs(b, vs[4], vs[5]);
+  Value t6 = mm512UnpackLoPs(b, vs[6], vs[7]);
+  Value t7 = mm512UnpackHiPs(b, vs[6], vs[7]);
+  Value t8 = mm512UnpackLoPs(b, vs[8], vs[9]);
+  Value t9 = mm512UnpackHiPs(b, vs[8], vs[9]);
+  Value ta = mm512UnpackLoPs(b, vs[10], vs[11]);
+  Value tb = mm512UnpackHiPs(b, vs[10], vs[11]);
+  Value tc = mm512UnpackLoPs(b, vs[12], vs[13]);
+  Value td = mm512UnpackHiPs(b, vs[12], vs[13]);
+  Value te = mm512UnpackLoPs(b, vs[14], vs[15]);
+  Value tf = mm512UnpackHiPs(b, vs[14], vs[15]);
+
+  // Interleave 64-bit lanes using
+  //   8x _mm512_unpacklo_epi64
+  //   8x _mm512_unpackhi_epi64
+  Value r0 = mm512UnpackLoPd(b, t0, t2);
+  Value r1 = mm512UnpackHiPd(b, t0, t2);
+  Value r2 = mm512UnpackLoPd(b, t1, t3);
+  Value r3 = mm512UnpackHiPd(b, t1, t3);
+  Value r4 = mm512UnpackLoPd(b, t4, t6);
+  Value r5 = mm512UnpackHiPd(b, t4, t6);
+  Value r6 = mm512UnpackLoPd(b, t5, t7);
+  Value r7 = mm512UnpackHiPd(b, t5, t7);
+  Value r8 = mm512UnpackLoPd(b, t8, ta);
+  Value r9 = mm512UnpackHiPd(b, t8, ta);
+  Value ra = mm512UnpackLoPd(b, t9, tb);
+  Value rb = mm512UnpackHiPd(b, t9, tb);
+  Value rc = mm512UnpackLoPd(b, tc, te);
+  Value rd = mm512UnpackHiPd(b, tc, te);
+  Value re = mm512UnpackLoPd(b, td, tf);
+  Value rf = mm512UnpackHiPd(b, td, tf);
+
+  // Permute 128-bit lanes using
+  //   16x _mm512_shuffle_i32x4
+  t0 = mm512ShuffleI32x4(b, r0, r4, 0x88);
+  t1 = mm512ShuffleI32x4(b, r1, r5, 0x88);
+  t2 = mm512ShuffleI32x4(b, r2, r6, 0x88);
+  t3 = mm512ShuffleI32x4(b, r3, r7, 0x88);
+  t4 = mm512ShuffleI32x4(b, r0, r4, 0xdd);
+  t5 = mm512ShuffleI32x4(b, r1, r5, 0xdd);
+  t6 = mm512ShuffleI32x4(b, r2, r6, 0xdd);
+  t7 = mm512ShuffleI32x4(b, r3, r7, 0xdd);
+  t8 = mm512ShuffleI32x4(b, r8, rc, 0x88);
+  t9 = mm512ShuffleI32x4(b, r9, rd, 0x88);
+  ta = mm512ShuffleI32x4(b, ra, re, 0x88);
+  tb = mm512ShuffleI32x4(b, rb, rf, 0x88);
+  tc = mm512ShuffleI32x4(b, r8, rc, 0xdd);
+  td = mm512ShuffleI32x4(b, r9, rd, 0xdd);
+  te = mm512ShuffleI32x4(b, ra, re, 0xdd);
+  tf = mm512ShuffleI32x4(b, rb, rf, 0xdd);
+
+  // Permute 256-bit lanes using another
+  //   16x _mm512_shuffle_i32x4
+  vs[0x0] = mm512ShuffleI32x4(b, t0, t8, 0x88);
+  vs[0x1] = mm512ShuffleI32x4(b, t1, t9, 0x88);
+  vs[0x2] = mm512ShuffleI32x4(b, t2, ta, 0x88);
+  vs[0x3] = mm512ShuffleI32x4(b, t3, tb, 0x88);
+  vs[0x4] = mm512ShuffleI32x4(b, t4, tc, 0x88);
+  vs[0x5] = mm512ShuffleI32x4(b, t5, td, 0x88);
+  vs[0x6] = mm512ShuffleI32x4(b, t6, te, 0x88);
+  vs[0x7] = mm512ShuffleI32x4(b, t7, tf, 0x88);
+  vs[0x8] = mm512ShuffleI32x4(b, t0, t8, 0xdd);
+  vs[0x9] = mm512ShuffleI32x4(b, t1, t9, 0xdd);
+  vs[0xa] = mm512ShuffleI32x4(b, t2, ta, 0xdd);
+  vs[0xb] = mm512ShuffleI32x4(b, t3, tb, 0xdd);
+  vs[0xc] = mm512ShuffleI32x4(b, t4, tc, 0xdd);
+  vs[0xd] = mm512ShuffleI32x4(b, t5, td, 0xdd);
+  vs[0xe] = mm512ShuffleI32x4(b, t6, te, 0xdd);
+  vs[0xf] = mm512ShuffleI32x4(b, t7, tf, 0xdd);
+}
+
 Value mlir::x86vector::avx2::intrin::mm256UnpackLoPs(ImplicitLocOpBuilder &b,
                                                      Value v1, Value v2) {
   return b.create<vector::ShuffleOp>(
@@ -299,10 +444,12 @@
     vs.push_back(ib.create<vector::ExtractOp>(reshInput, i));
 
   // Transpose set of 1-D vectors.
-  if (m == 4)
+  if (m == 4 && n == 8)
     transpose4x8xf32(ib, vs);
-  if (m == 8)
+  if (m == 8 && n == 8)
     transpose8x8xf32(ib, vs);
+  if (m == 16 && n == 16)
+    transpose16x16xf32(ib, vs);
 
   // Insert transposed 1-D vectors into the higher-order dimension of the
   // output vector.
@@ -324,6 +471,8 @@
       return applyRewrite();
     if (loweringOptions.transposeOptions.lower8x8xf32_ && m == 8 && n == 8)
       return applyRewrite();
+    if (loweringOptions.transposeOptions.lower16x16xf32_ && m == 16 && n == 16)
+      return applyRewrite();
     return failure();
   }
 
@@ -331,7 +480,7 @@
   LoweringOptions loweringOptions;
 };
 
-void mlir::x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
+void mlir::x86vector::populateSpecializedTransposeLoweringPatterns(
     RewritePatternSet &patterns, LoweringOptions options, int benefit) {
   patterns.add<TransposeOpLowering>(options, patterns.getContext(), benefit);
 }
diff --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
--- a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
@@ -609,3 +609,81 @@
     avx2_lowering_strategy = true
       : (!pdl.operation) -> !pdl.operation
 }
+
+// -----
+
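+// The 16x16 inner transpose lowers to four rounds of 16 vector.shuffle ops;
+// the CHECK lines below match them round by round.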
+func.func @transpose210_1x16x16xf32(%arg0: vector<1x16x16xf32>) -> vector<16x16x1xf32> {
+  // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  %0 = vector.transpose %arg0, [2, 1, 0] : vector<1x16x16xf32> to vector<16x16x1xf32>
+  return %0 : vector<16x16x1xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !pdl.operation):
+  transform.vector.lower_transpose %module_op
+    avx512_lowering_strategy = true
+      : (!pdl.operation) -> !pdl.operation
+}