diff --git a/mlir/include/mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h b/mlir/include/mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h
@@ -0,0 +1,28 @@
+//===- ConvertAVX512ToLLVM.h - AVX512 to LLVM conversion --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_AVX512TOLLVM_CONVERTAVX512TOLLVM_H_
+#define MLIR_CONVERSION_AVX512TOLLVM_CONVERTAVX512TOLLVM_H_
+
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+class LLVMTypeConverter;
+class ModuleOp;
+template <typename T> class OpPassBase;
+
+/// Collect a set of patterns to convert from the AVX512 dialect to LLVM.
+void populateAVX512ToLLVMConversionPatterns(LLVMTypeConverter &converter,
+                                            OwningRewritePatternList &patterns);
+
+/// Create a module pass to convert AVX512 operations to the LLVMIR dialect.
+OpPassBase<ModuleOp> *createLowerAVX512ToLLVMPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_AVX512TOLLVM_CONVERTAVX512TOLLVM_H_
diff --git a/mlir/include/mlir/Dialect/AVX512/AVX512Dialect.h b/mlir/include/mlir/Dialect/AVX512/AVX512Dialect.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/AVX512/AVX512Dialect.h
@@ -0,0 +1,35 @@
+//===- AVX512Dialect.h - MLIR Dialect for AVX512 ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the Target dialect for AVX512 in MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_AVX512_AVX512DIALECT_H_
+#define MLIR_DIALECT_AVX512_AVX512DIALECT_H_
+
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+namespace LLVM {
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/AVX512/AVX512Ops.h.inc"
+
+class AVX512Dialect : public Dialect {
+public:
+  explicit AVX512Dialect(MLIRContext *context);
+
+  static StringRef getDialectNamespace() { return "avx512"; }
+};
+
+} // namespace LLVM
+} // namespace mlir
+
+#endif // MLIR_DIALECT_AVX512_AVX512DIALECT_H_
diff --git a/mlir/include/mlir/Dialect/AVX512/AVX512Ops.td b/mlir/include/mlir/Dialect/AVX512/AVX512Ops.td
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/AVX512/AVX512Ops.td
@@ -0,0 +1,124 @@
+//===-- AVX512Ops.td - AVX512 dialect operation definitions *- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the basic operations for the AVX512 dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef AVX512_OPS
+#define AVX512_OPS
+
+include "mlir/IR/OpBase.td"
+include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
+
+def LLVM_AVX512_Dialect : Dialect {
+  let name = "avx512";
+  let cppNamespace = "LLVM";
+}
+
+class AVX512_Op<string mnemonic, list<OpTrait> traits = []> :
+  Op<LLVM_AVX512_Dialect, mnemonic, traits> {}
+
+// A 512-bit floating-point vector: vector<16xf32> or vector<8xf64>. A single
+// VectorOfLengthAndType<[16, 8], [F32, F64]> would also admit vector<8xf32>
+// and vector<16xf64>, neither of which is 512 bits wide.
+def AVX512_FloatVec512 : AnyTypeOf<[VectorOfLengthAndType<[16], [F32]>,
+                                    VectorOfLengthAndType<[8], [F64]>]>;
+
+// NOTE(review): `$k` is i32 while `$imm` is i16/i8 (sized like the vector, see
+// the assemblyFormat TODO), which looks swapped w.r.t. the "writemask `k`" and
+// "`imm`" roles in the description. Confirm before relying on the accessors.
+def MaskRndScaleOp : AVX512_Op<"mask.rndscale", [NoSideEffect,
+  AllTypesMatch<["src", "a", "dst"]>]>,
+  // Supports vector<16xf32> and vector<8xf64>.
+  Arguments<(ins AVX512_FloatVec512:$src,
+                 I32:$k,
+                 AVX512_FloatVec512:$a,
+                 AnyTypeOf<[I16, I8]>:$imm,
+                 // TODO(ntv): figure rounding out (optional operand?).
+                 I32:$rounding
+            )>,
+  Results<(outs AVX512_FloatVec512:$dst)> {
+  let summary = "Masked roundscale op";
+  let description = [{
+    The mask.rndscale is an AVX512 specific op that can lower to the proper
+    `llvm::Intrinsic::x86_avx512_mask_rndscale_ps_512` or
+    `llvm::Intrinsic::x86_avx512_mask_rndscale_pd_512` instruction depending on
+    the type of MLIR vectors it is applied to.
+
+    From the Intel Intrinsics Guide:
+    ================================
+    Round packed floating-point elements in `a` to the number of fraction bits
+    specified by `imm`, and store the results in `dst` using writemask `k`
+    (elements are copied from src when the corresponding mask bit is not set).
+  }];
+  // Fully specified by traits.
+  let verifier = ?;
+  let assemblyFormat =
+    // TODO(riverriddle, ntv): type($imm) should be dependent on type($dst).
+    "$src `,` $k `,` $a `,` $imm `,` $rounding attr-dict `:` type($dst) `,` type($imm)";
+}
+
+def MaskScaleFOp : AVX512_Op<"mask.scalef", [NoSideEffect,
+  AllTypesMatch<["src", "a", "b", "dst"]>]>,
+  // Supports vector<16xf32> and vector<8xf64>.
+  Arguments<(ins AVX512_FloatVec512:$src,
+                 AVX512_FloatVec512:$a,
+                 AVX512_FloatVec512:$b,
+                 AnyTypeOf<[I16, I8]>:$k,
+                 // TODO(ntv): figure rounding out (optional operand?).
+                 I32:$rounding
+            )>,
+  Results<(outs AVX512_FloatVec512:$dst)> {
+  let summary = "ScaleF op";
+  let description = [{
+    The scalef is an AVX512 specific op that can lower to the proper
+    `llvm::Intrinsic::x86_avx512_mask_scalef_ps_512` or
+    `llvm::Intrinsic::x86_avx512_mask_scalef_pd_512` instruction depending on
+    the type of MLIR vectors it is applied to.
+
+    From the Intel Intrinsics Guide:
+    ================================
+    Scale the packed floating-point elements in `a` using values from `b`, and
+    store the results in `dst` using writemask `k` (elements are copied from src
+    when the corresponding mask bit is not set).
+  }];
+  // Fully specified by traits.
+  let verifier = ?;
+  let assemblyFormat =
+    // TODO(riverriddle, ntv): type($k) should be dependent on type($dst).
+    "$src `,` $a `,` $b `,` $k `,` $rounding attr-dict `:` type($dst) `,` type($k)";
+}
+
+//============================================================================//
+// MLIR LLVM AVX512 intrinsics using the MLIR LLVM Dialect type system
+//============================================================================//
+
+// Note that AVX512 does not seem to like the mangling part of getDeclaration().
+class AVX512_IntrOp<string mnemonic, list<OpTrait> traits = []> :
+  LLVM_IntrOpBase<LLVM_AVX512_Dialect, mnemonic,
+                  "x86_avx512_" # !subst(".", "_", mnemonic),
+                  [], [], traits, 1>;
+
+def LLVM_x86_avx512_mask_rndscale_ps_512 :
+  AVX512_IntrOp<"mask.rndscale.ps.512">,
+  Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type)>;
+
+def LLVM_x86_avx512_mask_rndscale_pd_512 :
+  AVX512_IntrOp<"mask.rndscale.pd.512">,
+  Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type)>;
+
+def LLVM_x86_avx512_mask_scalef_ps_512 :
+  AVX512_IntrOp<"mask.scalef.ps.512">,
+  Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type)>;
+
+def LLVM_x86_avx512_mask_scalef_pd_512 :
+  AVX512_IntrOp<"mask.scalef.pd.512">,
+  Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type)>;
+
+#endif // AVX512_OPS
diff --git a/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp b/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp
@@ -0,0 +1,214 @@
+//===- ConvertAVX512ToLLVM.cpp - Convert AVX512 to the LLVM dialect -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h"
+
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
+#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
+#include "mlir/Dialect/AVX512/AVX512Dialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/VectorOps/VectorOps.h"
+#include "mlir/EDSC/Intrinsics.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+using namespace mlir::edsc;
+using namespace mlir::edsc::intrinsics;
+using namespace mlir::vector;
+using namespace mlir::LLVM;
+
+// TODO(ntv, zinenko): Code is currently copy-pasted and adapted from the code
+// 1-1 LLVM conversion. It would be better if it were properly exposed in core
+// and reusable.
+/// Basic lowering implementation for one-to-one rewriting from AVX512 Ops to
+/// LLVM Dialect Ops. Convert the type of the result to an LLVM type, pass
+/// operands as is, preserve attributes.
+///
+/// `SourceOp` names the AVX512 op being lowered (kept for readability at the
+/// call sites); `TargetOp` is the LLVM dialect intrinsic op that is created.
+/// Multiple results are packed into one LLVM struct and unpacked again with
+/// `llvm.extractvalue`, so callers see the same number of results.
+template <typename SourceOp, typename TargetOp>
+static PatternMatchResult matchAndRewriteOneToOne(
+    const ConvertToLLVMPattern &lowering, LLVMTypeConverter &typeConverter,
+    Operation *op, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) {
+  unsigned numResults = op->getNumResults();
+
+  Type packedType;
+  if (numResults != 0) {
+    packedType = typeConverter.packFunctionResults(op->getResultTypes());
+    if (!packedType) return lowering.matchFailure();
+  }
+
+  auto newOp = rewriter.create<TargetOp>(op->getLoc(), packedType, operands,
+                                         op->getAttrs());
+
+  // If the operation produced 0 or 1 result, return them immediately.
+  if (numResults == 0) {
+    rewriter.eraseOp(op);
+    return lowering.matchSuccess();
+  }
+  if (numResults == 1) {
+    rewriter.replaceOp(op, newOp.getOperation()->getResult(0));
+    return lowering.matchSuccess();
+  }
+
+  // Otherwise, it had been converted to an operation producing a structure.
+  // Extract individual results from the structure and return them as list.
+  SmallVector<Value, 4> results;
+  results.reserve(numResults);
+  for (unsigned i = 0; i < numResults; ++i) {
+    auto type = typeConverter.convertType(op->getResult(i).getType());
+    results.push_back(rewriter.create<LLVM::ExtractValueOp>(
+        op->getLoc(), type, newOp.getOperation()->getResult(0),
+        rewriter.getI64ArrayAttr(i)));
+  }
+  rewriter.replaceOp(op, results);
+  return lowering.matchSuccess();
+}
+
+// TODO(ntv): Patterns are too verbose due to the fact that we have 1 op (e.g.
+// MaskRndScaleOp) and different possible target ops. It would be better to take
+// a Functor so that all these conversions become 1-liners.
+/// Lowers `avx512.mask.rndscale` on vector<16xf32> operands to the
+/// `x86_avx512_mask_rndscale_ps_512` LLVM intrinsic op.
+struct MaskRndScaleOpPS512Conversion : public ConvertToLLVMPattern {
+  explicit MaskRndScaleOpPS512Conversion(MLIRContext *context,
+                                         LLVMTypeConverter &typeConverter)
+      : ConvertToLLVMPattern(MaskRndScaleOp::getOperationName(), context,
+                             typeConverter) {}
+
+  PatternMatchResult matchAndRewrite(
+      Operation *op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    // Only vector<16xf32> fits this 512-bit "ps" intrinsic; check both the
+    // length and the element type rather than the element type alone.
+    auto srcType = cast<MaskRndScaleOp>(op).src().getType().cast<VectorType>();
+    if (srcType.getNumElements() != 16 || !srcType.getElementType().isF32())
+      return this->matchFailure();
+    return matchAndRewriteOneToOne<MaskRndScaleOp,
+                                   LLVM::x86_avx512_mask_rndscale_ps_512>(
+        *this, this->typeConverter, op, operands, rewriter);
+  }
+};
+
+/// Lowers `avx512.mask.rndscale` on vector<8xf64> operands to the
+/// `x86_avx512_mask_rndscale_pd_512` LLVM intrinsic op.
+struct MaskRndScaleOpPD512Conversion : public ConvertToLLVMPattern {
+  explicit MaskRndScaleOpPD512Conversion(MLIRContext *context,
+                                         LLVMTypeConverter &typeConverter)
+      : ConvertToLLVMPattern(MaskRndScaleOp::getOperationName(), context,
+                             typeConverter) {}
+
+  PatternMatchResult matchAndRewrite(
+      Operation *op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    // Only vector<8xf64> fits this 512-bit "pd" intrinsic; check both the
+    // length and the element type rather than the element type alone.
+    auto srcType = cast<MaskRndScaleOp>(op).src().getType().cast<VectorType>();
+    if (srcType.getNumElements() != 8 || !srcType.getElementType().isF64())
+      return this->matchFailure();
+    return matchAndRewriteOneToOne<MaskRndScaleOp,
+                                   LLVM::x86_avx512_mask_rndscale_pd_512>(
+        *this, this->typeConverter, op, operands, rewriter);
+  }
+};
+
+/// Lowers `avx512.mask.scalef` on vector<16xf32> operands to the
+/// `x86_avx512_mask_scalef_ps_512` LLVM intrinsic op.
+struct ScaleFOpPS512Conversion : public ConvertToLLVMPattern {
+  explicit ScaleFOpPS512Conversion(MLIRContext *context,
+                                   LLVMTypeConverter &typeConverter)
+      : ConvertToLLVMPattern(MaskScaleFOp::getOperationName(), context,
+                             typeConverter) {}
+
+  PatternMatchResult matchAndRewrite(
+      Operation *op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    // Only vector<16xf32> fits this 512-bit "ps" intrinsic; check both the
+    // length and the element type rather than the element type alone.
+    auto srcType = cast<MaskScaleFOp>(op).src().getType().cast<VectorType>();
+    if (srcType.getNumElements() != 16 || !srcType.getElementType().isF32())
+      return this->matchFailure();
+    return matchAndRewriteOneToOne<MaskScaleFOp,
+                                   LLVM::x86_avx512_mask_scalef_ps_512>(
+        *this, this->typeConverter, op, operands, rewriter);
+  }
+};
+
+/// Lowers `avx512.mask.scalef` on vector<8xf64> operands to the
+/// `x86_avx512_mask_scalef_pd_512` LLVM intrinsic op.
+struct ScaleFOpPD512Conversion : public ConvertToLLVMPattern {
+  explicit ScaleFOpPD512Conversion(MLIRContext *context,
+                                   LLVMTypeConverter &typeConverter)
+      : ConvertToLLVMPattern(MaskScaleFOp::getOperationName(), context,
+                             typeConverter) {}
+
+  PatternMatchResult matchAndRewrite(
+      Operation *op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    // Only vector<8xf64> fits this 512-bit "pd" intrinsic; check both the
+    // length and the element type rather than the element type alone.
+    auto srcType = cast<MaskScaleFOp>(op).src().getType().cast<VectorType>();
+    if (srcType.getNumElements() != 8 || !srcType.getElementType().isF64())
+      return this->matchFailure();
+    return matchAndRewriteOneToOne<MaskScaleFOp,
+                                   LLVM::x86_avx512_mask_scalef_pd_512>(
+        *this, this->typeConverter, op, operands, rewriter);
+  }
+};
+
+/// Populate the given list with patterns that convert from AVX512 to LLVM.
+void mlir::populateAVX512ToLLVMConversionPatterns(
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+  MLIRContext *ctx = converter.getDialect()->getContext();
+  patterns.insert<MaskRndScaleOpPS512Conversion, MaskRndScaleOpPD512Conversion,
+                  ScaleFOpPS512Conversion, ScaleFOpPD512Conversion>(
+      ctx, converter);
+}
+
+namespace {
+struct LowerAVX512ToLLVMPass : public ModulePass<LowerAVX512ToLLVMPass> {
+  void runOnModule() override;
+};
+} // namespace
+
+void LowerAVX512ToLLVMPass::runOnModule() {
+  // Convert to the LLVM IR dialect.
+  OwningRewritePatternList patterns;
+  LLVMTypeConverter converter(&getContext());
+  populateAVX512ToLLVMConversionPatterns(converter, patterns);
+  populateVectorToLLVMConversionPatterns(converter, patterns);
+  populateStdToLLVMConversionPatterns(converter, patterns);
+
+  ConversionTarget target(getContext());
+  target.addLegalDialect<LLVM::LLVMDialect>();
+  target.addLegalDialect<LLVM::AVX512Dialect>();
+  // TODO(ntv): Ops in the td file need to be split to avoid adding the
+  // following manually.
+  target.addIllegalOp<MaskRndScaleOp, MaskScaleFOp>();
+  target.addDynamicallyLegalOp<FuncOp>(
+      [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
+  if (failed(
+          applyPartialConversion(getModule(), target, patterns, &converter))) {
+    signalPassFailure();
+  }
+}
+
+OpPassBase<ModuleOp> *mlir::createLowerAVX512ToLLVMPass() {
+  return new LowerAVX512ToLLVMPass();
+}
+
+static PassRegistration<LowerAVX512ToLLVMPass> pass(
+    "convert-avx512-to-llvm",
+    "Lower the operations from the avx512 dialect into the LLVM dialect");
diff --git a/mlir/lib/Dialect/AVX512/IR/AVX512Dialect.cpp b/mlir/lib/Dialect/AVX512/IR/AVX512Dialect.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/AVX512/IR/AVX512Dialect.cpp
@@ -0,0 +1,41 @@
+//===- AVX512Dialect.cpp - MLIR AVX512 ops implementation -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the AVX512 dialect and its operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/IR/IntrinsicsX86.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/VectorOps/VectorOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Dialect/AVX512/AVX512Dialect.h"
+
+using namespace mlir;
+
+/// Registers the ops listed in the GET_OP_LIST section of the generated
+/// AVX512Ops.cpp.inc with this dialect (namespace "avx512").
+LLVM::AVX512Dialect::AVX512Dialect(MLIRContext *context)
+    : Dialect(getDialectNamespace(), context) {
+  addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/AVX512/AVX512Ops.cpp.inc"
+      >();
+}
+
+namespace mlir {
+namespace LLVM {
+#define GET_OP_CLASSES
+#include "mlir/Dialect/AVX512/AVX512Ops.cpp.inc"
+} // namespace LLVM
+} // namespace mlir
+
+// Static initialization for AVX512Ops dialect registration.
+static DialectRegistration<LLVM::AVX512Dialect> avx512Ops;
diff --git a/mlir/lib/Target/LLVMIR/AVX512Intr.cpp b/mlir/lib/Target/LLVMIR/AVX512Intr.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Target/LLVMIR/AVX512Intr.cpp
@@ -0,0 +1,57 @@
+//===- AVX512Intr.cpp - Convert MLIR LLVM dialect to LLVM intrinsics ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a translation between the MLIR LLVM and AVX512 dialects
+// and LLVM IR with AVX512 intrinsics.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/IR/IntrinsicsX86.h"
+#include "mlir/Target/LLVMIR/ModuleTranslation.h"
+#include "mlir/Translation.h"
+#include "mlir/Dialect/AVX512/AVX512Dialect.h"
+
+using namespace mlir;
+
+namespace {
+/// Module translation that also understands the AVX512 intrinsic ops: each op
+/// is first offered to the AVX512Conversions.inc snippet (which presumably
+/// returns early for the ops it handles) and otherwise to the base class.
+class Avx512ModuleTranslation : public LLVM::ModuleTranslation {
+  // Presumably needed so ModuleTranslation::translateModule<T>() can construct
+  // this class through its protected constructor.
+  friend LLVM::ModuleTranslation;
+
+public:
+  using LLVM::ModuleTranslation::ModuleTranslation;
+
+protected:
+  LogicalResult convertOperation(Operation &opInst,
+                                 llvm::IRBuilder<> &builder) override {
+#include "mlir/Dialect/AVX512/AVX512Conversions.inc"
+
+    return LLVM::ModuleTranslation::convertOperation(opInst, builder);
+  }
+};
+
+/// Translates `m` to LLVM IR; callers treat a null result as failure. The
+/// translation class must be named explicitly: the base ModuleTranslation
+/// would never reach the AVX512 conversions included above.
+std::unique_ptr<llvm::Module> translateAvx512ModuleToLLVMIR(Operation *m) {
+  return LLVM::ModuleTranslation::translateModule<Avx512ModuleTranslation>(m);
+}
+} // end namespace
+
+static TranslateFromMLIRRegistration reg(
+    "avx512-mlir-to-llvmir", [](ModuleOp module, raw_ostream &output) {
+      auto llvmModule = translateAvx512ModuleToLLVMIR(module);
+      if (!llvmModule) return failure();
+
+      llvmModule->print(output, nullptr);
+      return success();
+    });
diff --git a/mlir/test/Conversion/AVX512ToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/AVX512ToLLVM/convert-to-llvm.mlir
new file mode 100644
--- 
/dev/null
+++ b/mlir/test/Conversion/AVX512ToLLVM/convert-to-llvm.mlir
@@ -0,0 +1,18 @@
+// RUN: mlir-opt %s -convert-avx512-to-llvm | mlir-opt | FileCheck %s
+
+// CHECK-LABEL: @avx512_mask_rndscale
+func @avx512_mask_rndscale(%a: vector<16xf32>, %b: vector<8xf64>, %i32: i32, %i16: i16, %i8: i8)
+  -> (vector<16xf32>, vector<8xf64>)
+{
+  // CHECK: avx512.mask.rndscale.ps.512
+  %0 = avx512.mask.rndscale %a, %i32, %a, %i16, %i32: vector<16xf32>, i16
+  // CHECK: avx512.mask.rndscale.pd.512
+  %1 = avx512.mask.rndscale %b, %i32, %b, %i8, %i32: vector<8xf64>, i8
+
+  // CHECK: avx512.mask.scalef.ps.512
+  %a0 = avx512.mask.scalef %a, %a, %a, %i16, %i32: vector<16xf32>, i16
+  // CHECK: avx512.mask.scalef.pd.512
+  %a1 = avx512.mask.scalef %b, %b, %b, %i8, %i32: vector<8xf64>, i8
+
+  return %a0, %a1: vector<16xf32>, vector<8xf64>
+}
diff --git a/mlir/test/Dialect/AVX512/roundtrip.mlir b/mlir/test/Dialect/AVX512/roundtrip.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/AVX512/roundtrip.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s
+
+// CHECK-LABEL: func @avx512_mask_rndscale
+func @avx512_mask_rndscale(%a: vector<16xf32>, %b: vector<8xf64>, %i32: i32, %i16: i16, %i8: i8)
+  -> (vector<16xf32>, vector<8xf64>)
+{
+  // CHECK: avx512.mask.rndscale {{.*}}: vector<16xf32>
+  %0 = avx512.mask.rndscale %a, %i32, %a, %i16, %i32 : vector<16xf32>, i16
+  // CHECK: avx512.mask.rndscale {{.*}}: vector<8xf64>
+  %1 = avx512.mask.rndscale %b, %i32, %b, %i8, %i32 : vector<8xf64>, i8
+  return %0, %1: vector<16xf32>, vector<8xf64>
+}
+
+// CHECK-LABEL: func @avx512_scalef
+func @avx512_scalef(%a: vector<16xf32>, %b: vector<8xf64>, %i32: i32, %i16: i16, %i8: i8)
+  -> (vector<16xf32>, vector<8xf64>)
+{
+  // CHECK: avx512.mask.scalef {{.*}}: vector<16xf32>
+  %0 = avx512.mask.scalef %a, %a, %a, %i16, %i32: vector<16xf32>, i16
+  // CHECK: avx512.mask.scalef {{.*}}: vector<8xf64>
+  %1 = avx512.mask.scalef %b, %b, %b, %i8, %i32 : vector<8xf64>, i8
+  return %0, %1: vector<16xf32>, vector<8xf64>
+}
diff --git a/mlir/test/Target/avx512.mlir
b/mlir/test/Target/avx512.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Target/avx512.mlir
@@ -0,0 +1,31 @@
+// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | mlir-translate --avx512-mlir-to-llvmir | FileCheck %s
+
+// CHECK-LABEL: define <16 x float> @LLVM_x86_avx512_mask_ps_512
+llvm.func @LLVM_x86_avx512_mask_ps_512(%vec: !llvm<"<16 x float>">,
+                                       %i32: !llvm.i32,
+                                       %i16: !llvm.i16)
+  -> (!llvm<"<16 x float>">)
+{
+  // CHECK: call <16 x float> @llvm.x86.avx512.mask.rndscale.ps.512(<16 x float>
+  %rndscale = "avx512.mask.rndscale.ps.512"(%vec, %i32, %vec, %i16, %i32) :
+    (!llvm<"<16 x float>">, !llvm.i32, !llvm<"<16 x float>">, !llvm.i16, !llvm.i32) -> !llvm<"<16 x float>">
+  // CHECK: call <16 x float> @llvm.x86.avx512.mask.scalef.ps.512(<16 x float>
+  %scalef = "avx512.mask.scalef.ps.512"(%vec, %vec, %vec, %i16, %i32) :
+    (!llvm<"<16 x float>">, !llvm<"<16 x float>">, !llvm<"<16 x float>">, !llvm.i16, !llvm.i32) -> !llvm<"<16 x float>">
+  llvm.return %scalef : !llvm<"<16 x float>">
+}
+
+// CHECK-LABEL: define <8 x double> @LLVM_x86_avx512_mask_pd_512
+llvm.func @LLVM_x86_avx512_mask_pd_512(%vec: !llvm<"<8 x double>">,
+                                       %i32: !llvm.i32,
+                                       %i8: !llvm.i8)
+  -> (!llvm<"<8 x double>">)
+{
+  // CHECK: call <8 x double> @llvm.x86.avx512.mask.rndscale.pd.512(<8 x double>
+  %rndscale = "avx512.mask.rndscale.pd.512"(%vec, %i32, %vec, %i8, %i32) :
+    (!llvm<"<8 x double>">, !llvm.i32, !llvm<"<8 x double>">, !llvm.i8, !llvm.i32) -> !llvm<"<8 x double>">
+  // CHECK: call <8 x double> @llvm.x86.avx512.mask.scalef.pd.512(<8 x double>
+  %scalef = "avx512.mask.scalef.pd.512"(%vec, %vec, %vec, %i8, %i32) :
+    (!llvm<"<8 x double>">, !llvm<"<8 x double>">, !llvm<"<8 x double>">, !llvm.i8, !llvm.i32) -> !llvm<"<8 x double>">
+  llvm.return %scalef : !llvm<"<8 x double>">
+}