diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -533,7 +533,11 @@ Option<"enableArmSVE", "enable-arm-sve", "bool", /*default=*/"false", "Enables the use of ArmSVE dialect while lowering the vector " - "dialect."> + "dialect.">, + Option<"enableX86Vector", "enable-x86vector", + "bool", /*default=*/"false", + "Enables the use of X86Vector dialect while lowering the vector " + "dialect."> ]; } diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h --- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h +++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h @@ -24,7 +24,7 @@ LowerVectorToLLVMOptions() : reassociateFPReductions(false), enableIndexOptimizations(true), enableArmNeon(false), enableArmSVE(false), enableAMX(false), - enableAVX512(false) {} + enableAVX512(false), enableX86Vector(false) {} LowerVectorToLLVMOptions &setReassociateFPReductions(bool b) { reassociateFPReductions = b; @@ -50,6 +50,10 @@ enableAVX512 = b; return *this; } + LowerVectorToLLVMOptions &setEnableX86Vector(bool b) { + enableX86Vector = b; + return *this; + } bool reassociateFPReductions; bool enableIndexOptimizations; @@ -57,6 +61,7 @@ bool enableArmSVE; bool enableAMX; bool enableAVX512; + bool enableX86Vector; }; /// Collect a set of patterns to convert from Vector contractions to LLVM Matrix diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt --- a/mlir/include/mlir/Dialect/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/CMakeLists.txt @@ -23,3 +23,4 @@ add_subdirectory(Tensor) add_subdirectory(Tosa) add_subdirectory(Vector) +add_subdirectory(X86Vector) diff --git a/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt b/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/X86Vector/CMakeLists.txt @@ -0,0 +1,6 @@ +add_mlir_dialect(X86Vector x86vector) +add_mlir_doc(X86Vector -gen-dialect-doc X86Vector Dialects/) + +set(LLVM_TARGET_DEFINITIONS X86Vector.td) +mlir_tablegen(X86VectorConversions.inc -gen-llvmir-conversions) +add_public_tablegen_target(MLIRX86VectorConversionsIncGen) diff --git a/mlir/include/mlir/Dialect/X86Vector/Transforms.h b/mlir/include/mlir/Dialect/X86Vector/Transforms.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/X86Vector/Transforms.h @@ -0,0 +1,31 @@ +//===- Transforms.h - X86Vector Dialect Transformation Entrypoints -*- C++ +//-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_X86VECTOR_TRANSFORMS_H +#define MLIR_DIALECT_X86VECTOR_TRANSFORMS_H + +namespace mlir { + +class LLVMConversionTarget; +class LLVMTypeConverter; +class RewritePatternSet; +using OwningRewritePatternList = RewritePatternSet; + +/// Collect a set of patterns to lower X86Vector ops to ops that map to LLVM +/// intrinsics. +void populateX86VectorLegalizeForLLVMExportPatterns( + LLVMTypeConverter &converter, RewritePatternSet &patterns); + +/// Configure the target to support lowering X86Vector ops to ops that map +/// to LLVM intrinsics. 
+void configureX86VectorLegalizeForExportTarget(LLVMConversionTarget &target);
+
+} // namespace mlir
+
+#endif // MLIR_DIALECT_X86VECTOR_TRANSFORMS_H
diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td
@@ -0,0 +1,270 @@
+//===-- X86VectorOps.td - X86Vector dialect operation defs -*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the basic operations for the X86Vector dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef X86VECTOR_OPS
+#define X86VECTOR_OPS
+
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
+
+//===----------------------------------------------------------------------===//
+// X86Vector dialect definition
+//===----------------------------------------------------------------------===//
+
+def X86Vector_Dialect : Dialect {
+  let name = "x86vector";
+  let cppNamespace = "::mlir::x86vector";
+}
+
+//===----------------------------------------------------------------------===//
+// AVX512 op definitions
+//===----------------------------------------------------------------------===//
+
+class AVX512_Op<string mnemonic, list<OpTrait> traits = []> :
+  Op<X86Vector_Dialect, "avx512." # mnemonic, traits> {}
+
+class AVX512_IntrOp<string mnemonic, int numResults,
+                    list<OpTrait> traits = []> :
+  LLVM_IntrOpBase<X86Vector_Dialect, "avx512.intr." # mnemonic,
+                  "x86_avx512_" # !subst(".", "_", mnemonic),
+                  [], [], traits, numResults>;
+
+// Defined by first result overload. May have to be extended for other
+// instructions in the future.
+class AVX512_IntrOverloadedOp<string mnemonic,
+                              list<OpTrait> traits = []> :
+  LLVM_IntrOpBase<X86Vector_Dialect, "avx512.intr." # mnemonic,
+                  "x86_avx512_" # !subst(".", "_", mnemonic),
+                  /*list<int> overloadedResults=*/[0],
+                  /*list<int> overloadedOperands=*/[],
+                  traits, /*numResults=*/1>;
+
+//----------------------------------------------------------------------------//
+// MaskCompressOp
+//----------------------------------------------------------------------------//
+
+def MaskCompressOp : AVX512_Op<"mask.compress", [NoSideEffect,
+  // TODO: Support optional arguments in `AllTypesMatch`. "type($src)" could
+  // then be removed from assemblyFormat.
+  AllTypesMatch<["a", "dst"]>,
+  TypesMatchWith<"`k` has the same number of bits as elements in `dst`",
+                 "dst", "k",
+                 "VectorType::get({$_self.cast<VectorType>().getShape()[0]}, "
+                 "IntegerType::get($_self.getContext(), 1))">]> {
+  let summary = "Masked compress op";
+  let description = [{
+    The mask.compress op is an AVX512 specific op that can lower to the
+    `llvm.mask.compress` intrinsic. Instead of `src`, a constant vector
+    attribute `constant_src` may be specified. If neither `src` nor
+    `constant_src` is specified, the remaining elements in the result vector
+    are set to zero.
+
+    #### From the Intel Intrinsics Guide:
+
+    Contiguously store the active integer/floating-point elements in `a` (those
+    with their respective bit set in writemask `k`) to `dst`, and pass through
+    the remaining elements from `src`.
+  }];
+  let verifier = [{ return ::verify(*this); }];
+  let arguments = (ins VectorOfLengthAndType<[16, 8],
+                                             [I1]>:$k,
+                   VectorOfLengthAndType<[16, 8],
+                                         [F32, I32, F64, I64]>:$a,
+                   Optional<VectorOfLengthAndType<[16, 8],
+                                                  [F32, I32, F64, I64]>>:$src,
+                   OptionalAttr<ElementsAttr>:$constant_src);
+  let results = (outs VectorOfLengthAndType<[16, 8],
+                                            [F32, I32, F64, I64]>:$dst);
+  let assemblyFormat = "$k `,` $a (`,` $src^)? attr-dict"
+                       " `:` type($dst) (`,` type($src)^)?";
+}
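For reference, a usage sketch of the three accepted `mask.compress` forms (this mirrors the `roundtrip.mlir` test added later in this patch; `%k1`/`%a1` and `%k2`/`%a2` are illustrative SSA names for a `vector<16xi1>`/`vector<16xf32>` pair and a `vector<8xi1>`/`vector<8xi64>` pair):

```mlir
// No `src`: remaining result elements are zero-filled.
%0 = x86vector.avx512.mask.compress %k1, %a1 : vector<16xf32>
// Remaining elements taken from the constant attribute.
%1 = x86vector.avx512.mask.compress %k1, %a1 {constant_src = dense<5.0> : vector<16xf32>} : vector<16xf32>
// Remaining elements passed through from the explicit `src` operand.
%2 = x86vector.avx512.mask.compress %k2, %a2, %a2 : vector<8xi64>, vector<8xi64>
```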
attr-dict" + " `:` type($dst) (`,` type($src)^)?"; +} + +def MaskCompressIntrOp : AVX512_IntrOverloadedOp<"mask.compress", [ + NoSideEffect, + AllTypesMatch<["a", "src", "res"]>, + TypesMatchWith<"`k` has the same number of bits as elements in `res`", + "res", "k", + "VectorType::get({$_self.cast().getShape()[0]}, " + "IntegerType::get($_self.getContext(), 1))">]> { + let arguments = (ins VectorOfLengthAndType<[16, 8], + [F32, I32, F64, I64]>:$a, + VectorOfLengthAndType<[16, 8], + [F32, I32, F64, I64]>:$src, + VectorOfLengthAndType<[16, 8], + [I1]>:$k); +} + +//----------------------------------------------------------------------------// +// MaskRndScaleOp +//----------------------------------------------------------------------------// + +def MaskRndScaleOp : AVX512_Op<"mask.rndscale", [NoSideEffect, + AllTypesMatch<["src", "a", "dst"]>, + TypesMatchWith<"imm has the same number of bits as elements in dst", + "dst", "imm", + "IntegerType::get($_self.getContext(), " + "($_self.cast().getShape()[0]))">]> { + let summary = "Masked roundscale op"; + let description = [{ + The mask.rndscale op is an AVX512 specific op that can lower to the proper + LLVMAVX512 operation: `llvm.mask.rndscale.ps.512` or + `llvm.mask.rndscale.pd.512` instruction depending on the type of vectors it + is applied to. + + #### From the Intel Intrinsics Guide: + + Round packed floating-point elements in `a` to the number of fraction bits + specified by `imm`, and store the results in `dst` using writemask `k` + (elements are copied from src when the corresponding mask bit is not set). + }]; + // Supports vector<16xf32> and vector<8xf64>. + let arguments = (ins VectorOfLengthAndType<[16, 8], [F32, F64]>:$src, + I32:$k, + VectorOfLengthAndType<[16, 8], [F32, F64]>:$a, + AnyTypeOf<[I16, I8]>:$imm, + // TODO: figure rounding out (optional operand?). + I32:$rounding + ); + let results = (outs VectorOfLengthAndType<[16, 8], [F32, F64]>:$dst); + let assemblyFormat = + "$src `,` $k `,` $a `,` $imm `,` $rounding attr-dict `:` type($dst)"; +} + +def MaskRndScalePSIntrOp : AVX512_IntrOp<"mask.rndscale.ps.512", 1, [ + NoSideEffect, + AllTypesMatch<["src", "a", "res"]>]> { + let arguments = (ins VectorOfLengthAndType<[16], [F32]>:$src, + I32:$k, + VectorOfLengthAndType<[16], [F32]>:$a, + I16:$imm, + LLVM_Type:$rounding); +} + +def MaskRndScalePDIntrOp : AVX512_IntrOp<"mask.rndscale.pd.512", 1, [ + NoSideEffect, + AllTypesMatch<["src", "a", "res"]>]> { + let arguments = (ins VectorOfLengthAndType<[8], [F64]>:$src, + I32:$k, + VectorOfLengthAndType<[8], [F64]>:$a, + I8:$imm, + LLVM_Type:$rounding); +} + +//----------------------------------------------------------------------------// +// MaskScaleFOp +//----------------------------------------------------------------------------// + +def MaskScaleFOp : AVX512_Op<"mask.scalef", [NoSideEffect, + AllTypesMatch<["src", "a", "b", "dst"]>, + TypesMatchWith<"k has the same number of bits as elements in dst", + "dst", "k", + "IntegerType::get($_self.getContext(), " + "($_self.cast().getShape()[0]))">]> { + let summary = "ScaleF op"; + let description = [{ + The `mask.scalef` op is an AVX512 specific op that can lower to the proper + LLVMAVX512 operation: `llvm.mask.scalef.ps.512` or + `llvm.mask.scalef.pd.512` depending on the type of MLIR vectors it is + applied to. 
+
+//----------------------------------------------------------------------------//
+// MaskScaleFOp
+//----------------------------------------------------------------------------//
+
+def MaskScaleFOp : AVX512_Op<"mask.scalef", [NoSideEffect,
+  AllTypesMatch<["src", "a", "b", "dst"]>,
+  TypesMatchWith<"k has the same number of bits as elements in dst",
+                 "dst", "k",
+                 "IntegerType::get($_self.getContext(), "
+                 "($_self.cast<VectorType>().getShape()[0]))">]> {
+  let summary = "ScaleF op";
+  let description = [{
+    The `mask.scalef` op is an AVX512 specific op that can lower to the proper
+    LLVM intrinsic: `llvm.mask.scalef.ps.512` or `llvm.mask.scalef.pd.512`,
+    depending on the type of MLIR vectors it is applied to.
+
+    #### From the Intel Intrinsics Guide:
+
+    Scale the packed floating-point elements in `a` using values from `b`, and
+    store the results in `dst` using writemask `k` (elements are copied from
+    src when the corresponding mask bit is not set).
+  }];
+  // Supports vector<16xf32> and vector<8xf64>.
+  let arguments = (ins VectorOfLengthAndType<[16, 8], [F32, F64]>:$src,
+                   VectorOfLengthAndType<[16, 8], [F32, F64]>:$a,
+                   VectorOfLengthAndType<[16, 8], [F32, F64]>:$b,
+                   AnyTypeOf<[I16, I8]>:$k,
+                   // TODO: figure rounding out (optional operand?).
+                   I32:$rounding
+                   );
+  let results = (outs VectorOfLengthAndType<[16, 8], [F32, F64]>:$dst);
+  // Fully specified by traits.
+  let assemblyFormat =
+    "$src `,` $a `,` $b `,` $k `,` $rounding attr-dict `:` type($dst)";
+}
+
+def MaskScaleFPSIntrOp : AVX512_IntrOp<"mask.scalef.ps.512", 1, [
+  NoSideEffect,
+  AllTypesMatch<["src", "a", "b", "res"]>]> {
+  let arguments = (ins VectorOfLengthAndType<[16], [F32]>:$src,
+                   VectorOfLengthAndType<[16], [F32]>:$a,
+                   VectorOfLengthAndType<[16], [F32]>:$b,
+                   I16:$k,
+                   LLVM_Type:$rounding);
+}
+
+def MaskScaleFPDIntrOp : AVX512_IntrOp<"mask.scalef.pd.512", 1, [
+  NoSideEffect,
+  AllTypesMatch<["src", "a", "b", "res"]>]> {
+  let arguments = (ins VectorOfLengthAndType<[8], [F64]>:$src,
+                   VectorOfLengthAndType<[8], [F64]>:$a,
+                   VectorOfLengthAndType<[8], [F64]>:$b,
+                   I8:$k,
+                   LLVM_Type:$rounding);
+}
+
+//----------------------------------------------------------------------------//
+// Vp2IntersectOp
+//----------------------------------------------------------------------------//
+
+def Vp2IntersectOp : AVX512_Op<"vp2intersect", [NoSideEffect,
+  AllTypesMatch<["a", "b"]>,
+  TypesMatchWith<"k1 has the same number of bits as elements in a",
+                 "a", "k1",
+                 "VectorType::get({$_self.cast<VectorType>().getShape()[0]}, "
+                 "IntegerType::get($_self.getContext(), 1))">,
+  TypesMatchWith<"k2 has the same number of bits as elements in b",
+                 // Should use `b` instead of `a`, but that would require
+                 // adding `type($b)` to assemblyFormat.
+                 "a", "k2",
+                 "VectorType::get({$_self.cast<VectorType>().getShape()[0]}, "
+                 "IntegerType::get($_self.getContext(), 1))">]> {
+  let summary = "Vp2Intersect op";
+  let description = [{
+    The `vp2intersect` op is an AVX512 specific op that can lower to the proper
+    LLVM intrinsic: `llvm.vp2intersect.d.512` or `llvm.vp2intersect.q.512`,
+    depending on the type of MLIR vectors it is applied to.
+
+    #### From the Intel Intrinsics Guide:
+
+    Compute intersection of packed integer vectors `a` and `b`, and store
+    indication of match in the corresponding bit of two mask registers
+    specified by `k1` and `k2`. A match in corresponding elements of `a` and
+    `b` is indicated by a set bit in the corresponding bit of the mask
+    registers.
+ }]; + let arguments = (ins VectorOfLengthAndType<[16, 8], [I32, I64]>:$a, + VectorOfLengthAndType<[16, 8], [I32, I64]>:$b + ); + let results = (outs VectorOfLengthAndType<[16, 8], [I1]>:$k1, + VectorOfLengthAndType<[16, 8], [I1]>:$k2 + ); + let assemblyFormat = + "$a `,` $b attr-dict `:` type($a)"; +} + +def Vp2IntersectDIntrOp : AVX512_IntrOp<"vp2intersect.d.512", 2, [ + NoSideEffect]> { + let arguments = (ins VectorOfLengthAndType<[16], [I32]>:$a, + VectorOfLengthAndType<[16], [I32]>:$b); +} + +def Vp2IntersectQIntrOp : AVX512_IntrOp<"vp2intersect.q.512", 2, [ + NoSideEffect]> { + let arguments = (ins VectorOfLengthAndType<[8], [I64]>:$a, + VectorOfLengthAndType<[8], [I64]>:$b); +} + +#endif // X86VECTOR_OPS diff --git a/mlir/include/mlir/Dialect/X86Vector/X86VectorDialect.h b/mlir/include/mlir/Dialect/X86Vector/X86VectorDialect.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/X86Vector/X86VectorDialect.h @@ -0,0 +1,27 @@ +//===- X86VectorDialect.h - MLIR Dialect for X86Vector ----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the Target dialect for X86Vector in MLIR. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_X86VECTOR_X86VECTORDIALECT_H_ +#define MLIR_DIALECT_X86VECTOR_X86VECTORDIALECT_H_ + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" + +#include "mlir/Dialect/X86Vector/X86VectorDialect.h.inc" + +#define GET_OP_CLASSES +#include "mlir/Dialect/X86Vector/X86Vector.h.inc" + +#endif // MLIR_DIALECT_X86VECTOR_X86VECTORDIALECT_H_ diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -43,6 +43,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Vector/VectorOps.h" +#include "mlir/Dialect/X86Vector/X86VectorDialect.h" #include "mlir/IR/Dialect.h" namespace mlir { @@ -78,7 +79,8 @@ SDBMDialect, shape::ShapeDialect, tensor::TensorDialect, - tosa::TosaDialect>(); + tosa::TosaDialect, + x86vector::X86VectorDialect>(); // clang-format on } diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h --- a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h +++ b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h @@ -22,6 +22,7 @@ #include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/X86Vector/X86VectorToLLVMIRTranslation.h" namespace mlir { class DialectRegistry; @@ -37,6 +38,7 @@ registerNVVMDialectTranslation(registry); registerOpenMPDialectTranslation(registry); registerROCDLDialectTranslation(registry); + registerX86VectorDialectTranslation(registry); } } // namespace mlir diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/X86Vector/X86VectorToLLVMIRTranslation.h b/mlir/include/mlir/Target/LLVMIR/Dialect/X86Vector/X86VectorToLLVMIRTranslation.h new file mode 100644 
--- /dev/null
+++ b/mlir/include/mlir/Target/LLVMIR/Dialect/X86Vector/X86VectorToLLVMIRTranslation.h
@@ -0,0 +1,32 @@
+//===- X86VectorToLLVMIRTranslation.h - X86Vector to LLVM IR ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This provides registration calls for X86Vector dialect to LLVM IR
+// translation.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TARGET_LLVMIR_DIALECT_X86VECTOR_X86VECTORTOLLVMIRTRANSLATION_H
+#define MLIR_TARGET_LLVMIR_DIALECT_X86VECTOR_X86VECTORTOLLVMIRTRANSLATION_H
+
+namespace mlir {
+
+class DialectRegistry;
+class MLIRContext;
+
+/// Register the X86Vector dialect and the translation from it to the LLVM IR
+/// in the given registry.
+void registerX86VectorDialectTranslation(DialectRegistry &registry);
+
+/// Register the X86Vector dialect and the translation from it in the registry
+/// associated with the given context.
+void registerX86VectorDialectTranslation(MLIRContext &context);
+
+} // namespace mlir
+
+#endif // MLIR_TARGET_LLVMIR_DIALECT_X86VECTOR_X86VECTORTOLLVMIRTRANSLATION_H
diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
--- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
@@ -27,4 +27,6 @@
   MLIRTargetLLVMIRExport
   MLIRTransforms
   MLIRVector
+  MLIRX86Vector
+  MLIRX86VectorTransforms
   )
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -24,6 +24,8 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Dialect/X86Vector/Transforms.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

 using namespace mlir;
@@ -39,6 +41,7 @@
     this->enableArmSVE = options.enableArmSVE;
     this->enableAMX = options.enableAMX;
     this->enableAVX512 = options.enableAVX512;
+    this->enableX86Vector = options.enableX86Vector;
   }
   // Override explicitly to allow conditional dialect dependence.
 void getDependentDialects(DialectRegistry &registry) const override {
@@ -52,6 +55,8 @@
       registry.insert<amx::AMXDialect>();
     if (enableAVX512)
       registry.insert<avx512::AVX512Dialect>();
+    if (enableX86Vector)
+      registry.insert<x86vector::X86VectorDialect>();
   }
   void runOnOperation() override;
 };
@@ -118,6 +123,10 @@
     configureAVX512LegalizeForExportTarget(target);
     populateAVX512LegalizeForLLVMExportPatterns(converter, patterns);
   }
+  if (enableX86Vector) {
+    configureX86VectorLegalizeForExportTarget(target);
+    populateX86VectorLegalizeForLLVMExportPatterns(converter, patterns);
+  }

   if (failed(
           applyPartialConversion(getOperation(), target, std::move(patterns))))
diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt
--- a/mlir/lib/Dialect/CMakeLists.txt
+++ b/mlir/lib/Dialect/CMakeLists.txt
@@ -25,6 +25,7 @@
 add_subdirectory(Tosa)
 add_subdirectory(Utils)
 add_subdirectory(Vector)
+add_subdirectory(X86Vector)

 set(LLVM_OPTIONAL_SOURCES
   Traits.cpp
diff --git a/mlir/lib/Dialect/X86Vector/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/CMakeLists.txt
@@ -0,0 +1,2 @@
+add_subdirectory(IR)
+add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/X86Vector/IR/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/IR/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/IR/CMakeLists.txt
@@ -0,0 +1,14 @@
+add_mlir_dialect_library(MLIRX86Vector
+  X86VectorDialect.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/X86Vector
+
+  DEPENDS
+  MLIRX86VectorIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRLLVMIR
+  MLIRSideEffectInterfaces
+  )
diff --git a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp
@@ -0,0 +1,45 @@
+//===- X86VectorDialect.cpp - MLIR X86Vector ops implementation ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the X86Vector dialect and its operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/TypeUtilities.h"
+
+using namespace mlir;
+
+void x86vector::X86VectorDialect::initialize() {
+  addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/X86Vector/X86Vector.cpp.inc"
+      >();
+}
+
+static LogicalResult verify(x86vector::MaskCompressOp op) {
+  if (op.src() && op.constant_src())
+    return emitError(op.getLoc(), "cannot use both src and constant_src");
+
+  if (op.src() && (op.src().getType() != op.dst().getType()))
+    return emitError(op.getLoc(),
+                     "failed to verify that src and dst have same type");
+
+  if (op.constant_src() && (op.constant_src()->getType() != op.dst().getType()))
+    return emitError(
+        op.getLoc(),
+        "failed to verify that constant_src and dst have same type");
+
+  return success();
+}
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/X86Vector/X86Vector.cpp.inc"
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
@@ -0,0 +1,12 @@
+add_mlir_dialect_library(MLIRX86VectorTransforms
+  LegalizeForLLVMExport.cpp
+
+  DEPENDS
+  MLIRX86VectorConversionsIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRX86Vector
+  MLIRIR
+  MLIRLLVMIR
+  MLIRStandardToLLVM
+  )
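For illustration, the kind of IR the verifier above rejects: `src` and `constant_src` are mutually exclusive ways of supplying the pass-through elements (SSA names are made up for the example):

```mlir
// error: cannot use both src and constant_src
%0 = x86vector.avx512.mask.compress %k, %a, %src
    {constant_src = dense<5.0> : vector<16xf32>} : vector<16xf32>, vector<16xf32>
```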
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
@@ -0,0 +1,142 @@
+//===- LegalizeForLLVMExport.cpp - Prepare X86Vector for LLVM translation
+//----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/X86Vector/Transforms.h"
+
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+using namespace mlir::x86vector;
+
+/// Extracts the "main" vector element type from the given X86Vector operation.
+template <typename OpTy>
+static Type getSrcVectorElementType(OpTy op) {
+  return op.src().getType().template cast<VectorType>().getElementType();
+}
+template <>
+Type getSrcVectorElementType(Vp2IntersectOp op) {
+  return op.a().getType().template cast<VectorType>().getElementType();
+}
+
+namespace {
+
+/// Base conversion for AVX512 ops that can be lowered to one of the two
+/// intrinsics based on the bitwidth of their "main" vector element type. This
+/// relies on the to-LLVM-dialect conversion helpers to correctly pack the
+/// results of multi-result intrinsic ops.
+template <typename OpTy, typename Intr32OpTy, typename Intr64OpTy>
+struct AVX512LowerToIntrinsic : public OpConversionPattern<OpTy> {
+  explicit AVX512LowerToIntrinsic(LLVMTypeConverter &converter)
+      : OpConversionPattern<OpTy>(converter, &converter.getContext()) {}
+
+  LLVMTypeConverter &getTypeConverter() const {
+    return *static_cast<LLVMTypeConverter *>(
+        OpConversionPattern<OpTy>::getTypeConverter());
+  }
+
+  LogicalResult
+  matchAndRewrite(OpTy op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type elementType = getSrcVectorElementType(op);
+    unsigned bitwidth = elementType.getIntOrFloatBitWidth();
+    if (bitwidth == 32)
+      return LLVM::detail::oneToOneRewrite(op, Intr32OpTy::getOperationName(),
+                                           operands, getTypeConverter(),
+                                           rewriter);
+    if (bitwidth == 64)
+      return LLVM::detail::oneToOneRewrite(op, Intr64OpTy::getOperationName(),
+                                           operands, getTypeConverter(),
+                                           rewriter);
+    return rewriter.notifyMatchFailure(
+        op, "expected 'src' to be either f32 or f64");
+  }
+};
+
+struct MaskCompressOpConversion
+    : public ConvertOpToLLVMPattern<MaskCompressOp> {
+  using ConvertOpToLLVMPattern<MaskCompressOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(MaskCompressOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    MaskCompressOp::Adaptor adaptor(operands);
+    auto opType = adaptor.a().getType();
+
+    Value src;
+    if (op.src()) {
+      src = adaptor.src();
+    } else if (op.constant_src()) {
+      src = rewriter.create<ConstantOp>(op.getLoc(), opType,
+                                        op.constant_srcAttr());
+    } else {
+      Attribute zeroAttr = rewriter.getZeroAttr(opType);
+      src = rewriter.create<ConstantOp>(op->getLoc(), opType, zeroAttr);
+    }
+
+    rewriter.replaceOpWithNewOp<MaskCompressIntrOp>(op, opType, adaptor.a(),
+                                                    src, adaptor.k());
+
+    return success();
+  }
+};
+
+/// An entry associating the "main" AVX512 op with its instantiations for
+/// vectors of 32-bit and 64-bit elements.
+template <typename OpTy, typename Intr32OpTy, typename Intr64OpTy>
+struct AVX512RegEntry {
+  using MainOp = OpTy;
+  using Intr32Op = Intr32OpTy;
+  using Intr64Op = Intr64OpTy;
+};
+
+/// A container for op association entries facilitating the configuration of
+/// dialect conversion.
+template <typename... OpClasses>
+struct RegistryImpl {
+  /// Registers the patterns specializing the "main" op to one of the
+  /// "intrinsic" ops depending on elemental type.
+  static void registerPatterns(LLVMTypeConverter &converter,
+                               RewritePatternSet &patterns) {
+    patterns.add<
+        AVX512LowerToIntrinsic<typename OpClasses::MainOp,
+                               typename OpClasses::Intr32Op,
+                               typename OpClasses::Intr64Op>...>(converter);
+  }
+
+  /// Configures the conversion target to lower out "main" ops.
+  static void configureTarget(LLVMConversionTarget &target) {
+    target.addIllegalOp<typename OpClasses::MainOp...>();
+    target.addLegalOp<typename OpClasses::Intr32Op...>();
+    target.addLegalOp<typename OpClasses::Intr64Op...>();
+  }
+};
+
+using Registry = RegistryImpl<
+    AVX512RegEntry<MaskRndScaleOp, MaskRndScalePSIntrOp, MaskRndScalePDIntrOp>,
+    AVX512RegEntry<MaskScaleFOp, MaskScaleFPSIntrOp, MaskScaleFPDIntrOp>,
+    AVX512RegEntry<Vp2IntersectOp, Vp2IntersectDIntrOp, Vp2IntersectQIntrOp>>;
+
+} // namespace
+
+/// Populate the given list with patterns that convert from X86Vector to LLVM.
+void mlir::populateX86VectorLegalizeForLLVMExportPatterns(
+    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
+  Registry::registerPatterns(converter, patterns);
+  patterns.add<MaskCompressOpConversion>(converter);
+}
+
+void mlir::configureX86VectorLegalizeForExportTarget(
+    LLVMConversionTarget &target) {
+  Registry::configureTarget(target);
+  target.addLegalOp<MaskCompressIntrOp>();
+  target.addIllegalOp<MaskCompressOp>();
+}
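A sketch of the effect of these patterns under `-convert-vector-to-llvm="enable-x86vector"`, based on the `legalize-for-llvm.mlir` test below (the exact result packing is left to the LLVM conversion helpers):

```mlir
// Before legalization:
%0 = x86vector.avx512.mask.rndscale %a, %i32, %a, %i16, %i32 : vector<16xf32>
%1 = x86vector.avx512.mask.compress %k, %a : vector<16xf32>
// After legalization (f32 case), the generic ops become intrinsic ops:
//   %0 -> x86vector.avx512.intr.mask.rndscale.ps.512
//   %1 -> x86vector.avx512.intr.mask.compress, with a dense<0.0> constant
//         materialized for the omitted `src`.
```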
diff --git a/mlir/lib/Target/LLVMIR/CMakeLists.txt b/mlir/lib/Target/LLVMIR/CMakeLists.txt
--- a/mlir/lib/Target/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Target/LLVMIR/CMakeLists.txt
@@ -44,6 +44,7 @@
   MLIRNVVMToLLVMIRTranslation
   MLIROpenMPToLLVMIRTranslation
   MLIRROCDLToLLVMIRTranslation
+  MLIRX86VectorToLLVMIRTranslation
   )

 add_mlir_translation_library(MLIRTargetLLVMIRImport
diff --git a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
--- a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
+++ b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
@@ -6,3 +6,4 @@
 add_subdirectory(NVVM)
 add_subdirectory(OpenMP)
 add_subdirectory(ROCDL)
+add_subdirectory(X86Vector)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/X86Vector/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/X86Vector/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Target/LLVMIR/Dialect/X86Vector/CMakeLists.txt
@@ -0,0 +1,16 @@
+add_mlir_translation_library(MLIRX86VectorToLLVMIRTranslation
+  X86VectorToLLVMIRTranslation.cpp
+
+  DEPENDS
+  MLIRX86VectorConversionsIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRX86Vector
+  MLIRLLVMIR
+  MLIRSupport
+  MLIRTargetLLVMIRExport
+  )
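For context before the translation code itself, a sketch of the last step of the chain, as exercised by the `x86vector.mlir` target test below: the intrinsic ops map directly to LLVM IR intrinsic calls via `mlir-translate --mlir-to-llvmir`.

```mlir
%0 = "x86vector.avx512.intr.mask.compress"(%a, %a, %k)
  : (vector<16xf32>, vector<16xf32>, vector<16xi1>) -> vector<16xf32>
// Emitted LLVM IR (see the target test):
//   call <16 x float> @llvm.x86.avx512.mask.compress.v16f32(...)
```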
diff --git a/mlir/lib/Target/LLVMIR/Dialect/X86Vector/X86VectorToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/X86Vector/X86VectorToLLVMIRTranslation.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Target/LLVMIR/Dialect/X86Vector/X86VectorToLLVMIRTranslation.cpp
@@ -0,0 +1,56 @@
+//===- X86VectorToLLVMIRTranslation.cpp - Translate X86Vector to LLVM IR --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a translation between the MLIR X86Vector dialect and
+// LLVM IR.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Target/LLVMIR/Dialect/X86Vector/X86VectorToLLVMIRTranslation.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Target/LLVMIR/ModuleTranslation.h"
+
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/IntrinsicsX86.h"
+
+using namespace mlir;
+using namespace mlir::LLVM;
+
+namespace {
+/// Implementation of the dialect interface that converts operations belonging
+/// to the X86Vector dialect to LLVM IR.
+class X86VectorDialectLLVMIRTranslationInterface
+    : public LLVMTranslationDialectInterface {
+public:
+  using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
+
+  /// Translates the given operation to LLVM IR using the provided IR builder
+  /// and saving the state in `moduleTranslation`.
+  LogicalResult
+  convertOperation(Operation *op, llvm::IRBuilderBase &builder,
+                   LLVM::ModuleTranslation &moduleTranslation) const final {
+    Operation &opInst = *op;
+#include "mlir/Dialect/X86Vector/X86VectorConversions.inc"
+
+    return failure();
+  }
+};
+} // end namespace
+
+void mlir::registerX86VectorDialectTranslation(DialectRegistry &registry) {
+  registry.insert<x86vector::X86VectorDialect>();
+  registry.addDialectInterface<x86vector::X86VectorDialect,
+                               X86VectorDialectLLVMIRTranslationInterface>();
+}
+
+void mlir::registerX86VectorDialectTranslation(MLIRContext &context) {
+  DialectRegistry registry;
+  registerX86VectorDialectTranslation(registry);
+  context.appendDialectRegistry(registry);
+}
diff --git a/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir
@@ -0,0 +1,44 @@
+// RUN: mlir-opt %s -convert-vector-to-llvm="enable-x86vector" | mlir-opt | FileCheck %s
+
+// CHECK-LABEL: func @avx512_mask_rndscale
+func @avx512_mask_rndscale(%a: vector<16xf32>, %b: vector<8xf64>, %i32: i32, %i16: i16, %i8: i8)
+  -> (vector<16xf32>, vector<8xf64>, vector<16xf32>, vector<8xf64>)
+{
+  // CHECK: x86vector.avx512.intr.mask.rndscale.ps.512
+  %0 = x86vector.avx512.mask.rndscale %a, %i32, %a, %i16, %i32: vector<16xf32>
+  // CHECK: x86vector.avx512.intr.mask.rndscale.pd.512
+  %1 = x86vector.avx512.mask.rndscale %b, %i32, %b, %i8, %i32: vector<8xf64>
+
+  // CHECK: x86vector.avx512.intr.mask.scalef.ps.512
+  %2 = x86vector.avx512.mask.scalef %a, %a, %a, %i16, %i32: vector<16xf32>
+  // CHECK: x86vector.avx512.intr.mask.scalef.pd.512
+  %3 = x86vector.avx512.mask.scalef %b, %b, %b, %i8, %i32: vector<8xf64>
+
+  // Keep results alive.
+  return %0, %1, %2, %3 : vector<16xf32>, vector<8xf64>, vector<16xf32>, vector<8xf64>
+}
+
+// CHECK-LABEL: func @avx512_mask_compress
+func @avx512_mask_compress(%k1: vector<16xi1>, %a1: vector<16xf32>,
+                           %k2: vector<8xi1>, %a2: vector<8xi64>)
+  -> (vector<16xf32>, vector<16xf32>, vector<8xi64>)
+{
+  // CHECK: x86vector.avx512.intr.mask.compress
+  %0 = x86vector.avx512.mask.compress %k1, %a1 : vector<16xf32>
+  // CHECK: x86vector.avx512.intr.mask.compress
+  %1 = x86vector.avx512.mask.compress %k1, %a1 {constant_src = dense<5.0> : vector<16xf32>} : vector<16xf32>
+  // CHECK: x86vector.avx512.intr.mask.compress
+  %2 = x86vector.avx512.mask.compress %k2, %a2, %a2 : vector<8xi64>, vector<8xi64>
+  return %0, %1, %2 : vector<16xf32>, vector<16xf32>, vector<8xi64>
+}
+
+// CHECK-LABEL: func @avx512_vp2intersect
+func @avx512_vp2intersect(%a: vector<16xi32>, %b: vector<8xi64>)
+  -> (vector<16xi1>, vector<16xi1>, vector<8xi1>, vector<8xi1>)
+{
+  // CHECK: x86vector.avx512.intr.vp2intersect.d.512
+  %0, %1 = x86vector.avx512.vp2intersect %a, %a : vector<16xi32>
+  // CHECK: x86vector.avx512.intr.vp2intersect.q.512
+  %2, %3 = x86vector.avx512.vp2intersect %b, %b : vector<8xi64>
+  return %0, %1, %2, %3 : vector<16xi1>, vector<16xi1>, vector<8xi1>, vector<8xi1>
+}
diff --git a/mlir/test/Dialect/X86Vector/roundtrip.mlir b/mlir/test/Dialect/X86Vector/roundtrip.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/X86Vector/roundtrip.mlir
@@ -0,0 +1,48 @@
+// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s
+
+// CHECK-LABEL: func @avx512_mask_rndscale
+func @avx512_mask_rndscale(%a: vector<16xf32>, %b: vector<8xf64>, %i32: i32, %i16: i16, %i8: i8)
+  -> (vector<16xf32>, vector<8xf64>)
+{
+  // CHECK: x86vector.avx512.mask.rndscale {{.*}}: vector<16xf32>
+  %0 = x86vector.avx512.mask.rndscale %a, %i32, %a, %i16, %i32 : 
vector<16xf32> + // CHECK: x86vector.avx512.mask.rndscale {{.*}}: vector<8xf64> + %1 = x86vector.avx512.mask.rndscale %b, %i32, %b, %i8, %i32 : vector<8xf64> + return %0, %1: vector<16xf32>, vector<8xf64> +} + +// CHECK-LABEL: func @avx512_scalef +func @avx512_scalef(%a: vector<16xf32>, %b: vector<8xf64>, %i32: i32, %i16: i16, %i8: i8) + -> (vector<16xf32>, vector<8xf64>) +{ + // CHECK: x86vector.avx512.mask.scalef {{.*}}: vector<16xf32> + %0 = x86vector.avx512.mask.scalef %a, %a, %a, %i16, %i32: vector<16xf32> + // CHECK: x86vector.avx512.mask.scalef {{.*}}: vector<8xf64> + %1 = x86vector.avx512.mask.scalef %b, %b, %b, %i8, %i32 : vector<8xf64> + return %0, %1: vector<16xf32>, vector<8xf64> +} + +// CHECK-LABEL: func @avx512_vp2intersect +func @avx512_vp2intersect(%a: vector<16xi32>, %b: vector<8xi64>) + -> (vector<16xi1>, vector<16xi1>, vector<8xi1>, vector<8xi1>) +{ + // CHECK: x86vector.avx512.vp2intersect {{.*}} : vector<16xi32> + %0, %1 = x86vector.avx512.vp2intersect %a, %a : vector<16xi32> + // CHECK: x86vector.avx512.vp2intersect {{.*}} : vector<8xi64> + %2, %3 = x86vector.avx512.vp2intersect %b, %b : vector<8xi64> + return %0, %1, %2, %3 : vector<16xi1>, vector<16xi1>, vector<8xi1>, vector<8xi1> +} + +// CHECK-LABEL: func @avx512_mask_compress +func @avx512_mask_compress(%k1: vector<16xi1>, %a1: vector<16xf32>, + %k2: vector<8xi1>, %a2: vector<8xi64>) + -> (vector<16xf32>, vector<16xf32>, vector<8xi64>) +{ + // CHECK: x86vector.avx512.mask.compress {{.*}} : vector<16xf32> + %0 = x86vector.avx512.mask.compress %k1, %a1 : vector<16xf32> + // CHECK: x86vector.avx512.mask.compress {{.*}} : vector<16xf32> + %1 = x86vector.avx512.mask.compress %k1, %a1 {constant_src = dense<5.0> : vector<16xf32>} : vector<16xf32> + // CHECK: x86vector.avx512.mask.compress {{.*}} : vector<8xi64> + %2 = x86vector.avx512.mask.compress %k2, %a2, %a2 : vector<8xi64>, vector<8xi64> + return %0, %1, %2 : vector<16xf32>, vector<16xf32>, vector<8xi64> +} diff --git a/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/lit.local.cfg b/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/lit.local.cfg new file mode 100644 --- /dev/null +++ b/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/lit.local.cfg @@ -0,0 +1,15 @@ +import sys + +# X86Vector tests must be enabled via build flag. +if config.mlir_run_x86vector_tests != 'ON': + config.unsupported = True + +# No JIT on win32. +if sys.platform == 'win32': + config.unsupported = True + +if config.intel_sde_executable: + # Run test in emulator (Intel SDE). 
+ config.substitutions.append(('%lli', config.intel_sde_executable + ' -tgl -- lli')) +else: + config.substitutions.append(('%lli', 'lli')) diff --git a/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/test-mask-compress.mlir b/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/test-mask-compress.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/test-mask-compress.mlir @@ -0,0 +1,27 @@ +// RUN: mlir-opt %s -convert-scf-to-std -convert-vector-to-llvm="enable-x86vector" -convert-std-to-llvm | \ +// RUN: mlir-translate --mlir-to-llvmir | \ +// RUN: %lli --entry-function=entry --mattr="avx512bw" --dlopen=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ +// RUN: FileCheck %s + +func @entry() -> i32 { + %i0 = constant 0 : i32 + + %a = std.constant dense<[1., 0., 0., 2., 4., 3., 5., 7., 8., 1., 5., 5., 3., 1., 0., 7.]> : vector<16xf32> + %k = std.constant dense<[1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0]> : vector<16xi1> + %r1 = x86vector.avx512.mask.compress %k, %a : vector<16xf32> + %r2 = x86vector.avx512.mask.compress %k, %a {constant_src = dense<5.0> : vector<16xf32>} : vector<16xf32> + + vector.print %r1 : vector<16xf32> + // CHECK: ( 1, 0, 2, 4, 5, 5, 3, 1, 0, 0, 0, 0, 0, 0, 0, 0 ) + + vector.print %r2 : vector<16xf32> + // CHECK: ( 1, 0, 2, 4, 5, 5, 3, 1, 0, 5, 5, 5, 5, 5, 5, 5 ) + + %src = std.constant dense<[0., 2., 1., 8., 6., 4., 4., 3., 2., 8., 5., 6., 3., 7., 6., 9.]> : vector<16xf32> + %r3 = x86vector.avx512.mask.compress %k, %a, %src : vector<16xf32>, vector<16xf32> + + vector.print %r3 : vector<16xf32> + // CHECK: ( 1, 0, 2, 4, 5, 5, 3, 1, 0, 8, 5, 6, 3, 7, 6, 9 ) + + return %i0 : i32 +} diff --git a/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/test-sparse-dot-product.mlir b/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/test-sparse-dot-product.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/test-sparse-dot-product.mlir @@ -0,0 +1,477 @@ +// RUN: mlir-opt %s -convert-scf-to-std -convert-vector-to-llvm="enable-x86vector" -convert-std-to-llvm | \ +// RUN: mlir-translate --mlir-to-llvmir | \ +// RUN: %lli --entry-function=entry --mattr="avx512bw,avx512vp2intersect" --dlopen=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ +// RUN: FileCheck %s + +// This test shows how to implement a sparse vector-vector dot product with +// AVX512. It uses vp2intersect, mask.compress and vector.contract to compute +// the dot product of two sparse HW vectors of 8 float64 elements ("segment"). +// Each sparse vector is represented by an index memref (A or C) and by a data +// memref (B or D), containing M or N elements. +// +// There are four different implementations: +// * `memref_dot_simple`: Simple O(N*M) implementation with two for loops. +// * `memref_dot_optimized`: An optimized O(N*M) version of the previous +// implementation, where the second for loop skips over some elements. +// * `memref_dot_while`: An optimized O(N+M) implementation that utilizes a +// single while loop, coiterating over both vectors. +// * `memref_dot_while_branchless`: An optimized O(N+M) implementation that +// consists of a single while loop and has no branches within the loop. 
+// +// Output of llvm-mca: +// https://gist.github.com/matthias-springer/72e7ee1b3c467e7aefb6e1fd862e4841 + +#contraction_accesses = [ + affine_map<(i) -> (i)>, + affine_map<(i) -> (i)>, + affine_map<(i) -> ()> +] +#contraction_trait = { + indexing_maps = #contraction_accesses, + iterator_types = ["reduction"] +} + +// Sparse vector dot product of two vectors. +func @vector_dot(%v_A : vector<8xi64>, %v_B : vector<8xf64>, + %v_C : vector<8xi64>, %v_D : vector<8xf64>) -> f64 { + // Compute intersection of indices. + %k0, %k1 = x86vector.avx512.vp2intersect %v_A, %v_C : vector<8xi64> + + // Filter out values without match and compress vector. + %p0 = x86vector.avx512.mask.compress %k0, %v_B : vector<8xf64> + %p1 = x86vector.avx512.mask.compress %k1, %v_D : vector<8xf64> + + // Dense vector dot product. + %acc = std.constant 0.0 : f64 + %r = vector.contract #contraction_trait %p0, %p1, %acc + : vector<8xf64>, vector<8xf64> into f64 + + return %r : f64 +} + +// Fill input memrefs will all zeros, so that they can be used with arbitrary +// input sizes up to 128 elements per sparse vector. +func @init_input(%m_A : memref, %m_B : memref, + %m_C : memref, %m_D : memref) { + %c0 = constant 0 : index + %v_data = constant dense<0.0> : vector<128xf64> + %v_index = constant dense<9223372036854775807> : vector<128xi64> + + vector.transfer_write %v_index, %m_A[%c0] : vector<128xi64>, memref + vector.transfer_write %v_data, %m_B[%c0] : vector<128xf64>, memref + vector.transfer_write %v_index, %m_C[%c0] : vector<128xi64>, memref + vector.transfer_write %v_data, %m_D[%c0] : vector<128xf64>, memref + + return +} + +func @fill_input_1(%m_A : memref, %m_B : memref, + %m_C : memref, %m_D : memref) + -> (index, index){ + call @init_input(%m_A, %m_B, %m_C, %m_D) + : (memref, memref, memref, memref) -> () + + %c0 = constant 0 : index + + %v_A = std.constant dense<[0, 1, 10, 12, 13, 17, 18, 21, + 51, 52, 57, 61, 62, 82, 98, 99]> : vector<16xi64> + %v_B = std.constant dense<[1., 5., 8., 3., 2., 1., 0., 9., + 6., 7., 7., 3., 5., 2., 9., 1.]> : vector<16xf64> + %v_C = std.constant dense<[1, 2, 5, 10, 11, 12, 47, 48, + 67, 68, 69, 70, 71, 72, 77, 78, + 79, 82, 83, 84, 85, 90, 91, 98]> : vector<24xi64> + %v_D = std.constant dense<[1., 5., 8., 3., 2., 1., 2., 9., + 6., 7., 7., 3., 5., 2., 9., 1., + 2., 9., 8., 7., 2., 0., 0., 4.]> : vector<24xf64> + + vector.transfer_write %v_A, %m_A[%c0] : vector<16xi64>, memref + vector.transfer_write %v_B, %m_B[%c0] : vector<16xf64>, memref + vector.transfer_write %v_C, %m_C[%c0] : vector<24xi64>, memref + vector.transfer_write %v_D, %m_D[%c0] : vector<24xf64>, memref + + %M = std.constant 16 : index + %N = std.constant 24 : index + + return %M, %N : index, index +} + +func @fill_input_2(%m_A : memref, %m_B : memref, + %m_C : memref, %m_D : memref) + -> (index, index){ + call @init_input(%m_A, %m_B, %m_C, %m_D) + : (memref, memref, memref, memref) -> () + + %c0 = constant 0 : index + + %v_A = std.constant dense<[0, 1, 3, 5, 6, 7, 8, 9, + 51, 52, 57, 61, 62, 63, 65, 66]> : vector<16xi64> + %v_B = std.constant dense<[1., 5., 8., 3., 2., 1., 2., 9., + 6., 7., 7., 3., 5., 2., 9., 1.]> : vector<16xf64> + %v_C = std.constant dense<[6, 7, 11, 12, 15, 17, 19, 21, + 30, 31, 33, 34, 37, 39, 40, 41, + 42, 44, 45, 46, 47, 48, 49, 50, + 62, 63, 64, 65, 66, 67, 68, 69, + 70, 77, 78, 79, 81, 82, 89, 99]> : vector<40xi64> + %v_D = std.constant dense<[1., 5., 8., 3., 2., 1., 2., 9., + 6., 7., 7., 3., 5., 2., 9., 1., + 2., 9., 8., 7., 2., 1., 2., 4., + 4., 5., 8., 8., 2., 3., 5., 1., + 8., 6., 6., 4., 
3., 8., 9., 2.]> : vector<40xf64> + + vector.transfer_write %v_A, %m_A[%c0] : vector<16xi64>, memref + vector.transfer_write %v_B, %m_B[%c0] : vector<16xf64>, memref + vector.transfer_write %v_C, %m_C[%c0] : vector<40xi64>, memref + vector.transfer_write %v_D, %m_D[%c0] : vector<40xf64>, memref + + %M = std.constant 16 : index + %N = std.constant 40 : index + + return %M, %N : index, index +} + +// Simple vector dot product implementation: Intersect every segment of size 8 +// in (%m_A, %m_B) with every segment of size 8 in (%m_C, %m_D). +func @memref_dot_simple(%m_A : memref, %m_B : memref, + %m_C : memref, %m_D : memref, + %M : index, %N : index) + -> f64 { + // Helper constants for loops. + %c0 = constant 0 : index + %c8 = constant 8 : index + + %data_zero = constant 0.0 : f64 + %index_padding = constant 9223372036854775807 : i64 + + // Notation: %sum is the current (partial) aggregated dot product sum. + + %r0 = scf.for %a = %c0 to %M step %c8 + iter_args(%sum0 = %data_zero) -> (f64) { + %v_A = vector.transfer_read %m_A[%a], %index_padding + : memref, vector<8xi64> + %v_B = vector.transfer_read %m_B[%a], %data_zero + : memref, vector<8xf64> + + %r1 = scf.for %b = %c0 to %N step %c8 + iter_args(%sum1 = %sum0) -> (f64) { + %v_C = vector.transfer_read %m_C[%b], %index_padding + : memref, vector<8xi64> + %v_D = vector.transfer_read %m_D[%b], %data_zero + : memref, vector<8xf64> + + %subresult = call @vector_dot(%v_A, %v_B, %v_C, %v_D) + : (vector<8xi64>, vector<8xf64>, vector<8xi64>, vector<8xf64>) -> f64 + %r2 = addf %sum1, %subresult : f64 + scf.yield %r2 : f64 + } + + scf.yield %r1 : f64 + } + + return %r0 : f64 +} + +// Optimized vector dot product implementation: Taking advantage of the fact +// that indices in %m_A and %m_C are sorted ascendingly, skip over segments +// in (%m_C, %m_D) that are know to have no intersection with the current +// segment from (%m_A, %m_B). +func @memref_dot_optimized(%m_A : memref, %m_B : memref, + %m_C : memref, %m_D : memref, + %M : index, %N : index) + -> f64 { + // Helper constants for loops. + %c0 = constant 0 : index + %i0 = constant 0 : i32 + %i7 = constant 7 : i32 + %c8 = constant 8 : index + + %data_zero = constant 0.0 : f64 + %index_padding = constant 9223372036854775807 : i64 + + // Notation: %sum is the current (partial) aggregated dot product sum. + // %j_start is the value from which the inner for loop starts iterating. This + // value keeps increasing if earlier segments of (%m_C, %m_D) are known to + // be no longer needed. + + %r0, %t0 = scf.for %a = %c0 to %M step %c8 + iter_args(%sum0 = %data_zero, %b_start0 = %c0) -> (f64, index) { + %v_A = vector.transfer_read %m_A[%a], %index_padding + : memref, vector<8xi64> + %segA_min = vector.extractelement %v_A[%i0 : i32] : vector<8xi64> + + %r1, %next_b_start0 = scf.for %b = %b_start0 to %N step %c8 + iter_args(%sum1 = %sum0, %b_start1 = %b_start0) -> (f64, index) { + %v_C = vector.transfer_read %m_C[%b], %index_padding + : memref, vector<8xi64> + %segB_max = vector.extractelement %v_C[%i7 : i32] : vector<8xi64> + %seg1_done = cmpi "slt", %segB_max, %segA_min : i64 + + %r2, %next_b_start1 = scf.if %seg1_done -> (f64, index) { + // %v_C segment is done, no need to examine this one again (ever). 
+ %next_b_start2 = addi %b_start1, %c8 : index + scf.yield %sum1, %next_b_start2 : f64, index + } else { + %v_B = vector.transfer_read %m_B[%a], %data_zero + : memref, vector<8xf64> + %v_D = vector.transfer_read %m_D[%b], %data_zero + : memref, vector<8xf64> + + %subresult = call @vector_dot(%v_A, %v_B, %v_C, %v_D) + : (vector<8xi64>, vector<8xf64>, vector<8xi64>, vector<8xf64>) + -> f64 + %r3 = addf %sum1, %subresult : f64 + scf.yield %r3, %b_start1 : f64, index + } + + scf.yield %r2, %next_b_start1 : f64, index + } + + scf.yield %r1, %next_b_start0 : f64, index + } + + return %r0 : f64 +} + +// Vector dot product with a while loop. Implemented as follows: +// +// r = 0.0, a = 0, b = 0 +// while (a < M && b < N) { +// segA = A[a:a+8], segB = B[b:b+8] +// if (segB[7] < segA[0]) b += 8 +// elif (segA[7] < segB[0]) a += 8 +// else { +// r += vector_dot(...) +// if (segA[7] < segB[7]) a += 8 +// elif (segB[7] < segA[7]) b += 8 +// else a += 8, b += 8 +// } +// } +func @memref_dot_while(%m_A : memref, %m_B : memref, + %m_C : memref, %m_D : memref, + %M : index, %N : index) + -> f64 { + // Helper constants for loops. + %c0 = constant 0 : index + %i0 = constant 0 : i32 + %i7 = constant 7 : i32 + %c8 = constant 8 : index + + %data_zero = constant 0.0 : f64 + %index_padding = constant 9223372036854775807 : i64 + + %r0, %a0, %b0 = scf.while (%r1 = %data_zero, %a1 = %c0, %b1 = %c0) + : (f64, index, index) -> (f64, index, index) { + %cond_i = cmpi "slt", %a1, %M : index + %cond_j = cmpi "slt", %b1, %N : index + %cond = and %cond_i, %cond_j : i1 + scf.condition(%cond) %r1, %a1, %b1 : f64, index, index + } do { + ^bb0(%r1 : f64, %a1 : index, %b1 : index): + // v_A, v_B, seg*_* could be part of the loop state to avoid a few + // redundant reads. + %v_A = vector.transfer_read %m_A[%a1], %index_padding + : memref, vector<8xi64> + %v_C = vector.transfer_read %m_C[%b1], %index_padding + : memref, vector<8xi64> + + %segA_min = vector.extractelement %v_A[%i0 : i32] : vector<8xi64> + %segA_max = vector.extractelement %v_A[%i7 : i32] : vector<8xi64> + %segB_min = vector.extractelement %v_C[%i0 : i32] : vector<8xi64> + %segB_max = vector.extractelement %v_C[%i7 : i32] : vector<8xi64> + + %seg1_done = cmpi "slt", %segB_max, %segA_min : i64 + %r2, %a2, %b2 = scf.if %seg1_done -> (f64, index, index) { + %b3 = addi %b1, %c8 : index + scf.yield %r1, %a1, %b3 : f64, index, index + } else { + %seg0_done = cmpi "slt", %segA_max, %segB_min : i64 + %r4, %a4, %b4 = scf.if %seg0_done -> (f64, index, index) { + %a5 = addi %a1, %c8 : index + scf.yield %r1, %a5, %b1 : f64, index, index + } else { + %v_B = vector.transfer_read %m_B[%a1], %data_zero + : memref, vector<8xf64> + %v_D = vector.transfer_read %m_D[%b1], %data_zero + : memref, vector<8xf64> + + %subresult = call @vector_dot(%v_A, %v_B, %v_C, %v_D) + : (vector<8xi64>, vector<8xf64>, vector<8xi64>, vector<8xf64>) + -> f64 + %r6 = addf %r1, %subresult : f64 + + %incr_a = cmpi "slt", %segA_max, %segB_max : i64 + %a6, %b6 = scf.if %incr_a -> (index, index) { + %a7 = addi %a1, %c8 : index + scf.yield %a7, %b1 : index, index + } else { + %incr_b = cmpi "slt", %segB_max, %segA_max : i64 + %a8, %b8 = scf.if %incr_b -> (index, index) { + %b9 = addi %b1, %c8 : index + scf.yield %a1, %b9 : index, index + } else { + %a10 = addi %a1, %c8 : index + %b10 = addi %b1, %c8 : index + scf.yield %a10, %b10 : index, index + } + scf.yield %a8, %b8 : index, index + } + scf.yield %r6, %a6, %b6 : f64, index, index + } + scf.yield %r4, %a4, %b4 : f64, index, index + } + scf.yield %r2, %a2, %b2 : 
f64, index, index + } + + return %r0 : f64 +} + +// Vector dot product with a while loop that has no branches (apart from the +// while loop itself). Implemented as follows: +// +// r = 0.0, a = 0, b = 0 +// while (a < M && b < N) { +// segA = A[a:a+8], segB = B[b:b+8] +// r += vector_dot(...) +// a += (segA[7] <= segB[7]) * 8 +// b += (segB[7] <= segA[7]) * 8 +// } +func @memref_dot_while_branchless(%m_A : memref, %m_B : memref, + %m_C : memref, %m_D : memref, + %M : index, %N : index) + -> f64 { + // Helper constants for loops. + %c0 = constant 0 : index + %i7 = constant 7 : i32 + %c8 = constant 8 : index + + %data_zero = constant 0.0 : f64 + %index_padding = constant 9223372036854775807 : i64 + + %r0, %a0, %b0 = scf.while (%r1 = %data_zero, %a1 = %c0, %b1 = %c0) + : (f64, index, index) -> (f64, index, index) { + %cond_i = cmpi "slt", %a1, %M : index + %cond_j = cmpi "slt", %b1, %N : index + %cond = and %cond_i, %cond_j : i1 + scf.condition(%cond) %r1, %a1, %b1 : f64, index, index + } do { + ^bb0(%r1 : f64, %a1 : index, %b1 : index): + // v_A, v_B, seg*_* could be part of the loop state to avoid a few + // redundant reads. + %v_A = vector.transfer_read %m_A[%a1], %index_padding + : memref, vector<8xi64> + %v_B = vector.transfer_read %m_B[%a1], %data_zero + : memref, vector<8xf64> + %v_C = vector.transfer_read %m_C[%b1], %index_padding + : memref, vector<8xi64> + %v_D = vector.transfer_read %m_D[%b1], %data_zero + : memref, vector<8xf64> + + %subresult = call @vector_dot(%v_A, %v_B, %v_C, %v_D) + : (vector<8xi64>, vector<8xf64>, vector<8xi64>, vector<8xf64>) + -> f64 + %r2 = addf %r1, %subresult : f64 + + %segA_max = vector.extractelement %v_A[%i7 : i32] : vector<8xi64> + %segB_max = vector.extractelement %v_C[%i7 : i32] : vector<8xi64> + + %cond_a = cmpi "sle", %segA_max, %segB_max : i64 + %cond_a_i64 = zexti %cond_a : i1 to i64 + %cond_a_idx = index_cast %cond_a_i64 : i64 to index + %incr_a = muli %cond_a_idx, %c8 : index + %a2 = addi %a1, %incr_a : index + + %cond_b = cmpi "sle", %segB_max, %segA_max : i64 + %cond_b_i64 = zexti %cond_b : i1 to i64 + %cond_b_idx = index_cast %cond_b_i64 : i64 to index + %incr_b = muli %cond_b_idx, %c8 : index + %b2 = addi %b1, %incr_b : index + + scf.yield %r2, %a2, %b2 : f64, index, index + } + + return %r0 : f64 +} + +func @entry() -> i32 { + // Initialize large buffers that can be used for multiple test cases of + // different sizes. + %b_A = memref.alloc() : memref<128xi64> + %b_B = memref.alloc() : memref<128xf64> + %b_C = memref.alloc() : memref<128xi64> + %b_D = memref.alloc() : memref<128xf64> + + %m_A = memref.cast %b_A : memref<128xi64> to memref + %m_B = memref.cast %b_B : memref<128xf64> to memref + %m_C = memref.cast %b_C : memref<128xi64> to memref + %m_D = memref.cast %b_D : memref<128xf64> to memref + + // --- Test case 1 ---. + // M and N must be a multiple of 8 if smaller than 128. + // (Because padding kicks in only for out-of-bounds accesses.) 
+ %M1, %N1 = call @fill_input_1(%m_A, %m_B, %m_C, %m_D) + : (memref, memref, memref, memref) + -> (index, index) + + %r0 = call @memref_dot_simple(%m_A, %m_B, %m_C, %m_D, %M1, %N1) + : (memref, memref, memref, memref, + index, index) -> f64 + vector.print %r0 : f64 + // CHECK: 86 + + %r1 = call @memref_dot_optimized(%m_A, %m_B, %m_C, %m_D, %M1, %N1) + : (memref, memref, memref, memref, + index, index) -> f64 + vector.print %r1 : f64 + // CHECK: 86 + + %r2 = call @memref_dot_while(%m_A, %m_B, %m_C, %m_D, %M1, %N1) + : (memref, memref, memref, memref, + index, index) -> f64 + vector.print %r2 : f64 + // CHECK: 86 + + %r6 = call @memref_dot_while_branchless(%m_A, %m_B, %m_C, %m_D, %M1, %N1) + : (memref, memref, memref, memref, + index, index) -> f64 + vector.print %r6 : f64 + // CHECK: 86 + + // --- Test case 2 ---. + // M and N must be a multiple of 8 if smaller than 128. + // (Because padding kicks in only for out-of-bounds accesses.) + %M2, %N2 = call @fill_input_2(%m_A, %m_B, %m_C, %m_D) + : (memref, memref, memref, memref) + -> (index, index) + + %r3 = call @memref_dot_simple(%m_A, %m_B, %m_C, %m_D, %M2, %N2) + : (memref, memref, memref, memref, + index, index) -> f64 + vector.print %r3 : f64 + // CHECK: 111 + + %r4 = call @memref_dot_optimized(%m_A, %m_B, %m_C, %m_D, %M2, %N2) + : (memref, memref, memref, memref, + index, index) -> f64 + vector.print %r4 : f64 + // CHECK: 111 + + %r5 = call @memref_dot_while(%m_A, %m_B, %m_C, %m_D, %M2, %N2) + : (memref, memref, memref, memref, + index, index) -> f64 + vector.print %r5 : f64 + // CHECK: 111 + + %r7 = call @memref_dot_while_branchless(%m_A, %m_B, %m_C, %m_D, %M2, %N2) + : (memref, memref, memref, memref, + index, index) -> f64 + vector.print %r7 : f64 + // CHECK: 111 + + // Release all resources. + memref.dealloc %b_A : memref<128xi64> + memref.dealloc %b_B : memref<128xf64> + memref.dealloc %b_C : memref<128xi64> + memref.dealloc %b_D : memref<128xf64> + + %r = constant 0 : i32 + return %r : i32 +} diff --git a/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/test-vp2intersect-i32.mlir b/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/test-vp2intersect-i32.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Integration/Dialect/Vector/CPU/X86Vector/test-vp2intersect-i32.mlir @@ -0,0 +1,52 @@ +// RUN: mlir-opt %s -convert-scf-to-std -convert-vector-to-llvm="enable-x86vector" -convert-std-to-llvm | \ +// RUN: mlir-translate --mlir-to-llvmir | \ +// RUN: %lli --entry-function=entry --mattr="avx512bw,avx512vp2intersect" --dlopen=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ +// RUN: FileCheck %s + +// Note: To run this test, your CPU must support AVX512 vp2intersect. 
+ +func @entry() -> i32 { + %i0 = constant 0 : i32 + %i1 = constant 1: i32 + %i2 = constant 2: i32 + %i3 = constant 7: i32 + %i4 = constant 12: i32 + %i5 = constant -10: i32 + %i6 = constant -219: i32 + + %v0 = vector.broadcast %i1 : i32 to vector<16xi32> + %v1 = vector.insert %i2, %v0[1] : i32 into vector<16xi32> + %v2 = vector.insert %i3, %v1[4] : i32 into vector<16xi32> + %v3 = vector.insert %i4, %v2[6] : i32 into vector<16xi32> + %v4 = vector.insert %i5, %v3[7] : i32 into vector<16xi32> + %v5 = vector.insert %i0, %v4[10] : i32 into vector<16xi32> + %v6 = vector.insert %i0, %v5[12] : i32 into vector<16xi32> + %v7 = vector.insert %i3, %v6[13] : i32 into vector<16xi32> + %v8 = vector.insert %i3, %v7[14] : i32 into vector<16xi32> + %v9 = vector.insert %i0, %v8[15] : i32 into vector<16xi32> + vector.print %v9 : vector<16xi32> + // CHECK: ( 1, 2, 1, 1, 7, 1, 12, -10, 1, 1, 0, 1, 0, 7, 7, 0 ) + + %w0 = vector.broadcast %i1 : i32 to vector<16xi32> + %w1 = vector.insert %i2, %w0[4] : i32 into vector<16xi32> + %w2 = vector.insert %i6, %w1[7] : i32 into vector<16xi32> + %w3 = vector.insert %i4, %w2[8] : i32 into vector<16xi32> + %w4 = vector.insert %i4, %w3[9] : i32 into vector<16xi32> + %w5 = vector.insert %i4, %w4[10] : i32 into vector<16xi32> + %w6 = vector.insert %i0, %w5[11] : i32 into vector<16xi32> + %w7 = vector.insert %i0, %w6[12] : i32 into vector<16xi32> + %w8 = vector.insert %i0, %w7[13] : i32 into vector<16xi32> + %w9 = vector.insert %i0, %w8[15] : i32 into vector<16xi32> + vector.print %w9 : vector<16xi32> + // CHECK: ( 1, 1, 1, 1, 2, 1, 1, -219, 12, 12, 12, 0, 0, 0, 1, 0 ) + + %k1, %k2 = x86vector.avx512.vp2intersect %v9, %w9 : vector<16xi32> + + vector.print %k1 : vector<16xi1> + // CHECK: ( 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1 ) + + vector.print %k2 : vector<16xi1> + // CHECK: ( 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1 ) + + return %i0 : i32 +} diff --git a/mlir/test/Target/LLVMIR/x86vector.mlir b/mlir/test/Target/LLVMIR/x86vector.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Target/LLVMIR/x86vector.mlir @@ -0,0 +1,61 @@ +// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s + +// CHECK-LABEL: define <16 x float> @LLVM_x86_avx512_mask_ps_512 +llvm.func @LLVM_x86_avx512_mask_ps_512(%a: vector<16 x f32>, + %c: i16) + -> (vector<16 x f32>) +{ + %b = llvm.mlir.constant(42 : i32) : i32 + // CHECK: call <16 x float> @llvm.x86.avx512.mask.rndscale.ps.512(<16 x float> + %0 = "x86vector.avx512.intr.mask.rndscale.ps.512"(%a, %b, %a, %c, %b) : + (vector<16 x f32>, i32, vector<16 x f32>, i16, i32) -> vector<16 x f32> + // CHECK: call <16 x float> @llvm.x86.avx512.mask.scalef.ps.512(<16 x float> + %1 = "x86vector.avx512.intr.mask.scalef.ps.512"(%a, %a, %a, %c, %b) : + (vector<16 x f32>, vector<16 x f32>, vector<16 x f32>, i16, i32) -> vector<16 x f32> + llvm.return %1: vector<16 x f32> +} + +// CHECK-LABEL: define <8 x double> @LLVM_x86_avx512_mask_pd_512 +llvm.func @LLVM_x86_avx512_mask_pd_512(%a: vector<8xf64>, + %c: i8) + -> (vector<8xf64>) +{ + %b = llvm.mlir.constant(42 : i32) : i32 + // CHECK: call <8 x double> @llvm.x86.avx512.mask.rndscale.pd.512(<8 x double> + %0 = "x86vector.avx512.intr.mask.rndscale.pd.512"(%a, %b, %a, %c, %b) : + (vector<8xf64>, i32, vector<8xf64>, i8, i32) -> vector<8xf64> + // CHECK: call <8 x double> @llvm.x86.avx512.mask.scalef.pd.512(<8 x double> + %1 = "x86vector.avx512.intr.mask.scalef.pd.512"(%a, %a, %a, %c, %b) : + (vector<8xf64>, vector<8xf64>, vector<8xf64>, i8, i32) -> vector<8xf64> + llvm.return %1: vector<8xf64> +} + +// 
CHECK-LABEL: define <16 x float> @LLVM_x86_mask_compress +llvm.func @LLVM_x86_mask_compress(%k: vector<16xi1>, %a: vector<16xf32>) + -> vector<16xf32> +{ + // CHECK: call <16 x float> @llvm.x86.avx512.mask.compress.v16f32( + %0 = "x86vector.avx512.intr.mask.compress"(%a, %a, %k) : + (vector<16xf32>, vector<16xf32>, vector<16xi1>) -> vector<16xf32> + llvm.return %0 : vector<16xf32> +} + +// CHECK-LABEL: define { <16 x i1>, <16 x i1> } @LLVM_x86_vp2intersect_d_512 +llvm.func @LLVM_x86_vp2intersect_d_512(%a: vector<16xi32>, %b: vector<16xi32>) + -> !llvm.struct<(vector<16 x i1>, vector<16 x i1>)> +{ + // CHECK: call { <16 x i1>, <16 x i1> } @llvm.x86.avx512.vp2intersect.d.512(<16 x i32> + %0 = "x86vector.avx512.intr.vp2intersect.d.512"(%a, %b) : + (vector<16xi32>, vector<16xi32>) -> !llvm.struct<(vector<16 x i1>, vector<16 x i1>)> + llvm.return %0 : !llvm.struct<(vector<16 x i1>, vector<16 x i1>)> +} + +// CHECK-LABEL: define { <8 x i1>, <8 x i1> } @LLVM_x86_vp2intersect_q_512 +llvm.func @LLVM_x86_vp2intersect_q_512(%a: vector<8xi64>, %b: vector<8xi64>) + -> !llvm.struct<(vector<8 x i1>, vector<8 x i1>)> +{ + // CHECK: call { <8 x i1>, <8 x i1> } @llvm.x86.avx512.vp2intersect.q.512(<8 x i64> + %0 = "x86vector.avx512.intr.vp2intersect.q.512"(%a, %b) : + (vector<8xi64>, vector<8xi64>) -> !llvm.struct<(vector<8 x i1>, vector<8 x i1>)> + llvm.return %0 : !llvm.struct<(vector<8 x i1>, vector<8 x i1>)> +} diff --git a/mlir/test/lit.site.cfg.py.in b/mlir/test/lit.site.cfg.py.in --- a/mlir/test/lit.site.cfg.py.in +++ b/mlir/test/lit.site.cfg.py.in @@ -49,6 +49,7 @@ config.intel_sde_executable = "@INTEL_SDE_EXECUTABLE@" config.mlir_run_amx_tests = "@MLIR_RUN_AMX_TESTS@" config.mlir_run_avx512_tests = "@MLIR_RUN_AVX512_TESTS@" +config.mlir_run_x86vector_tests = "@MLIR_RUN_X86VECTOR_TESTS@" config.mlir_include_integration_tests = "@MLIR_INCLUDE_INTEGRATION_TESTS@" # Support substitution of the tools_dir with user parameters. This is