diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransformsBase.td b/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransformsBase.td
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransformsBase.td
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransformsBase.td
@@ -18,14 +18,17 @@
 // intrinsics.
 def VectorTransposeLowering_FlatTranspose:
   I32EnumAttrCase<"Flat", 1, "flat_transpose">;
-// Lower 2-D transpose to `vector.shuffle`.
-def VectorTransposeLowering_Shuffle:
-  I32EnumAttrCase<"Shuffle", 2, "shuffle">;
+// Lower 2-D transpose to a `vector.shuffle` on the flattened 1-D vector.
+def VectorTransposeLowering_Shuffle1D:
+  I32EnumAttrCase<"Shuffle1D", 2, "shuffle_1d">;
+// Lower 16x16 transpose to a sequence of `vector.shuffle` ops on its rows.
+def VectorTransposeLowering_Shuffle16x16:
+  I32EnumAttrCase<"Shuffle16x16", 3, "shuffle_16x16">;
 def VectorTransposeLoweringAttr : I32EnumAttr<
   "VectorTransposeLowering",
   "control the lowering of `vector.transpose` operations.",
   [VectorTransposeLowering_Elementwise, VectorTransposeLowering_FlatTranspose,
-   VectorTransposeLowering_Shuffle]> {
+   VectorTransposeLowering_Shuffle1D, VectorTransposeLowering_Shuffle16x16]> {
   let cppNamespace = "::mlir::vector";
 }
 
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -52,6 +52,238 @@
   result.append(transpose.begin(), transpose.begin() + numTransposedDims);
 }
 
+/// Returns true if the option is a vector.shuffle approach.
+static bool isShuffleLike(VectorTransposeLowering lowering) {
+  return lowering == VectorTransposeLowering::Shuffle1D ||
+         lowering == VectorTransposeLowering::Shuffle16x16;
+}
+
+/// Returns `vals` repeated once per 128-bit lane, offset by 4 elements per
+/// lane. `numBits` must be a multiple of 128.
+static SmallVector<int64_t>
+getUnpackShufflePermFor128Lane(ArrayRef<int64_t> vals, int numBits) {
+  assert(vals.size() % 4 == 0);
+  int numElem = numBits / 32;
+  assert(numElem % 4 == 0);
+  SmallVector<int64_t> res;
+  for (int i = 0; i < numElem; i += 4)
+    for (int64_t v : vals)
+      res.push_back(v + i);
+  return res;
+}
+
+/// Lower to vector.shuffle on v1 and v2 with UnpackLoPd shuffle mask. For
+/// example, when targeting 512-bit vectors, returns
+/// vector.shuffle on v1, v2, [0, 1, 16, 17,
+///                            0+4, 1+4, 16+4, 17+4,
+///                            0+8, 1+8, 16+8, 17+8,
+///                            0+12, 1+12, 16+12, 17+12].
+static Value createUnpackLoPd(ImplicitLocOpBuilder &b, Value v1, Value v2,
+                              int numBits) {
+  int numElem = numBits / 32;
+  return b.create<vector::ShuffleOp>(
+      v1, v2,
+      getUnpackShufflePermFor128Lane({0, 1, numElem, numElem + 1}, numBits));
+}
+
+/// Lower to vector.shuffle on v1 and v2 with UnpackHiPd shuffle mask. For
+/// example, when targeting 512-bit vectors, returns
+/// vector.shuffle on v1, v2, [2, 3, 18, 19,
+///                            2+4, 3+4, 18+4, 19+4,
+///                            2+8, 3+8, 18+8, 19+8,
+///                            2+12, 3+12, 18+12, 19+12].
+static Value createUnpackHiPd(ImplicitLocOpBuilder &b, Value v1, Value v2,
+                              int numBits) {
+  int numElem = numBits / 32;
+  return b.create<vector::ShuffleOp>(
+      v1, v2,
+      getUnpackShufflePermFor128Lane({2, 3, numElem + 2, numElem + 3},
+                                     numBits));
+}
+
+/// Lower to vector.shuffle on v1 and v2 with UnpackLoPs shuffle mask. For
+/// example, when targeting 512-bit vectors, returns
+/// vector.shuffle on v1, v2, [0, 16, 1, 17,
+///                            0+4, 16+4, 1+4, 17+4,
+///                            0+8, 16+8, 1+8, 17+8,
+///                            0+12, 16+12, 1+12, 17+12].
+static Value createUnpackLoPs(ImplicitLocOpBuilder &b, Value v1, Value v2,
+                              int numBits) {
+  int numElem = numBits / 32;
+  return b.create<vector::ShuffleOp>(
+      v1, v2,
+      getUnpackShufflePermFor128Lane({0, numElem, 1, numElem + 1}, numBits));
+}
+
+/// Lower to vector.shuffle on v1 and v2 with UnpackHiPs shuffle mask. For
+/// example, when targeting 512-bit vectors, returns
+/// vector.shuffle on v1, v2, [2, 18, 3, 19,
+///                            2+4, 18+4, 3+4, 19+4,
+///                            2+8, 18+8, 3+8, 19+8,
+///                            2+12, 18+12, 3+12, 19+12].
+static Value createUnpackHiPs(ImplicitLocOpBuilder &b, Value v1, Value v2,
+                              int numBits) {
+  int numElem = numBits / 32;
+  return b.create<vector::ShuffleOp>(
+      v1, v2,
+      getUnpackShufflePermFor128Lane({2, numElem + 2, 3, numElem + 3},
+                                     numBits));
+}
+
+/// Returns a vector.shuffle that shuffles 128-bits (composed of 4 32-bit
+/// elements) selected by `mask` from `v1` and `v2`. I.e.,
+///
+/// DEFINE SELECT4(src, control) {
+/// CASE(control[1:0]) OF
+/// 0: tmp[127:0] := src[127:0]
+/// 1: tmp[127:0] := src[255:128]
+/// 2: tmp[127:0] := src[383:256]
+/// 3: tmp[127:0] := src[511:384]
+/// ESAC
+/// RETURN tmp[127:0]
+/// }
+/// dst[127:0] := SELECT4(v1[511:0], mask[1:0])
+/// dst[255:128] := SELECT4(v1[511:0], mask[3:2])
+/// dst[383:256] := SELECT4(v2[511:0], mask[5:4])
+/// dst[511:384] := SELECT4(v2[511:0], mask[7:6])
+static Value createMm512ShuffleI32x4(ImplicitLocOpBuilder &b, Value v1,
+                                     Value v2, uint8_t mask) {
+  assert(v1.getType().cast<VectorType>().getShape()[0] == 16 &&
+         "expected a vector with length=16");
+  SmallVector<int64_t> shuffleMask;
+  auto appendToMask = [&](int64_t base, uint8_t control) {
+    if (control == 0)
+      llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 0, base + 1,
+                                                        base + 2, base + 3});
+    else if (control == 1)
+      llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 4, base + 5,
+                                                        base + 6, base + 7});
+    else if (control == 2)
+      llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 8, base + 9,
+                                                        base + 10, base + 11});
+    else if (control == 3)
+      llvm::append_range(shuffleMask, ArrayRef<int64_t>{base + 12, base + 13,
+                                                        base + 14, base + 15});
+    else
+      llvm_unreachable("control > 3 : overflow");
+  };
+  uint8_t b01 = mask & 0x3;
+  uint8_t b23 = (mask >> 2) & 0x3;
+  uint8_t b45 = (mask >> 4) & 0x3;
+  uint8_t b67 = (mask >> 6) & 0x3;
+  appendToMask(0, b01);
+  appendToMask(0, b23);
+  appendToMask(16, b45);
+  appendToMask(16, b67);
+  return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
+}
+
+/// Returns the row-major `n` x `m` transpose of the 1-D `m` x `n` `source`.
+static Value transposeToShuffle1D(OpBuilder &b, Value source, int m, int n) {
+  SmallVector<int64_t> mask;
+  mask.reserve(m * n);
+  for (int64_t j = 0; j < n; ++j)
+    for (int64_t i = 0; i < m; ++i)
+      mask.push_back(i * n + j);
+  return b.create<vector::ShuffleOp>(source.getLoc(), source, source, mask);
+}
+
+/// Transposes the 16x16 vector `source` with 64 vector.shuffle ops on its
+/// rows, following the AVX-512 unpack + shuffle_i32x4 transpose sequence.
+static Value transpose16x16(OpBuilder &builder, Value source, int m, int n) {
+  assert(m == 16 && n == 16 && "expected a 16x16 vector");
+  ImplicitLocOpBuilder b(source.getLoc(), builder);
+  SmallVector<Value> vs;
+  for (int64_t i = 0; i < m; ++i)
+    vs.push_back(b.create<vector::ExtractOp>(source, i));
+
+  // Interleave 32-bit lanes using
+  //   8x _mm512_unpacklo_epi32
+  //   8x _mm512_unpackhi_epi32
+  Value t0 = createUnpackLoPs(b, vs[0x0], vs[0x1], 512);
+  Value t1 = createUnpackHiPs(b, vs[0x0], vs[0x1], 512);
+  Value t2 = createUnpackLoPs(b, vs[0x2], vs[0x3], 512);
+  Value t3 = createUnpackHiPs(b, vs[0x2], vs[0x3], 512);
+  Value t4 = createUnpackLoPs(b, vs[0x4], vs[0x5], 512);
+  Value t5 = createUnpackHiPs(b, vs[0x4], vs[0x5], 512);
+  Value t6 = createUnpackLoPs(b, vs[0x6], vs[0x7], 512);
+  Value t7 = createUnpackHiPs(b, vs[0x6], vs[0x7], 512);
+  Value t8 = createUnpackLoPs(b, vs[0x8], vs[0x9], 512);
+  Value t9 = createUnpackHiPs(b, vs[0x8], vs[0x9], 512);
+  Value ta = createUnpackLoPs(b, vs[0xa], vs[0xb], 512);
+  Value tb = createUnpackHiPs(b, vs[0xa], vs[0xb], 512);
+  Value tc = createUnpackLoPs(b, vs[0xc], vs[0xd], 512);
+  Value td = createUnpackHiPs(b, vs[0xc], vs[0xd], 512);
+  Value te = createUnpackLoPs(b, vs[0xe], vs[0xf], 512);
+  Value tf = createUnpackHiPs(b, vs[0xe], vs[0xf], 512);
+
+  // Interleave 64-bit lanes using
+  //   8x _mm512_unpacklo_epi64
+  //   8x _mm512_unpackhi_epi64
+  Value r0 = createUnpackLoPd(b, t0, t2, 512);
+  Value r1 = createUnpackHiPd(b, t0, t2, 512);
+  Value r2 = createUnpackLoPd(b, t1, t3, 512);
+  Value r3 = createUnpackHiPd(b, t1, t3, 512);
+  Value r4 = createUnpackLoPd(b, t4, t6, 512);
+  Value r5 = createUnpackHiPd(b, t4, t6, 512);
+  Value r6 = createUnpackLoPd(b, t5, t7, 512);
+  Value r7 = createUnpackHiPd(b, t5, t7, 512);
+  Value r8 = createUnpackLoPd(b, t8, ta, 512);
+  Value r9 = createUnpackHiPd(b, t8, ta, 512);
+  Value ra = createUnpackLoPd(b, t9, tb, 512);
+  Value rb = createUnpackHiPd(b, t9, tb, 512);
+  Value rc = createUnpackLoPd(b, tc, te, 512);
+  Value rd = createUnpackHiPd(b, tc, te, 512);
+  Value re = createUnpackLoPd(b, td, tf, 512);
+  Value rf = createUnpackHiPd(b, td, tf, 512);
+
+  // Permute 128-bit lanes using
+  //   16x _mm512_shuffle_i32x4
+  t0 = createMm512ShuffleI32x4(b, r0, r4, 0x88);
+  t1 = createMm512ShuffleI32x4(b, r1, r5, 0x88);
+  t2 = createMm512ShuffleI32x4(b, r2, r6, 0x88);
+  t3 = createMm512ShuffleI32x4(b, r3, r7, 0x88);
+  t4 = createMm512ShuffleI32x4(b, r0, r4, 0xdd);
+  t5 = createMm512ShuffleI32x4(b, r1, r5, 0xdd);
+  t6 = createMm512ShuffleI32x4(b, r2, r6, 0xdd);
+  t7 = createMm512ShuffleI32x4(b, r3, r7, 0xdd);
+  t8 = createMm512ShuffleI32x4(b, r8, rc, 0x88);
+  t9 = createMm512ShuffleI32x4(b, r9, rd, 0x88);
+  ta = createMm512ShuffleI32x4(b, ra, re, 0x88);
+  tb = createMm512ShuffleI32x4(b, rb, rf, 0x88);
+  tc = createMm512ShuffleI32x4(b, r8, rc, 0xdd);
+  td = createMm512ShuffleI32x4(b, r9, rd, 0xdd);
+  te = createMm512ShuffleI32x4(b, ra, re, 0xdd);
+  tf = createMm512ShuffleI32x4(b, rb, rf, 0xdd);
+
+  // Permute 256-bit lanes using again
+  //   16x _mm512_shuffle_i32x4
+  vs[0x0] = createMm512ShuffleI32x4(b, t0, t8, 0x88);
+  vs[0x1] = createMm512ShuffleI32x4(b, t1, t9, 0x88);
+  vs[0x2] = createMm512ShuffleI32x4(b, t2, ta, 0x88);
+  vs[0x3] = createMm512ShuffleI32x4(b, t3, tb, 0x88);
+  vs[0x4] = createMm512ShuffleI32x4(b, t4, tc, 0x88);
+  vs[0x5] = createMm512ShuffleI32x4(b, t5, td, 0x88);
+  vs[0x6] = createMm512ShuffleI32x4(b, t6, te, 0x88);
+  vs[0x7] = createMm512ShuffleI32x4(b, t7, tf, 0x88);
+  vs[0x8] = createMm512ShuffleI32x4(b, t0, t8, 0xdd);
+  vs[0x9] = createMm512ShuffleI32x4(b, t1, t9, 0xdd);
+  vs[0xa] = createMm512ShuffleI32x4(b, t2, ta, 0xdd);
+  vs[0xb] = createMm512ShuffleI32x4(b, t3, tb, 0xdd);
+  vs[0xc] = createMm512ShuffleI32x4(b, t4, tc, 0xdd);
+  vs[0xd] = createMm512ShuffleI32x4(b, t5, td, 0xdd);
+  vs[0xe] = createMm512ShuffleI32x4(b, t6, te, 0xdd);
+  vs[0xf] = createMm512ShuffleI32x4(b, t7, tf, 0xdd);
+
+  auto reshInputType = VectorType::get(
+      {m, n}, source.getType().cast<VectorType>().getElementType());
+  Value res =
+      b.create<arith::ConstantOp>(reshInputType, b.getZeroAttr(reshInputType));
+  for (int64_t i = 0; i < m; ++i)
+    res = b.create<vector::InsertOp>(vs[i], res, i);
+  return res;
+}
+
 namespace {
 /// Progressive lowering of TransposeOp.
 /// One:
@@ -84,8 +316,7 @@
     for (auto attr : op.getTransp())
       transp.push_back(attr.cast<IntegerAttr>().getInt());
 
-    if (vectorTransformOptions.vectorTransposeLowering ==
-            vector::VectorTransposeLowering::Shuffle &&
+    if (isShuffleLike(vectorTransformOptions.vectorTransposeLowering) &&
         resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0)
       return rewriter.notifyMatchFailure(
           op, "Options specifies lowering to shuffle");
@@ -145,7 +376,8 @@
   vector::VectorTransformsOptions vectorTransformOptions;
 };
 
-/// Rewrite a 2-D vector.transpose as a sequence of:
+/// Rewrite a 2-D vector.transpose as a sequence of shuffle ops.
+/// Shuffle1D, and Shuffle16x16 on non-16x16 shapes, lower it to:
 ///   vector.shape_cast 2D -> 1D
 ///   vector.shuffle
 ///   vector.shape_cast 1D -> 2D
@@ -174,24 +406,33 @@
     if (transp[0] != 1 && transp[1] != 0)
       return rewriter.notifyMatchFailure(op, "Not a 2D transpose permutation");
 
-    if (vectorTransformOptions.vectorTransposeLowering !=
-        VectorTransposeLowering::Shuffle)
-      return rewriter.notifyMatchFailure(op, "Options do not ask for Shuffle");
-
+    Value res;
     int64_t m = srcType.getShape().front(), n = srcType.getShape().back();
-    Value casted = rewriter.create<vector::ShapeCastOp>(
-        loc, VectorType::get({m * n}, srcType.getElementType()),
-        op.getVector());
-    SmallVector<int64_t> mask;
-    mask.reserve(m * n);
-    for (int64_t j = 0; j < n; ++j)
-      for (int64_t i = 0; i < m; ++i)
-        mask.push_back(i * n + j);
-
-    Value shuffled =
-        rewriter.create<vector::ShuffleOp>(loc, casted, casted, mask);
+    switch (vectorTransformOptions.vectorTransposeLowering) {
+    case VectorTransposeLowering::Shuffle16x16:
+      // 16x16 vectors are transposed with row-wise shuffles (transpose16x16).
+      if (m == 16 && n == 16) {
+        res = transpose16x16(rewriter, op.getVector(), m, n);
+        break;
+      }
+      // Other 2-D shapes fall back to the 1-D shuffle lowering: the generic
+      // TransposeOpLowering pattern bails out on every shuffle-like 2-D
+      // transpose, so failing here would leave the op unlowered.
+      [[fallthrough]];
+    case VectorTransposeLowering::Shuffle1D: {
+      Value casted = rewriter.create<vector::ShapeCastOp>(
+          loc, VectorType::get({m * n}, srcType.getElementType()),
+          op.getVector());
+      res = transposeToShuffle1D(rewriter, casted, m, n);
+      break;
+    }
+    case VectorTransposeLowering::EltWise:
+    case VectorTransposeLowering::Flat:
+      return rewriter.notifyMatchFailure(op, "Options do not ask for Shuffle");
+    }
+
     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
-        op, op.getResultVectorType(), shuffled);
+        op, op.getResultVectorType(), res);
     return success();
   }
 
diff --git a/mlir/test/Dialect/LLVM/transform-e2e.mlir b/mlir/test/Dialect/LLVM/transform-e2e.mlir
--- a/mlir/test/Dialect/LLVM/transform-e2e.mlir
+++ b/mlir/test/Dialect/LLVM/transform-e2e.mlir
@@ -54,6 +54,6 @@
     : (!pdl.operation) -> !pdl.operation
 
   %func_8 = transform.vector.lower_transpose %func_7
-    lowering_strategy = "shuffle"
+    lowering_strategy = "shuffle_1d"
     : (!pdl.operation) -> !pdl.operation
 }
diff --git a/mlir/test/Dialect/Vector/transform-vector.mlir b/mlir/test/Dialect/Vector/transform-vector.mlir
--- a/mlir/test/Dialect/Vector/transform-vector.mlir
+++ b/mlir/test/Dialect/Vector/transform-vector.mlir
@@ -57,6 +57,6 @@
     : (!pdl.operation) -> !pdl.operation
 
   %func_8 = transform.vector.lower_transpose %func_7
-    lowering_strategy = "shuffle"
+    lowering_strategy = "shuffle_1d"
     : (!pdl.operation) -> !pdl.operation
 }
diff --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
--- a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
@@ -100,7 +100,7 @@
 transform.sequence failures(propagate) {
 ^bb1(%module_op: !pdl.operation):
   transform.vector.lower_transpose %module_op
-    lowering_strategy = "shuffle"
+    lowering_strategy = "shuffle_1d"
     : (!pdl.operation) -> !pdl.operation
 }
 
@@ -609,3 +609,100 @@
     avx2_lowering_strategy = true
     : (!pdl.operation) -> !pdl.operation
 }
+
+// -----
+
+func.func @transpose_16x16xf32(%arg0: vector<16x16xf32>) -> vector<16x16xf32> {
+  // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  // CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+  %0 = vector.transpose %arg0, [1, 0] : vector<16x16xf32> to vector<16x16xf32>
+  return %0 : vector<16x16xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !pdl.operation):
+  transform.vector.lower_transpose %module_op
+    lowering_strategy = "shuffle_16x16"
+    : (!pdl.operation) -> !pdl.operation
+}
+
+// -----
+
+// shuffle_16x16 falls back to the 1-D shuffle lowering for non-16x16 shapes.
+// CHECK-LABEL: func @transpose_shuffle16x16_fallback_4x8xf32
+//       CHECK:   vector.shape_cast {{.*}} : vector<4x8xf32> to vector<32xf32>
+//       CHECK:   vector.shuffle {{.*}} [0, 8, 16, 24, 1, 9, 17, 25, 2, 10, 18, 26, 3, 11, 19, 27, 4, 12, 20, 28, 5, 13, 21, 29, 6, 14, 22, 30, 7, 15, 23, 31] : vector<32xf32>, vector<32xf32>
+//       CHECK:   vector.shape_cast {{.*}} : vector<32xf32> to vector<8x4xf32>
+func.func @transpose_shuffle16x16_fallback_4x8xf32(%arg0: vector<4x8xf32>) -> vector<8x4xf32> {
+  %0 = vector.transpose %arg0, [1, 0] : vector<4x8xf32> to vector<8x4xf32>
+  return %0 : vector<8x4xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !pdl.operation):
+  transform.vector.lower_transpose %module_op
+    lowering_strategy = "shuffle_16x16"
+    : (!pdl.operation) -> !pdl.operation
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transpose-16x16.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transpose-16x16.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transpose-16x16.mlir
@@ -0,0 +1,38 @@
+// RUN: mlir-opt %s -convert-scf-to-cf \
+// RUN:   -test-transform-dialect-interpreter \
+// RUN:   -test-transform-dialect-erase-schedule \
+// RUN:   -convert-vector-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_c_runner_utils | \
+// RUN: FileCheck %s
+
+func.func @entry() {
+  %in = arith.constant dense<[[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0], [16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0], [32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0], [48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0, 61.0, 62.0, 63.0], [64.0, 65.0, 66.0, 67.0, 68.0, 69.0, 70.0, 71.0, 72.0, 73.0, 74.0, 75.0, 76.0, 77.0, 78.0, 79.0], [80.0, 81.0, 82.0, 83.0, 84.0, 85.0, 86.0, 87.0, 88.0, 89.0, 90.0, 91.0, 92.0, 93.0, 94.0, 95.0], [96.0, 97.0, 98.0, 99.0, 100.0, 101.0, 102.0, 103.0, 104.0, 105.0, 106.0, 107.0, 108.0, 109.0, 110.0, 111.0], [112.0, 113.0, 114.0, 115.0, 116.0, 117.0, 118.0, 119.0, 120.0, 121.0, 122.0, 123.0, 124.0, 125.0, 126.0, 127.0], [128.0, 129.0, 130.0, 131.0, 132.0, 133.0, 134.0, 135.0, 136.0, 137.0, 138.0, 139.0, 140.0, 141.0, 142.0, 143.0], [144.0, 145.0, 146.0, 147.0, 148.0, 149.0, 150.0, 151.0, 152.0, 153.0, 154.0, 155.0, 156.0, 157.0, 158.0, 159.0], [160.0, 161.0, 162.0, 163.0, 164.0, 165.0, 166.0, 167.0, 168.0, 169.0, 170.0, 171.0, 172.0, 173.0, 174.0, 175.0], [176.0, 177.0, 178.0, 179.0, 180.0, 181.0, 182.0, 183.0, 184.0, 185.0, 186.0, 187.0, 188.0, 189.0, 190.0, 191.0], [192.0, 193.0, 194.0, 195.0, 196.0, 197.0, 198.0, 199.0, 200.0, 201.0, 202.0, 203.0, 204.0, 205.0, 206.0, 207.0], [208.0, 209.0, 210.0, 211.0, 212.0, 213.0, 214.0, 215.0, 216.0, 217.0, 218.0, 219.0, 220.0, 221.0, 222.0, 223.0], [224.0, 225.0, 226.0, 227.0, 228.0, 229.0, 230.0, 231.0, 232.0, 233.0, 234.0, 235.0, 236.0, 237.0, 238.0, 239.0], [240.0, 241.0, 242.0, 243.0, 244.0, 245.0, 246.0, 247.0, 248.0, 249.0, 250.0, 251.0, 252.0, 253.0, 254.0, 255.0]]> : vector<16x16xf32>
+  %0 = vector.transpose %in, [1, 0] : vector<16x16xf32> to vector<16x16xf32>
+  vector.print %0 : vector<16x16xf32>
+  // CHECK:      ( ( 0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240 ),
+  // CHECK-SAME:   ( 1, 17, 33, 49, 65, 81, 97, 113, 129, 145, 161, 177, 193, 209, 225, 241 ),
+  // CHECK-SAME:   ( 2, 18, 34, 50, 66, 82, 98, 114, 130, 146, 162, 178, 194, 210, 226, 242 ),
+  // CHECK-SAME:   ( 3, 19, 35, 51, 67, 83, 99, 115, 131, 147, 163, 179, 195, 211, 227, 243 ),
+  // CHECK-SAME:   ( 4, 20, 36, 52, 68, 84, 100, 116, 132, 148, 164, 180, 196, 212, 228, 244 ),
+  // CHECK-SAME:   ( 5, 21, 37, 53, 69, 85, 101, 117, 133, 149, 165, 181, 197, 213, 229, 245 ),
+  // CHECK-SAME:   ( 6, 22, 38, 54, 70, 86, 102, 118, 134, 150, 166, 182, 198, 214, 230, 246 ),
+  // CHECK-SAME:   ( 7, 23, 39, 55, 71, 87, 103, 119, 135, 151, 167, 183, 199, 215, 231, 247 ),
+  // CHECK-SAME:   ( 8, 24, 40, 56, 72, 88, 104, 120, 136, 152, 168, 184, 200, 216, 232, 248 ),
+  // CHECK-SAME:   ( 9, 25, 41, 57, 73, 89, 105, 121, 137, 153, 169, 185, 201, 217, 233, 249 ),
+  // CHECK-SAME:   ( 10, 26, 42, 58, 74, 90, 106, 122, 138, 154, 170, 186, 202, 218, 234, 250 ),
+  // CHECK-SAME:   ( 11, 27, 43, 59, 75, 91, 107, 123, 139, 155, 171, 187, 203, 219, 235, 251 ),
+  // CHECK-SAME:   ( 12, 28, 44, 60, 76, 92, 108, 124, 140, 156, 172, 188, 204, 220, 236, 252 ),
+  // CHECK-SAME:   ( 13, 29, 45, 61, 77, 93, 109, 125, 141, 157, 173, 189, 205, 221, 237, 253 ),
+  // CHECK-SAME:   ( 14, 30, 46, 62, 78, 94, 110, 126, 142, 158, 174, 190, 206, 222, 238, 254 ),
+  // CHECK-SAME:   ( 15, 31, 47, 63, 79, 95, 111, 127, 143, 159, 175, 191, 207, 223, 239, 255 ) )
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !pdl.operation):
+  transform.vector.lower_transpose %module_op
+    lowering_strategy = "shuffle_16x16"
+    : (!pdl.operation) -> !pdl.operation
+}
+