diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -16,6 +16,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/Dialect/Vector/VectorTransforms.h" +#include "mlir/Dialect/X86Vector/Transforms.h" #include "mlir/IR/Identifier.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/Bufferize.h" @@ -993,6 +994,12 @@ transposeLowering = val; return *this; } + /// Enable AVX2-specific lowerings. + bool avx2Lowering = false; + LinalgVectorLoweringOptions &enableAVX2Lowering(bool val = true) { + avx2Lowering = val; + return *this; + } /// Configure the post staged-patterns late vector.transfer to scf /// conversion. @@ -1009,6 +1016,13 @@ vectorTransformOptions = options; return *this; } + /// Configure the AVX2-specific lowerings used when `avx2Lowering` is set. + x86vector::avx2::LoweringOptions avx2LoweringOptions; + LinalgVectorLoweringOptions & + setAVX2LoweringOptions(x86vector::avx2::LoweringOptions options) { + avx2LoweringOptions = options; + return *this; + } }; //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/X86Vector/Transforms.h b/mlir/include/mlir/Dialect/X86Vector/Transforms.h --- a/mlir/include/mlir/Dialect/X86Vector/Transforms.h +++ b/mlir/include/mlir/Dialect/X86Vector/Transforms.h @@ -9,13 +9,126 @@ #ifndef MLIR_DIALECT_X86VECTOR_TRANSFORMS_H #define MLIR_DIALECT_X86VECTOR_TRANSFORMS_H +#include "mlir/IR/Value.h" + namespace mlir { +class ImplicitLocOpBuilder; class LLVMConversionTarget; class LLVMTypeConverter; class RewritePatternSet; using OwningRewritePatternList = RewritePatternSet; +namespace x86vector { + +/// Helper to build 8-bit masks from 2-bit or 4-bit fields and to split them back.
+struct MaskHelper { + /// b01 captures the lower 2 bits, b67 captures the higher 2 bits. + /// Meant to be used with instructions such as mm256ShufflePs. + template + static char shuffle() { + static_assert(b01 <= 0x03, "overflow"); + static_assert(b23 <= 0x03, "overflow"); + static_assert(b45 <= 0x03, "overflow"); + static_assert(b67 <= 0x03, "overflow"); + return (b67 << 6) + (b45 << 4) + (b23 << 2) + b01; + } + /// Inverse of shuffle(): b01 gets the lower 2 bits, b67 the higher 2 bits. + static void extractShuffle(char mask, char &b01, char &b23, char &b45, + char &b67) { + b67 = (mask & (0x03 << 6)) >> 6; + b45 = (mask & (0x03 << 4)) >> 4; + b23 = (mask & (0x03 << 2)) >> 2; + b01 = mask & 0x03; + } + /// b03 captures the lower 4 bits, b47 captures the higher 4 bits. + /// Meant to be used with instructions such as mm256Permute2f128Ps. + template + static char permute() { + static_assert(b03 <= 0x0f, "overflow"); + static_assert(b47 <= 0x0f, "overflow"); + return (b47 << 4) + b03; + } + /// Inverse of permute(): b03 gets the lower 4 bits, b47 the higher 4 bits. + static void extractPermute(char mask, char &b03, char &b47) { + b47 = (mask & (0x0f << 4)) >> 4; + b03 = mask & 0x0f; + } +}; + +//===----------------------------------------------------------------------===// +/// Helpers extracted from: +/// - clang/lib/Headers/avxintrin.h +/// - clang/test/CodeGen/X86/avx-builtins.c +/// - clang/test/CodeGen/X86/avx2-builtins.c +/// - clang/test/CodeGen/X86/avx-shuffle-builtins.c +/// as well as the Intel Intrinsics Guide +/// (https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html) +/// make it easier to just implement known good lowerings. +/// All intrinsics correspond 1-1 to the Intel definition. +//===----------------------------------------------------------------------===// + +namespace avx2 { + +/// Lower to vector.shuffle v1, v2, [0, 8, 1, 9, 4, 12, 5, 13].
+Value mm256UnpackLoPs(ImplicitLocOpBuilder &b, Value v1, Value v2); + +/// Lower to vector.shuffle v1, v2, [2, 10, 3, 11, 6, 14, 7, 15]. +Value mm256UnpackHiPs(ImplicitLocOpBuilder &b, Value v1, Value v2); + +/// a a b b a a b b +/// Take an 8 bit mask, 2 bit for each position of a[0, 4) **and** b[0, 4): +/// 0:127 | 128:255 +/// b01 b23 b45+8 b67+8 | b01+4 b23+4 b45+12 b67+12 +Value mm256ShufflePs(ImplicitLocOpBuilder &b, Value v1, Value v2, char mask); + +// imm[0:1] out of imm[0:3] is: +// 0 1 2 3 +// a[0:127] or a[128:255] or b[0:127] or b[128:255] | +// a[0:127] or a[128:255] or b[0:127] or b[128:255] +// 0 1 2 3 +// imm[0:1] out of imm[4:7]. +Value mm256Permute2f128Ps(ImplicitLocOpBuilder &b, Value v1, Value v2, + char mask); + +/// 4x8xf32-specific AVX2 transpose lowering. +void transpose4x8xf32(ImplicitLocOpBuilder &ib, MutableArrayRef vs); + +/// 8x8xf32-specific AVX2 transpose lowering. +void transpose8x8xf32(ImplicitLocOpBuilder &ib, MutableArrayRef vs); + +/// Structure to control the behavior of specialized AVX2 transpose lowering. +struct TransposeLoweringOptions { + bool lower4x8xf32_ = false; + TransposeLoweringOptions &lower4x8xf32(bool lower = true) { + lower4x8xf32_ = lower; + return *this; + } + bool lower8x8xf32_ = false; + TransposeLoweringOptions &lower8x8xf32(bool lower = true) { + lower8x8xf32_ = lower; + return *this; + } +}; + +/// Options for controlling specialized AVX2 lowerings. +struct LoweringOptions { + /// Configure specialized transpose lowerings. + TransposeLoweringOptions transposeOptions; + LoweringOptions &setTransposeOptions(TransposeLoweringOptions options) { + transposeOptions = options; + return *this; + } +}; + +/// Insert specialized transpose lowering patterns.
+void populateSpecializedTransposeLoweringPatterns( + RewritePatternSet &patterns, LoweringOptions options = LoweringOptions(), + int benefit = 10); + +} // namespace avx2 +} // namespace x86vector + /// Collect a set of patterns to lower X86Vector ops to ops that map to LLVM /// intrinsics. void populateX86VectorLegalizeForLLVMExportPatterns( diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -51,5 +51,6 @@ MLIRTransforms MLIRTransformUtils MLIRVector + MLIRX86VectorTransforms MLIRVectorToSCF ) diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp @@ -334,6 +334,9 @@ if (options.transposeLowering) { vector::populateVectorTransposeLoweringPatterns( patterns, options.vectorTransformOptions); + if (options.avx2Lowering) + x86vector::avx2::populateSpecializedTransposeLoweringPatterns( + patterns, options.avx2LoweringOptions, /*benefit=*/10); } (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); } diff --git a/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp b/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp @@ -0,0 +1,208 @@ +//===- AVXTranspose.cpp - Lower Vector transpose to AVX -------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements vector.transpose rewrites as AVX patterns for particular +// sizes of interest. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Vector/VectorOps.h" +#include "mlir/Dialect/X86Vector/Transforms.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" + +using namespace mlir; +using namespace mlir::vector; +using namespace mlir::x86vector; +using namespace mlir::x86vector::avx2; + +Value mlir::x86vector::avx2::mm256UnpackLoPs(ImplicitLocOpBuilder &b, Value v1, + Value v2) { + return b.create( + v1, v2, ArrayRef{0, 8, 1, 9, 4, 12, 5, 13}); +} + +Value mlir::x86vector::avx2::mm256UnpackHiPs(ImplicitLocOpBuilder &b, Value v1, + Value v2) { + return b.create( + v1, v2, ArrayRef{2, 10, 3, 11, 6, 14, 7, 15}); +} +/// a a b b a a b b +/// Takes an 8 bit mask, 2 bit for each position of a[0, 4) **and** b[0, 4): +/// 0:127 | 128:255 +/// b01 b23 b45+8 b67+8 | b01+4 b23+4 b45+12 b67+12 +Value mlir::x86vector::avx2::mm256ShufflePs(ImplicitLocOpBuilder &b, Value v1, + Value v2, char mask) { + char b01, b23, b45, b67; + MaskHelper::extractShuffle(mask, b01, b23, b45, b67); + SmallVector shuffleMask{b01, b23, b45 + 8, b67 + 8, + b01 + 4, b23 + 4, b45 + 8 + 4, b67 + 8 + 4}; + return b.create(v1, v2, shuffleMask); +} + +// imm[0:1] out of imm[0:3] is: +// 0 1 2 3 +// a[0:127] or a[128:255] or b[0:127] or b[128:255] | +// a[0:127] or a[128:255] or b[0:127] or b[128:255] +// 0 1 2 3 +// imm[0:1] out of imm[4:7].
+Value mlir::x86vector::avx2::mm256Permute2f128Ps(ImplicitLocOpBuilder &b, + Value v1, Value v2, + char mask) { + SmallVector shuffleMask; + auto appendToMask = [&](char control) { + if (control == 0) + llvm::append_range(shuffleMask, ArrayRef{0, 1, 2, 3}); + else if (control == 1) + llvm::append_range(shuffleMask, ArrayRef{4, 5, 6, 7}); + else if (control == 2) + llvm::append_range(shuffleMask, ArrayRef{8, 9, 10, 11}); + else if (control == 3) + llvm::append_range(shuffleMask, ArrayRef{12, 13, 14, 15}); + else + llvm_unreachable("control > 3 : overflow"); + }; + char b03, b47; + MaskHelper::extractPermute(mask, b03, b47); + appendToMask(b03); + appendToMask(b47); + return b.create(v1, v2, shuffleMask); +} + +/// AVX2 4x8xf32-specific transpose lowering using a "C intrinsics" model. +void mlir::x86vector::avx2::transpose4x8xf32(ImplicitLocOpBuilder &ib, + MutableArrayRef vs) { + auto vt = VectorType::get({8}, Float32Type::get(ib.getContext())); +#ifndef NDEBUG + assert(vs.size() == 4 && "expects 4 vectors"); + assert(llvm::all_of(ValueRange{vs}.getTypes(), + [&](Type t) { return t == vt; }) && + "expects all types to be vector<8xf32>"); +#endif + + Value T0 = mm256UnpackLoPs(ib, vs[0], vs[1]); + Value T1 = mm256UnpackHiPs(ib, vs[0], vs[1]); + Value T2 = mm256UnpackLoPs(ib, vs[2], vs[3]); + Value T3 = mm256UnpackHiPs(ib, vs[2], vs[3]); + Value S0 = mm256ShufflePs(ib, T0, T2, MaskHelper::shuffle<1, 0, 1, 0>()); + Value S1 = mm256ShufflePs(ib, T0, T2, MaskHelper::shuffle<3, 2, 3, 2>()); + Value S2 = mm256ShufflePs(ib, T1, T3, MaskHelper::shuffle<1, 0, 1, 0>()); + Value S3 = mm256ShufflePs(ib, T1, T3, MaskHelper::shuffle<3, 2, 3, 2>()); + vs[0] = mm256Permute2f128Ps(ib, S0, S1, MaskHelper::permute<2, 0>()); + vs[1] = mm256Permute2f128Ps(ib, S2, S3, MaskHelper::permute<2, 0>()); + vs[2] = mm256Permute2f128Ps(ib, S0, S1, MaskHelper::permute<3, 1>()); + vs[3] = mm256Permute2f128Ps(ib, S2, S3, MaskHelper::permute<3, 1>()); +} + +/// AVX2 8x8xf32-specific transpose lowering 
using a "C intrinsics" model. +void mlir::x86vector::avx2::transpose8x8xf32(ImplicitLocOpBuilder &ib, + MutableArrayRef vs) { + auto vt = VectorType::get({8}, Float32Type::get(ib.getContext())); + (void)vt; + assert(vs.size() == 8 && "expects 8 vectors"); + assert(llvm::all_of(ValueRange{vs}.getTypes(), + [&](Type t) { return t == vt; }) && + "expects all types to be vector<8xf32>"); + + Value T0 = mm256UnpackLoPs(ib, vs[0], vs[1]); + Value T1 = mm256UnpackHiPs(ib, vs[0], vs[1]); + Value T2 = mm256UnpackLoPs(ib, vs[2], vs[3]); + Value T3 = mm256UnpackHiPs(ib, vs[2], vs[3]); + Value T4 = mm256UnpackLoPs(ib, vs[4], vs[5]); + Value T5 = mm256UnpackHiPs(ib, vs[4], vs[5]); + Value T6 = mm256UnpackLoPs(ib, vs[6], vs[7]); + Value T7 = mm256UnpackHiPs(ib, vs[6], vs[7]); + Value S0 = mm256ShufflePs(ib, T0, T2, MaskHelper::shuffle<1, 0, 1, 0>()); + Value S1 = mm256ShufflePs(ib, T0, T2, MaskHelper::shuffle<3, 2, 3, 2>()); + Value S2 = mm256ShufflePs(ib, T1, T3, MaskHelper::shuffle<1, 0, 1, 0>()); + Value S3 = mm256ShufflePs(ib, T1, T3, MaskHelper::shuffle<3, 2, 3, 2>()); + Value S4 = mm256ShufflePs(ib, T4, T6, MaskHelper::shuffle<1, 0, 1, 0>()); + Value S5 = mm256ShufflePs(ib, T4, T6, MaskHelper::shuffle<3, 2, 3, 2>()); + Value S6 = mm256ShufflePs(ib, T5, T7, MaskHelper::shuffle<1, 0, 1, 0>()); + Value S7 = mm256ShufflePs(ib, T5, T7, MaskHelper::shuffle<3, 2, 3, 2>()); + vs[0] = mm256Permute2f128Ps(ib, S0, S4, MaskHelper::permute<2, 0>()); + vs[1] = mm256Permute2f128Ps(ib, S1, S5, MaskHelper::permute<2, 0>()); + vs[2] = mm256Permute2f128Ps(ib, S2, S6, MaskHelper::permute<2, 0>()); + vs[3] = mm256Permute2f128Ps(ib, S3, S7, MaskHelper::permute<2, 0>()); + vs[4] = mm256Permute2f128Ps(ib, S0, S4, MaskHelper::permute<3, 1>()); + vs[5] = mm256Permute2f128Ps(ib, S1, S5, MaskHelper::permute<3, 1>()); + vs[6] = mm256Permute2f128Ps(ib, S2, S6, MaskHelper::permute<3, 1>()); + vs[7] = mm256Permute2f128Ps(ib, S3, S7, MaskHelper::permute<3, 1>()); +} + +/// Rewrite avx2-specific 2-D 
vector.transpose, for the supported cases and +/// depending on the `TransposeLoweringOptions`. +class TransposeOpLowering : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + TransposeOpLowering(LoweringOptions loweringOptions, MLIRContext *context, + int benefit) + : OpRewritePattern(context, benefit), + loweringOptions(loweringOptions) {} + + LogicalResult matchAndRewrite(vector::TransposeOp op, + PatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + VectorType srcType = op.getVectorType(); + if (srcType.getRank() != 2) + return rewriter.notifyMatchFailure(op, "Not a 2-D transpose"); + + SmallVector transp; + for (auto attr : op.transp()) + transp.push_back(attr.cast().getInt()); + if (transp[0] != 1 && transp[1] != 0) + return rewriter.notifyMatchFailure(op, "Not a 2-D transpose permutation"); + + int64_t m = srcType.getShape().front(), n = srcType.getShape().back(); + + auto applyRewrite = [&]() { + ImplicitLocOpBuilder ib(loc, rewriter); + SmallVector vs; + for (int64_t i = 0; i < m; ++i) + vs.push_back(ib.create(op.vector(), i)); + if (m == 4) + transpose4x8xf32(ib, vs); + if (m == 8) + transpose8x8xf32(ib, vs); + auto flattenedType = + VectorType::get({n * m}, op.getVectorType().getElementType()); + auto transposedType = + VectorType::get({n, m}, op.getVectorType().getElementType()); + Value res = ib.create( + op.getVectorType(), ib.getZeroAttr(op.getVectorType())); + // When m == 4, the transposed rows are still assembled as 4x8 and are + // reinterpreted as 8x4 via the shape_casts below.
+ for (int64_t i = 0; i < m; ++i) + res = ib.create(vs[i], res, i); + if (m == 4) { + res = ib.create(flattenedType, res); + res = ib.create(transposedType, res); + } + + rewriter.replaceOp(op, res); + return success(); + }; + + if (loweringOptions.transposeOptions.lower4x8xf32_ && m == 4 && n == 8) + return applyRewrite(); + if (loweringOptions.transposeOptions.lower8x8xf32_ && m == 8 && n == 8) + return applyRewrite(); + return failure(); + } + +private: + LoweringOptions loweringOptions; +}; + +void mlir::x86vector::avx2::populateSpecializedTransposeLoweringPatterns( + RewritePatternSet &patterns, LoweringOptions options, int benefit) { + patterns.add(options, patterns.getContext(), benefit); +} diff --git a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_dialect_library(MLIRX86VectorTransforms + AVXTranspose.cpp LegalizeForLLVMExport.cpp DEPENDS @@ -10,4 +11,5 @@ MLIRIR MLIRLLVMCommonConversion MLIRLLVMIR + MLIRVector ) diff --git a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-outerproduct=1 | FileCheck %s +// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-outerproduct=1 | FileCheck %s #matvec_accesses = [ affine_map<(i, j) -> (i, j)>, diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir @@ -1,7 +1,7 @@ -// RUN: mlir-opt %s 
-test-vector-contraction-conversion | FileCheck %s -// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-lower-matrix-intrinsics=1 | FileCheck %s --check-prefix=MATRIX -// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-outerproduct=1 | FileCheck %s --check-prefix=OUTERPRODUCT -// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-filter-outerproduct=1 | FileCheck %s --check-prefix=FILTEROUTERPRODUCT +// RUN: mlir-opt %s -test-vector-contraction-lowering | FileCheck %s +// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-lower-matrix-intrinsics=1 | FileCheck %s --check-prefix=MATRIX +// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-outerproduct=1 | FileCheck %s --check-prefix=OUTERPRODUCT +// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-filter-outerproduct=1 | FileCheck %s --check-prefix=FILTEROUTERPRODUCT #dotp_accesses = [ affine_map<(i) -> (i)>, @@ -149,8 +149,7 @@ // CHECK-SAME: %[[B:.*1]]: vector<2x2xf32>, // CHECK-SAME: %[[C:.*2]]: vector<2x2xf32> // CHECK: %[[R:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32> -// ... 
bunch of extract insert to transpose B into Bt -// CHECK: %[[Bt:.*]] = vector.insert %{{.*}}, %{{.*}} [1, 1] : f32 into vector<2x2xf32> +// CHECK: %[[Bt:.*]] = vector.transpose %arg1, [1, 0] : vector<2x2xf32> to vector<2x2xf32> // CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x2xf32> // CHECK: %[[T2:.*]] = vector.extract %[[Bt]][0] : vector<2x2xf32> // CHECK: %[[T9:.*]] = arith.mulf %[[T0]], %[[T2]] : vector<2xf32> @@ -399,28 +398,6 @@ return %0: vector<16xi32> } -// CHECK-LABEL: func @transpose23 -// CHECK-SAME: %[[A:.*]]: vector<2x3xf32> -// CHECK: %[[Z:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32> -// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<2x3xf32> -// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[Z]] [0, 0] : f32 into vector<3x2xf32> -// CHECK: %[[T2:.*]] = vector.extract %[[A]][1, 0] : vector<2x3xf32> -// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 1] : f32 into vector<3x2xf32> -// CHECK: %[[T4:.*]] = vector.extract %[[A]][0, 1] : vector<2x3xf32> -// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [1, 0] : f32 into vector<3x2xf32> -// CHECK: %[[T6:.*]] = vector.extract %[[A]][1, 1] : vector<2x3xf32> -// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 1] : f32 into vector<3x2xf32> -// CHECK: %[[T8:.*]] = vector.extract %[[A]][0, 2] : vector<2x3xf32> -// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [2, 0] : f32 into vector<3x2xf32> -// CHECK: %[[T10:.*]] = vector.extract %[[A]][1, 2] : vector<2x3xf32> -// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [2, 1] : f32 into vector<3x2xf32> -// CHECK: return %[[T11]] : vector<3x2xf32> - -func @transpose23(%arg0: vector<2x3xf32>) -> vector<3x2xf32> { - %0 = vector.transpose %arg0, [1, 0] : vector<2x3xf32> to vector<3x2xf32> - return %0 : vector<3x2xf32> -} - // CHECK-LABEL: func @nop_shape_cast // CHECK-SAME: %[[A:.*]]: vector<16xf32> // CHECK: return %[[A]] : vector<16xf32> diff --git a/mlir/test/Dialect/Vector/vector-flat-transforms.mlir 
b/mlir/test/Dialect/Vector/vector-flat-transforms.mlir deleted file mode 100644 --- a/mlir/test/Dialect/Vector/vector-flat-transforms.mlir +++ /dev/null @@ -1,65 +0,0 @@ -// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-flat-transpose=1 | FileCheck %s - -// Tests for lowering 2-D vector.transpose into vector.flat_transpose. -// -// TODO: having ShapeCastOp2DDownCastRewritePattern and -// ShapeCastOp2DUpCastRewritePattern too early in the greedy rewriting -// patterns misses opportunities to fold shape casts! - -// No shape cast folding expected. -// -// CHECK-LABEL: func @transpose44_44( -// CHECK-SAME: %[[A:.*]]: vector<4x4xf32> -// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<4x4xf32> -// CHECK: %[[T8:.*]] = vector.flat_transpose %{{.*}} {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32> -// CHECK: %[[T9:.*]] = vector.extract_strided_slice %[[T8]] {offsets = [0], sizes = [4], strides = [1]} : vector<16xf32> to vector<4xf32> -// -func @transpose44_44(%arg0: vector<4x4xf32>) -> vector<4x4xf32> { - %0 = vector.transpose %arg0, [1, 0] : vector<4x4xf32> to vector<4x4xf32> - return %0 : vector<4x4xf32> -} - -// Folds preceding shape cast as expected, -// no following shape cast folding expected. -// -// FIXME: PR49590 - shape_cast not stable. -// -// CHECK-LABEL: func @transpose16_44( -// CHECK-SAME: %[[A:.*]]: vector<16xf32> -// HECK: %[[T0:.*]] = vector.flat_transpose %[[A]] {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32> -// HECK: %[[T1:.*]] = vector.extract_strided_slice %[[T0]] {offsets = [0], sizes = [4], strides = [1]} : vector<16xf32> to vector<4xf32> -// -func @transpose16_44(%arg0: vector<16xf32>) -> vector<4x4xf32> { - %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<4x4xf32> - %1 = vector.transpose %0, [1, 0] : vector<4x4xf32> to vector<4x4xf32> - return %1 : vector<4x4xf32> -} - -// No preceding shape cast folding expected, -// but FAILS to fold following cast. 
-// -// CHECK-LABEL: func @transpose44_16( -// CHECK-SAME: %[[A:.*]]: vector<4x4xf32> -// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<4x4xf32> -// CHECK: %[[T8:.*]] = vector.flat_transpose %{{.*}} {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32> -func @transpose44_16(%arg0: vector<4x4xf32>) -> vector<16xf32> { - %0 = vector.transpose %arg0, [1, 0] : vector<4x4xf32> to vector<4x4xf32> - %1 = vector.shape_cast %0 : vector<4x4xf32> to vector<16xf32> - return %1 : vector<16xf32> -} - -// Folds preceding shape cast as expected, -// but FAILS to fold following cast. -// -// FIXME: PR49590 - shape_cast not stable. -// -// CHECK-LABEL: func @transpose16_16( -// CHECK-SAME: %[[A:.*]]: vector<16xf32> -// HECK: %[[T0:.*]] = vector.flat_transpose %[[A]] {columns = 4 : i32, rows = 4 : i32} : vector<16xf32> -> vector<16xf32> -// -func @transpose16_16(%arg0: vector<16xf32>) -> vector<16xf32> { - %0 = vector.shape_cast %arg0 : vector<16xf32> to vector<4x4xf32> - %1 = vector.transpose %0, [1, 0] : vector<4x4xf32> to vector<4x4xf32> - %2 = vector.shape_cast %1 : vector<4x4xf32> to vector<16xf32> - return %2 : vector<16xf32> -} diff --git a/mlir/test/Dialect/Vector/vector-mem-transforms.mlir b/mlir/test/Dialect/Vector/vector-mem-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-mem-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-mem-transforms.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -test-vector-to-vector-conversion | FileCheck %s +// RUN: mlir-opt %s -test-vector-to-vector-lowering | FileCheck %s // CHECK-LABEL: func @maskedload0( // CHECK-SAME: %[[A0:.*]]: memref, diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-transforms.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -test-vector-to-vector-conversion="unroll" | FileCheck %s +// RUN: mlir-opt %s -test-vector-to-vector-lowering="unroll" | 
FileCheck %s // CHECK-DAG: #[[MAP1:map[0-9]+]] = affine_map<(d0, d1, d2) -> (d1, d2)> diff --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir @@ -0,0 +1,101 @@ +// RUN: mlir-opt %s -test-vector-transpose-lowering=eltwise=1 | FileCheck %s --check-prefix=ELTWISE +// RUN: mlir-opt %s -test-vector-transpose-lowering=shuffle=1 | FileCheck %s --check-prefix=SHUFFLE +// RUN: mlir-opt %s -test-vector-transpose-lowering=flat=1 | FileCheck %s --check-prefix=FLAT +// RUN: mlir-opt %s -test-vector-transpose-lowering=avx2=1 | FileCheck %s --check-prefix=AVX2 + +// ELTWISE-LABEL: func @transpose23 +// ELTWISE-SAME: %[[A:.*]]: vector<2x3xf32> +// ELTWISE: %[[Z:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32> +// ELTWISE: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<2x3xf32> +// ELTWISE: %[[T1:.*]] = vector.insert %[[T0]], %[[Z]] [0, 0] : f32 into vector<3x2xf32> +// ELTWISE: %[[T2:.*]] = vector.extract %[[A]][1, 0] : vector<2x3xf32> +// ELTWISE: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 1] : f32 into vector<3x2xf32> +// ELTWISE: %[[T4:.*]] = vector.extract %[[A]][0, 1] : vector<2x3xf32> +// ELTWISE: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [1, 0] : f32 into vector<3x2xf32> +// ELTWISE: %[[T6:.*]] = vector.extract %[[A]][1, 1] : vector<2x3xf32> +// ELTWISE: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 1] : f32 into vector<3x2xf32> +// ELTWISE: %[[T8:.*]] = vector.extract %[[A]][0, 2] : vector<2x3xf32> +// ELTWISE: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [2, 0] : f32 into vector<3x2xf32> +// ELTWISE: %[[T10:.*]] = vector.extract %[[A]][1, 2] : vector<2x3xf32> +// ELTWISE: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [2, 1] : f32 into vector<3x2xf32> +// ELTWISE: return %[[T11]] : vector<3x2xf32> +func @transpose23(%arg0: vector<2x3xf32>) -> vector<3x2xf32> { + %0 = vector.transpose 
%arg0, [1, 0] : vector<2x3xf32> to vector<3x2xf32> + return %0 : vector<3x2xf32> +} + +// SHUFFLE-LABEL: func @transpose +// FLAT-LABEL: func @transpose( +func @transpose(%arg0: vector<2x4xf32>) -> vector<4x2xf32> { + // SHUFFLE: vector.shape_cast %{{.*}} : vector<2x4xf32> to vector<8xf32> + // 0 4 + // 0 1 2 3 1 5 + // 4 5 6 7 -> 2 6 + // 3 7 + // SHUFFLE-NEXT: vector.shuffle %{{.*}} [0, 4, 1, 5, 2, 6, 3, 7] : vector<8xf32>, vector<8xf32> + // SHUFFLE-NEXT: vector.shape_cast %{{.*}} : vector<8xf32> to vector<4x2xf32> + + // FLAT: vector.shape_cast {{.*}} : vector<2x4xf32> to vector<8xf32> + // FLAT: vector.flat_transpose %{{.*}} {columns = 2 : i32, rows = 4 : i32} : vector<8xf32> -> vector<8xf32> + // FLAT: vector.shape_cast {{.*}} : vector<8xf32> to vector<4x2xf32> + %0 = vector.transpose %arg0, [1, 0] : vector<2x4xf32> to vector<4x2xf32> + return %0 : vector<4x2xf32> +} + +// AVX2-LABEL: func @transpose4x8 +func @transpose4x8xf32(%arg0: vector<4x8xf32>) -> vector<8x4xf32> { + // AVX2: vector.extract {{.*}}[0] + // AVX2-NEXT: vector.extract {{.*}}[1] + // AVX2-NEXT: vector.extract {{.*}}[2] + // AVX2-NEXT: vector.extract {{.*}}[3] + // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32> + // AVX2-NEXT: vector.shuffle {{.*}} [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32> + // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32> + // AVX2-NEXT: vector.shuffle {{.*}} [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32> + // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : 
vector<8xf32>, vector<8xf32> + // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> + // AVX2-NEXT: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> + // AVX2-NEXT: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> + // AVX2-NEXT: vector.insert {{.*}}[0] + // AVX2-NEXT: vector.insert {{.*}}[1] + // AVX2-NEXT: vector.insert {{.*}}[2] + // AVX2-NEXT: vector.insert {{.*}}[3] + // AVX2-NEXT: vector.shape_cast {{.*}} vector<4x8xf32> to vector<32xf32> + // AVX2-NEXT: vector.shape_cast {{.*}} vector<32xf32> to vector<8x4xf32> + %0 = vector.transpose %arg0, [1, 0] : vector<4x8xf32> to vector<8x4xf32> + return %0 : vector<8x4xf32> +} + +// AVX2-LABEL: func @transpose8x8 +func @transpose8x8xf32(%arg0: vector<8x8xf32>) -> vector<8x8xf32> { + // AVX2: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // AVX2-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32> + // AVX2-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32> + // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32> + // AVX2-NEXT: vector.shuffle {{.*}} [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32> + // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32> + // AVX2-NEXT: vector.shuffle {{.*}} [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, 
vector<8xf32> + // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32> + // AVX2-NEXT: vector.shuffle {{.*}} [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32> + // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32> + // AVX2-NEXT: vector.shuffle {{.*}} [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32> + // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> + // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> + // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> + // AVX2-NEXT: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32> + // AVX2-NEXT: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> + // AVX2-NEXT: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> + // AVX2-NEXT: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> + // AVX2-NEXT: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32> + %0 = vector.transpose %arg0, [1, 0] : vector<8x8xf32> to vector<8x8xf32> + return %0 : vector<8x8xf32> +} diff --git a/mlir/test/Dialect/Vector/vector-transpose-to-shuffle.mlir b/mlir/test/Dialect/Vector/vector-transpose-to-shuffle.mlir deleted file mode 100644 --- a/mlir/test/Dialect/Vector/vector-transpose-to-shuffle.mlir +++ /dev/null @@ -1,14 +0,0 @@ -// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-shuffle-transpose=1 | FileCheck %s - -// CHECK-LABEL: func @transpose -func @transpose(%arg0: vector<2x4xf32>) -> vector<4x2xf32> { - // CHECK: vector.shape_cast %{{.*}} : vector<2x4xf32> to vector<8xf32> - // 0 4 - // 0 1 2 3 1 5 - // 4 5 6 7 -> 2 6 - // 3 7 - // CHECK: vector.shuffle %{{.*}} [0, 4, 1, 5, 2, 6, 3, 7] : vector<8xf32>, vector<8xf32> - // CHECK: vector.shape_cast %{{.*}} : 
vector<8xf32> to vector<4x2xf32> - %0 = vector.transpose %arg0, [1, 0] : vector<2x4xf32> to vector<4x2xf32> - return %0 : vector<4x2xf32> -} diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -1,4 +1,4 @@ -//===- TestVectorToVectorConversion.cpp - Test VectorTransfers lowering ---===// +//===- TestVectorTransforms.cpp - Test Vector transforms and lowerings ----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -11,27 +11,31 @@ #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Vector/VectorTransforms.h" #include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; +using namespace mlir::linalg; using namespace mlir::vector; namespace { -struct TestVectorToVectorConversion - : public PassWrapper { - TestVectorToVectorConversion() = default; - TestVectorToVectorConversion(const TestVectorToVectorConversion &pass) {} +struct TestVectorToVectorLowering + : public PassWrapper { + TestVectorToVectorLowering() = default; + TestVectorToVectorLowering(const TestVectorToVectorLowering &pass) {} StringRef getArgument() const final { - return "test-vector-to-vector-conversion"; + return "test-vector-to-vector-lowering"; } StringRef getDescription() const final { - return "Test conversion patterns between ops in the vector dialect"; + return "Test lowering patterns between ops in the 
vector dialect"; } void getDependentDialects(DialectRegistry &registry) const override { @@ -95,31 +99,22 @@ } }; -struct TestVectorContractionConversion - : public PassWrapper { +struct TestVectorContractionLowering + : public PassWrapper { StringRef getArgument() const final { - return "test-vector-contraction-conversion"; + return "test-vector-contraction-lowering"; } StringRef getDescription() const final { - return "Test conversion patterns that lower contract ops in the vector " + return "Test lowering patterns that lower contract ops in the vector " "dialect"; } - TestVectorContractionConversion() = default; - TestVectorContractionConversion(const TestVectorContractionConversion &pass) { - } + TestVectorContractionLowering() = default; + TestVectorContractionLowering(const TestVectorContractionLowering &pass) {} Option lowerToFlatMatrix{ *this, "vector-lower-matrix-intrinsics", llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"), llvm::cl::init(false)}; - Option lowerToFlatTranspose{ - *this, "vector-flat-transpose", - llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"), - llvm::cl::init(false)}; - Option lowerToShuffleTranspose{ - *this, "vector-shuffle-transpose", - llvm::cl::desc("Lower 2-D vector.transpose to shape_cast + shuffle"), - llvm::cl::init(false)}; Option lowerToOuterProduct{ *this, "vector-outerproduct", llvm::cl::desc("Lower vector.contract to vector.outerproduct"), @@ -165,31 +160,91 @@ contractLowering = VectorContractLowering::Matmul; VectorMultiReductionLowering vectorMultiReductionLowering = VectorMultiReductionLowering::InnerParallel; - VectorTransposeLowering transposeLowering = - VectorTransposeLowering::EltWise; - if (lowerToFlatTranspose) - transposeLowering = VectorTransposeLowering::Flat; - if (lowerToShuffleTranspose) - transposeLowering = VectorTransposeLowering::Shuffle; - VectorTransformsOptions options{ - contractLowering, vectorMultiReductionLowering, transposeLowering}; + 
VectorTransformsOptions options{contractLowering, + vectorMultiReductionLowering, + VectorTransposeLowering()}; populateVectorBroadcastLoweringPatterns(patterns); populateVectorContractLoweringPatterns(patterns, options); populateVectorMaskOpLoweringPatterns(patterns); - if (!lowerToShuffleTranspose) - populateVectorShapeCastLoweringPatterns(patterns); - populateVectorTransposeLoweringPatterns(patterns, options); + populateVectorShapeCastLoweringPatterns(patterns); (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); } }; +struct TestVectorTransposeLowering + : public PassWrapper { + StringRef getArgument() const final { + return "test-vector-transpose-lowering"; + } + StringRef getDescription() const final { + return "Test lowering patterns that lower transpose ops in the vector " + "dialect"; + } + TestVectorTransposeLowering() = default; + TestVectorTransposeLowering(const TestVectorTransposeLowering &pass) {} + + Option lowerToEltwise{ + *this, "eltwise", + llvm::cl::desc("Lower 2-D vector.transpose to eltwise insert/extract"), + llvm::cl::init(false)}; + Option lowerToFlatTranspose{ + *this, "flat", + llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"), + llvm::cl::init(false)}; + Option lowerToShuffleTranspose{ + *this, "shuffle", + llvm::cl::desc("Lower 2-D vector.transpose to shape_cast + shuffle"), + llvm::cl::init(false)}; + Option lowerToAvx2{ + *this, "avx2", + llvm::cl::desc("Lower vector.transpose to avx2-specific patterns"), + llvm::cl::init(false)}; + + void runOnFunction() override { + RewritePatternSet patterns(&getContext()); + + // Test on one pattern in isolation. + // Explicitly disable shape_cast lowering. 
+ LinalgVectorLoweringOptions options = LinalgVectorLoweringOptions() + .enableVectorTransposeLowering() + .enableShapeCastLowering(false); + if (lowerToEltwise) { + options = options.setVectorTransformsOptions( + VectorTransformsOptions().setVectorTransposeLowering( + VectorTransposeLowering::EltWise)); + } + if (lowerToFlatTranspose) { + options = options.setVectorTransformsOptions( + VectorTransformsOptions().setVectorTransposeLowering( + VectorTransposeLowering::Flat)); + } + if (lowerToShuffleTranspose) { + options = options.setVectorTransformsOptions( + VectorTransformsOptions().setVectorTransposeLowering( + VectorTransposeLowering::Shuffle)); + } + if (lowerToAvx2) { + options = options.enableAVX2Lowering().setAVX2LoweringOptions( + x86vector::avx2::LoweringOptions().setTransposeOptions( + x86vector::avx2::TransposeLoweringOptions() + .lower4x8xf32() + .lower8x8xf32())); + } + + OpPassManager dynamicPM("builtin.func"); + dynamicPM.addPass(createLinalgStrategyLowerVectorsPass(options)); + if (failed(runPipeline(dynamicPM, getFunction()))) + return signalPassFailure(); + } +}; + struct TestVectorUnrollingPatterns : public PassWrapper { StringRef getArgument() const final { return "test-vector-unrolling-patterns"; } StringRef getDescription() const final { - return "Test conversion patterns to unroll contract ops in the vector " + return "Test lowering patterns to unroll contract ops in the vector " "dialect"; } TestVectorUnrollingPatterns() = default; @@ -248,7 +303,7 @@ return "test-vector-distribute-patterns"; } StringRef getDescription() const final { - return "Test conversion patterns to distribute vector ops in the vector " + return "Test lowering patterns to distribute vector ops in the vector " "dialect"; } TestVectorDistributePatterns() = default; @@ -302,7 +357,7 @@ : public PassWrapper { StringRef getArgument() const final { return "test-vector-to-forloop"; } StringRef getDescription() const final { - return "Test conversion patterns to break up a 
vector op into a for loop"; + return "Test lowering patterns to break up a vector op into a for loop"; } TestVectorToLoopPatterns() = default; TestVectorToLoopPatterns(const TestVectorToLoopPatterns &pass) {} @@ -365,7 +420,7 @@ return "test-vector-transfer-unrolling-patterns"; } StringRef getDescription() const final { - return "Test conversion patterns to unroll transfer ops in the vector " + return "Test lowering patterns to unroll transfer ops in the vector " "dialect"; } void runOnFunction() override { @@ -391,7 +446,7 @@ return "test-vector-transfer-full-partial-split"; } StringRef getDescription() const final { - return "Test conversion patterns to split " + return "Test lowering patterns to split " "transfer ops via scf.if + linalg ops"; } TestVectorTransferFullPartialSplitPatterns() = default; @@ -439,7 +494,7 @@ return "test-vector-transfer-lowering-patterns"; } StringRef getDescription() const final { - return "Test conversion patterns to lower transfer ops to other vector ops"; + return "Test lowering patterns to lower transfer ops to other vector ops"; } void runOnFunction() override { RewritePatternSet patterns(&getContext()); @@ -462,7 +517,7 @@ return "test-vector-multi-reduction-lowering-patterns"; } StringRef getDescription() const final { - return "Test conversion patterns to lower vector.multi_reduction to other " + return "Test lowering patterns to lower vector.multi_reduction to other " "vector ops"; } Option useOuterReductions{ @@ -495,7 +550,7 @@ } StringRef getDescription() const final { - return "Test conversion patterns that reducedes the rank of the vector " + return "Test lowering patterns that reduces the rank of the vector " "transfer memory and vector operands."; } @@ -527,10 +582,12 @@ namespace mlir { namespace test { -void registerTestVectorConversions() { - PassRegistration(); +void registerTestVectorLowerings() { + PassRegistration(); + + PassRegistration(); - PassRegistration(); + PassRegistration(); PassRegistration(); diff 
--git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -108,6 +108,7 @@ void registerTestSCFUtilsPass(); void registerTestSliceAnalysisPass(); void registerTestVectorConversions(); +void registerTestVectorLowerings(); } // namespace test } // namespace mlir @@ -198,6 +199,7 @@ mlir::test::registerTestSCFUtilsPass(); mlir::test::registerTestSliceAnalysisPass(); mlir::test::registerTestVectorConversions(); + mlir::test::registerTestVectorLowerings(); } #endif diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -1458,6 +1458,7 @@ ":LLVMCommonConversion", ":LLVMDialect", ":StandardOps", + ":VectorOps", ":X86Vector", "//llvm:Core", "//llvm:Support", @@ -6400,6 +6401,7 @@ ":TransformUtils", ":VectorOps", ":VectorToSCF", + ":X86VectorTransforms", "//llvm:Support", ], ) diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel @@ -484,6 +484,7 @@ "//mlir:Affine", "//mlir:Analysis", "//mlir:LinalgOps", + "//mlir:LinalgTransforms", "//mlir:MemRefDialect", "//mlir:Pass", "//mlir:SCFDialect",