diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -26,6 +26,7 @@
 #include "mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h"
 #include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
 #include "mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.h"
+#include "mlir/Conversion/SVEToLLVM/ConvertSVEToLLVM.h"
 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
 #include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -379,6 +379,17 @@
   let dependentDialects = ["spirv::SPIRVDialect"];
 }
 
+//===----------------------------------------------------------------------===//
+// SVEToLLVM
+//===----------------------------------------------------------------------===//
+
+def ConvertSVEToLLVM : Pass<"convert-sve-to-llvm", "ModuleOp"> {
+  let summary = "Convert the operations from the sve dialect into the LLVM "
+                "dialect";
+  let constructor = "mlir::createConvertSVEToLLVMPass()";
+  let dependentDialects = ["LLVM::LLVMDialect", "LLVM::LLVMSVEDialect"];
+}
+
 //===----------------------------------------------------------------------===//
 // VectorToSCF
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Conversion/SVEToLLVM/ConvertSVEToLLVM.h b/mlir/include/mlir/Conversion/SVEToLLVM/ConvertSVEToLLVM.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Conversion/SVEToLLVM/ConvertSVEToLLVM.h
@@ -0,0 +1,42 @@
+//===- ConvertSVEToLLVM.h - Conversion Patterns from SVE to LLVM ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_SVETOLLVM_CONVERTSVETOLLVM_H_
+#define MLIR_CONVERSION_SVETOLLVM_CONVERTSVETOLLVM_H_
+
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include <memory>
+
+namespace mlir {
+class LLVMTypeConverter;
+class ModuleOp;
+template <typename T>
+class OperationPass;
+class OwningRewritePatternList;
+
+//===----------------------------------------------------------------------===//
+// SVE Scalable Vector Type Conversion
+//===----------------------------------------------------------------------===//
+
+/// LLVM type converter that additionally lowers `!sve.vector` types to LLVM
+/// scalable vector types.
+class SVETypeConverter : public LLVMTypeConverter {
+public:
+  explicit SVETypeConverter(MLIRContext *ctx);
+};
+
+/// Collect a set of patterns to convert from the SVE dialect to LLVM.
+void populateSVEToLLVMConversionPatterns(SVETypeConverter &converter,
+                                         OwningRewritePatternList &patterns);
+
+/// Create a pass to convert SVE operations to the LLVMIR dialect.
+std::unique_ptr<OperationPass<ModuleOp>> createConvertSVEToLLVMPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_SVETOLLVM_CONVERTSVETOLLVM_H_
diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt
--- a/mlir/include/mlir/Dialect/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/CMakeLists.txt
@@ -13,5 +13,6 @@
 add_subdirectory(Shape)
 add_subdirectory(SPIRV)
 add_subdirectory(StandardOps)
+add_subdirectory(SVE)
 add_subdirectory(Tosa)
 add_subdirectory(Vector)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
--- a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
@@ -30,3 +30,9 @@
 set(LLVM_TARGET_DEFINITIONS LLVMAVX512.td)
 mlir_tablegen(LLVMAVX512Conversions.inc -gen-llvmir-conversions)
 add_public_tablegen_target(MLIRLLVMAVX512ConversionsIncGen)
+
+add_mlir_dialect(LLVMSVE llvm_sve LLVMSVE)
+
+set(LLVM_TARGET_DEFINITIONS LLVMSVE.td)
+mlir_tablegen(LLVMSVEConversions.inc -gen-llvmir-conversions)
+add_public_tablegen_target(MLIRLLVMSVEConversionsIncGen)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMSVE.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMSVE.td
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMSVE.td
@@ -0,0 +1,73 @@
+//===-- LLVMSVE.td - LLVMSVE dialect op definitions --------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the basic operations for the LLVMSVE dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVMIR_SVE_OPS
+#define LLVMIR_SVE_OPS
+
+include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
+
+//===----------------------------------------------------------------------===//
+// LLVMSVE dialect definition
+//===----------------------------------------------------------------------===//
+
+def LLVMSVE_Dialect : Dialect {
+  let name = "llvm_sve";
+  let cppNamespace = "::mlir::LLVM";
+}
+
+//----------------------------------------------------------------------------//
+// MLIR LLVM SVE intrinsics using the MLIR LLVM Dialect type system
+//----------------------------------------------------------------------------//
+
+class LLVMSVE_NonSVEIntrUnaryOverloadedOp<string mnemonic,
+                                          list<OpTrait> traits = []> :
+  LLVM_IntrOpBase</*Dialect dialect=*/LLVMSVE_Dialect,
+                  /*string opName=*/mnemonic,
+                  /*string enumName=*/mnemonic,
+                  /*list<int> overloadedResults=*/[0],
+                  /*list<int> overloadedOperands=*/[], // defined by result overload
+                  /*list<OpTrait> traits=*/traits,
+                  /*int numResults=*/1>;
+
+class LLVMSVE_IntrOp<string mnemonic, list<OpTrait> traits = []> :
+  LLVM_IntrOpBase<LLVMSVE_Dialect, mnemonic, "aarch64_sve_" # !subst(".", "_", mnemonic), [], [], traits, 1>;
+
+class LLVMSVE_IntrBinaryOverloadedOp<string mnemonic,
+                                     list<OpTrait> traits = []> :
+  LLVM_IntrOpBase</*Dialect dialect=*/LLVMSVE_Dialect,
+                  /*string opName=*/mnemonic,
+                  /*string enumName=*/"aarch64_sve_" # !subst(".", "_", mnemonic),
+                  /*list<int> overloadedResults=*/[0],
+                  /*list<int> overloadedOperands=*/[], // defined by result overload
+                  /*list<OpTrait> traits=*/traits,
+                  /*int numResults=*/1>;
+
+def LLVM_aarch64_sve_ummla :
+  LLVMSVE_IntrBinaryOverloadedOp<"ummla">,
+  Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>;
+
+def LLVM_aarch64_sve_smmla :
+  LLVMSVE_IntrBinaryOverloadedOp<"smmla">,
+  Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>;
+
+def LLVM_aarch64_sve_sdot :
+  LLVMSVE_IntrBinaryOverloadedOp<"sdot">,
+  Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>;
+
+def LLVM_aarch64_sve_udot :
+  LLVMSVE_IntrBinaryOverloadedOp<"udot">,
+  Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>;
+
+def LLVM_vector_scale :
+  LLVMSVE_NonSVEIntrUnaryOverloadedOp<"vscale">;
+
+#endif // LLVMIR_SVE_OPS
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMSVEDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMSVEDialect.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMSVEDialect.h
@@ -0,0 +1,24 @@
+//===- LLVMSVEDialect.h - MLIR Dialect for LLVMSVE --------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the Target dialect for LLVMSVE in MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LLVMIR_LLVMSVEDIALECT_H_
+#define MLIR_DIALECT_LLVMIR_LLVMSVEDIALECT_H_
+
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/LLVMIR/LLVMSVE.h.inc"
+
+#include "mlir/Dialect/LLVMIR/LLVMSVEDialect.h.inc"
+
+#endif // MLIR_DIALECT_LLVMIR_LLVMSVEDIALECT_H_
diff --git a/mlir/include/mlir/Dialect/SVE/CMakeLists.txt b/mlir/include/mlir/Dialect/SVE/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SVE/CMakeLists.txt
@@ -0,0 +1 @@
+add_mlir_dialect(SVE sve SVE)
diff --git a/mlir/include/mlir/Dialect/SVE/SVE.td b/mlir/include/mlir/Dialect/SVE/SVE.td
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SVE/SVE.td
@@ -0,0 +1,255 @@
+//===-- SVE.td - SVE dialect operation definitions ---------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the basic operations for the SVE dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef SVE_OPS
+#define SVE_OPS
+
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+//===----------------------------------------------------------------------===//
+// SVE dialect definition
+//===----------------------------------------------------------------------===//
+
+def SVE_Dialect : Dialect {
+  let name = "sve";
+  let cppNamespace = "::mlir::sve";
+  let summary = "Basic dialect to target Arm SVE architectures";
+  let description = [{
+    This dialect contains the definitions necessary to target SVE scalable
+    vector operations, including a scalable vector type and intrinsics for
+    some SVE instructions.
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// SVE type definitions
+//===----------------------------------------------------------------------===//
+
+def SVE_ScalableVectorType : DialectType<SVE_Dialect,
+    CPred<"$_self.isa<ScalableVectorType>()">,
+              "scalable vector type">,
+    BuildableType<"$_builder.getType<ScalableVectorType>()"> {
+  let typeDescription = [{
+    `sve.vector` represents vectors that will be processed by a scalable
+    vector architecture.
+  }];
+}
+
+class SVE_Type<string name> : TypeDef<SVE_Dialect, name> { }
+
+def ScalableVectorType : SVE_Type<"ScalableVector"> {
+  let mnemonic = "vector";
+
+  let summary = "Scalable vector type";
+
+  let description = [{
+    A type representing scalable length SIMD vectors. Unlike fixed-length SIMD
+    vectors, whose size is constant and known at compile time, scalable
+    vectors' length is constant but determined by the specific hardware at
+    run time.
+  }];
+
+  let parameters = (ins
+    ArrayRefParameter<"int64_t", "Vector shape">:$shape,
+    "Type":$elementType
+  );
+
+  let printer = [{
+    $_printer << "vector<";
+    for (int64_t dim : getShape())
+      $_printer << dim << 'x';
+    $_printer << getElementType() << '>';
+  }];
+
+  let parser = [{
+    VectorType vector;
+    if ($_parser.parseType(vector))
+      return Type();
+    return get(ctxt, vector.getShape(), vector.getElementType());
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// SVE type traits
+//===----------------------------------------------------------------------===//
+
+def IsScalableVectorTypePred :
+    CPred<"$_self.isa<::mlir::sve::ScalableVectorType>()">;
+
+class ScalableVectorOf<list<Type> allowedTypes> :
+    ContainerType<AnyTypeOf<allowedTypes>, IsScalableVectorTypePred,
+          "$_self.cast<::mlir::sve::ScalableVectorType>().getElementType()",
+          "scalable vector">;
+
+//===----------------------------------------------------------------------===//
+// SVE op definitions
+//===----------------------------------------------------------------------===//
+
+class SVE_Op<string mnemonic, list<OpTrait> traits = []> :
+  Op<SVE_Dialect, mnemonic, traits> {}
+
+def UmmlaOp : SVE_Op<"ummla",
+               [NoSideEffect,
+               AllTypesMatch<["src1", "src2"]>,
+               AllTypesMatch<["acc", "dst"]>,
+             ]> {
+  let summary = "Matrix-matrix multiply and accumulate op";
+  let description = [{
+    The ummla op is an SVE specific op that can lower to the proper LLVMSVE
+    operation: `llvm.aarch64.sve.ummla` instruction.
+
+    UMMLA: Unsigned integer matrix multiply-accumulate.
+
+    The unsigned integer matrix multiply-accumulate instruction multiplies the
+    2×8 matrix of unsigned 8-bit integer values held in each <16xui8> segment
+    of the `$src1` vector by the 8×2 matrix of unsigned 8-bit integer values
+    in the corresponding <16xui8> segment of the `$src2` vector. The resulting
+    2×2 widened 32-bit integer matrix product is then destructively added to the
+    32-bit integer matrix accumulator held in the corresponding <4xui32> segment
+    of the `$acc` vector. This is equivalent to performing an 8-way dot product
+    per destination element.
+
+    Elements of the `$src1` matrix are stored in row-major order, in blocks of
+    2x8 elements per segment. Elements of the `$src2` matrix are stored in
+    column-major order, in blocks of 8x2 elements per segment. The elements of
+    the `$acc` matrix are stored in row-major order, in blocks of 2x2 elements
+    per segment.
+  }];
+  // Supports vector<16xi8>.
+  let arguments = (ins
+          ScalableVectorOf<[UI32]>:$acc,
+          ScalableVectorOf<[UI8]>:$src1,
+          ScalableVectorOf<[UI8]>:$src2
+  );
+  let results = (outs ScalableVectorOf<[UI32]>:$dst);
+  let assemblyFormat =
+    "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `->` type($dst)";
+}
+
+def SmmlaOp : SVE_Op<"smmla",
+               [NoSideEffect,
+               AllTypesMatch<["src1", "src2"]>,
+               AllTypesMatch<["acc", "dst"]>,
+             ]> {
+  let summary = "Matrix-matrix multiply and accumulate op";
+  let description = [{
+    The smmla op is an SVE specific op that can lower to the proper LLVMSVE
+    operation: `llvm.aarch64.sve.smmla` instruction.
+
+    SMMLA: Signed integer matrix multiply-accumulate.
+
+    The signed integer matrix multiply-accumulate instruction multiplies the
+    2×8 matrix of signed 8-bit integer values held in each <16xsi8> segment
+    of the `$src1` vector by the 8×2 matrix of signed 8-bit integer values
+    in the corresponding <16xsi8> segment of the `$src2` vector. The resulting
+    2×2 widened 32-bit integer matrix product is then destructively added to the
+    32-bit integer matrix accumulator held in the corresponding <4xsi32> segment
+    of the `$acc` vector. This is equivalent to performing an 8-way dot product
+    per destination element.
+
+    Elements of the `$src1` matrix are stored in row-major order, in blocks of
+    2x8 elements per segment. Elements of the `$src2` matrix are stored in
+    column-major order, in blocks of 8x2 elements per segment. The elements of
+    the `$acc` matrix are stored in row-major order, in blocks of 2x2 elements
+    per segment.
+  }];
+  // Supports vector<16xsi8>.
+  let arguments = (ins
+          ScalableVectorOf<[SI32]>:$acc,
+          ScalableVectorOf<[SI8]>:$src1,
+          ScalableVectorOf<[SI8]>:$src2
+  );
+  let results = (outs ScalableVectorOf<[SI32]>:$dst);
+  let assemblyFormat =
+    "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `->` type($dst)";
+}
+
+def SdotOp : SVE_Op<"sdot",
+               [NoSideEffect,
+               AllTypesMatch<["src1", "src2"]>,
+               AllTypesMatch<["acc", "dst"]>,
+             ]> {
+  let summary = "Vector-vector dot product and accumulate op";
+  let description = [{
+    The sdot op is an SVE specific op that can lower to the proper LLVMSVE
+    operation: `llvm.aarch64.sve.sdot` instruction.
+
+    SDOT: Signed integer dot product.
+
+    The signed integer dot product instruction computes the dot product of a
+    group of four signed 8-bit or 16-bit integer values held in each 32-bit or
+    64-bit element of the `$src1` vector multiplied by a group of four signed
+    8-bit or 16-bit integer values in the corresponding 32-bit or 64-bit
+    element of the `$src2` vector, and then destructively adds the widened dot
+    product to the corresponding 32-bit or 64-bit element of the `$acc` vector.
+  }];
+  // Supports vector<16xi8> and vector<8xi16>.
+  let arguments = (ins
+          ScalableVectorOf<[SI32, SI64]>:$acc,
+          ScalableVectorOf<[SI8, SI16]>:$src1,
+          ScalableVectorOf<[SI8, SI16]>:$src2
+  );
+  let results = (outs ScalableVectorOf<[SI32, SI64]>:$dst);
+  let assemblyFormat =
+    "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `->` type($dst)";
+}
+
+def UdotOp : SVE_Op<"udot",
+               [NoSideEffect,
+               AllTypesMatch<["src1", "src2"]>,
+               AllTypesMatch<["acc", "dst"]>,
+             ]> {
+  let summary = "Vector-vector dot product and accumulate op";
+  let description = [{
+    The udot op is an SVE specific op that can lower to the proper LLVMSVE
+    operation: `llvm.aarch64.sve.udot` instruction.
+
+    UDOT: Unsigned integer dot product.
+
+    The unsigned integer dot product instruction computes the dot product of a
+    group of four unsigned 8-bit or 16-bit integer values held in each 32-bit
+    or 64-bit element of the `$src1` vector multiplied by a group of four
+    unsigned 8-bit or 16-bit integer values in the corresponding 32-bit or
+    64-bit element of the `$src2` vector, and then destructively adds the
+    widened dot product to the corresponding 32-bit or 64-bit element of the
+    `$acc` vector.
+  }];
+  // Supports vector<16xi8> and vector<8xi16>.
+  let arguments = (ins
+          ScalableVectorOf<[UI32, UI64]>:$acc,
+          ScalableVectorOf<[UI8, UI16]>:$src1,
+          ScalableVectorOf<[UI8, UI16]>:$src2
+  );
+  let results = (outs ScalableVectorOf<[UI32, UI64]>:$dst);
+  let assemblyFormat =
+    "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `->` type($dst)";
+}
+
+def VectorScaleOp : SVE_Op<"vector_scale",
+                 [NoSideEffect]> {
+  let summary = "Load vector scale size";
+  let description = [{
+    The vector_scale op returns the scale of the scalable vectors, a positive
+    integer value that is constant at runtime but unknown at compile time.
+    The scale of the vector indicates the multiplicity of the vectors and
+    vector operations. I.e.: an !sve.vector<4xi32> is equivalent to
+    vector_scale consecutive vector<4xi32>; and an operation on an
+    !sve.vector<4xi32> is equivalent to performing that operation vector_scale
+    times, once on each <4xi32> segment of the scalable vector. The vector_scale
+    op can be used to calculate the step in vector-length agnostic (VLA) loops.
+  }];
+  let results = (outs Index:$res);
+  let assemblyFormat =
+    "attr-dict `:` type($res)";
+}
+
+#endif // SVE_OPS
diff --git a/mlir/include/mlir/Dialect/SVE/SVEDialect.h b/mlir/include/mlir/Dialect/SVE/SVEDialect.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SVE/SVEDialect.h
@@ -0,0 +1,29 @@
+//===- SVEDialect.h - MLIR Dialect for SVE ----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the Target dialect for SVE in MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SVE_SVEDIALECT_H_
+#define MLIR_DIALECT_SVE_SVEDIALECT_H_
+
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+
+#include "mlir/Dialect/SVE/SVEDialect.h.inc"
+
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/SVE/SVETypes.h.inc"
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/SVE/SVE.h.inc"
+
+#endif // MLIR_DIALECT_SVE_SVEDIALECT_H_
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/GPU/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMSVEDialect.h"
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
@@ -31,6 +32,7 @@
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/SDBM/SDBMDialect.h"
 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/Dialect/SVE/SVEDialect.h"
 #include "mlir/Dialect/Shape/IR/Shape.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
@@ -49,6 +51,7 @@
                   gpu::GPUDialect,
                   LLVM::LLVMAVX512Dialect,
                   LLVM::LLVMDialect,
+                  LLVM::LLVMSVEDialect,
                   linalg::LinalgDialect,
                   scf::SCFDialect,
                   omp::OpenMPDialect,
@@ -57,6 +60,7 @@
                   quant::QuantizationDialect,
                   spirv::SPIRVDialect,
                   StandardOpsDialect,
+                  sve::SVEDialect,
                   vector::VectorDialect,
                   NVVM::NVVMDialect,
                   ROCDL::ROCDLDialect,
diff --git a/mlir/include/mlir/InitAllTranslations.h b/mlir/include/mlir/InitAllTranslations.h
--- a/mlir/include/mlir/InitAllTranslations.h
+++ b/mlir/include/mlir/InitAllTranslations.h
@@ -23,6 +23,7 @@
 void registerToNVVMIRTranslation();
 void registerToROCDLIRTranslation();
 void registerAVX512ToLLVMIRTranslation();
+void registerSVEToLLVMIRTranslation();
 
 // This function should be called before creating any MLIRContext if one
 // expects all the possible translations to be made available to the context
@@ -36,6 +37,7 @@
     registerToNVVMIRTranslation();
     registerToROCDLIRTranslation();
     registerAVX512ToLLVMIRTranslation();
+    registerSVEToLLVMIRTranslation();
     return true;
   }();
   (void)initOnce;
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -19,6 +19,7 @@
 add_subdirectory(SPIRVToLLVM)
 add_subdirectory(StandardToLLVM)
 add_subdirectory(StandardToSPIRV)
+add_subdirectory(SVEToLLVM)
 add_subdirectory(VectorToROCDL)
 add_subdirectory(VectorToLLVM)
 add_subdirectory(VectorToSCF)
diff --git a/mlir/lib/Conversion/PassDetail.h b/mlir/lib/Conversion/PassDetail.h
--- a/mlir/lib/Conversion/PassDetail.h
+++ b/mlir/lib/Conversion/PassDetail.h
@@ -27,6 +27,7 @@
 namespace LLVM {
 class LLVMDialect;
 class LLVMAVX512Dialect;
+class LLVMSVEDialect;
 } // end namespace LLVM
 
 namespace NVVM {
diff --git a/mlir/lib/Conversion/SVEToLLVM/CMakeLists.txt b/mlir/lib/Conversion/SVEToLLVM/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/SVEToLLVM/CMakeLists.txt
@@ -0,0 +1,20 @@
+add_mlir_conversion_library(MLIRSVEToLLVM
+  ConvertSVEToLLVM.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/SVEToLLVM
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRSVE
+  MLIRLLVMSVE
+  MLIRLLVMIR
+  MLIRStandardToLLVM
+  MLIRTransforms
+  MLIRVectorToLLVM
+  )
diff --git a/mlir/lib/Conversion/SVEToLLVM/ConvertSVEToLLVM.cpp b/mlir/lib/Conversion/SVEToLLVM/ConvertSVEToLLVM.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/SVEToLLVM/ConvertSVEToLLVM.cpp
@@ -0,0 +1,256 @@
+//===- ConvertSVEToLLVM.cpp - Convert SVE to the LLVM dialect -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/SVEToLLVM/ConvertSVEToLLVM.h"
+
+#include "../PassDetail.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
+#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMSVEDialect.h"
+#include "mlir/Dialect/SVE/SVEDialect.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+using namespace mlir::sve;
+
+template <typename OpTy>
+static Type getSrc1VectorElementType(OpTy op) {
+  return op.src1()
+      .getType()
+      .template cast<ScalableVectorType>()
+      .getElementType();
+}
+template <typename OpTy>
+static Type getSrc2VectorElementType(OpTy op) {
+  return op.src2()
+      .getType()
+      .template cast<ScalableVectorType>()
+      .getElementType();
+}
+template <typename OpTy>
+static Type getAccVectorElementType(OpTy op) {
+  return op.acc()
+      .getType()
+      .template cast<ScalableVectorType>()
+      .getElementType();
+}
+
+/// Basic lowering implementation for one-to-one rewriting from SVE Ops to
+/// LLVM Dialect Ops. Convert the type of the result to an LLVM type, pass
+/// operands as is, preserve attributes.
+template <typename SourceOp, typename TargetOp>
+static LogicalResult
+matchAndRewriteOneToOne(const ConvertToLLVMPattern &lowering,
+                        LLVMTypeConverter &typeConverter, Operation *op,
+                        ArrayRef<Value> operands,
+                        ConversionPatternRewriter &rewriter) {
+  unsigned numResults = op->getNumResults();
+
+  Type packedType;
+  if (numResults != 0) {
+    packedType = typeConverter.packFunctionResults(op->getResultTypes());
+    if (!packedType)
+      return failure();
+  }
+
+  auto newOp = rewriter.create<TargetOp>(op->getLoc(), packedType, operands,
+                                         op->getAttrs());
+
+  // If the operation produced 0 or 1 result, return them immediately.
+  if (numResults == 0)
+    return rewriter.eraseOp(op), success();
+  if (numResults == 1)
+    return rewriter.replaceOp(op, newOp.getOperation()->getResult(0)),
+           success();
+
+  // Otherwise, it had been converted to an operation producing a structure.
+  // Extract individual results from the structure and return them as list.
+  SmallVector<Value, 4> results;
+  results.reserve(numResults);
+  for (unsigned i = 0; i < numResults; ++i) {
+    auto type = typeConverter.convertType(op->getResult(i).getType());
+    results.push_back(rewriter.create<LLVM::ExtractValueOp>(
+        op->getLoc(), type, newOp.getOperation()->getResult(0),
+        rewriter.getI64ArrayAttr(i)));
+  }
+  rewriter.replaceOp(op, results);
+  return success();
+}
+
+namespace {
+
+struct UmmlaOpConversion : public ConvertToLLVMPattern {
+  explicit UmmlaOpConversion(MLIRContext *context,
+                             LLVMTypeConverter &typeConverter)
+      : ConvertToLLVMPattern(UmmlaOp::getOperationName(), context,
+                             typeConverter) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (!(getAccVectorElementType(cast<UmmlaOp>(op)).isInteger(32) &&
+          getSrc1VectorElementType(cast<UmmlaOp>(op)).isInteger(8) &&
+          getSrc2VectorElementType(cast<UmmlaOp>(op)).isInteger(8)))
+      return failure();
+    return matchAndRewriteOneToOne<UmmlaOp, LLVM::aarch64_sve_ummla>(
+        *this, this->typeConverter, op, operands, rewriter);
+  }
+};
+
+struct SmmlaOpConversion : public ConvertToLLVMPattern {
+  explicit SmmlaOpConversion(MLIRContext *context,
+                             LLVMTypeConverter &typeConverter)
+      : ConvertToLLVMPattern(SmmlaOp::getOperationName(), context,
+                             typeConverter) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (!(getAccVectorElementType(cast<SmmlaOp>(op)).isInteger(32) &&
+          getSrc1VectorElementType(cast<SmmlaOp>(op)).isInteger(8) &&
+          getSrc2VectorElementType(cast<SmmlaOp>(op)).isInteger(8)))
+      return failure();
+    return matchAndRewriteOneToOne<SmmlaOp, LLVM::aarch64_sve_smmla>(
+        *this, this->typeConverter, op, operands, rewriter);
+  }
+};
+
+struct SdotOpConversion : public ConvertToLLVMPattern {
+  explicit SdotOpConversion(MLIRContext *context,
+                            LLVMTypeConverter &typeConverter)
+      : ConvertToLLVMPattern(SdotOp::getOperationName(), context,
+                             typeConverter) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (!((getAccVectorElementType(cast<SdotOp>(op)).isInteger(32) &&
+           getSrc1VectorElementType(cast<SdotOp>(op)).isInteger(8) &&
+           getSrc2VectorElementType(cast<SdotOp>(op)).isInteger(8)) ||
+          (getAccVectorElementType(cast<SdotOp>(op)).isInteger(64) &&
+           getSrc1VectorElementType(cast<SdotOp>(op)).isInteger(16) &&
+           getSrc2VectorElementType(cast<SdotOp>(op)).isInteger(16))))
+      return failure();
+    return matchAndRewriteOneToOne<SdotOp, LLVM::aarch64_sve_sdot>(
+        *this, this->typeConverter, op, operands, rewriter);
+  }
+};
+
+struct UdotOpConversion : public ConvertToLLVMPattern {
+  explicit UdotOpConversion(MLIRContext *context,
+                            LLVMTypeConverter &typeConverter)
+      : ConvertToLLVMPattern(UdotOp::getOperationName(), context,
+                             typeConverter) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (!((getAccVectorElementType(cast<UdotOp>(op)).isInteger(32) &&
+           getSrc1VectorElementType(cast<UdotOp>(op)).isInteger(8) &&
+           getSrc2VectorElementType(cast<UdotOp>(op)).isInteger(8)) ||
+          (getAccVectorElementType(cast<UdotOp>(op)).isInteger(64) &&
+           getSrc1VectorElementType(cast<UdotOp>(op)).isInteger(16) &&
+           getSrc2VectorElementType(cast<UdotOp>(op)).isInteger(16))))
+      return failure();
+    return matchAndRewriteOneToOne<UdotOp, LLVM::aarch64_sve_udot>(
+        *this, this->typeConverter, op, operands, rewriter);
+  }
+};
+
+struct VectorScaleOpConversion : public ConvertToLLVMPattern {
+  explicit VectorScaleOpConversion(MLIRContext *context,
+                                   LLVMTypeConverter &typeConverter)
+      : ConvertToLLVMPattern(VectorScaleOp::getOperationName(), context,
+                             typeConverter) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    return matchAndRewriteOneToOne<VectorScaleOp, LLVM::vector_scale>(
+        *this, this->typeConverter, op, operands, rewriter);
+  }
+};
+
+} // namespace
+
+/// Populate the given list with patterns that convert from SVE to LLVM.
+void mlir::populateSVEToLLVMConversionPatterns(
+    SVETypeConverter &converter, OwningRewritePatternList &patterns) {
+  MLIRContext *ctx = converter.getDialect()->getContext();
+  // clang-format off
+  patterns.insert<UmmlaOpConversion,
+                  SmmlaOpConversion,
+                  SdotOpConversion,
+                  UdotOpConversion,
+                  VectorScaleOpConversion>(ctx, converter);
+  // clang-format on
+}
+
+namespace {
+struct ConvertSVEToLLVMPass
+    : public ConvertSVEToLLVMBase<ConvertSVEToLLVMPass> {
+  void runOnOperation() override;
+};
+} // namespace
+
+// Extract an LLVM IR type from the LLVM IR dialect type.
+static LLVM::LLVMType unwrap(Type type) {
+  if (!type)
+    return nullptr;
+  auto *mlirContext = type.getContext();
+  auto wrappedLLVMType = type.dyn_cast<LLVM::LLVMType>();
+  if (!wrappedLLVMType)
+    emitError(UnknownLoc::get(mlirContext),
+              "conversion resulted in a non-LLVM type");
+  return wrappedLLVMType;
+}
+
+static Optional<Type> convertScalableVectorType(ScalableVectorType svType,
+                                                LLVMTypeConverter &converter) {
+  // Only rank-1 shapes map onto an LLVM scalable vector; reject the others.
+  if (svType.getShape().size() != 1)
+    return {};
+  auto elementType = unwrap(converter.convertType(svType.getElementType()));
+  if (!elementType)
+    return {};
+  return LLVM::LLVMScalableVectorType::get(elementType, svType.getShape()[0]);
+}
+
+SVETypeConverter::SVETypeConverter(MLIRContext *ctx) : LLVMTypeConverter(ctx) {
+  addConversion([this](ScalableVectorType svType) {
+    return convertScalableVectorType(svType, *this);
+  });
+}
+
+void ConvertSVEToLLVMPass::runOnOperation() {
+  OwningRewritePatternList patterns;
+  SVETypeConverter sveConverter(&getContext());
+  populateSVEToLLVMConversionPatterns(sveConverter, patterns);
+  populateVectorToLLVMConversionPatterns(sveConverter, patterns);
+  populateStdToLLVMConversionPatterns(sveConverter, patterns);
+
+  ConversionTarget target(getContext());
+  target.addLegalDialect<LLVM::LLVMDialect>();
+  target.addLegalDialect<LLVM::LLVMSVEDialect>();
+  target.addIllegalDialect<sve::SVEDialect>();
+  if (failed(applyPartialConversion(getOperation(), target,
+                                    std::move(patterns)))) {
+    signalPassFailure();
+  }
+}
+
+std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertSVEToLLVMPass() {
+  return std::make_unique<ConvertSVEToLLVMPass>();
+}
diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt
--- a/mlir/lib/Dialect/CMakeLists.txt
+++ b/mlir/lib/Dialect/CMakeLists.txt
@@ -14,6 +14,7 @@
 add_subdirectory(Shape)
 add_subdirectory(SPIRV)
 add_subdirectory(StandardOps)
+add_subdirectory(SVE)
 add_subdirectory(Tosa)
 add_subdirectory(Vector)
 
diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
--- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
@@ -49,6 +49,27 @@
   MLIRSideEffectInterfaces
   )
 
+add_mlir_dialect_library(MLIRLLVMSVE
+  IR/LLVMSVEDialect.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LLVMIR
+
+  DEPENDS
+  MLIRLLVMSVEIncGen
+  MLIRLLVMSVEConversionsIncGen
+  intrinsics_gen
+
+  LINK_COMPONENTS
+  AsmParser
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRLLVMIR
+  MLIRSideEffectInterfaces
+  )
+
 add_mlir_dialect_library(MLIRNVVMIR
   IR/NVVMDialect.cpp
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMSVEDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMSVEDialect.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMSVEDialect.cpp
@@ -0,0 +1,31 @@
+//===- LLVMSVEDialect.cpp - MLIR LLVMSVE ops implementation ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the LLVMSVE dialect and its operations.
+// +//===----------------------------------------------------------------------===// + +#include "llvm/IR/IntrinsicsAArch64.h" + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMSVEDialect.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/TypeUtilities.h" + +using namespace mlir; + +void LLVM::LLVMSVEDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "mlir/Dialect/LLVMIR/LLVMSVE.cpp.inc" + >(); +} + +#define GET_OP_CLASSES +#include "mlir/Dialect/LLVMIR/LLVMSVE.cpp.inc" diff --git a/mlir/lib/Dialect/SVE/CMakeLists.txt b/mlir/lib/Dialect/SVE/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/SVE/CMakeLists.txt @@ -0,0 +1,14 @@ +add_mlir_dialect_library(MLIRSVE + IR/SVEDialect.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SVE + + DEPENDS + MLIRSVEIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRSideEffectInterfaces + MLIRVectorToLLVM + ) diff --git a/mlir/lib/Dialect/SVE/IR/SVEDialect.cpp b/mlir/lib/Dialect/SVE/IR/SVEDialect.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/SVE/IR/SVEDialect.cpp @@ -0,0 +1,63 @@ +//===- SVEDialect.cpp - MLIR SVE dialect implementation -------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the SVE dialect and its operations. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/SVE/SVEDialect.h" +#include "mlir/Dialect/Vector/VectorOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; + +void sve::SVEDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "mlir/Dialect/SVE/SVE.cpp.inc" + >(); + addTypes< +#define GET_TYPEDEF_LIST +#include "mlir/Dialect/SVE/SVETypes.cpp.inc" + >(); +} + +#define GET_OP_CLASSES +#include "mlir/Dialect/SVE/SVE.cpp.inc" + +#define GET_TYPEDEF_CLASSES +#include "mlir/Dialect/SVE/SVETypes.cpp.inc" + +namespace mlir { +namespace sve { + +//===----------------------------------------------------------------------===// +// VectorType +//===----------------------------------------------------------------------===// + +Type SVEDialect::parseType(DialectAsmParser &parser) const { + llvm::SMLoc typeLoc = parser.getCurrentLocation(); + auto genType = generatedTypeParser(getContext(), parser, "vector"); + if (genType != Type()) + return genType; + parser.emitError(typeLoc, "unknown type in SVE dialect"); + return Type(); +} + +void SVEDialect::printType(Type type, DialectAsmPrinter &os) const { + if (failed(generatedTypePrinter(type, os))) + llvm_unreachable("unexpected 'sve' type kind"); +} + +} // namespace sve +} // namespace mlir diff --git a/mlir/lib/Target/CMakeLists.txt b/mlir/lib/Target/CMakeLists.txt --- a/mlir/lib/Target/CMakeLists.txt +++ b/mlir/lib/Target/CMakeLists.txt @@ -94,3 +94,22 @@ MLIRROCDLIR MLIRTargetLLVMIRModuleTranslation ) + +add_mlir_translation_library(MLIRTargetSVE + LLVMIR/LLVMSVEIntr.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Target/LLVMIR + + DEPENDS + MLIRLLVMSVEConversionsIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRIR + 
MLIRLLVMSVE + MLIRLLVMIR + MLIRTargetLLVMIRModuleTranslation + ) diff --git a/mlir/lib/Target/LLVMIR/LLVMSVEIntr.cpp b/mlir/lib/Target/LLVMIR/LLVMSVEIntr.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Target/LLVMIR/LLVMSVEIntr.cpp @@ -0,0 +1,63 @@ +//===- LLVMSVEIntr.cpp - Convert MLIR LLVM dialect to LLVM intrinsics -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a translation between the MLIR LLVM and SVE dialects and +// LLVM IR with SVE intrinsics. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/LLVMIR/LLVMSVEDialect.h" +#include "mlir/Target/LLVMIR/ModuleTranslation.h" +#include "mlir/Translation.h" +#include "llvm/IR/IntrinsicsAArch64.h" + +using namespace mlir; + +namespace { +class LLVMSVEModuleTranslation : public LLVM::ModuleTranslation { + friend LLVM::ModuleTranslation; + +public: + using LLVM::ModuleTranslation::ModuleTranslation; + +protected: + LogicalResult convertOperation(Operation &opInst, + llvm::IRBuilder<> &builder) override { +#include "mlir/Dialect/LLVMIR/LLVMSVEConversions.inc" + + return LLVM::ModuleTranslation::convertOperation(opInst, builder); + } +}; + +std::unique_ptr +translateLLVMSVEModuleToLLVMIR(Operation *m, llvm::LLVMContext &llvmContext, + StringRef name) { + return LLVM::ModuleTranslation::translateModule( + m, llvmContext, name); +} +} // end namespace + +namespace mlir { +void registerSVEToLLVMIRTranslation() { + TranslateFromMLIRRegistration reg( + "sve-mlir-to-llvmir", + [](ModuleOp module, raw_ostream &output) { + llvm::LLVMContext llvmContext; + auto llvmModule = translateLLVMSVEModuleToLLVMIR(module, llvmContext, + "LLVMDialectModule"); + if (!llvmModule) + return 
failure(); + + llvmModule->print(output, nullptr); + return success(); + }, + [](DialectRegistry ®istry) { + registry.insert(); + }); +} +} // namespace mlir diff --git a/mlir/test/Conversion/SVEToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/SVEToLLVM/convert-to-llvm.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/SVEToLLVM/convert-to-llvm.mlir @@ -0,0 +1,48 @@ +// RUN: mlir-opt %s -convert-sve-to-llvm | mlir-opt | FileCheck %s + +func @sve_sdot(%a: !sve.vector<16xsi8>, + %b: !sve.vector<16xsi8>, + %c: !sve.vector<4xsi32>) + -> !sve.vector<4xsi32> +{ + // CHECK: llvm_sve.sdot + %0 = sve.sdot %c, %a, %b : !sve.vector<16xsi8> -> !sve.vector<4xsi32> + return %0 : !sve.vector<4xsi32> +} + +func @sve_smmla(%a: !sve.vector<16xsi8>, + %b: !sve.vector<16xsi8>, + %c: !sve.vector<4xsi32>) + -> !sve.vector<4xsi32> +{ + // CHECK: llvm_sve.smmla + %0 = sve.smmla %c, %a, %b : !sve.vector<16xsi8> -> !sve.vector<4xsi32> + return %0 : !sve.vector<4xsi32> +} + +func @sve_udot(%a: !sve.vector<16xui8>, + %b: !sve.vector<16xui8>, + %c: !sve.vector<4xui32>) + -> !sve.vector<4xui32> +{ + // CHECK: llvm_sve.udot + %0 = sve.udot %c, %a, %b : !sve.vector<16xui8> -> !sve.vector<4xui32> + return %0 : !sve.vector<4xui32> +} + +func @sve_ummla(%a: !sve.vector<16xui8>, + %b: !sve.vector<16xui8>, + %c: !sve.vector<4xui32>) + -> !sve.vector<4xui32> +{ + // CHECK: llvm_sve.ummla + %0 = sve.ummla %c, %a, %b : !sve.vector<16xui8> -> !sve.vector<4xui32> + return %0 : !sve.vector<4xui32> +} + +func @get_vector_scale() -> index +{ + // CHECK: llvm_sve.vscale + %0 = sve.vector_scale : index + return %0 : index +} diff --git a/mlir/test/Dialect/SVE/roundtrip.mlir b/mlir/test/Dialect/SVE/roundtrip.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/SVE/roundtrip.mlir @@ -0,0 +1,44 @@ +// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s + +func @sve_sdot(%a: !sve.vector<16xsi8>, + %b: !sve.vector<16xsi8>, + %c: !sve.vector<4xsi32>) -> !sve.vector<4xsi32> +{ + 
// CHECK: sve.sdot {{.*}}: !sve.vector<16xsi8> -> !sve.vector<4xsi32> + %0 = sve.sdot %c, %a, %b : !sve.vector<16xsi8> -> !sve.vector<4xsi32> + return %0 : !sve.vector<4xsi32> +} + +func @sve_smmla(%a: !sve.vector<16xsi8>, + %b: !sve.vector<16xsi8>, + %c: !sve.vector<4xsi32>) -> !sve.vector<4xsi32> +{ + // CHECK: sve.smmla {{.*}}: !sve.vector<16xsi8> -> !sve.vector<4xsi32> + %0 = sve.smmla %c, %a, %b : !sve.vector<16xsi8> -> !sve.vector<4xsi32> + return %0 : !sve.vector<4xsi32> +} + +func @sve_udot(%a: !sve.vector<16xui8>, + %b: !sve.vector<16xui8>, + %c: !sve.vector<4xui32>) -> !sve.vector<4xui32> +{ + // CHECK: sve.udot {{.*}}: !sve.vector<16xui8> -> !sve.vector<4xui32> + %0 = sve.udot %c, %a, %b : !sve.vector<16xui8> -> !sve.vector<4xui32> + return %0 : !sve.vector<4xui32> +} + +func @sve_ummla(%a: !sve.vector<16xui8>, + %b: !sve.vector<16xui8>, + %c: !sve.vector<4xui32>) -> !sve.vector<4xui32> +{ + // CHECK: sve.ummla {{.*}}: !sve.vector<16xui8> -> !sve.vector<4xui32> + %0 = sve.ummla %c, %a, %b : !sve.vector<16xui8> -> !sve.vector<4xui32> + return %0 : !sve.vector<4xui32> +} + +func @get_vector_scale() -> index +{ + // CHECK: sve.vector_scale : index + %0 = sve.vector_scale : index + return %0 : index +} diff --git a/mlir/test/Target/sve.mlir b/mlir/test/Target/sve.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Target/sve.mlir @@ -0,0 +1,58 @@ +// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | mlir-translate --sve-mlir-to-llvmir | FileCheck %s + +// CHECK-LABEL: define @sve_sdot +llvm.func @sve_sdot(%arg0: !llvm.vec, + %arg1: !llvm.vec, + %arg2: !llvm.vec) + -> !llvm.vec +{ + // CHECK: call @llvm.aarch64.sve.sdot.nxv4i32(, !llvm.vec, !llvm.vec) + -> !llvm.vec + llvm.return %0 : !llvm.vec +} + +// CHECK-LABEL: define @sve_smmla +llvm.func @sve_smmla(%arg0: !llvm.vec, + %arg1: !llvm.vec, + %arg2: !llvm.vec) + -> !llvm.vec { + // CHECK: call @llvm.aarch64.sve.smmla.nxv4i32(, !llvm.vec, !llvm.vec) + -> !llvm.vec + llvm.return %0 : !llvm.vec +} + +// 
CHECK-LABEL: define @sve_udot +llvm.func @sve_udot(%arg0: !llvm.vec, + %arg1: !llvm.vec, + %arg2: !llvm.vec) + -> !llvm.vec +{ + // CHECK: call @llvm.aarch64.sve.udot.nxv4i32(, !llvm.vec, !llvm.vec) + -> !llvm.vec + llvm.return %0 : !llvm.vec +} + +// CHECK-LABEL: define @sve_ummla +llvm.func @sve_ummla(%arg0: !llvm.vec, + %arg1: !llvm.vec, + %arg2: !llvm.vec) + -> !llvm.vec { + // CHECK: call @llvm.aarch64.sve.ummla.nxv4i32(, !llvm.vec, !llvm.vec) + -> !llvm.vec + llvm.return %0 : !llvm.vec +} + +// CHECK-LABEL: define i64 @get_vector_scale() +llvm.func @get_vector_scale() -> !llvm.i64 { + // CHECK: call i64 @llvm.vscale.i64() + %0 = "llvm_sve.vscale"() : () -> !llvm.i64 + llvm.return %0 : !llvm.i64 +}