diff --git a/mlir/include/mlir/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.h b/mlir/include/mlir/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.h @@ -0,0 +1,23 @@ +//===- ArmSVEToLLVM.h - Conversion Patterns from ArmSVE to LLVM -----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_ARMSVETOLLVM_ARMSVETOLLVM_H_ +#define MLIR_CONVERSION_ARMSVETOLLVM_ARMSVETOLLVM_H_ + +namespace mlir { + +class LLVMTypeConverter; +class OwningRewritePatternList; + +/// Collect ArmSVE-to-LLVM patterns and extend `converter` to ArmSVE types. +void populateArmSVEToLLVMConversionPatterns(LLVMTypeConverter &converter, + OwningRewritePatternList &patterns); + +} // namespace mlir + +#endif // MLIR_CONVERSION_ARMSVETOLLVM_ARMSVETOLLVM_H_ diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -396,8 +396,8 @@ operations. The lowering pass provides several options to control the kinds of optimizations that are allowed. It also provides options that enable the use of one or more architectural-specific dialects - (AVX512, ArmNeon, SVE, etc.) in combination with the architectural-neutral - vector dialect lowering. + (AVX512, ArmNeon, ArmSVE, etc.) in combination with the + architectural-neutral vector dialect lowering.
}]; let constructor = "mlir::createConvertVectorToLLVMPass()"; @@ -418,7 +418,11 @@ Option<"enableArmNeon", "enable-arm-neon", "bool", /*default=*/"false", "Enables the use of ArmNeon dialect while lowering the vector " - "dialect."> + "dialect.">, + Option<"enableArmSVE", "enable-arm-sve", + "bool", /*default=*/"false", + "Enables the use of ArmSVE dialect while lowering the vector " + "dialect."> ]; } diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h --- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h +++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h @@ -23,7 +23,7 @@ struct LowerVectorToLLVMOptions { LowerVectorToLLVMOptions() : reassociateFPReductions(false), enableIndexOptimizations(true), - enableArmNeon(false), enableAVX512(false) {} + enableArmNeon(false), enableArmSVE(false), enableAVX512(false) {} LowerVectorToLLVMOptions &setReassociateFPReductions(bool b) { reassociateFPReductions = b; @@ -33,18 +33,23 @@ enableIndexOptimizations = b; return *this; } - LowerVectorToLLVMOptions &setEnableAVX512(bool b) { - enableAVX512 = b; - return *this; - } LowerVectorToLLVMOptions &setEnableArmNeon(bool b) { enableArmNeon = b; return *this; } + LowerVectorToLLVMOptions &setEnableArmSVE(bool b) { + enableArmSVE = b; + return *this; + } + LowerVectorToLLVMOptions &setEnableAVX512(bool b) { + enableAVX512 = b; + return *this; + } bool reassociateFPReductions; bool enableIndexOptimizations; bool enableArmNeon; + bool enableArmSVE; bool enableAVX512; }; diff --git a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td @@ -0,0 +1,276 @@ +//===-- ArmSVE.td - ArmSVE dialect operation definitions ---*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the ArmSVE dialect, its types and its operations. +// +//===----------------------------------------------------------------------===// + +#ifndef ARMSVE_OPS +#define ARMSVE_OPS + +include "mlir/Interfaces/SideEffectInterfaces.td" + +//===----------------------------------------------------------------------===// +// ArmSVE dialect definition +//===----------------------------------------------------------------------===// + +def ArmSVE_Dialect : Dialect { + let name = "arm_sve"; + let cppNamespace = "::mlir::arm_sve"; + let summary = "Basic dialect to target Arm SVE architectures"; + let description = [{ + This dialect contains the definitions necessary to target Arm SVE scalable + vector operations, including a scalable vector type and intrinsics for + some Arm SVE instructions. + }]; +} + +//===----------------------------------------------------------------------===// +// ArmSVE type definitions +//===----------------------------------------------------------------------===// + +def ArmSVE_ScalableVectorType : DialectType()">, + "scalable vector type">, + BuildableType<"$_builder.getType()"> { + let typeDescription = [{ + `arm_sve.vector` represents vectors that will be processed by a scalable + vector architecture. + }]; +} + +class ArmSVE_Type : TypeDef { } + +def ScalableVectorType : ArmSVE_Type<"ScalableVector"> { + let mnemonic = "vector"; + + let summary = "Scalable vector type"; + + let description = [{ + A type representing scalable length SIMD vectors. Unlike fixed-length SIMD + vectors, whose size is constant and known at compile time, scalable + vectors' length is constant but determined by the specific hardware at + run time.
+ }]; + + let parameters = (ins + ArrayRefParameter<"int64_t", "Vector shape">:$shape, + "Type":$elementType + ); + + let printer = [{ + $_printer << "vector<"; + for (int64_t dim : getShape()) + $_printer << dim << 'x'; + $_printer << getElementType() << '>'; + }]; + + let parser = [{ + VectorType vector; + if ($_parser.parseType(vector)) + return Type(); + return get(ctxt, vector.getShape(), vector.getElementType()); + }]; + + let extraClassDeclaration = [{ + bool hasStaticShape() const { + return llvm::none_of(getShape(), ShapedType::isDynamic); + } + int64_t getNumElements() const { + assert(hasStaticShape() && + "cannot get element count of dynamic shaped type"); + ArrayRef shape = getShape(); + int64_t num = 1; + for (auto dim : shape) + num *= dim; + return num; + } + }]; +} + +//===----------------------------------------------------------------------===// +// ArmSVE type constraints +//===----------------------------------------------------------------------===// + +def IsScalableVectorTypePred : + CPred<"$_self.isa<::mlir::arm_sve::ScalableVectorType>()">; + +class ScalableVectorOf allowedTypes> : + ContainerType, IsScalableVectorTypePred, + "$_self.cast<::mlir::arm_sve::ScalableVectorType>().getElementType()", + "scalable vector">; + +class IsScalableVectorOfLengthPred allowedLengths> : + And<[IsScalableVectorTypePred, + Or().getNumElements() == }] + # allowedlength>)>]>; + +class ScalableVectorOfLength allowedLengths> : Type< + IsScalableVectorOfLengthPred, + " of length " # StrJoinInt.result>; + +class ScalableVectorOfLengthAndType allowedLengths, + list allowedTypes> : Type< + And<[ScalableVectorOf.predicate, + ScalableVectorOfLength.predicate]>, + ScalableVectorOf.description # + ScalableVectorOfLength.description>; + +//===----------------------------------------------------------------------===// +// ArmSVE op definitions +//===----------------------------------------------------------------------===// + +class ArmSVE_Op traits = []> : + Op {} + +def
SdotOp : ArmSVE_Op<"sdot", + [NoSideEffect, + AllTypesMatch<["src1", "src2"]>, + AllTypesMatch<["acc", "dst"]>, + ]> { + let summary = "Vector-vector dot product and accumulate op"; + let description = [{ + SDOT: Signed integer addition of dot product. + + This function maps to the SDOT instruction, and it takes signless integer + operands that the operation interprets as signed. It partitions the second + and third vector inputs into groups of four elements. It calculates the dot + product of each group (without loss of precision) and then adds each result + to the overlapping element of the first vector input. + + Source: + https://developer.arm.com/documentation/100987/0000 + }]; + // Supports either: + // (vector<16xi8>, vector<16xi8>) -> (vector<4xi32>) + // (vector<8xi16>, vector<8xi16>) -> (vector<2xi64>) + let arguments = (ins + ScalableVectorOfLengthAndType<[4, 2], [I32, I64]>:$acc, + ScalableVectorOfLengthAndType<[16, 8], [I8, I16]>:$src1, + ScalableVectorOfLengthAndType<[16, 8], [I8, I16]>:$src2 + ); + let results = (outs ScalableVectorOfLengthAndType<[4, 2], [I32, I64]>:$dst); + let assemblyFormat = + "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)"; +} + +def SmmlaOp : ArmSVE_Op<"smmla", + [NoSideEffect, + AllTypesMatch<["src1", "src2"]>, + AllTypesMatch<["acc", "dst"]>, + ]> { + let summary = "Matrix-matrix multiply and accumulate op"; + let description = [{ + SMMLA: Signed integer matrix multiply-accumulate. + + This function maps to the SMMLA instruction, and it takes signless integer + operands that the operation interprets as signed. It partitions the inputs + into 128-bit quadwords, with the first input containing a row-by-row 2×2 + matrix of 32-bit integers, the second input containing a row-by-row 2×8 + matrix of 8-bit integers, and the third input containing a column-by-column + 8×2 matrix of 8-bit integers.
For each quadword, it multiplies the second + input matrix by the third input matrix using natural arithmetic and then adds + the result to the first input using modular arithmetic. + + Source: + https://developer.arm.com/documentation/100987/0000 + }]; + // Supports (vector<16xi8>, vector<16xi8>) -> (vector<4xi32>) + let arguments = (ins + ScalableVectorOfLengthAndType<[4], [I32]>:$acc, + ScalableVectorOfLengthAndType<[16], [I8]>:$src1, + ScalableVectorOfLengthAndType<[16], [I8]>:$src2 + ); + let results = (outs ScalableVectorOfLengthAndType<[4], [I32]>:$dst); + let assemblyFormat = + "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)"; +} + +def UdotOp : ArmSVE_Op<"udot", + [NoSideEffect, + AllTypesMatch<["src1", "src2"]>, + AllTypesMatch<["acc", "dst"]>, + ]> { + let summary = "Vector-vector dot product and accumulate op"; + let description = [{ + UDOT: Unsigned integer addition of dot product. + + This function maps to the UDOT instruction, and it takes signless integer + operands that the operation interprets as unsigned. It partitions the second + and third vector inputs into groups of four elements. It calculates the dot + product of each group (without loss of precision) and then adds each result + to the overlapping element of the first vector input. + + Source: + https://developer.arm.com/documentation/100987/0000 + }]; + // Supports either: + // (vector<16xi8>, vector<16xi8>) -> (vector<4xi32>) + // (vector<8xi16>,
vector<8xi16>) -> (vector<2xi64>) + let arguments = (ins + ScalableVectorOfLengthAndType<[4, 2], [I32, I64]>:$acc, + ScalableVectorOfLengthAndType<[16, 8], [I8, I16]>:$src1, + ScalableVectorOfLengthAndType<[16, 8], [I8, I16]>:$src2 + ); + let results = (outs ScalableVectorOfLengthAndType<[4, 2], [I32, I64]>:$dst); + let assemblyFormat = + "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)"; +} + +def UmmlaOp : ArmSVE_Op<"ummla", + [NoSideEffect, + AllTypesMatch<["src1", "src2"]>, + AllTypesMatch<["acc", "dst"]>, + ]> { + let summary = "Matrix-matrix multiply and accumulate op"; + let description = [{ + UMMLA: Unsigned integer matrix multiply-accumulate. + + This function maps to the UMMLA instruction, and it takes signless integer + operands that the operation interprets as unsigned. It partitions the inputs + into 128-bit quadwords, with the first input containing a row-by-row 2×2 + matrix of 32-bit integers, the second input containing a row-by-row 2×8 + matrix of 8-bit integers, and the third input containing a column-by-column + 8×2 matrix of 8-bit integers. For each quadword, it multiplies the second + input matrix by the third input matrix using natural arithmetic and then adds + the result to the first input using modular arithmetic.
+ + Source: + https://developer.arm.com/documentation/100987/0000 + }]; + // Supports (vector<16xi8>, vector<16xi8>) -> (vector<4xi32>) + let arguments = (ins + ScalableVectorOfLengthAndType<[4], [I32]>:$acc, + ScalableVectorOfLengthAndType<[16], [I8]>:$src1, + ScalableVectorOfLengthAndType<[16], [I8]>:$src2 + ); + let results = (outs ScalableVectorOfLengthAndType<[4], [I32]>:$dst); + let assemblyFormat = + "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)"; +} + +def VectorScaleOp : ArmSVE_Op<"vector_scale", + [NoSideEffect]> { + let summary = "Get the vector scale"; + let description = [{ + The vector_scale op returns the scale of the scalable vectors, a positive + integer value that is constant at runtime but unknown at compile time. + The scale of the vector indicates the multiplicity of the vectors and + vector operations. I.e.: an !arm_sve.vector<4xi32> is equivalent to + vector_scale consecutive vector<4xi32>; and an operation on an + !arm_sve.vector<4xi32> is equivalent to performing that operation vector_scale + times, once on each <4xi32> segment of the scalable vector. The vector_scale + op can be used to calculate the step in vector-length agnostic (VLA) loops. + }]; + let results = (outs Index:$res); + let assemblyFormat = + "attr-dict `:` type($res)"; +} + +#endif // ARMSVE_OPS diff --git a/mlir/include/mlir/Dialect/ArmSVE/ArmSVEDialect.h b/mlir/include/mlir/Dialect/ArmSVE/ArmSVEDialect.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/ArmSVE/ArmSVEDialect.h @@ -0,0 +1,29 @@ +//===- ArmSVEDialect.h - MLIR Dialect for Arm SVE ---------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the Target dialect for ArmSVE in MLIR.
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_ARMSVE_ARMSVEDIALECT_H +#define MLIR_DIALECT_ARMSVE_ARMSVEDIALECT_H + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" + +#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h.inc" + +#define GET_TYPEDEF_CLASSES +#include "mlir/Dialect/ArmSVE/ArmSVETypes.h.inc" + +#define GET_OP_CLASSES +#include "mlir/Dialect/ArmSVE/ArmSVE.h.inc" + +#endif // MLIR_DIALECT_ARMSVE_ARMSVEDIALECT_H diff --git a/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt @@ -0,0 +1 @@ +add_mlir_dialect(ArmSVE arm_sve ArmSVE) diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt --- a/mlir/include/mlir/Dialect/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/CMakeLists.txt @@ -1,6 +1,7 @@ add_subdirectory(Affine) add_subdirectory(Async) add_subdirectory(ArmNeon) +add_subdirectory(ArmSVE) add_subdirectory(AVX512) add_subdirectory(GPU) add_subdirectory(Linalg) diff --git a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt --- a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt @@ -37,3 +37,9 @@ set(LLVM_TARGET_DEFINITIONS LLVMArmNeon.td) mlir_tablegen(LLVMArmNeonConversions.inc -gen-llvmir-conversions) add_public_tablegen_target(MLIRLLVMArmNeonConversionsIncGen) + +add_mlir_dialect(LLVMArmSVE llvm_arm_sve LLVMArmSVE) +add_mlir_doc(LLVMArmSVE -gen-dialect-doc LLVMArmSVE Dialects/) +set(LLVM_TARGET_DEFINITIONS LLVMArmSVE.td) +mlir_tablegen(LLVMArmSVEConversions.inc -gen-llvmir-conversions) +add_public_tablegen_target(MLIRLLVMArmSVEConversionsIncGen) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMArmSVE.td
b/mlir/include/mlir/Dialect/LLVMIR/LLVMArmSVE.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMArmSVE.td @@ -0,0 +1,70 @@ +//===-- LLVMArmSVE.td - LLVMArmSVE dialect op definitions --*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the basic operations for the LLVMArmSVE dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVMIR_ARMSVE_OPS +#define LLVMIR_ARMSVE_OPS + +include "mlir/Dialect/LLVMIR/LLVMOpBase.td" + +//===----------------------------------------------------------------------===// +// LLVMArmSVE dialect definition +//===----------------------------------------------------------------------===// + +def LLVMArmSVE_Dialect : Dialect { + let name = "llvm_arm_sve"; + let cppNamespace = "::mlir::LLVM"; +} + +//----------------------------------------------------------------------------// +// MLIR LLVM Arm SVE intrinsics using the MLIR LLVM Dialect type system +//----------------------------------------------------------------------------// + +class LLVMArmSVE_NonSVEIntrUnaryOverloadedOp traits = []> : + LLVM_IntrOpBase overloadedResults=*/[0], + /*list overloadedOperands=*/[], // defined by result overload + /*list traits=*/traits, + /*int numResults=*/1>; + +class LLVMArmSVE_IntrBinaryOverloadedOp traits = []> : + LLVM_IntrOpBase overloadedResults=*/[0], + /*list overloadedOperands=*/[], // defined by result overload + /*list traits=*/traits, + /*int numResults=*/1>; + +def LLVM_aarch64_arm_sve_ummla : + LLVMArmSVE_IntrBinaryOverloadedOp<"ummla">, + Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>; + +def LLVM_aarch64_arm_sve_smmla : + LLVMArmSVE_IntrBinaryOverloadedOp<"smmla">, +
Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>; + +def LLVM_aarch64_arm_sve_sdot : + LLVMArmSVE_IntrBinaryOverloadedOp<"sdot">, + Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>; + +def LLVM_aarch64_arm_sve_udot : + LLVMArmSVE_IntrBinaryOverloadedOp<"udot">, + Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>; + +def LLVM_vector_scale : + LLVMArmSVE_NonSVEIntrUnaryOverloadedOp<"vscale">; + +#endif // LLVMIR_ARMSVE_OPS diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h @@ -0,0 +1,24 @@ +//===- LLVMArmSVEDialect.h - MLIR Dialect for LLVMArmSVE --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the Target dialect for LLVMArmSVE in MLIR.
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_LLVMIR_LLVMARMSVEDIALECT_H_ +#define MLIR_DIALECT_LLVMIR_LLVMARMSVEDIALECT_H_ + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" + +#define GET_OP_CLASSES +#include "mlir/Dialect/LLVMIR/LLVMArmSVE.h.inc" + +#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h.inc" + +#endif // MLIR_DIALECT_LLVMIR_LLVMARMSVEDIALECT_H_ diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -17,10 +17,12 @@ #include "mlir/Dialect/AVX512/AVX512Dialect.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" +#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h" #include "mlir/Dialect/Async/IR/Async.h" #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h" #include "mlir/Dialect/LLVMIR/LLVMArmNeonDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" @@ -54,6 +56,7 @@ LLVM::LLVMAVX512Dialect, LLVM::LLVMDialect, LLVM::LLVMArmNeonDialect, + LLVM::LLVMArmSVEDialect, linalg::LinalgDialect, scf::SCFDialect, omp::OpenMPDialect, @@ -62,6 +65,7 @@ quant::QuantizationDialect, spirv::SPIRVDialect, StandardOpsDialect, + arm_sve::ArmSVEDialect, vector::VectorDialect, NVVM::NVVMDialect, ROCDL::ROCDLDialect, diff --git a/mlir/include/mlir/InitAllTranslations.h b/mlir/include/mlir/InitAllTranslations.h --- a/mlir/include/mlir/InitAllTranslations.h +++ b/mlir/include/mlir/InitAllTranslations.h @@ -24,6 +24,7 @@ void registerToROCDLIRTranslation(); void registerArmNeonToLLVMIRTranslation(); void registerAVX512ToLLVMIRTranslation(); +void registerArmSVEToLLVMIRTranslation(); // This function should be called before creating any MLIRContext if one //
expects all the possible translations to be made available to the context @@ -38,6 +39,7 @@ registerToROCDLIRTranslation(); registerArmNeonToLLVMIRTranslation(); registerAVX512ToLLVMIRTranslation(); + registerArmSVEToLLVMIRTranslation(); return true; }(); (void)initOnce; diff --git a/mlir/lib/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.cpp b/mlir/lib/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.cpp @@ -0,0 +1,75 @@ +//===- ArmSVEToLLVM.cpp - Convert ArmSVE to the LLVM dialect --------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" +#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Vector/VectorOps.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/PatternMatch.h" + +using namespace mlir; +using namespace mlir::arm_sve; +using namespace mlir::vector; + +using SdotOpLowering = + OneToOneConvertToLLVMPattern; + +using SmmlaOpLowering = + OneToOneConvertToLLVMPattern; + +using UdotOpLowering = + OneToOneConvertToLLVMPattern; + +using UmmlaOpLowering = + OneToOneConvertToLLVMPattern; + +using VectorScaleOpLowering = + OneToOneConvertToLLVMPattern; + +// Cast `type` to an LLVM dialect type; emits an error for non-LLVM types.
+static LLVM::LLVMType unwrap(Type type) { + if (!type) + return nullptr; + auto *mlirContext = type.getContext(); + auto wrappedLLVMType = type.dyn_cast(); + if (!wrappedLLVMType) + emitError(UnknownLoc::get(mlirContext), + "conversion resulted in a non-LLVM type"); + return wrappedLLVMType; +} + +static Optional +convertScalableVectorTypeToLLVM(ScalableVectorType svType, + LLVMTypeConverter &converter) { + // LLVM scalable vectors are 1-D (vscale x N x T): reject any other rank + // instead of silently dropping all but the innermost dimension. + auto elementType = unwrap(converter.convertType(svType.getElementType())); + if (!elementType || svType.getShape().size() != 1) + return {}; + + return LLVM::LLVMScalableVectorType::get(elementType, svType.getShape()[0]); +} + +/// Add ArmSVE-to-LLVM patterns and the scalable vector type conversion. +void mlir::populateArmSVEToLLVMConversionPatterns( + LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { + converter.addConversion([&converter](ScalableVectorType svType) { + return convertScalableVectorTypeToLLVM(svType, converter); + }); + // clang-format off + patterns.insert(converter); + // clang-format on +} diff --git a/mlir/lib/Conversion/ArmSVEToLLVM/CMakeLists.txt b/mlir/lib/Conversion/ArmSVEToLLVM/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/ArmSVEToLLVM/CMakeLists.txt @@ -0,0 +1,19 @@ +add_mlir_conversion_library(MLIRArmSVEToLLVM + ArmSVEToLLVM.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArmSVEToLLVM + + DEPENDS + MLIRConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRArmSVE + MLIRLLVMArmSVE + MLIRLLVMIR + MLIRStandardToLLVM + MLIRTransforms + ) diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -20,6 +20,7 @@ add_subdirectory(SPIRVToLLVM) add_subdirectory(StandardToLLVM) add_subdirectory(StandardToSPIRV) +add_subdirectory(ArmSVEToLLVM) add_subdirectory(VectorToROCDL) add_subdirectory(VectorToLLVM) add_subdirectory(VectorToSCF)
diff --git a/mlir/lib/Conversion/PassDetail.h b/mlir/lib/Conversion/PassDetail.h --- a/mlir/lib/Conversion/PassDetail.h +++ b/mlir/lib/Conversion/PassDetail.h @@ -26,6 +26,7 @@ namespace LLVM { class LLVMArmNeonDialect; +class LLVMArmSVEDialect; class LLVMAVX512Dialect; class LLVMDialect; } // end namespace LLVM diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt --- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt +++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt @@ -19,6 +19,9 @@ MLIRAVX512ToLLVM MLIRLLVMArmNeon MLIRLLVMAVX512 + MLIRArmSVE + MLIRArmSVEToLLVM + MLIRLLVMArmSVE MLIRLLVMIR MLIRStandardToLLVM MLIRTargetLLVMIRModuleTranslation diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -12,12 +12,15 @@ #include "mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h" #include "mlir/Conversion/ArmNeonToLLVM/ArmNeonToLLVM.h" +#include "mlir/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Dialect/AVX512/AVX512Dialect.h" #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" +#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h" #include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h" #include "mlir/Dialect/LLVMIR/LLVMArmNeonDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -32,6 +35,7 @@ this->reassociateFPReductions = options.reassociateFPReductions; this->enableIndexOptimizations = options.enableIndexOptimizations; this->enableArmNeon = options.enableArmNeon; + this->enableArmSVE = options.enableArmSVE;
this->enableAVX512 = options.enableAVX512; } // Override explicitly to allow conditional dialect dependence. @@ -39,6 +43,8 @@ registry.insert(); if (enableArmNeon) registry.insert(); + if (enableArmSVE) + registry.insert(); if (enableAVX512) registry.insert(); } @@ -73,6 +79,11 @@ target.addIllegalDialect(); populateArmNeonToLLVMConversionPatterns(converter, patterns); } + if (enableArmSVE) { + target.addLegalDialect(); + target.addIllegalDialect(); + populateArmSVEToLLVMConversionPatterns(converter, patterns); + } if (enableAVX512) { target.addLegalDialect(); target.addIllegalDialect(); diff --git a/mlir/lib/Dialect/ArmSVE/CMakeLists.txt b/mlir/lib/Dialect/ArmSVE/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/ArmSVE/CMakeLists.txt @@ -0,0 +1,13 @@ +add_mlir_dialect_library(MLIRArmSVE + IR/ArmSVEDialect.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSVE + + DEPENDS + MLIRArmSVEIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRSideEffectInterfaces + ) diff --git a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp @@ -0,0 +1,57 @@ +//===- ArmSVEDialect.cpp - MLIR ArmSVE dialect implementation -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the ArmSVE dialect, its operations and its types.
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h" +#include "mlir/Dialect/Vector/VectorOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; + +void arm_sve::ArmSVEDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "mlir/Dialect/ArmSVE/ArmSVE.cpp.inc" + >(); + addTypes< +#define GET_TYPEDEF_LIST +#include "mlir/Dialect/ArmSVE/ArmSVETypes.cpp.inc" + >(); +} + +#define GET_OP_CLASSES +#include "mlir/Dialect/ArmSVE/ArmSVE.cpp.inc" + +#define GET_TYPEDEF_CLASSES +#include "mlir/Dialect/ArmSVE/ArmSVETypes.cpp.inc" + +//===----------------------------------------------------------------------===// +// ArmSVE dialect type parsing and printing +//===----------------------------------------------------------------------===// + +Type arm_sve::ArmSVEDialect::parseType(DialectAsmParser &parser) const { + llvm::SMLoc typeLoc = parser.getCurrentLocation(); + auto genType = generatedTypeParser(getContext(), parser, "vector"); + if (genType != Type()) + return genType; + parser.emitError(typeLoc, "unknown type in ArmSVE dialect"); + return Type(); +} + +void arm_sve::ArmSVEDialect::printType(Type type, DialectAsmPrinter &os) const { + if (failed(generatedTypePrinter(type, os))) + llvm_unreachable("unexpected 'arm_sve' type kind"); +} diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt --- a/mlir/lib/Dialect/CMakeLists.txt +++ b/mlir/lib/Dialect/CMakeLists.txt @@ -1,5 +1,6 @@ add_subdirectory(Affine) add_subdirectory(ArmNeon) +add_subdirectory(ArmSVE) add_subdirectory(Async) add_subdirectory(AVX512) add_subdirectory(GPU) diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt --- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt +++
b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt @@ -70,6 +70,27 @@ MLIRSideEffectInterfaces ) +add_mlir_dialect_library(MLIRLLVMArmSVE + IR/LLVMArmSVEDialect.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LLVMIR + + DEPENDS + MLIRLLVMArmSVEIncGen + MLIRLLVMArmSVEConversionsIncGen + intrinsics_gen + + LINK_COMPONENTS + AsmParser + Core + + LINK_LIBS PUBLIC + MLIRIR + MLIRLLVMIR + MLIRSideEffectInterfaces + ) + add_mlir_dialect_library(MLIRNVVMIR IR/NVVMDialect.cpp diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMArmSVEDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMArmSVEDialect.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMArmSVEDialect.cpp @@ -0,0 +1,31 @@ +//===- LLVMArmSVEDialect.cpp - MLIR LLVMArmSVE ops implementation ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the LLVMArmSVE dialect and its operations.
+// +//===----------------------------------------------------------------------===// + +#include "llvm/IR/IntrinsicsAArch64.h" + +#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/TypeUtilities.h" + +using namespace mlir; + +void LLVM::LLVMArmSVEDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "mlir/Dialect/LLVMIR/LLVMArmSVE.cpp.inc" + >(); +} + +#define GET_OP_CLASSES +#include "mlir/Dialect/LLVMIR/LLVMArmSVE.cpp.inc" diff --git a/mlir/lib/Target/CMakeLists.txt b/mlir/lib/Target/CMakeLists.txt --- a/mlir/lib/Target/CMakeLists.txt +++ b/mlir/lib/Target/CMakeLists.txt @@ -74,6 +74,25 @@ MLIRTargetLLVMIRModuleTranslation ) +add_mlir_translation_library(MLIRTargetArmSVE + LLVMIR/LLVMArmSVEIntr.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Target/LLVMIR + + DEPENDS + MLIRLLVMArmSVEConversionsIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRIR + MLIRLLVMArmSVE + MLIRLLVMIR + MLIRTargetLLVMIRModuleTranslation + ) + add_mlir_translation_library(MLIRTargetNVVMIR LLVMIR/ConvertToNVVMIR.cpp diff --git a/mlir/lib/Target/LLVMIR/LLVMArmSVEIntr.cpp b/mlir/lib/Target/LLVMIR/LLVMArmSVEIntr.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Target/LLVMIR/LLVMArmSVEIntr.cpp @@ -0,0 +1,63 @@ +//===- LLVMArmSVEIntr.cpp - Convert MLIR LLVM dialect to LLVM intrinsics --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a translation between the MLIR LLVM and ArmSVE dialects +// and LLVM IR with Arm SVE intrinsics. 
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h"
+#include "mlir/Target/LLVMIR/ModuleTranslation.h"
+#include "mlir/Translation.h"
+#include "llvm/IR/IntrinsicsAArch64.h"
+
+using namespace mlir;
+
+namespace {
+class LLVMArmSVEModuleTranslation : public LLVM::ModuleTranslation {
+  friend LLVM::ModuleTranslation;
+
+public:
+  using LLVM::ModuleTranslation::ModuleTranslation;
+
+protected:
+  LogicalResult convertOperation(Operation &opInst,
+                                 llvm::IRBuilder<> &builder) override {
+#include "mlir/Dialect/LLVMIR/LLVMArmSVEConversions.inc"
+
+    return LLVM::ModuleTranslation::convertOperation(opInst, builder);
+  }
+};
+} // end namespace
+
+static std::unique_ptr<llvm::Module>
+translateLLVMArmSVEModuleToLLVMIR(Operation *m, llvm::LLVMContext &llvmContext,
+                                  StringRef name) {
+  return LLVM::ModuleTranslation::translateModule<LLVMArmSVEModuleTranslation>(
+      m, llvmContext, name);
+}
+
+namespace mlir {
+void registerArmSVEToLLVMIRTranslation() {
+  TranslateFromMLIRRegistration reg(
+      "arm-sve-mlir-to-llvmir",
+      [](ModuleOp module, raw_ostream &output) {
+        llvm::LLVMContext llvmContext;
+        auto llvmModule = translateLLVMArmSVEModuleToLLVMIR(
+            module, llvmContext, "LLVMDialectModule");
+        if (!llvmModule)
+          return failure();
+
+        llvmModule->print(output, nullptr);
+        return success();
+      },
+      [](DialectRegistry &registry) {
+        registry.insert<LLVM::LLVMArmSVEDialect, LLVM::LLVMDialect>();
+      });
+}
+} // namespace mlir
diff --git a/mlir/test/Conversion/ArmSVEToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/ArmSVEToLLVM/convert-to-llvm.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/ArmSVEToLLVM/convert-to-llvm.mlir
@@ -0,0 +1,47 @@
+// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sve" | mlir-opt | FileCheck %s
+
+func @arm_sve_sdot(%a: !arm_sve.vector<16xi8>,
+                   %b: !arm_sve.vector<16xi8>,
+                   %c: !arm_sve.vector<4xi32>)
+    -> !arm_sve.vector<4xi32> {
+  // CHECK: llvm_arm_sve.sdot
+  %0 = arm_sve.sdot %c, %a, %b :
+               !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
+  return %0 : !arm_sve.vector<4xi32>
+}
+
+func @arm_sve_smmla(%a: !arm_sve.vector<16xi8>,
+                    %b: !arm_sve.vector<16xi8>,
+                    %c: !arm_sve.vector<4xi32>)
+    -> !arm_sve.vector<4xi32> {
+  // CHECK: llvm_arm_sve.smmla
+  %0 = arm_sve.smmla %c, %a, %b :
+               !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
+  return %0 : !arm_sve.vector<4xi32>
+}
+
+func @arm_sve_udot(%a: !arm_sve.vector<16xi8>,
+                   %b: !arm_sve.vector<16xi8>,
+                   %c: !arm_sve.vector<4xi32>)
+    -> !arm_sve.vector<4xi32> {
+  // CHECK: llvm_arm_sve.udot
+  %0 = arm_sve.udot %c, %a, %b :
+               !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
+  return %0 : !arm_sve.vector<4xi32>
+}
+
+func @arm_sve_ummla(%a: !arm_sve.vector<16xi8>,
+                    %b: !arm_sve.vector<16xi8>,
+                    %c: !arm_sve.vector<4xi32>)
+    -> !arm_sve.vector<4xi32> {
+  // CHECK: llvm_arm_sve.ummla
+  %0 = arm_sve.ummla %c, %a, %b :
+               !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
+  return %0 : !arm_sve.vector<4xi32>
+}
+
+func @get_vector_scale() -> index {
+  // CHECK: llvm_arm_sve.vscale
+  %0 = arm_sve.vector_scale : index
+  return %0 : index
+}
diff --git a/mlir/test/Dialect/ArmSVE/roundtrip.mlir b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
@@ -0,0 +1,43 @@
+// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s
+
+func @arm_sve_sdot(%a: !arm_sve.vector<16xi8>,
+                   %b: !arm_sve.vector<16xi8>,
+                   %c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
+  // CHECK: arm_sve.sdot {{.*}}: !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32
+  %0 = arm_sve.sdot %c, %a, %b :
+             !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
+  return %0 : !arm_sve.vector<4xi32>
+}
+
+func @arm_sve_smmla(%a: !arm_sve.vector<16xi8>,
+                    %b: !arm_sve.vector<16xi8>,
+                    %c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
+  // CHECK: arm_sve.smmla {{.*}}: !arm_sve.vector<16xi8> to !arm_sve.vector<4xi3
+  %0 = arm_sve.smmla %c, %a, %b :
+             !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
+  return %0 : !arm_sve.vector<4xi32>
+}
+
+func @arm_sve_udot(%a: !arm_sve.vector<16xi8>,
+                   %b: !arm_sve.vector<16xi8>,
+                   %c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
+  // CHECK: arm_sve.udot {{.*}}: !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32
+  %0 = arm_sve.udot %c, %a, %b :
+             !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
+  return %0 : !arm_sve.vector<4xi32>
+}
+
+func @arm_sve_ummla(%a: !arm_sve.vector<16xi8>,
+                    %b: !arm_sve.vector<16xi8>,
+                    %c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> {
+  // CHECK: arm_sve.ummla {{.*}}: !arm_sve.vector<16xi8> to !arm_sve.vector<4xi3
+  %0 = arm_sve.ummla %c, %a, %b :
+             !arm_sve.vector<16xi8> to !arm_sve.vector<4xi32>
+  return %0 : !arm_sve.vector<4xi32>
+}
+
+func @get_vector_scale() -> index {
+  // CHECK: arm_sve.vector_scale : index
+  %0 = arm_sve.vector_scale : index
+  return %0 : index
+}
diff --git a/mlir/test/Target/arm-sve.mlir b/mlir/test/Target/arm-sve.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Target/arm-sve.mlir
@@ -0,0 +1,56 @@
+// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | mlir-translate --arm-sve-mlir-to-llvmir | FileCheck %s
+
+// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_sdot
+llvm.func @arm_sve_sdot(%arg0: !llvm.vec<? x 16 x i8>,
+                        %arg1: !llvm.vec<? x 16 x i8>,
+                        %arg2: !llvm.vec<? x 4 x i32>)
+                        -> !llvm.vec<? x 4 x i32> {
+  // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.sdot.nxv4i32(<vscale x 4
+  %0 = "llvm_arm_sve.sdot"(%arg2, %arg0, %arg1) :
+    (!llvm.vec<? x 4 x i32>, !llvm.vec<? x 16 x i8>, !llvm.vec<? x 16 x i8>)
+        -> !llvm.vec<? x 4 x i32>
+  llvm.return %0 : !llvm.vec<? x 4 x i32>
+}
+
+// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_smmla
+llvm.func @arm_sve_smmla(%arg0: !llvm.vec<? x 16 x i8>,
+                         %arg1: !llvm.vec<? x 16 x i8>,
+                         %arg2: !llvm.vec<? x 4 x i32>)
+                         -> !llvm.vec<? x 4 x i32> {
+  // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.smmla.nxv4i32(<vscale x 4
+  %0 = "llvm_arm_sve.smmla"(%arg2, %arg0, %arg1) :
+    (!llvm.vec<? x 4 x i32>, !llvm.vec<? x 16 x i8>, !llvm.vec<? x 16 x i8>)
+        -> !llvm.vec<? x 4 x i32>
+  llvm.return %0 : !llvm.vec<? x 4 x i32>
+}
+
+// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_udot
+llvm.func @arm_sve_udot(%arg0: !llvm.vec<? x 16 x i8>,
+                        %arg1: !llvm.vec<? x 16 x i8>,
+                        %arg2: !llvm.vec<? x 4 x i32>)
+                        -> !llvm.vec<? x 4 x i32> {
+  // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.udot.nxv4i32(<vscale x 4
+  %0 = "llvm_arm_sve.udot"(%arg2, %arg0, %arg1) :
+    (!llvm.vec<? x 4 x i32>, !llvm.vec<? x 16 x i8>, !llvm.vec<? x 16 x i8>)
+        -> !llvm.vec<? x 4 x i32>
+  llvm.return %0 : !llvm.vec<? x 4 x i32>
+}
+
+// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_ummla
+llvm.func @arm_sve_ummla(%arg0: !llvm.vec<? x 16 x i8>,
+                         %arg1: !llvm.vec<? x 16 x i8>,
+                         %arg2: !llvm.vec<? x 4 x i32>)
+                         -> !llvm.vec<? x 4 x i32> {
+  // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.ummla.nxv4i32(<vscale x 4
+  %0 = "llvm_arm_sve.ummla"(%arg2, %arg0, %arg1) :
+    (!llvm.vec<? x 4 x i32>, !llvm.vec<? x 16 x i8>, !llvm.vec<? x 16 x i8>)
+        -> !llvm.vec<? x 4 x i32>
+  llvm.return %0 : !llvm.vec<? x 4 x i32>
+}
+
+// CHECK-LABEL: define i64 @get_vector_scale()
+llvm.func @get_vector_scale() -> !llvm.i64 {
+  // CHECK: call i64 @llvm.vscale.i64()
+  %0 = "llvm_arm_sve.vscale"() : () -> !llvm.i64
+  llvm.return %0 : !llvm.i64
+}