diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Vector/VectorTransforms.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
 #include "mlir/IR/Identifier.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/Bufferize.h"
@@ -1018,6 +1019,12 @@
     transposeLowering = val;
     return *this;
   }
+  /// Enable AVX2-specific lowerings.
+  bool avx2Lowering = false;
+  LinalgVectorLoweringOptions &enableAVX2Lowering(bool val = true) {
+    avx2Lowering = val;
+    return *this;
+  }
 
   /// Configure the post staged-patterns late vector.transfer to scf
   /// conversion.
@@ -1034,6 +1041,13 @@
     vectorTransformOptions = options;
     return *this;
   }
+  /// Configure specialized vector lowerings.
+  x86vector::avx2::LoweringOptions avx2LoweringOptions;
+  LinalgVectorLoweringOptions &
+  setAVX2LoweringOptions(x86vector::avx2::LoweringOptions options) {
+    avx2LoweringOptions = options;
+    return *this;
+  }
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/X86Vector/Transforms.h b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
--- a/mlir/include/mlir/Dialect/X86Vector/Transforms.h
+++ b/mlir/include/mlir/Dialect/X86Vector/Transforms.h
@@ -9,13 +9,126 @@
 #ifndef MLIR_DIALECT_X86VECTOR_TRANSFORMS_H
 #define MLIR_DIALECT_X86VECTOR_TRANSFORMS_H
 
+#include "mlir/IR/Value.h"
+
 namespace mlir {
 
+class ImplicitLocOpBuilder;
 class LLVMConversionTarget;
 class LLVMTypeConverter;
 class RewritePatternSet;
 using OwningRewritePatternList = RewritePatternSet;
 
+namespace x86vector {
+
+/// Helper class to factor out the creation and extraction of masks from nibs.
+struct MaskHelper {
+  /// b01 captures the lower 2 bits, b67 captures the higher 2 bits.
+  /// Meant to be used with instructions such as mm256ShufflePs.
+  /// Parameter order mirrors Intel's _MM_SHUFFLE(b67, b45, b23, b01), so
+  /// shuffle<1, 0, 1, 0>() == 0x44.
+  template <char b67, char b45, char b23, char b01>
+  static char shuffle() {
+    static_assert(b01 <= 0x03, "overflow");
+    static_assert(b23 <= 0x03, "overflow");
+    static_assert(b45 <= 0x03, "overflow");
+    static_assert(b67 <= 0x03, "overflow");
+    return (b67 << 6) + (b45 << 4) + (b23 << 2) + b01;
+  }
+  /// b01 captures the lower 2 bits, b67 captures the higher 2 bits.
+  static void extractShuffle(char mask, char &b01, char &b23, char &b45,
+                             char &b67) {
+    b67 = (mask & (0x03 << 6)) >> 6;
+    b45 = (mask & (0x03 << 4)) >> 4;
+    b23 = (mask & (0x03 << 2)) >> 2;
+    b01 = mask & 0x03;
+  }
+  /// b03 captures the lower 4 bits, b47 captures the higher 4 bits.
+  /// Meant to be used with instructions such as mm256Permute2f128Ps.
+  /// Parameter order is high nibble first, so permute<2, 0>() == 0x20.
+  template <char b47, char b03>
+  static char permute() {
+    static_assert(b03 <= 0x0f, "overflow");
+    static_assert(b47 <= 0x0f, "overflow");
+    return (b47 << 4) + b03;
+  }
+  /// b03 captures the lower 4 bits, b47 captures the higher 4 bits.
+  static void extractPermute(char mask, char &b03, char &b47) {
+    b47 = (mask & (0x0f << 4)) >> 4;
+    b03 = mask & 0x0f;
+  }
+};
+
+//===----------------------------------------------------------------------===//
+/// Helpers extracted from:
+///   - clang/lib/Headers/avxintrin.h
+///   - clang/test/CodeGen/X86/avx-builtins.c
+///   - clang/test/CodeGen/X86/avx2-builtins.c
+///   - clang/test/CodeGen/X86/avx-shuffle-builtins.c
+/// as well as the Intel Intrinsics Guide
+/// (https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html)
+/// make it easier to just implement known good lowerings.
+/// All intrinsics correspond 1-1 to the Intel definition.
+//===----------------------------------------------------------------------===//
+
+namespace avx2 {
+
+/// Lower to vector.shuffle v1, v2, [0, 8, 1, 9, 4, 12, 5, 13].
+Value mm256UnpackLoPs(ImplicitLocOpBuilder &b, Value v1, Value v2);
+
+/// Lower to vector.shuffle v1, v2, [2, 10, 3, 11, 6, 14, 7, 15].
+Value mm256UnpackHiPs(ImplicitLocOpBuilder &b, Value v1, Value v2);
+
+///                            a  a   b   b  a  a   b   b
+/// Take an 8 bit mask, 2 bit for each position of a[0, 3) **and** b[0, 4):
+///                                 0:127    |         128:255
+///                            b01  b23  C8  D8  |  b01+4 b23+4 C8+4 D8+4
+Value mm256ShufflePs(ImplicitLocOpBuilder &b, Value v1, Value v2, char mask);
+
+// imm[0:1] out of imm[0:3] is:
+//    0             1           2             3
+// a[0:127] or a[128:255] or b[0:127] or b[128:255]    |
+//          a[0:127] or a[128:255] or b[0:127] or b[128:255]
+//             0             1           2             3
+// imm[0:1] out of imm[4:7].
+Value mm256Permute2f128Ps(ImplicitLocOpBuilder &b, Value v1, Value v2,
+                          char mask);
+
+/// 4x8xf32-specific AVX2 transpose lowering.
+void transpose4x8xf32(ImplicitLocOpBuilder &ib, MutableArrayRef<Value> vs);
+
+/// 8x8xf32-specific AVX2 transpose lowering.
+void transpose8x8xf32(ImplicitLocOpBuilder &ib, MutableArrayRef<Value> vs);
+
+/// Structure to control the behavior of specialized avx2 transpose lowering.
+struct TransposeLoweringOptions {
+  bool lower4x8xf32_ = false;
+  TransposeLoweringOptions &lower4x8xf32(bool lower = true) {
+    lower4x8xf32_ = lower;
+    return *this;
+  }
+  bool lower8x8xf32_ = false;
+  TransposeLoweringOptions &lower8x8xf32(bool lower = true) {
+    lower8x8xf32_ = lower;
+    return *this;
+  }
+};
+
+/// Options for controlling specialized AVX2 lowerings.
+struct LoweringOptions {
+  /// Configure specialized vector lowerings.
+  TransposeLoweringOptions transposeOptions;
+  LoweringOptions &setTransposeOptions(TransposeLoweringOptions options) {
+    transposeOptions = options;
+    return *this;
+  }
+};
+
+/// Insert specialized transpose lowering patterns.
+void populateSpecializedTransposeLoweringPatterns(
+    RewritePatternSet &patterns, LoweringOptions options = LoweringOptions(),
+    int benefit = 10);
+
+} // namespace avx2
+} // namespace x86vector
+
 /// Collect a set of patterns to lower X86Vector ops to ops that map to LLVM
 /// intrinsics.
 void populateX86VectorLegalizeForLLVMExportPatterns(
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -85,5 +85,6 @@
   MLIRTransforms
   MLIRTransformUtils
   MLIRVector
+  MLIRX86VectorTransforms
   MLIRVectorToSCF
   )
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
@@ -334,6 +334,9 @@
     if (options.transposeLowering) {
       vector::populateVectorTransposeLoweringPatterns(
           patterns, options.vectorTransformOptions);
+      if (options.avx2Lowering)
+        x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
+            patterns, options.avx2LoweringOptions);
     }
     (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
   }
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp b/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp
@@ -0,0 +1,198 @@
+//===- AVXTranspose.cpp - Transforms from Vector to X86Vector dialects ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites as 1->N patterns.
+//
+//===----------------------------------------------------------------------===//
+
+// NOTE(review): StandardOps include added for ConstantOp used below — confirm
+// against the in-tree dialect layout of this branch.
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+using namespace mlir::x86vector;
+using namespace mlir::x86vector::avx2;
+
+/// Lower to vector.shuffle v1, v2, [0, 8, 1, 9, 4, 12, 5, 13], the 1-1
+/// equivalent of _mm256_unpacklo_ps.
+Value mlir::x86vector::avx2::mm256UnpackLoPs(ImplicitLocOpBuilder &b, Value v1,
+                                             Value v2) {
+  return b.create<vector::ShuffleOp>(
+      v1, v2, ArrayRef<int64_t>{0, 8, 1, 9, 4, 12, 5, 13});
+}
+
+/// Lower to vector.shuffle v1, v2, [2, 10, 3, 11, 6, 14, 7, 15], the 1-1
+/// equivalent of _mm256_unpackhi_ps.
+Value mlir::x86vector::avx2::mm256UnpackHiPs(ImplicitLocOpBuilder &b, Value v1,
+                                             Value v2) {
+  return b.create<vector::ShuffleOp>(
+      v1, v2, ArrayRef<int64_t>{2, 10, 3, 11, 6, 14, 7, 15});
+}
+
+///                            a  a   b   b  a  a   b   b
+/// Takes an 8 bit mask, 2 bit for each position of a[0, 3) **and** b[0, 4):
+///                                 0:127    |         128:255
+///                            b01  b23  C8  D8  |  b01+4 b23+4 C8+4 D8+4
+Value mlir::x86vector::avx2::mm256ShufflePs(ImplicitLocOpBuilder &b, Value v1,
+                                            Value v2, char mask) {
+  char b01, b23, b45, b67;
+  MaskHelper::extractShuffle(mask, b01, b23, b45, b67);
+  SmallVector<int64_t> shuffleMask{b01,     b23,     b45 + 8,     b67 + 8,
+                                   b01 + 4, b23 + 4, b45 + 8 + 4, b67 + 8 + 4};
+  return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
+}
+
+// imm[0:1] out of imm[0:3] is:
+//    0             1           2             3
+// a[0:127] or a[128:255] or b[0:127] or b[128:255]    |
+//          a[0:127] or a[128:255] or b[0:127] or b[128:255]
+//             0             1           2             3
+// imm[0:1] out of imm[4:7].
+Value mlir::x86vector::avx2::mm256Permute2f128Ps(ImplicitLocOpBuilder &b,
+                                                 Value v1, Value v2,
+                                                 char mask) {
+  SmallVector<int64_t> shuffleMask;
+  auto appendToMask = [&](char control) {
+    if (control == 0)
+      llvm::append_range(shuffleMask, ArrayRef<int64_t>{0, 1, 2, 3});
+    else if (control == 1)
+      llvm::append_range(shuffleMask, ArrayRef<int64_t>{4, 5, 6, 7});
+    else if (control == 2)
+      llvm::append_range(shuffleMask, ArrayRef<int64_t>{8, 9, 10, 11});
+    else if (control == 3)
+      llvm::append_range(shuffleMask, ArrayRef<int64_t>{12, 13, 14, 15});
+    else
+      llvm_unreachable("control > 3 : overflow");
+  };
+  char b03, b47;
+  MaskHelper::extractPermute(mask, b03, b47);
+  appendToMask(b03);
+  appendToMask(b47);
+  return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
+}
+
+/// AVX2 4x8xf32-specific transpose lowering using a "C intrinsics" model.
+void mlir::x86vector::avx2::transpose4x8xf32(ImplicitLocOpBuilder &ib,
+                                             MutableArrayRef<Value> vs) {
+  auto vt = VectorType::get({8}, Float32Type::get(ib.getContext()));
+  (void)vt;
+  assert(vs.size() == 4 && "expects 4 vectors");
+  assert(llvm::all_of(ValueRange{vs}.getTypes(),
+                      [&](Type t) { return t == vt; }) &&
+         "expects all types to be vector<8xf32>");
+
+  Value T0 = mm256UnpackLoPs(ib, vs[0], vs[1]);
+  Value T1 = mm256UnpackHiPs(ib, vs[0], vs[1]);
+  Value T2 = mm256UnpackLoPs(ib, vs[2], vs[3]);
+  Value T3 = mm256UnpackHiPs(ib, vs[2], vs[3]);
+  Value S0 = mm256ShufflePs(ib, T0, T2, MaskHelper::shuffle<1, 0, 1, 0>());
+  Value S1 = mm256ShufflePs(ib, T0, T2, MaskHelper::shuffle<3, 2, 3, 2>());
+  Value S2 = mm256ShufflePs(ib, T1, T3, MaskHelper::shuffle<1, 0, 1, 0>());
+  Value S3 = mm256ShufflePs(ib, T1, T3, MaskHelper::shuffle<3, 2, 3, 2>());
+  vs[0] = mm256Permute2f128Ps(ib, S0, S1, MaskHelper::permute<2, 0>());
+  vs[1] = mm256Permute2f128Ps(ib, S2, S3, MaskHelper::permute<2, 0>());
+  vs[2] = mm256Permute2f128Ps(ib, S0, S1, MaskHelper::permute<3, 1>());
+  vs[3] = mm256Permute2f128Ps(ib, S2, S3, MaskHelper::permute<3, 1>());
+}
+
+/// AVX2 8x8xf32-specific transpose lowering using a "C intrinsics" model.
+void mlir::x86vector::avx2::transpose8x8xf32(ImplicitLocOpBuilder &ib,
+                                             MutableArrayRef<Value> vs) {
+  auto vt = VectorType::get({8}, Float32Type::get(ib.getContext()));
+  (void)vt;
+  assert(vs.size() == 8 && "expects 8 vectors");
+  assert(llvm::all_of(ValueRange{vs}.getTypes(),
+                      [&](Type t) { return t == vt; }) &&
+         "expects all types to be vector<8xf32>");
+
+  Value T0 = mm256UnpackLoPs(ib, vs[0], vs[1]);
+  Value T1 = mm256UnpackHiPs(ib, vs[0], vs[1]);
+  Value T2 = mm256UnpackLoPs(ib, vs[2], vs[3]);
+  Value T3 = mm256UnpackHiPs(ib, vs[2], vs[3]);
+  Value T4 = mm256UnpackLoPs(ib, vs[4], vs[5]);
+  Value T5 = mm256UnpackHiPs(ib, vs[4], vs[5]);
+  Value T6 = mm256UnpackLoPs(ib, vs[6], vs[7]);
+  Value T7 = mm256UnpackHiPs(ib, vs[6], vs[7]);
+  Value S0 = mm256ShufflePs(ib, T0, T2, MaskHelper::shuffle<1, 0, 1, 0>());
+  Value S1 = mm256ShufflePs(ib, T0, T2, MaskHelper::shuffle<3, 2, 3, 2>());
+  Value S2 = mm256ShufflePs(ib, T1, T3, MaskHelper::shuffle<1, 0, 1, 0>());
+  Value S3 = mm256ShufflePs(ib, T1, T3, MaskHelper::shuffle<3, 2, 3, 2>());
+  Value S4 = mm256ShufflePs(ib, T4, T6, MaskHelper::shuffle<1, 0, 1, 0>());
+  Value S5 = mm256ShufflePs(ib, T4, T6, MaskHelper::shuffle<3, 2, 3, 2>());
+  Value S6 = mm256ShufflePs(ib, T5, T7, MaskHelper::shuffle<1, 0, 1, 0>());
+  Value S7 = mm256ShufflePs(ib, T5, T7, MaskHelper::shuffle<3, 2, 3, 2>());
+  vs[0] = mm256Permute2f128Ps(ib, S0, S4, MaskHelper::permute<2, 0>());
+  vs[1] = mm256Permute2f128Ps(ib, S1, S5, MaskHelper::permute<2, 0>());
+  vs[2] = mm256Permute2f128Ps(ib, S2, S6, MaskHelper::permute<2, 0>());
+  vs[3] = mm256Permute2f128Ps(ib, S3, S7, MaskHelper::permute<2, 0>());
+  vs[4] = mm256Permute2f128Ps(ib, S0, S4, MaskHelper::permute<3, 1>());
+  vs[5] = mm256Permute2f128Ps(ib, S1, S5, MaskHelper::permute<3, 1>());
+  vs[6] = mm256Permute2f128Ps(ib, S2, S6, MaskHelper::permute<3, 1>());
+  vs[7] = mm256Permute2f128Ps(ib, S3, S7, MaskHelper::permute<3, 1>());
+}
+
+/// Rewrite avx2-specific 2-D vector.transpose, for the supported cases and
+/// depending on the `TransposeLoweringOptions`.
+class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
+public:
+  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
+
+  TransposeOpLowering(LoweringOptions loweringOptions, MLIRContext *context,
+                      int benefit)
+      : OpRewritePattern<vector::TransposeOp>(context, benefit),
+        loweringOptions(loweringOptions) {}
+
+  LogicalResult matchAndRewrite(vector::TransposeOp op,
+                                PatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    VectorType srcType = op.getVectorType();
+    if (srcType.getRank() != 2)
+      return rewriter.notifyMatchFailure(op, "Not a 2-D transpose");
+
+    SmallVector<int64_t, 4> transp;
+    for (auto attr : op.transp())
+      transp.push_back(attr.cast<IntegerAttr>().getInt());
+    // Only the [1, 0] permutation is supported: reject anything else.
+    if (transp[0] != 1 || transp[1] != 0)
+      return rewriter.notifyMatchFailure(op, "Not a 2-D transpose permutation");
+
+    int64_t m = srcType.getShape().front(), n = srcType.getShape().back();
+
+    auto applyRewrite = [&]() {
+      ImplicitLocOpBuilder ib(loc, rewriter);
+      SmallVector<Value> vs;
+      // Extract 1-D vectors from the higher-order dimension of the input.
+      for (int i = 0; i < m; ++i)
+        vs.push_back(ib.create<vector::ExtractOp>(op.vector(), i));
+      if (m == 4)
+        transpose4x8xf32(ib, vs);
+      if (m == 8)
+        transpose8x8xf32(ib, vs);
+      // Reinsert the transposed 1-D vectors into a zero-initialized result.
+      Value res = ib.create<ConstantOp>(op.getVectorType(),
+                                        ib.getZeroAttr(op.getVectorType()));
+      for (int i = 0; i < m; ++i)
+        res = ib.create<vector::InsertOp>(vs[i], res, i);
+      rewriter.replaceOp(op, res);
+      return success();
+    };
+
+    if (loweringOptions.transposeOptions.lower4x8xf32_ && m == 4 && n == 8)
+      return applyRewrite();
+    if (loweringOptions.transposeOptions.lower8x8xf32_ && m == 8 && n == 8)
+      return applyRewrite();
+    return failure();
+  }
+
+private:
+  LoweringOptions loweringOptions;
+};
+
+void mlir::x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
+    RewritePatternSet &patterns, LoweringOptions options, int benefit) {
+  patterns.add<TransposeOpLowering>(options, patterns.getContext(), benefit);
+}
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIRX86VectorTransforms
+  AVXTranspose.cpp
   LegalizeForLLVMExport.cpp
 
   DEPENDS
@@ -10,4 +11,5 @@
   MLIRIR
   MLIRLLVMCommonConversion
   MLIRLLVMIR
+  MLIRVector
   )
diff --git a/mlir/test/Dialect/Vector/vector-transpose-to-shuffle.mlir b/mlir/test/Dialect/Vector/vector-transpose-to-shuffle.mlir
--- a/mlir/test/Dialect/Vector/vector-transpose-to-shuffle.mlir
+++ b/mlir/test/Dialect/Vector/vector-transpose-to-shuffle.mlir
@@ -12,3 +12,9 @@
   %0 = vector.transpose %arg0, [1, 0] : vector<2x4xf32> to vector<4x2xf32>
   return %0 : vector<4x2xf32>
 }
+
+// CHECK-LABEL: func @transpose8x8
+func @transpose8x8(%arg0: vector<8x8xf32>) -> vector<8x8xf32> {
+  %0 = vector.transpose %arg0, [1, 0] : vector<8x8xf32> to vector<8x8xf32>
+  return %0 : vector<8x8xf32>
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -1458,6 +1458,7 @@
         ":LLVMCommonConversion",
         ":LLVMDialect",
         ":StandardOps",
+        ":VectorOps",
         ":X86Vector",
         "//llvm:Core",
         "//llvm:Support",
@@ -6398,6 +6399,7 @@
         ":TransformUtils",
         ":VectorOps",
         ":VectorToSCF",
+        ":X86VectorTransforms",
        "//llvm:Support",
    ],
 )