diff --git a/mlir/include/mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h b/mlir/include/mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h deleted file mode 100644 --- a/mlir/include/mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h +++ /dev/null @@ -1,23 +0,0 @@ -//===- ConvertAVX512ToLLVM.h - Conversion Patterns from AVX512 to LLVM ----===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_CONVERSION_AVX512TOLLVM_CONVERTAVX512TOLLVM_H_ -#define MLIR_CONVERSION_AVX512TOLLVM_CONVERTAVX512TOLLVM_H_ - -namespace mlir { - -class LLVMTypeConverter; -class OwningRewritePatternList; - -/// Collect a set of patterns to convert from the AVX512 dialect to LLVM. -void populateAVX512ToLLVMConversionPatterns(LLVMTypeConverter &converter, - OwningRewritePatternList &patterns); - -} // namespace mlir - -#endif // MLIR_CONVERSION_AVX512TOLLVM_CONVERTAVX512TOLLVM_H_ diff --git a/mlir/include/mlir/Dialect/AVX512/AVX512.td b/mlir/include/mlir/Dialect/AVX512/AVX512.td --- a/mlir/include/mlir/Dialect/AVX512/AVX512.td +++ b/mlir/include/mlir/Dialect/AVX512/AVX512.td @@ -14,6 +14,7 @@ #define AVX512_OPS include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Dialect/LLVMIR/LLVMOpBase.td" //===----------------------------------------------------------------------===// // AVX512 dialect definition @@ -31,6 +32,24 @@ class AVX512_Op traits = []> : Op {} +class AVX512_IntrOp traits = []> : + LLVM_IntrOpBase; + +// Defined by first result overload. May have to be extended for other +// instructions in the future. +class AVX512_IntrOverloadedOp traits = []> : + LLVM_IntrOpBase overloadedResults=*/[0], + /*list overloadedOperands=*/[], + traits, /*numResults=*/1>; +//----------------------------------------------------------------------------// +// MaskCompressOp +//----------------------------------------------------------------------------// + def MaskCompressOp : AVX512_Op<"mask.compress", [NoSideEffect, // TODO: Support optional arguments in `AllTypesMatch`. "type($src)" could // then be removed from assemblyFormat. @@ -67,6 +86,25 @@ " `:` type($dst) (`,` type($src)^)?"; } +def MaskCompressIntrOp : AVX512_IntrOverloadedOp<"mask.compress", [ + NoSideEffect, + AllTypesMatch<["a", "src", "res"]>, + TypesMatchWith<"`k` has the same number of bits as elements in `res`", + "res", "k", + "VectorType::get({$_self.cast().getShape()[0]}, " + "IntegerType::get($_self.getContext(), 1))">]> { + let arguments = (ins VectorOfLengthAndType<[16, 8], + [F32, I32, F64, I64]>:$a, + VectorOfLengthAndType<[16, 8], + [F32, I32, F64, I64]>:$src, + VectorOfLengthAndType<[16, 8], + [I1]>:$k); +} + +//----------------------------------------------------------------------------// +// MaskRndScaleOp +//----------------------------------------------------------------------------// + def MaskRndScaleOp : AVX512_Op<"mask.rndscale", [NoSideEffect, AllTypesMatch<["src", "a", "dst"]>, TypesMatchWith<"imm has the same number of bits as elements in dst", @@ -99,6 +137,30 @@ "$src `,` $k `,` $a `,` $imm `,` $rounding attr-dict `:` type($dst)"; } +def MaskRndScalePSIntrOp : AVX512_IntrOp<"mask.rndscale.ps.512", 1, [ + NoSideEffect, + AllTypesMatch<["src", "a", "res"]>]> { + let arguments = (ins VectorOfLengthAndType<[16], [F32]>:$src, + I32:$k, + VectorOfLengthAndType<[16], [F32]>:$a, + I16:$imm, + LLVM_Type:$rounding); +} + +def MaskRndScalePDIntrOp : AVX512_IntrOp<"mask.rndscale.pd.512", 1, [ + NoSideEffect, + AllTypesMatch<["src", "a", "res"]>]> { + let arguments = (ins VectorOfLengthAndType<[8], [F64]>:$src, + I32:$k, + VectorOfLengthAndType<[8], [F64]>:$a, + I8:$imm, + LLVM_Type:$rounding); +} + +//----------------------------------------------------------------------------// +// MaskScaleFOp +//----------------------------------------------------------------------------// + def MaskScaleFOp : AVX512_Op<"mask.scalef", [NoSideEffect, AllTypesMatch<["src", "a", "b", "dst"]>, TypesMatchWith<"k has the same number of bits as elements in dst", @@ -132,6 +194,30 @@ "$src `,` $a `,` $b `,` $k `,` $rounding attr-dict `:` type($dst)"; } +def MaskScaleFPSIntrOp : AVX512_IntrOp<"mask.scalef.ps.512", 1, [ + NoSideEffect, + AllTypesMatch<["src", "a", "b", "res"]>]> { + let arguments = (ins VectorOfLengthAndType<[16], [F32]>:$src, + VectorOfLengthAndType<[16], [F32]>:$a, + VectorOfLengthAndType<[16], [F32]>:$b, + I16:$k, + LLVM_Type:$rounding); +} + +def MaskScaleFPDIntrOp : AVX512_IntrOp<"mask.scalef.pd.512", 1, [ + NoSideEffect, + AllTypesMatch<["src", "a", "b", "res"]>]> { + let arguments = (ins VectorOfLengthAndType<[8], [F64]>:$src, + VectorOfLengthAndType<[8], [F64]>:$a, + VectorOfLengthAndType<[8], [F64]>:$b, + I8:$k, + LLVM_Type:$rounding); +} + +//----------------------------------------------------------------------------// +// Vp2IntersectOp +//----------------------------------------------------------------------------// + def Vp2IntersectOp : AVX512_Op<"vp2intersect", [NoSideEffect, AllTypesMatch<["a", "b"]>, TypesMatchWith<"k1 has the same number of bits as elements in a", @@ -169,4 +255,16 @@ "$a `,` $b attr-dict `:` type($a)"; } +def Vp2IntersectDIntrOp : AVX512_IntrOp<"vp2intersect.d.512", 2, [ + NoSideEffect]> { + let arguments = (ins VectorOfLengthAndType<[16], [I32]>:$a, + VectorOfLengthAndType<[16], [I32]>:$b); +} + +def Vp2IntersectQIntrOp : AVX512_IntrOp<"vp2intersect.q.512", 2, [ + NoSideEffect]> { + let arguments = (ins VectorOfLengthAndType<[8], [I64]>:$a, + VectorOfLengthAndType<[8], [I64]>:$b); +} + #endif // AVX512_OPS diff --git a/mlir/include/mlir/Dialect/AVX512/CMakeLists.txt b/mlir/include/mlir/Dialect/AVX512/CMakeLists.txt --- a/mlir/include/mlir/Dialect/AVX512/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/AVX512/CMakeLists.txt @@ -1,2 +1,6 @@ add_mlir_dialect(AVX512 avx512) add_mlir_doc(AVX512 -gen-dialect-doc AVX512 Dialects/) + +set(LLVM_TARGET_DEFINITIONS AVX512.td) +mlir_tablegen(AVX512Conversions.inc -gen-llvmir-conversions) +add_public_tablegen_target(MLIRAVX512ConversionsIncGen) diff --git a/mlir/include/mlir/Dialect/AVX512/Transforms.h b/mlir/include/mlir/Dialect/AVX512/Transforms.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/AVX512/Transforms.h @@ -0,0 +1,29 @@ +//===- Transforms.h - AVX512 Dialect Transformation Entrypoints -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_AVX512_TRANSFORMS_H +#define MLIR_DIALECT_AVX512_TRANSFORMS_H + +namespace mlir { + +class LLVMConversionTarget; +class LLVMTypeConverter; +class OwningRewritePatternList; + +/// Collect a set of patterns to lower AVX512 ops to ops that map to LLVM +/// intrinsics. +void populateAVX512LegalizeForLLVMExportPatterns( + LLVMTypeConverter &converter, OwningRewritePatternList &patterns); + +/// Configure the target to support lowering AVX512 ops to ops that map to LLVM +/// intrinsics. +void configureAVX512LegalizeForExportTarget(LLVMConversionTarget &target); + +} // namespace mlir + +#endif // MLIR_DIALECT_AVX512_TRANSFORMS_H diff --git a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt --- a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt @@ -35,12 +35,6 @@ mlir_tablegen(ROCDLConversions.inc -gen-llvmir-conversions) add_public_tablegen_target(MLIRROCDLConversionsIncGen) -add_mlir_dialect(LLVMAVX512 llvm_avx512 LLVMAVX512) -add_mlir_doc(LLVMAVX512 -gen-dialect-doc LLVMAVX512 Dialects/) -set(LLVM_TARGET_DEFINITIONS LLVMAVX512.td) -mlir_tablegen(LLVMAVX512Conversions.inc -gen-llvmir-conversions) -add_public_tablegen_target(MLIRLLVMAVX512ConversionsIncGen) - add_mlir_dialect(LLVMArmSVE llvm_arm_sve LLVMArmSVE) add_mlir_doc(LLVMArmSVE -gen-dialect-doc LLVMArmSve Dialects/) set(LLVM_TARGET_DEFINITIONS LLVMArmSVE.td) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAVX512.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAVX512.td deleted file mode 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAVX512.td +++ /dev/null @@ -1,74 +0,0 @@ -//===-- LLVMAVX512.td - LLVMAVX512 dialect op definitions --*- tablegen -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file defines the basic operations for the LLVMAVX512 dialect. -// -//===----------------------------------------------------------------------===// - -#ifndef LLVMIR_AVX512_OPS -#define LLVMIR_AVX512_OPS - -include "mlir/Dialect/LLVMIR/LLVMOpBase.td" - -//===----------------------------------------------------------------------===// -// LLVMAVX512 dialect definition -//===----------------------------------------------------------------------===// - -def LLVMAVX512_Dialect : Dialect { - let name = "llvm_avx512"; - let cppNamespace = "::mlir::LLVM"; -} - -//----------------------------------------------------------------------------// -// MLIR LLVM AVX512 intrinsics using the MLIR LLVM Dialect type system -//----------------------------------------------------------------------------// - -class LLVMAVX512_IntrOp traits = []> : - LLVM_IntrOpBase; - -// Defined by first result overload. May have to be extended for other -// instructions in the future. -class LLVMAVX512_IntrOverloadedOp traits = []> : - LLVM_IntrOpBase overloadedResults=*/[0], - /*list overloadedOperands=*/[], - traits, /*numResults=*/1>; - -def LLVM_x86_avx512_mask_rndscale_ps_512 : - LLVMAVX512_IntrOp<"mask.rndscale.ps.512", 1>, - Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type)>; - -def LLVM_x86_avx512_mask_rndscale_pd_512 : - LLVMAVX512_IntrOp<"mask.rndscale.pd.512", 1>, - Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type)>; - -def LLVM_x86_avx512_mask_scalef_ps_512 : - LLVMAVX512_IntrOp<"mask.scalef.ps.512", 1>, - Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type)>; - -def LLVM_x86_avx512_mask_scalef_pd_512 : - LLVMAVX512_IntrOp<"mask.scalef.pd.512", 1>, - Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type)>; - -def LLVM_x86_avx512_mask_compress : - LLVMAVX512_IntrOverloadedOp<"mask.compress">, - Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>; - -def LLVM_x86_avx512_vp2intersect_d_512 : - LLVMAVX512_IntrOp<"vp2intersect.d.512", 2>, - Arguments<(ins LLVM_Type, LLVM_Type)>; - -def LLVM_x86_avx512_vp2intersect_q_512 : - LLVMAVX512_IntrOp<"vp2intersect.q.512", 2>, - Arguments<(ins LLVM_Type, LLVM_Type)>; - -#endif // AVX512_OPS diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h deleted file mode 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h +++ /dev/null @@ -1,24 +0,0 @@ -//===- LLVMAVX512Dialect.h - MLIR Dialect for LLVMAVX512 --------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file declares the Target dialect for LLVMAVX512 in MLIR. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_DIALECT_LLVMIR_LLVMAVX512DIALECT_H_ -#define MLIR_DIALECT_LLVMIR_LLVMAVX512DIALECT_H_ - -#include "mlir/IR/Dialect.h" -#include "mlir/IR/OpDefinition.h" - -#define GET_OP_CLASSES -#include "mlir/Dialect/LLVMIR/LLVMAVX512.h.inc" - -#include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h.inc" - -#endif // MLIR_DIALECT_LLVMIR_LLVMAVX512DIALECT_H_ diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -21,7 +21,6 @@ #include "mlir/Dialect/Async/IR/Async.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/GPU/GPUDialect.h" -#include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h" #include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" @@ -55,7 +54,6 @@ avx512::AVX512Dialect, complex::ComplexDialect, gpu::GPUDialect, - LLVM::LLVMAVX512Dialect, LLVM::LLVMDialect, LLVM::LLVMArmSVEDialect, linalg::LinalgDialect, diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/AVX512/AVX512ToLLVMIRTranslation.h b/mlir/include/mlir/Target/LLVMIR/Dialect/AVX512/AVX512ToLLVMIRTranslation.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Target/LLVMIR/Dialect/AVX512/AVX512ToLLVMIRTranslation.h @@ -0,0 +1,32 @@ +//===- AVX512ToLLVMIRTranslation.h - AVX512 to LLVM IR ----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This provides registration calls for AVX512 dialect to LLVM IR +// translation. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TARGET_LLVMIR_DIALECT_AVX512_AVX512TOLLVMIRTRANSLATION_H +#define MLIR_TARGET_LLVMIR_DIALECT_AVX512_AVX512TOLLVMIRTRANSLATION_H + +namespace mlir { + +class DialectRegistry; +class MLIRContext; + +/// Register the AVX512 dialect and the translation from it to the LLVM IR +/// in the given registry; +void registerAVX512DialectTranslation(DialectRegistry ®istry); + +/// Register the AVX512 dialect and the translation from it in the registry +/// associated with the given context. +void registerAVX512DialectTranslation(MLIRContext &context); + +} // namespace mlir + +#endif // MLIR_TARGET_LLVMIR_DIALECT_AVX512_AVX512TOLLVMIRTRANSLATION_H diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h --- a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h +++ b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h @@ -14,8 +14,8 @@ #ifndef MLIR_TARGET_LLVMIR_DIALECT_ALL_H #define MLIR_TARGET_LLVMIR_DIALECT_ALL_H +#include "mlir/Target/LLVMIR/Dialect/AVX512/AVX512ToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/ArmNeon/ArmNeonToLLVMIRTranslation.h" -#include "mlir/Target/LLVMIR/Dialect/LLVMAVX512/LLVMAVX512ToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/LLVMArmSVE/LLVMArmSVEToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" @@ -29,7 +29,7 @@ /// corresponding translation interfaces. static inline void registerAllToLLVMIRTranslations(DialectRegistry ®istry) { registerArmNeonDialectTranslation(registry); - registerLLVMAVX512DialectTranslation(registry); + registerAVX512DialectTranslation(registry); registerLLVMArmSVEDialectTranslation(registry); registerLLVMDialectTranslation(registry); registerNVVMDialectTranslation(registry); diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/LLVMAVX512/LLVMAVX512ToLLVMIRTranslation.h b/mlir/include/mlir/Target/LLVMIR/Dialect/LLVMAVX512/LLVMAVX512ToLLVMIRTranslation.h deleted file mode 100644 --- a/mlir/include/mlir/Target/LLVMIR/Dialect/LLVMAVX512/LLVMAVX512ToLLVMIRTranslation.h +++ /dev/null @@ -1,32 +0,0 @@ -//===- LLVMAVX512ToLLVMIRTranslation.h - LLVMAVX512 to LLVM IR --*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This provides registration calls for LLVMAVX512 dialect to LLVM IR -// translation. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_TARGET_LLVMIR_DIALECT_LLVMAVX512_LLVMAVX512TOLLVMIRTRANSLATION_H -#define MLIR_TARGET_LLVMIR_DIALECT_LLVMAVX512_LLVMAVX512TOLLVMIRTRANSLATION_H - -namespace mlir { - -class DialectRegistry; -class MLIRContext; - -/// Register the LLVMAVX512 dialect and the translation from it to the LLVM IR -/// in the given registry; -void registerLLVMAVX512DialectTranslation(DialectRegistry ®istry); - -/// Register the LLVMAVX512 dialect and the translation from it in the registry -/// associated with the given context. -void registerLLVMAVX512DialectTranslation(MLIRContext &context); - -} // namespace mlir - -#endif // MLIR_TARGET_LLVMIR_DIALECT_LLVMAVX512_LLVMAVX512TOLLVMIRTRANSLATION_H diff --git a/mlir/lib/Conversion/AVX512ToLLVM/CMakeLists.txt b/mlir/lib/Conversion/AVX512ToLLVM/CMakeLists.txt deleted file mode 100644 --- a/mlir/lib/Conversion/AVX512ToLLVM/CMakeLists.txt +++ /dev/null @@ -1,19 +0,0 @@ -add_mlir_conversion_library(MLIRAVX512ToLLVM - ConvertAVX512ToLLVM.cpp - - ADDITIONAL_HEADER_DIRS - ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/AVX512ToLLVM - - DEPENDS - MLIRConversionPassIncGen - - LINK_COMPONENTS - Core - - LINK_LIBS PUBLIC - MLIRAVX512 - MLIRLLVMAVX512 - MLIRLLVMIR - MLIRStandardToLLVM - MLIRTransforms - ) diff --git a/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp b/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp deleted file mode 100644 --- a/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp +++ /dev/null @@ -1,143 +0,0 @@ -//===- ConvertAVX512ToLLVM.cpp - Convert AVX512 to the LLVM dialect -------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h" - -#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" -#include "mlir/Dialect/AVX512/AVX512Dialect.h" -#include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "mlir/Dialect/Vector/VectorOps.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/PatternMatch.h" - -using namespace mlir; -using namespace mlir::vector; -using namespace mlir::avx512; - -template -static Type getSrcVectorElementType(Operation *op) { - return cast(op) - .src() - .getType() - .template cast() - .getElementType(); -} - -namespace { - -// TODO: turn these into simpler declarative templated patterns when we've had -// enough. -struct MaskRndScaleOp512Conversion : public ConvertToLLVMPattern { - explicit MaskRndScaleOp512Conversion(MLIRContext *context, - LLVMTypeConverter &typeConverter) - : ConvertToLLVMPattern(MaskRndScaleOp::getOperationName(), context, - typeConverter) {} - - LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - Type elementType = getSrcVectorElementType(op); - if (elementType.isF32()) - return LLVM::detail::oneToOneRewrite( - op, LLVM::x86_avx512_mask_rndscale_ps_512::getOperationName(), - operands, *getTypeConverter(), rewriter); - if (elementType.isF64()) - return LLVM::detail::oneToOneRewrite( - op, LLVM::x86_avx512_mask_rndscale_pd_512::getOperationName(), - operands, *getTypeConverter(), rewriter); - return failure(); - } -}; - -struct MaskCompressOpConversion - : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(MaskCompressOp op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - MaskCompressOp::Adaptor adaptor(operands); - auto opType = adaptor.a().getType(); - - Value src; - if (op.src()) { - src = adaptor.src(); - } else if (op.constant_src()) { - src = rewriter.create(op.getLoc(), opType, - op.constant_srcAttr()); - } else { - Attribute zeroAttr = rewriter.getZeroAttr(opType); - src = rewriter.create(op->getLoc(), opType, zeroAttr); - } - - rewriter.replaceOpWithNewOp( - op, opType, adaptor.a(), src, adaptor.k()); - - return success(); - } -}; - -struct ScaleFOp512Conversion : public ConvertToLLVMPattern { - explicit ScaleFOp512Conversion(MLIRContext *context, - LLVMTypeConverter &typeConverter) - : ConvertToLLVMPattern(MaskScaleFOp::getOperationName(), context, - typeConverter) {} - - LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - Type elementType = getSrcVectorElementType(op); - if (elementType.isF32()) - return LLVM::detail::oneToOneRewrite( - op, LLVM::x86_avx512_mask_scalef_ps_512::getOperationName(), operands, - *getTypeConverter(), rewriter); - if (elementType.isF64()) - return LLVM::detail::oneToOneRewrite( - op, LLVM::x86_avx512_mask_scalef_pd_512::getOperationName(), operands, - *getTypeConverter(), rewriter); - return failure(); - } -}; - -struct Vp2IntersectOp512Conversion - : public ConvertOpToLLVMPattern { - explicit Vp2IntersectOp512Conversion(MLIRContext *context, - LLVMTypeConverter &typeConverter) - : ConvertOpToLLVMPattern(typeConverter) {} - - LogicalResult - matchAndRewrite(Vp2IntersectOp op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - Type elementType = - op.a().getType().template cast().getElementType(); - if (elementType.isInteger(32)) - return LLVM::detail::oneToOneRewrite( - op, LLVM::x86_avx512_vp2intersect_d_512::getOperationName(), operands, - *getTypeConverter(), rewriter); - if (elementType.isInteger(64)) - return LLVM::detail::oneToOneRewrite( - op, LLVM::x86_avx512_vp2intersect_q_512::getOperationName(), operands, - *getTypeConverter(), rewriter); - return failure(); - } -}; -} // namespace - -/// Populate the given list with patterns that convert from AVX512 to LLVM. -void mlir::populateAVX512ToLLVMConversionPatterns( - LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { - // clang-format off - patterns.insert(&converter.getContext(), - converter); - patterns.insert(converter); - // clang-format on -} diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -1,6 +1,5 @@ add_subdirectory(AffineToStandard) add_subdirectory(AsyncToLLVM) -add_subdirectory(AVX512ToLLVM) add_subdirectory(ComplexToLLVM) add_subdirectory(GPUCommon) add_subdirectory(GPUToNVVM) diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt --- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt +++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt @@ -15,8 +15,7 @@ LINK_LIBS PUBLIC MLIRArmNeon MLIRAVX512 - MLIRAVX512ToLLVM - MLIRLLVMAVX512 + MLIRAVX512Transforms MLIRArmSVE MLIRArmSVEToLLVM MLIRLLVMArmSVE diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -10,14 +10,13 @@ #include "../PassDetail.h" -#include "mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h" #include "mlir/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Dialect/AVX512/AVX512Dialect.h" +#include "mlir/Dialect/AVX512/Transforms.h" #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" #include "mlir/Dialect/ArmSVE/ArmSVEDialect.h" -#include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h" #include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" @@ -45,7 +44,7 @@ if (enableArmSVE) registry.insert(); if (enableAVX512) - registry.insert(); + registry.insert(); } void runOnOperation() override; }; @@ -104,9 +103,8 @@ populateArmSVEToLLVMConversionPatterns(converter, patterns); } if (enableAVX512) { - target.addLegalDialect(); - target.addIllegalDialect(); - populateAVX512ToLLVMConversionPatterns(converter, patterns); + configureAVX512LegalizeForExportTarget(target); + populateAVX512LegalizeForLLVMExportPatterns(converter, patterns); } if (failed( diff --git a/mlir/lib/Dialect/AVX512/CMakeLists.txt b/mlir/lib/Dialect/AVX512/CMakeLists.txt --- a/mlir/lib/Dialect/AVX512/CMakeLists.txt +++ b/mlir/lib/Dialect/AVX512/CMakeLists.txt @@ -1,13 +1,2 @@ -add_mlir_dialect_library(MLIRAVX512 - IR/AVX512Dialect.cpp - - ADDITIONAL_HEADER_DIRS - ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AVX512 - - DEPENDS - MLIRAVX512IncGen - - LINK_LIBS PUBLIC - MLIRIR - MLIRSideEffectInterfaces - ) +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/mlir/lib/Dialect/AVX512/IR/AVX512Dialect.cpp b/mlir/lib/Dialect/AVX512/IR/AVX512Dialect.cpp --- a/mlir/lib/Dialect/AVX512/IR/AVX512Dialect.cpp +++ b/mlir/lib/Dialect/AVX512/IR/AVX512Dialect.cpp @@ -11,7 +11,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/AVX512/AVX512Dialect.h" -#include "mlir/Dialect/Vector/VectorOps.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/TypeUtilities.h" diff --git a/mlir/lib/Dialect/AVX512/CMakeLists.txt b/mlir/lib/Dialect/AVX512/IR/CMakeLists.txt copy from mlir/lib/Dialect/AVX512/CMakeLists.txt copy to mlir/lib/Dialect/AVX512/IR/CMakeLists.txt --- a/mlir/lib/Dialect/AVX512/CMakeLists.txt +++ b/mlir/lib/Dialect/AVX512/IR/CMakeLists.txt @@ -1,5 +1,5 @@ add_mlir_dialect_library(MLIRAVX512 - IR/AVX512Dialect.cpp + AVX512Dialect.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AVX512 @@ -9,5 +9,6 @@ LINK_LIBS PUBLIC MLIRIR + MLIRLLVMIR MLIRSideEffectInterfaces ) diff --git a/mlir/lib/Dialect/AVX512/Transforms/CMakeLists.txt b/mlir/lib/Dialect/AVX512/Transforms/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/AVX512/Transforms/CMakeLists.txt @@ -0,0 +1,11 @@ +add_mlir_dialect_library(MLIRAVX512Transforms + LegalizeForLLVMExport.cpp + + DEPENDS + MLIRAVX512ConversionsIncGen + + LINK_LIBS PUBLIC + MLIRAVX512 + MLIRIR + MLIRLLVMIR + ) diff --git a/mlir/lib/Dialect/AVX512/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AVX512/Transforms/LegalizeForLLVMExport.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/AVX512/Transforms/LegalizeForLLVMExport.cpp @@ -0,0 +1,141 @@ +//===- LegalizeForLLVMExport.cpp - Prepare AVX512 for LLVM translation ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/AVX512/Transforms.h" + +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" +#include "mlir/Dialect/AVX512/AVX512Dialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/PatternMatch.h" + +using namespace mlir; +using namespace mlir::avx512; + +/// Extracts the "main" vector element type from the given AVX512 operation. +template +static Type getSrcVectorElementType(OpTy op) { + return op.src().getType().template cast().getElementType(); +} +template <> +Type getSrcVectorElementType(Vp2IntersectOp op) { + return op.a().getType().template cast().getElementType(); +} + +namespace { +/// Base conversion for AVX512 ops that can be lowered to one of the two +/// intrinsics based on the bitwidth of their "main" vector element type. This +/// relies on the to-LLVM-dialect conversion helpers to correctly pack the +/// results of multi-result intrinsic ops. +template +struct LowerToIntrinsic : public OpConversionPattern { + explicit LowerToIntrinsic(LLVMTypeConverter &converter) + : OpConversionPattern(converter, &converter.getContext()) {} + + LLVMTypeConverter &getTypeConverter() const { + return *static_cast( + OpConversionPattern::getTypeConverter()); + } + + LogicalResult + matchAndRewrite(OpTy op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + Type elementType = getSrcVectorElementType(op); + unsigned bitwidth = elementType.getIntOrFloatBitWidth(); + if (bitwidth == 32) + return LLVM::detail::oneToOneRewrite(op, Intr32OpTy::getOperationName(), + operands, getTypeConverter(), + rewriter); + if (bitwidth == 64) + return LLVM::detail::oneToOneRewrite(op, Intr64OpTy::getOperationName(), + operands, getTypeConverter(), + rewriter); + return rewriter.notifyMatchFailure( + op, "expected 'src' to be either f32 or f64"); + } +}; + +struct MaskCompressOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(MaskCompressOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + MaskCompressOp::Adaptor adaptor(operands); + auto opType = adaptor.a().getType(); + + Value src; + if (op.src()) { + src = adaptor.src(); + } else if (op.constant_src()) { + src = rewriter.create(op.getLoc(), opType, + op.constant_srcAttr()); + } else { + Attribute zeroAttr = rewriter.getZeroAttr(opType); + src = rewriter.create(op->getLoc(), opType, zeroAttr); + } + + rewriter.replaceOpWithNewOp(op, opType, adaptor.a(), + src, adaptor.k()); + + return success(); + } +}; + +/// An entry associating the "main" AVX512 op with its instantiations for +/// vectors of 32-bit and 64-bit elements. +template +struct RegEntry { + using MainOp = OpTy; + using Intr32Op = Intr32OpTy; + using Intr64Op = Intr64OpTy; +}; + +/// A container for op association entries facilitating the configuration of +/// dialect conversion. +template +struct RegistryImpl { + /// Registers the patterns specializing the "main" op to one of the + /// "intrinsic" ops depending on elemental type. + static void registerPatterns(LLVMTypeConverter &converter, + OwningRewritePatternList &patterns) { + patterns + .insert...>(converter); + } + + /// Configures the conversion target to lower out "main" ops. + static void configureTarget(LLVMConversionTarget &target) { + target.addIllegalOp(); + target.addLegalOp(); + target.addLegalOp(); + } +}; + +using Registry = RegistryImpl< + RegEntry, + RegEntry, + RegEntry>; + +} // namespace + +/// Populate the given list with patterns that convert from AVX512 to LLVM. +void mlir::populateAVX512LegalizeForLLVMExportPatterns( + LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { + Registry::registerPatterns(converter, patterns); + patterns.insert(converter); +} + +void mlir::configureAVX512LegalizeForExportTarget( + LLVMConversionTarget &target) { + Registry::configureTarget(target); + target.addLegalOp(); + target.addIllegalOp(); +} diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt --- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt +++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt @@ -29,27 +29,6 @@ MLIRSupport ) -add_mlir_dialect_library(MLIRLLVMAVX512 - IR/LLVMAVX512Dialect.cpp - - ADDITIONAL_HEADER_DIRS - ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LLVMIR - - DEPENDS - MLIRLLVMAVX512IncGen - MLIRLLVMAVX512ConversionsIncGen - intrinsics_gen - - LINK_COMPONENTS - AsmParser - Core - - LINK_LIBS PUBLIC - MLIRIR - MLIRLLVMIR - MLIRSideEffectInterfaces - ) - add_mlir_dialect_library(MLIRLLVMArmSVE IR/LLVMArmSVEDialect.cpp diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMAVX512Dialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMAVX512Dialect.cpp deleted file mode 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMAVX512Dialect.cpp +++ /dev/null @@ -1,31 +0,0 @@ -//===- LLVMAVX512Dialect.cpp - MLIR LLVMAVX512 ops implementation ---------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements the LLVMAVX512 dialect and its operations. -// -//===----------------------------------------------------------------------===// - -#include "llvm/IR/IntrinsicsX86.h" - -#include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/OpImplementation.h" -#include "mlir/IR/TypeUtilities.h" - -using namespace mlir; - -void LLVM::LLVMAVX512Dialect::initialize() { - addOperations< -#define GET_OP_LIST -#include "mlir/Dialect/LLVMIR/LLVMAVX512.cpp.inc" - >(); -} - -#define GET_OP_CLASSES -#include "mlir/Dialect/LLVMIR/LLVMAVX512.cpp.inc" diff --git a/mlir/lib/Target/LLVMIR/CMakeLists.txt b/mlir/lib/Target/LLVMIR/CMakeLists.txt --- a/mlir/lib/Target/LLVMIR/CMakeLists.txt +++ b/mlir/lib/Target/LLVMIR/CMakeLists.txt @@ -37,8 +37,8 @@ LINK_LIBS PUBLIC MLIRArmNeonToLLVMIRTranslation + MLIRAVX512ToLLVMIRTranslation MLIRLLVMArmSVEToLLVMIRTranslation - MLIRLLVMAVX512ToLLVMIRTranslation MLIRLLVMToLLVMIRTranslation MLIRNVVMToLLVMIRTranslation MLIROpenMPToLLVMIRTranslation diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMAVX512/LLVMAVX512ToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/AVX512/AVX512ToLLVMIRTranslation.cpp rename from mlir/lib/Target/LLVMIR/Dialect/LLVMAVX512/LLVMAVX512ToLLVMIRTranslation.cpp rename to mlir/lib/Target/LLVMIR/Dialect/AVX512/AVX512ToLLVMIRTranslation.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMAVX512/LLVMAVX512ToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/AVX512/AVX512ToLLVMIRTranslation.cpp @@ -1,4 +1,4 @@ -//===- LLVMAVX512ToLLVMIRTranslation.cpp - Translate LLVMAVX512 to LLVM IR-===// +//===- AVX512ToLLVMIRTranslation.cpp - Translate AVX512 to LLVM IR---------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,13 +6,13 @@ // //===----------------------------------------------------------------------===// // -// This file implements a translation between the MLIR LLVMAVX512 dialect and +// This file implements a translation between the MLIR AVX512 dialect and // LLVM IR. // //===----------------------------------------------------------------------===// -#include "mlir/Target/LLVMIR/Dialect/LLVMAVX512/LLVMAVX512ToLLVMIRTranslation.h" -#include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h" +#include "mlir/Target/LLVMIR/Dialect/AVX512/AVX512ToLLVMIRTranslation.h" +#include "mlir/Dialect/AVX512/AVX512Dialect.h" #include "mlir/IR/Operation.h" #include "mlir/Target/LLVMIR/ModuleTranslation.h" @@ -24,8 +24,8 @@ namespace { /// Implementation of the dialect interface that converts operations belonging -/// to the LLVMAVX512 dialect to LLVM IR. -class LLVMAVX512DialectLLVMIRTranslationInterface +/// to the AVX512 dialect to LLVM IR. +class AVX512DialectLLVMIRTranslationInterface : public LLVMTranslationDialectInterface { public: using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface; @@ -36,21 +36,21 @@ convertOperation(Operation *op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation) const final { Operation &opInst = *op; -#include "mlir/Dialect/LLVMIR/LLVMAVX512Conversions.inc" +#include "mlir/Dialect/AVX512/AVX512Conversions.inc" return failure(); } }; } // end namespace -void mlir::registerLLVMAVX512DialectTranslation(DialectRegistry ®istry) { - registry.insert(); - registry.addDialectInterface(); +void mlir::registerAVX512DialectTranslation(DialectRegistry ®istry) { + registry.insert(); + registry.addDialectInterface(); } -void mlir::registerLLVMAVX512DialectTranslation(MLIRContext &context) { +void mlir::registerAVX512DialectTranslation(MLIRContext &context) { DialectRegistry registry; - registerLLVMAVX512DialectTranslation(registry); + registerAVX512DialectTranslation(registry); context.appendDialectRegistry(registry); } diff --git a/mlir/lib/Target/LLVMIR/Dialect/AVX512/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/AVX512/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Target/LLVMIR/Dialect/AVX512/CMakeLists.txt @@ -0,0 +1,16 @@ +add_mlir_translation_library(MLIRAVX512ToLLVMIRTranslation + AVX512ToLLVMIRTranslation.cpp + + DEPENDS + MLIRAVX512ConversionsIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRIR + MLIRAVX512 + MLIRLLVMIR + MLIRSupport + MLIRTargetLLVMIRExport + ) diff --git a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt --- a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt +++ b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt @@ -1,6 +1,6 @@ add_subdirectory(ArmNeon) +add_subdirectory(AVX512) add_subdirectory(LLVMArmSVE) -add_subdirectory(LLVMAVX512) add_subdirectory(LLVMIR) add_subdirectory(NVVM) add_subdirectory(OpenMP) diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMAVX512/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/LLVMAVX512/CMakeLists.txt deleted file mode 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMAVX512/CMakeLists.txt +++ /dev/null @@ -1,16 +0,0 @@ -add_mlir_translation_library(MLIRLLVMAVX512ToLLVMIRTranslation - LLVMAVX512ToLLVMIRTranslation.cpp - - DEPENDS - MLIRLLVMAVX512ConversionsIncGen - - LINK_COMPONENTS - Core - - LINK_LIBS PUBLIC - MLIRIR - MLIRLLVMAVX512 - MLIRLLVMIR - MLIRSupport - MLIRTargetLLVMIRExport - ) diff --git a/mlir/test/Conversion/AVX512ToLLVM/convert-to-llvm.mlir b/mlir/test/Dialect/AVX512/legalize-for-llvm.mlir rename from mlir/test/Conversion/AVX512ToLLVM/convert-to-llvm.mlir rename to mlir/test/Dialect/AVX512/legalize-for-llvm.mlir --- a/mlir/test/Conversion/AVX512ToLLVM/convert-to-llvm.mlir +++ b/mlir/test/Dialect/AVX512/legalize-for-llvm.mlir @@ -3,14 +3,14 @@ func @avx512_mask_rndscale(%a: vector<16xf32>, %b: vector<8xf64>, %i32: i32, %i16: i16, %i8: i8) -> (vector<16xf32>, vector<8xf64>, vector<16xf32>, vector<8xf64>) { - // CHECK: llvm_avx512.mask.rndscale.ps.512 + // CHECK: avx512.intr.mask.rndscale.ps.512 %0 = avx512.mask.rndscale %a, %i32, %a, %i16, %i32: vector<16xf32> - // CHECK: llvm_avx512.mask.rndscale.pd.512 + // CHECK: avx512.intr.mask.rndscale.pd.512 %1 = avx512.mask.rndscale %b, %i32, %b, %i8, %i32: vector<8xf64> - // CHECK: llvm_avx512.mask.scalef.ps.512 + // CHECK: avx512.intr.mask.scalef.ps.512 %2 = avx512.mask.scalef %a, %a, %a, %i16, %i32: vector<16xf32> - // CHECK: llvm_avx512.mask.scalef.pd.512 + // CHECK: avx512.intr.mask.scalef.pd.512 %3 = avx512.mask.scalef %b, %b, %b, %i8, %i32: vector<8xf64> // Keep results alive. @@ -21,11 +21,11 @@ %k2: vector<8xi1>, %a2: vector<8xi64>) -> (vector<16xf32>, vector<16xf32>, vector<8xi64>) { - // CHECK: llvm_avx512.mask.compress + // CHECK: avx512.intr.mask.compress %0 = avx512.mask.compress %k1, %a1 : vector<16xf32> - // CHECK: llvm_avx512.mask.compress + // CHECK: avx512.intr.mask.compress %1 = avx512.mask.compress %k1, %a1 {constant_src = dense<5.0> : vector<16xf32>} : vector<16xf32> - // CHECK: llvm_avx512.mask.compress + // CHECK: avx512.intr.mask.compress %2 = avx512.mask.compress %k2, %a2, %a2 : vector<8xi64>, vector<8xi64> return %0, %1, %2 : vector<16xf32>, vector<16xf32>, vector<8xi64> } @@ -33,9 +33,9 @@ func @avx512_vp2intersect(%a: vector<16xi32>, %b: vector<8xi64>) -> (vector<16xi1>, vector<16xi1>, vector<8xi1>, vector<8xi1>) { - // CHECK: llvm_avx512.vp2intersect.d.512 + // CHECK: avx512.intr.vp2intersect.d.512 %0, %1 = avx512.vp2intersect %a, %a : vector<16xi32> - // CHECK: llvm_avx512.vp2intersect.q.512 + // CHECK: avx512.intr.vp2intersect.q.512 %2, %3 = avx512.vp2intersect %b, %b : vector<8xi64> return %0, %1, %2, %3 : vector<16xi1>, vector<16xi1>, vector<8xi1>, vector<8xi1> } diff --git a/mlir/test/Target/LLVMIR/avx512.mlir b/mlir/test/Target/LLVMIR/avx512.mlir --- a/mlir/test/Target/LLVMIR/avx512.mlir +++ b/mlir/test/Target/LLVMIR/avx512.mlir @@ -7,10 +7,10 @@ { %b = llvm.mlir.constant(42 : i32) : i32 // CHECK: call <16 x float> @llvm.x86.avx512.mask.rndscale.ps.512(<16 x float> - %0 = "llvm_avx512.mask.rndscale.ps.512"(%a, %b, %a, %c, %b) : + %0 = "avx512.intr.mask.rndscale.ps.512"(%a, %b, %a, %c, %b) : (vector<16 x f32>, i32, vector<16 x f32>, i16, i32) -> vector<16 x f32> // CHECK: call <16 x float> @llvm.x86.avx512.mask.scalef.ps.512(<16 x float> - %1 = "llvm_avx512.mask.scalef.ps.512"(%a, %a, %a, %c, %b) : + %1 = "avx512.intr.mask.scalef.ps.512"(%a, %a, %a, %c, %b) : (vector<16 x f32>, vector<16 x f32>, vector<16 x f32>, i16, i32) -> vector<16 x f32> llvm.return %1: vector<16 x f32> } @@ -22,10 +22,10 @@ { %b = llvm.mlir.constant(42 : i32) : i32 // CHECK: call <8 x double> @llvm.x86.avx512.mask.rndscale.pd.512(<8 x double> - %0 = "llvm_avx512.mask.rndscale.pd.512"(%a, %b, %a, %c, %b) : + %0 = "avx512.intr.mask.rndscale.pd.512"(%a, %b, %a, %c, %b) : (vector<8xf64>, i32, vector<8xf64>, i8, i32) -> vector<8xf64> // CHECK: call <8 x double> @llvm.x86.avx512.mask.scalef.pd.512(<8 x double> - %1 = "llvm_avx512.mask.scalef.pd.512"(%a, %a, %a, %c, %b) : + %1 = "avx512.intr.mask.scalef.pd.512"(%a, %a, %a, %c, %b) : (vector<8xf64>, vector<8xf64>, vector<8xf64>, i8, i32) -> vector<8xf64> llvm.return %1: vector<8xf64> } @@ -35,7 +35,7 @@ -> vector<16xf32> { // CHECK: call <16 x float> @llvm.x86.avx512.mask.compress.v16f32( - %0 = "llvm_avx512.mask.compress"(%a, %a, %k) : + %0 = "avx512.intr.mask.compress"(%a, %a, %k) : (vector<16xf32>, vector<16xf32>, vector<16xi1>) -> vector<16xf32> llvm.return %0 : vector<16xf32> } @@ -45,7 +45,7 @@ -> !llvm.struct<(vector<16 x i1>, vector<16 x i1>)> { // CHECK: call { <16 x i1>, <16 x i1> } @llvm.x86.avx512.vp2intersect.d.512(<16 x i32> - %0 = "llvm_avx512.vp2intersect.d.512"(%a, %b) : + %0 = "avx512.intr.vp2intersect.d.512"(%a, %b) : (vector<16xi32>, vector<16xi32>) -> !llvm.struct<(vector<16 x i1>, vector<16 x i1>)> llvm.return %0 : !llvm.struct<(vector<16 x i1>, vector<16 x i1>)> } @@ -55,7 +55,7 @@ -> !llvm.struct<(vector<8 x i1>, vector<8 x i1>)> { // CHECK: call { <8 x i1>, <8 x i1> } @llvm.x86.avx512.vp2intersect.q.512(<8 x i64> - %0 = "llvm_avx512.vp2intersect.q.512"(%a, %b) : + %0 = "avx512.intr.vp2intersect.q.512"(%a, %b) : (vector<8xi64>, vector<8xi64>) -> !llvm.struct<(vector<8 x i1>, vector<8 x i1>)> llvm.return %0 : !llvm.struct<(vector<8 x i1>, vector<8 x i1>)> } diff --git a/mlir/test/mlir-opt/commandline.mlir b/mlir/test/mlir-opt/commandline.mlir --- a/mlir/test/mlir-opt/commandline.mlir +++ b/mlir/test/mlir-opt/commandline.mlir @@ -11,7 +11,6 @@ // CHECK-NEXT: linalg // CHECK-NEXT: llvm // CHECK-NEXT: llvm_arm_sve -// CHECK-NEXT: llvm_avx512 // CHECK-NEXT: math // CHECK-NEXT: nvvm // CHECK-NEXT: omp