diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -26,6 +26,7 @@
 #include "mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h"
 #include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
 #include "mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.h"
+#include "mlir/Conversion/SVEToLLVM/ConvertSVEToLLVM.h"
 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
 #include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -379,6 +379,17 @@
   let dependentDialects = ["spirv::SPIRVDialect"];
 }
 
+//===----------------------------------------------------------------------===//
+// SVEToLLVM
+//===----------------------------------------------------------------------===//
+
+def ConvertSVEToLLVM : Pass<"convert-sve-to-llvm", "ModuleOp"> {
+  let summary = "Convert the operations from the sve dialect into the LLVM "
+                "dialect";
+  let constructor = "mlir::createConvertSVEToLLVMPass()";
+  let dependentDialects = ["LLVM::LLVMDialect", "LLVM::LLVMSVEDialect"];
+}
+
 //===----------------------------------------------------------------------===//
 // VectorToSCF
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Conversion/SVEToLLVM/ConvertSVEToLLVM.h b/mlir/include/mlir/Conversion/SVEToLLVM/ConvertSVEToLLVM.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Conversion/SVEToLLVM/ConvertSVEToLLVM.h
@@ -0,0 +1,40 @@
+//===- ConvertSVEToLLVM.h - Conversion Patterns from SVE to LLVM ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_SVETOLLVM_CONVERTSVETOLLVM_H_
+#define MLIR_CONVERSION_SVETOLLVM_CONVERTSVETOLLVM_H_
+
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include <memory>
+
+namespace mlir {
+class LLVMTypeConverter;
+class ModuleOp;
+template <typename T>
+class OperationPass;
+class OwningRewritePatternList;
+
+//===----------------------------------------------------------------------===//
+// SVE Scalable Vector Type Conversion
+//===----------------------------------------------------------------------===//
+
+class SVETypeConverter : public LLVMTypeConverter {
+public:
+  explicit SVETypeConverter(MLIRContext *ctx);
+};
+
+/// Collect a set of patterns to convert from the SVE dialect to LLVM.
+void populateSVEToLLVMConversionPatterns(SVETypeConverter &converter,
+                                         OwningRewritePatternList &patterns);
+
+/// Create a pass to convert SVE operations to the LLVMIR dialect.
+std::unique_ptr<OperationPass<ModuleOp>> createConvertSVEToLLVMPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_SVETOLLVM_CONVERTSVETOLLVM_H_
diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt
--- a/mlir/include/mlir/Dialect/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/CMakeLists.txt
@@ -13,5 +13,6 @@
 add_subdirectory(Shape)
 add_subdirectory(SPIRV)
 add_subdirectory(StandardOps)
+add_subdirectory(SVE)
 add_subdirectory(Tosa)
 add_subdirectory(Vector)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
--- a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
@@ -30,3 +30,9 @@
 set(LLVM_TARGET_DEFINITIONS LLVMAVX512.td)
 mlir_tablegen(LLVMAVX512Conversions.inc -gen-llvmir-conversions)
 add_public_tablegen_target(MLIRLLVMAVX512ConversionsIncGen)
+
+add_mlir_dialect(LLVMSVE llvm_sve LLVMSVE)
+
+set(LLVM_TARGET_DEFINITIONS LLVMSVE.td)
+mlir_tablegen(LLVMSVEConversions.inc -gen-llvmir-conversions)
+add_public_tablegen_target(MLIRLLVMSVEConversionsIncGen)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMSVE.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMSVE.td
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMSVE.td
@@ -0,0 +1,70 @@
+//===-- LLVMSVE.td - LLVMSVE dialect op definitions --------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the basic operations for the LLVMSVE dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVMIR_SVE_OPS
+#define LLVMIR_SVE_OPS
+
+include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
+
+//===----------------------------------------------------------------------===//
+// LLVMSVE dialect definition
+//===----------------------------------------------------------------------===//
+
+def LLVMSVE_Dialect : Dialect {
+  let name = "llvm_sve";
+  let cppNamespace = "::mlir::LLVM";
+}
+
+//----------------------------------------------------------------------------//
+// MLIR LLVM SVE intrinsics using the MLIR LLVM Dialect type system
+//----------------------------------------------------------------------------//
+
+class LLVMSVE_NonSVEIntrUnaryOverloadedOp<string mnemonic,
+                                          list<OpTrait> traits = []> :
+  LLVM_IntrOpBase</*Dialect dialect=*/LLVMSVE_Dialect,
+                  /*string opName=*/mnemonic,
+                  /*string enumName=*/mnemonic,
+                  /*list<int> overloadedResults=*/[0],
+                  /*list<int> overloadedOperands=*/[], // defined by result overload
+                  /*list<OpTrait> traits=*/traits,
+                  /*int numResults=*/1>;
+
+class LLVMSVE_IntrBinaryOverloadedOp<string mnemonic,
+                                     list<OpTrait> traits = []> :
+  LLVM_IntrOpBase</*Dialect dialect=*/LLVMSVE_Dialect,
+                  /*string opName=*/mnemonic,
+                  /*string enumName=*/"aarch64_sve_" # !subst(".", "_", mnemonic),
+                  /*list<int> overloadedResults=*/[0],
+                  /*list<int> overloadedOperands=*/[], // defined by result overload
+                  /*list<OpTrait> traits=*/traits,
+                  /*int numResults=*/1>;
+
+def LLVM_aarch64_sve_ummla :
+  LLVMSVE_IntrBinaryOverloadedOp<"ummla">,
+  Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>;
+
+def LLVM_aarch64_sve_smmla :
+  LLVMSVE_IntrBinaryOverloadedOp<"smmla">,
+  Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>;
+
+def LLVM_aarch64_sve_sdot :
+  LLVMSVE_IntrBinaryOverloadedOp<"sdot">,
+  Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>;
+
+def LLVM_aarch64_sve_udot :
+  LLVMSVE_IntrBinaryOverloadedOp<"udot">,
+  Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>;
+
+def LLVM_vscale :
+  LLVMSVE_NonSVEIntrUnaryOverloadedOp<"vscale">;
+
+#endif // LLVMIR_SVE_OPS
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMSVEDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMSVEDialect.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMSVEDialect.h
@@ -0,0 +1,24 @@
+//===- LLVMSVEDialect.h - MLIR Dialect for LLVMSVE --------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the Target dialect for LLVMSVE in MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LLVMIR_LLVMSVEDIALECT_H_
+#define MLIR_DIALECT_LLVMIR_LLVMSVEDIALECT_H_
+
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/LLVMIR/LLVMSVE.h.inc"
+
+#include "mlir/Dialect/LLVMIR/LLVMSVEDialect.h.inc"
+
+#endif // MLIR_DIALECT_LLVMIR_LLVMSVEDIALECT_H_
diff --git a/mlir/include/mlir/Dialect/SVE/CMakeLists.txt b/mlir/include/mlir/Dialect/SVE/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SVE/CMakeLists.txt
@@ -0,0 +1 @@
+add_mlir_dialect(SVE sve SVE)
diff --git a/mlir/include/mlir/Dialect/SVE/SVE.td b/mlir/include/mlir/Dialect/SVE/SVE.td
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SVE/SVE.td
@@ -0,0 +1,163 @@
+//===-- SVE.td - SVE dialect operation definitions ---------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the basic operations for the SVE dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef SVE_OPS
+#define SVE_OPS
+
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+//===----------------------------------------------------------------------===//
+// SVE dialect definition
+//===----------------------------------------------------------------------===//
+
+def SVE_Dialect : Dialect {
+  let name = "sve";
+  let cppNamespace = "::mlir::sve";
+  let summary = "Basic dialect to target Arm SVE architectures";
+  let description = [{
+    This dialect contains the definitions necessary to target SVE scalable
+    vector operations, including a scalable vector type and intrinsics for
+    some SVE instructions.
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// SVE scalable vector type definition
+//===----------------------------------------------------------------------===//
+
+def SVE_ScalableVectorType : DialectType<SVE_Dialect,
+                               CPred<"$_self.isa<ScalableVectorType>()">,
+                               "scalable vector type">,
+                             BuildableType<"$_builder.getType<ScalableVectorType>()"> {
+  let typeDescription = [{
+    `sve.vector` represents vectors that will be processed by a scalable
+    vector architecture.
+  }];
+}
+
+def IsScalableVectorTypePred :
+    CPred<"$_self.isa<::mlir::sve::ScalableVectorType>()">;
+
+class ScalableVectorOf<list<Type> allowedTypes> :
+  ContainerType<AnyTypeOf<allowedTypes>, IsScalableVectorTypePred,
+                "$_self.cast<::mlir::sve::ScalableVectorType>().getElementType()",
+                "scalable vector">;
+
+//===----------------------------------------------------------------------===//
+// SVE op definitions
+//===----------------------------------------------------------------------===//
+
+class SVE_Op<string mnemonic, list<OpTrait> traits = []> :
+  Op<SVE_Dialect, mnemonic, traits> {}
+
+def UmmlaOp : SVE_Op<"ummla",
+               [NoSideEffect,
+               AllTypesMatch<["src1", "src2"]>,
+               AllTypesMatch<["acc", "dst"]>,
+             ]> {
+  let summary = "Matrix-matrix multiply and accumulate op";
+  let description = [{
+    The ummla op is an SVE specific op that can lower to the proper LLVMSVE
+    operation: `llvm.aarch64.sve.ummla` instruction.
+    TODO: Add documentation for ummla.
+  }];
+  // Supports vector<16xi8>.
+  let arguments = (ins
+          ScalableVectorOf<[UI32]>:$acc,
+          ScalableVectorOf<[UI8]>:$src1,
+          ScalableVectorOf<[UI8]>:$src2
+  );
+  let results = (outs ScalableVectorOf<[UI32]>:$dst);
+  let assemblyFormat =
+    "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `->` type($dst)";
+}
+
+def SmmlaOp : SVE_Op<"smmla",
+               [NoSideEffect,
+               AllTypesMatch<["src1", "src2"]>,
+               AllTypesMatch<["acc", "dst"]>,
+             ]> {
+  let summary = "Matrix-matrix multiply and accumulate op";
+  let description = [{
+    The smmla op is an SVE specific op that can lower to the proper LLVMSVE
+    operation: `llvm.aarch64.sve.smmla` instruction.
+    TODO: Add documentation for smmla.
+  }];
+  // Supports vector<16xi8>.
+  let arguments = (ins
+          ScalableVectorOf<[SI32]>:$acc,
+          ScalableVectorOf<[SI8]>:$src1,
+          ScalableVectorOf<[SI8]>:$src2
+  );
+  let results = (outs ScalableVectorOf<[SI32]>:$dst);
+  let assemblyFormat =
+    "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `->` type($dst)";
+}
+
+def SdotOp : SVE_Op<"sdot",
+               [NoSideEffect,
+               AllTypesMatch<["src1", "src2"]>,
+               AllTypesMatch<["acc", "dst"]>,
+             ]> {
+  let summary = "Vector-vector dot product and accumulate op";
+  let description = [{
+    The sdot op is an SVE specific op that can lower to the proper LLVMSVE
+    operation: `llvm.aarch64.sve.sdot` instruction.
+    TODO: Add documentation for sdot.
+  }];
+  // Supports vector<16xi8> and vector<8xi16>.
+  let arguments = (ins
+          ScalableVectorOf<[SI32, SI64]>:$acc,
+          ScalableVectorOf<[SI8, SI16]>:$src1,
+          ScalableVectorOf<[SI8, SI16]>:$src2
+  );
+  let results = (outs ScalableVectorOf<[SI32, SI64]>:$dst);
+  let assemblyFormat =
+    "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `->` type($dst)";
+}
+
+def UdotOp : SVE_Op<"udot",
+               [NoSideEffect,
+               AllTypesMatch<["src1", "src2"]>,
+               AllTypesMatch<["acc", "dst"]>,
+             ]> {
+  let summary = "Vector-vector dot product and accumulate op";
+  let description = [{
+    The udot op is an SVE specific op that can lower to the proper LLVMSVE
+    operation: `llvm.aarch64.sve.udot` instruction.
+    TODO: Add documentation for udot.
+  }];
+  // Supports vector<16xi8> and vector<8xi16>.
+  let arguments = (ins
+          ScalableVectorOf<[UI32, UI64]>:$acc,
+          ScalableVectorOf<[UI8, UI16]>:$src1,
+          ScalableVectorOf<[UI8, UI16]>:$src2
+  );
+  let results = (outs ScalableVectorOf<[UI32, UI64]>:$dst);
+  let assemblyFormat =
+    "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `->` type($dst)";
+}
+
+def VscaleOp : SVE_Op<"vscale",
+                 [NoSideEffect]> {
+  let summary = "Load vector scale size";
+  let description = [{
+    The vscale op is a scalable vector architecture intrinsic that will
+    lower to `llvm.vscale`. vscale returns the scaling factor for
+    scalable vectors.
+  }];
+  let results = (outs Index:$res);
+  let assemblyFormat =
+    "attr-dict `:` type($res)";
+}
+
+#endif // SVE_OPS
diff --git a/mlir/include/mlir/Dialect/SVE/SVEDialect.h b/mlir/include/mlir/Dialect/SVE/SVEDialect.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SVE/SVEDialect.h
@@ -0,0 +1,100 @@
+//===- SVEDialect.h - MLIR Dialect for SVE ----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the Target dialect for SVE in MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SVE_SVEDIALECT_H_
+#define MLIR_DIALECT_SVE_SVEDIALECT_H_
+
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+
+namespace mlir {
+namespace sve {
+
+///===----------------------------------------------------------------------===//
+///=== ScalableVectorType
+///===----------------------------------------------------------------------===//
+
+struct ScalableVectorTypeStorage : public TypeStorage {
+  ScalableVectorTypeStorage(unsigned shapeSize, Type elementTy,
+                            const int64_t *shapeElements)
+      : shapeElements(shapeElements), shapeSize(shapeSize),
+        elementType(elementTy) {}
+
+  /// Hash key for uniquing
+  using KeyTy = std::pair<ArrayRef<int64_t>, Type>;
+  bool operator==(const KeyTy &key) const {
+    return key == KeyTy(getShape(), elementType);
+  }
+
+  /// Construction.
+  static ScalableVectorTypeStorage *construct(TypeStorageAllocator &allocator,
+                                              const KeyTy &key) {
+    // Copy the shape into the bump pointer.
+    ArrayRef<int64_t> shape = allocator.copyInto(key.first);
+
+    // Initialize the memory using placement new.
+    return new (allocator.allocate<ScalableVectorTypeStorage>())
+        ScalableVectorTypeStorage(shape.size(), key.second, shape.data());
+  }
+
+  ArrayRef<int64_t> getShape() const {
+    return ArrayRef<int64_t>(shapeElements, shapeSize);
+  }
+
+  const int64_t *shapeElements;
+  unsigned shapeSize;
+  Type elementType;
+};
+
+/// Scalable vector types represent multi-dimensional SIMD vectors that will be
+/// processed by a scalable vector length processor. They have a fixed
+/// known constant shape with one or more dimensions.
+class ScalableVectorType : public Type::TypeBase<ScalableVectorType, Type,
+                                                 ScalableVectorTypeStorage> {
+public:
+  using Base::Base;
+
+  /// Get or create a new ScalableVectorType of the provided shape and element
+  /// type. Assumes the arguments define a well-formed ScalableVectorType.
+  static ScalableVectorType get(ArrayRef<int64_t> shape, Type elementType);
+
+  /// Get or create a new ScalableVectorType of the provided shape and element
+  /// type declared at the given, potentially unknown, location. If the
+  /// ScalableVectorType defined by the arguments would be ill-formed, emit
+  /// errors and return nullptr-wrapping type.
+  static ScalableVectorType getChecked(ArrayRef<int64_t> shape,
+                                       Type elementType, Location location);
+
+  static LogicalResult verifyConstructionInvariants(Location loc,
+                                                    ArrayRef<int64_t> shape,
+                                                    Type elementType);
+
+  static bool isValidElementType(Type t) {
+    return t.isa<IntegerType, IndexType, FloatType>();
+  }
+
+  ArrayRef<int64_t> getShape() const;
+
+  Type getElementType() const;
+};
+
+} // namespace sve
+} // namespace mlir
+
+#include "mlir/Dialect/SVE/SVEDialect.h.inc"
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/SVE/SVE.h.inc"
+
+#endif // MLIR_DIALECT_SVE_SVEDIALECT_H_
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/GPU/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMSVEDialect.h"
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
@@ -31,6 +32,7 @@
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/SDBM/SDBMDialect.h"
 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/Dialect/SVE/SVEDialect.h"
 #include "mlir/Dialect/Shape/IR/Shape.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
@@ -49,6 +51,7 @@
                   gpu::GPUDialect,
                   LLVM::LLVMAVX512Dialect,
                   LLVM::LLVMDialect,
+                  LLVM::LLVMSVEDialect,
                   linalg::LinalgDialect,
                   scf::SCFDialect,
                   omp::OpenMPDialect,
@@ -57,6 +60,7 @@
                   quant::QuantizationDialect,
                   spirv::SPIRVDialect,
                   StandardOpsDialect,
+                  sve::SVEDialect,
                   vector::VectorDialect,
                   NVVM::NVVMDialect,
                   ROCDL::ROCDLDialect,
diff --git a/mlir/include/mlir/InitAllTranslations.h b/mlir/include/mlir/InitAllTranslations.h
--- a/mlir/include/mlir/InitAllTranslations.h
+++ b/mlir/include/mlir/InitAllTranslations.h
@@ -23,6 +23,7 @@
 void registerToNVVMIRTranslation();
 void registerToROCDLIRTranslation();
 void registerAVX512ToLLVMIRTranslation();
+void registerSVEToLLVMIRTranslation();
 
 // This function should be called before creating any MLIRContext if one
 // expects all the possible translations to be made available to the context
@@ -36,6 +37,7 @@
     registerToNVVMIRTranslation();
     registerToROCDLIRTranslation();
     registerAVX512ToLLVMIRTranslation();
+    registerSVEToLLVMIRTranslation();
     return true;
   }();
   (void)initOnce;
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -19,6 +19,7 @@
 add_subdirectory(SPIRVToLLVM)
 add_subdirectory(StandardToLLVM)
 add_subdirectory(StandardToSPIRV)
+add_subdirectory(SVEToLLVM)
 add_subdirectory(VectorToROCDL)
 add_subdirectory(VectorToLLVM)
 add_subdirectory(VectorToSCF)
diff --git a/mlir/lib/Conversion/PassDetail.h b/mlir/lib/Conversion/PassDetail.h
--- a/mlir/lib/Conversion/PassDetail.h
+++ b/mlir/lib/Conversion/PassDetail.h
@@ -27,6 +27,7 @@
 namespace LLVM {
 class LLVMDialect;
 class LLVMAVX512Dialect;
+class LLVMSVEDialect;
 } // end namespace LLVM
 
 namespace NVVM {
diff --git a/mlir/lib/Conversion/SVEToLLVM/CMakeLists.txt b/mlir/lib/Conversion/SVEToLLVM/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/SVEToLLVM/CMakeLists.txt
@@ -0,0 +1,19 @@
+add_mlir_conversion_library(MLIRSVEToLLVM
+  ConvertSVEToLLVM.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/SVEToLLVM
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRSVE
+  MLIRLLVMSVE
+  MLIRLLVMIR
+  MLIRStandardToLLVM
+  MLIRTransforms
+  )
diff --git a/mlir/lib/Conversion/SVEToLLVM/ConvertSVEToLLVM.cpp b/mlir/lib/Conversion/SVEToLLVM/ConvertSVEToLLVM.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/SVEToLLVM/ConvertSVEToLLVM.cpp
@@ -0,0 +1,256 @@
+//===- ConvertSVEToLLVM.cpp - Convert SVE to the LLVM dialect -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/SVEToLLVM/ConvertSVEToLLVM.h"
+
+#include "../PassDetail.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
+#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMSVEDialect.h"
+#include "mlir/Dialect/SVE/SVEDialect.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+using namespace mlir::sve;
+
+template <typename OpTy>
+static Type getSrc1VectorElementType(OpTy op) {
+  return op.src1()
+      .getType()
+      .template cast<ScalableVectorType>()
+      .getElementType();
+}
+template <typename OpTy>
+static Type getSrc2VectorElementType(OpTy op) {
+  return op.src2()
+      .getType()
+      .template cast<ScalableVectorType>()
+      .getElementType();
+}
+template <typename OpTy>
+static Type getAccVectorElementType(OpTy op) {
+  return op.acc()
+      .getType()
+      .template cast<ScalableVectorType>()
+      .getElementType();
+}
+
+/// Basic lowering implementation for one-to-one rewriting from SVE Ops to
+/// LLVM Dialect Ops. Convert the type of the result to an LLVM type, pass
+/// operands as is, preserve attributes.
+template <typename SourceOp, typename TargetOp>
+static LogicalResult
+matchAndRewriteOneToOne(const ConvertToLLVMPattern &lowering,
+                        LLVMTypeConverter &typeConverter, Operation *op,
+                        ArrayRef<Value> operands,
+                        ConversionPatternRewriter &rewriter) {
+  unsigned numResults = op->getNumResults();
+
+  Type packedType;
+  if (numResults != 0) {
+    packedType = typeConverter.packFunctionResults(op->getResultTypes());
+    if (!packedType)
+      return failure();
+  }
+
+  auto newOp = rewriter.create<TargetOp>(op->getLoc(), packedType, operands,
+                                         op->getAttrs());
+
+  // If the operation produced 0 or 1 result, return them immediately.
+  if (numResults == 0)
+    return rewriter.eraseOp(op), success();
+  if (numResults == 1)
+    return rewriter.replaceOp(op, newOp.getOperation()->getResult(0)),
+           success();
+
+  // Otherwise, it had been converted to an operation producing a structure.
+  // Extract individual results from the structure and return them as list.
+  SmallVector<Value, 4> results;
+  results.reserve(numResults);
+  for (unsigned i = 0; i < numResults; ++i) {
+    auto type = typeConverter.convertType(op->getResult(i).getType());
+    results.push_back(rewriter.create<LLVM::ExtractValueOp>(
+        op->getLoc(), type, newOp.getOperation()->getResult(0),
+        rewriter.getI64ArrayAttr(i)));
+  }
+  rewriter.replaceOp(op, results);
+  return success();
+}
+
+namespace {
+
+struct UmmlaOpConversion : public ConvertToLLVMPattern {
+  explicit UmmlaOpConversion(MLIRContext *context,
+                             LLVMTypeConverter &typeConverter)
+      : ConvertToLLVMPattern(UmmlaOp::getOperationName(), context,
+                             typeConverter) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (!(getAccVectorElementType(cast<UmmlaOp>(op)).isInteger(32) &&
+          getSrc1VectorElementType(cast<UmmlaOp>(op)).isInteger(8) &&
+          getSrc2VectorElementType(cast<UmmlaOp>(op)).isInteger(8)))
+      return failure();
+    return matchAndRewriteOneToOne<UmmlaOp, LLVM::aarch64_sve_ummla>(
+        *this, this->typeConverter, op, operands, rewriter);
+  }
+};
+
+struct SmmlaOpConversion : public ConvertToLLVMPattern {
+  explicit SmmlaOpConversion(MLIRContext *context,
+                             LLVMTypeConverter &typeConverter)
+      : ConvertToLLVMPattern(SmmlaOp::getOperationName(), context,
+                             typeConverter) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (!(getAccVectorElementType(cast<SmmlaOp>(op)).isInteger(32) &&
+          getSrc1VectorElementType(cast<SmmlaOp>(op)).isInteger(8) &&
+          getSrc2VectorElementType(cast<SmmlaOp>(op)).isInteger(8)))
+      return failure();
+    return matchAndRewriteOneToOne<SmmlaOp, LLVM::aarch64_sve_smmla>(
+        *this, this->typeConverter, op, operands, rewriter);
+  }
+};
+
+struct SdotOpConversion : public ConvertToLLVMPattern {
+  explicit SdotOpConversion(MLIRContext *context,
+                            LLVMTypeConverter &typeConverter)
+      : ConvertToLLVMPattern(SdotOp::getOperationName(), context,
+                             typeConverter) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (!((getAccVectorElementType(cast<SdotOp>(op)).isInteger(32) &&
+           getSrc1VectorElementType(cast<SdotOp>(op)).isInteger(8) &&
+           getSrc2VectorElementType(cast<SdotOp>(op)).isInteger(8)) ||
+          (getAccVectorElementType(cast<SdotOp>(op)).isInteger(64) &&
+           getSrc1VectorElementType(cast<SdotOp>(op)).isInteger(16) &&
+           getSrc2VectorElementType(cast<SdotOp>(op)).isInteger(16))))
+      return failure();
+    return matchAndRewriteOneToOne<SdotOp, LLVM::aarch64_sve_sdot>(
+        *this, this->typeConverter, op, operands, rewriter);
+  }
+};
+
+struct UdotOpConversion : public ConvertToLLVMPattern {
+  explicit UdotOpConversion(MLIRContext *context,
+                            LLVMTypeConverter &typeConverter)
+      : ConvertToLLVMPattern(UdotOp::getOperationName(), context,
+                             typeConverter) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (!((getAccVectorElementType(cast<UdotOp>(op)).isInteger(32) &&
+           getSrc1VectorElementType(cast<UdotOp>(op)).isInteger(8) &&
+           getSrc2VectorElementType(cast<UdotOp>(op)).isInteger(8)) ||
+          (getAccVectorElementType(cast<UdotOp>(op)).isInteger(64) &&
+           getSrc1VectorElementType(cast<UdotOp>(op)).isInteger(16) &&
+           getSrc2VectorElementType(cast<UdotOp>(op)).isInteger(16))))
+      return failure();
+    return matchAndRewriteOneToOne<UdotOp, LLVM::aarch64_sve_udot>(
+        *this, this->typeConverter, op, operands, rewriter);
+  }
+};
+
+struct VscaleOpConversion : public ConvertToLLVMPattern {
+  explicit VscaleOpConversion(MLIRContext *context,
+                              LLVMTypeConverter &typeConverter)
+      : ConvertToLLVMPattern(VscaleOp::getOperationName(), context,
+                             typeConverter) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    return matchAndRewriteOneToOne<VscaleOp, LLVM::vscale>(
+        *this, this->typeConverter, op, operands, rewriter);
+  }
+};
+
+} // namespace
+
+/// Populate the given list with patterns that convert from SVE to LLVM.
+void mlir::populateSVEToLLVMConversionPatterns(
+    SVETypeConverter &converter, OwningRewritePatternList &patterns) {
+  MLIRContext *ctx = converter.getDialect()->getContext();
+  // clang-format off
+  patterns.insert<UmmlaOpConversion,
+                  SmmlaOpConversion,
+                  SdotOpConversion,
+                  UdotOpConversion,
+                  VscaleOpConversion>(ctx, converter);
+  // clang-format on
+}
+
+namespace {
+struct ConvertSVEToLLVMPass
+    : public ConvertSVEToLLVMBase<ConvertSVEToLLVMPass> {
+  void runOnOperation() override;
+};
+} // namespace
+
+// Extract an LLVM IR type from the LLVM IR dialect type.
+static LLVM::LLVMType unwrap(Type type) {
+  if (!type)
+    return nullptr;
+  auto *mlirContext = type.getContext();
+  auto wrappedLLVMType = type.dyn_cast<LLVM::LLVMType>();
+  if (!wrappedLLVMType)
+    emitError(UnknownLoc::get(mlirContext),
+              "conversion resulted in a non-LLVM type");
+  return wrappedLLVMType;
+}
+
+static Optional<Type> convertScalableVectorType(ScalableVectorType svType,
+                                                LLVMTypeConverter &converter) {
+  auto elementType = unwrap(converter.convertType(svType.getElementType()));
+  // LLVM scalable vectors are 1-D; reject other ranks rather than drop dims.
+  if (!elementType || svType.getShape().size() != 1)
+    return {};
+  auto sVectorType =
+      LLVM::LLVMScalableVectorType::get(elementType, svType.getShape().back());
+  return sVectorType;
+}
+
+SVETypeConverter::SVETypeConverter(MLIRContext *ctx) : LLVMTypeConverter(ctx) {
+  addConversion([this](ScalableVectorType svType) {
+    return convertScalableVectorType(svType, *this);
+  });
+}
+
+void ConvertSVEToLLVMPass::runOnOperation() {
+  OwningRewritePatternList patterns;
+  SVETypeConverter sveConverter(&getContext());
+  populateSVEToLLVMConversionPatterns(sveConverter, patterns);
+  populateVectorToLLVMConversionPatterns(sveConverter, patterns);
+  populateStdToLLVMConversionPatterns(sveConverter, patterns);
+
+  ConversionTarget target(getContext());
+  target.addLegalDialect<LLVM::LLVMDialect>();
+  target.addLegalDialect<LLVM::LLVMSVEDialect>();
+  target.addIllegalDialect<SVEDialect>();
+  if (failed(applyPartialConversion(getOperation(), target,
+                                    std::move(patterns)))) {
+    signalPassFailure();
+  }
+}
+
+std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertSVEToLLVMPass() {
+  return std::make_unique<ConvertSVEToLLVMPass>();
+}
diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt
--- a/mlir/lib/Dialect/CMakeLists.txt
+++ b/mlir/lib/Dialect/CMakeLists.txt
@@ -14,6 +14,7 @@
 add_subdirectory(Shape)
 add_subdirectory(SPIRV)
 add_subdirectory(StandardOps)
+add_subdirectory(SVE)
 add_subdirectory(Tosa)
 add_subdirectory(Vector)
 
diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
--- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt
@@ -49,6 +49,27 @@
   MLIRSideEffectInterfaces
   )
 
+add_mlir_dialect_library(MLIRLLVMSVE
+  IR/LLVMSVEDialect.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LLVMIR
+
+  DEPENDS
+  MLIRLLVMSVEIncGen
+  MLIRLLVMSVEConversionsIncGen
+  intrinsics_gen
+
+  LINK_COMPONENTS
+  AsmParser
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRLLVMIR
+  MLIRSideEffectInterfaces
+  )
+
 add_mlir_dialect_library(MLIRNVVMIR
   IR/NVVMDialect.cpp
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMSVEDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMSVEDialect.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMSVEDialect.cpp
@@ -0,0 +1,31 @@
+//===- LLVMSVEDialect.cpp - MLIR LLVMSVE ops implementation ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the LLVMSVE dialect and its operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/IR/IntrinsicsAArch64.h"
+
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMSVEDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/TypeUtilities.h"
+
+using namespace mlir;
+
+void LLVM::LLVMSVEDialect::initialize() {
+  addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/LLVMIR/LLVMSVE.cpp.inc"
+      >();
+}
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/LLVMIR/LLVMSVE.cpp.inc"
diff --git a/mlir/lib/Dialect/SVE/CMakeLists.txt b/mlir/lib/Dialect/SVE/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/SVE/CMakeLists.txt
@@ -0,0 +1,14 @@
+add_mlir_dialect_library(MLIRSVE
+  IR/SVEDialect.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SVE
+
+  DEPENDS
+  MLIRSVEIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRSideEffectInterfaces
+  MLIRVectorToLLVM
+  )
diff --git a/mlir/lib/Dialect/SVE/IR/SVEDialect.cpp b/mlir/lib/Dialect/SVE/IR/SVEDialect.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/SVE/IR/SVEDialect.cpp
@@ -0,0 +1,99 @@
+//===- SVEDialect.cpp - MLIR SVE dialect implementation -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the SVE dialect and its operations.
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/SVE/SVEDialect.h" +#include "mlir/Dialect/Vector/VectorOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; + +void sve::SVEDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "mlir/Dialect/SVE/SVE.cpp.inc" + >(); + addTypes(); +} + +#define GET_OP_CLASSES +#include "mlir/Dialect/SVE/SVE.cpp.inc" + +namespace mlir { +namespace sve { + +//===----------------------------------------------------------------------===// +// VectorType +//===----------------------------------------------------------------------===// + +ScalableVectorType ScalableVectorType::get(ArrayRef shape, + Type elementType) { + return Base::get(elementType.getContext(), shape, elementType); +} + +ScalableVectorType ScalableVectorType::getChecked(ArrayRef shape, + Type elementType, + Location location) { + return Base::getChecked(location, shape, elementType); +} + +ArrayRef ScalableVectorType::getShape() const { + return getImpl()->getShape(); +} + +Type ScalableVectorType::getElementType() const { + return getImpl()->elementType; +} + +LogicalResult ScalableVectorType::verifyConstructionInvariants( + Location loc, ArrayRef shape, Type elementType) { + if (shape.empty()) + return emitError(loc, + "scalable vector types must have at least one dimension"); + + if (!isValidElementType(elementType)) + return emitError(loc, "vector elements must be int or float type"); + + if (any_of(shape, [](int64_t i) { return i <= 0; })) + return emitError(loc, "vector types must have positive constant sizes"); + + return success(); +} + +Type SVEDialect::parseType(DialectAsmParser &parser) const { + llvm::SMLoc typeLoc = parser.getCurrentLocation(); + VectorType vectorTypeType; + if 
(parser.parseType(vectorTypeType)) { + parser.emitError(typeLoc, "unknown type in SVE dialect"); + return Type(); + } + return ScalableVectorType::get(vectorTypeType.getShape(), + vectorTypeType.getElementType()); +} + +void SVEDialect::printType(Type type, DialectAsmPrinter &os) const { + TypeSwitch(type) + .Case([&](ScalableVectorType svTy) { + os << "vector<"; + for (int64_t dim : svTy.getShape()) + os << dim << 'x'; + os << svTy.getElementType() << '>'; + }) + .Default([](Type) { llvm_unreachable("unexpected 'sve' type kind"); }); +} + +} // namespace sve +} // namespace mlir diff --git a/mlir/lib/Target/CMakeLists.txt b/mlir/lib/Target/CMakeLists.txt --- a/mlir/lib/Target/CMakeLists.txt +++ b/mlir/lib/Target/CMakeLists.txt @@ -94,3 +94,22 @@ MLIRROCDLIR MLIRTargetLLVMIRModuleTranslation ) + +add_mlir_translation_library(MLIRTargetSVE + LLVMIR/LLVMSVEIntr.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Target/LLVMIR + + DEPENDS + MLIRLLVMSVEConversionsIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRIR + MLIRLLVMSVE + MLIRLLVMIR + MLIRTargetLLVMIRModuleTranslation + ) diff --git a/mlir/lib/Target/LLVMIR/LLVMSVEIntr.cpp b/mlir/lib/Target/LLVMIR/LLVMSVEIntr.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Target/LLVMIR/LLVMSVEIntr.cpp @@ -0,0 +1,63 @@ +//===- LLVMSVEIntr.cpp - Convert MLIR LLVM dialect to LLVM intrinsics -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a translation between the MLIR LLVM and SVE dialects and +// LLVM IR with SVE intrinsics. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/LLVMIR/LLVMSVEDialect.h" +#include "mlir/Target/LLVMIR/ModuleTranslation.h" +#include "mlir/Translation.h" +#include "llvm/IR/IntrinsicsAArch64.h" + +using namespace mlir; + +namespace { +class LLVMSVEModuleTranslation : public LLVM::ModuleTranslation { + friend LLVM::ModuleTranslation; + +public: + using LLVM::ModuleTranslation::ModuleTranslation; + +protected: + LogicalResult convertOperation(Operation &opInst, + llvm::IRBuilder<> &builder) override { +#include "mlir/Dialect/LLVMIR/LLVMSVEConversions.inc" + + return LLVM::ModuleTranslation::convertOperation(opInst, builder); + } +}; + +std::unique_ptr +translateLLVMSVEModuleToLLVMIR(Operation *m, llvm::LLVMContext &llvmContext, + StringRef name) { + return LLVM::ModuleTranslation::translateModule( + m, llvmContext, name); +} +} // end namespace + +namespace mlir { +void registerSVEToLLVMIRTranslation() { + TranslateFromMLIRRegistration reg( + "sve-mlir-to-llvmir", + [](ModuleOp module, raw_ostream &output) { + llvm::LLVMContext llvmContext; + auto llvmModule = translateLLVMSVEModuleToLLVMIR(module, llvmContext, + "LLVMDialectModule"); + if (!llvmModule) + return failure(); + + llvmModule->print(output, nullptr); + return success(); + }, + [](DialectRegistry ®istry) { + registry.insert(); + }); +} +} // namespace mlir diff --git a/mlir/test/Conversion/SVEToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/SVEToLLVM/convert-to-llvm.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/SVEToLLVM/convert-to-llvm.mlir @@ -0,0 +1,48 @@ +// RUN: mlir-opt %s -convert-sve-to-llvm | mlir-opt | FileCheck %s + +func @sve_sdot(%a: !sve.vector<16xsi8>, + %b: !sve.vector<16xsi8>, + %c: !sve.vector<4xsi32>) + -> !sve.vector<4xsi32> +{ + // CHECK: llvm_sve.sdot + %0 = sve.sdot %c, %a, %b : !sve.vector<16xsi8> -> !sve.vector<4xsi32> + return %0 : !sve.vector<4xsi32> +} + +func @sve_smmla(%a: 
!sve.vector<16xsi8>, + %b: !sve.vector<16xsi8>, + %c: !sve.vector<4xsi32>) + -> !sve.vector<4xsi32> +{ + // CHECK: llvm_sve.smmla + %0 = sve.smmla %c, %a, %b : !sve.vector<16xsi8> -> !sve.vector<4xsi32> + return %0 : !sve.vector<4xsi32> +} + +func @sve_udot(%a: !sve.vector<16xui8>, + %b: !sve.vector<16xui8>, + %c: !sve.vector<4xui32>) + -> !sve.vector<4xui32> +{ + // CHECK: llvm_sve.udot + %0 = sve.udot %c, %a, %b : !sve.vector<16xui8> -> !sve.vector<4xui32> + return %0 : !sve.vector<4xui32> +} + +func @sve_ummla(%a: !sve.vector<16xui8>, + %b: !sve.vector<16xui8>, + %c: !sve.vector<4xui32>) + -> !sve.vector<4xui32> +{ + // CHECK: llvm_sve.ummla + %0 = sve.ummla %c, %a, %b : !sve.vector<16xui8> -> !sve.vector<4xui32> + return %0 : !sve.vector<4xui32> +} + +func @get_vscale() -> index +{ + // CHECK: llvm_sve.vscale + %0 = sve.vscale : index + return %0 : index +} diff --git a/mlir/test/Dialect/SVE/roundtrip.mlir b/mlir/test/Dialect/SVE/roundtrip.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/SVE/roundtrip.mlir @@ -0,0 +1,44 @@ +// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s + +func @sve_sdot(%a: !sve.vector<16xsi8>, + %b: !sve.vector<16xsi8>, + %c: !sve.vector<4xsi32>) -> !sve.vector<4xsi32> +{ + // CHECK: sve.sdot {{.*}}: !sve.vector<16xsi8> -> !sve.vector<4xsi32> + %0 = sve.sdot %c, %a, %b : !sve.vector<16xsi8> -> !sve.vector<4xsi32> + return %0 : !sve.vector<4xsi32> +} + +func @sve_smmla(%a: !sve.vector<16xsi8>, + %b: !sve.vector<16xsi8>, + %c: !sve.vector<4xsi32>) -> !sve.vector<4xsi32> +{ + // CHECK: sve.smmla {{.*}}: !sve.vector<16xsi8> -> !sve.vector<4xsi32> + %0 = sve.smmla %c, %a, %b : !sve.vector<16xsi8> -> !sve.vector<4xsi32> + return %0 : !sve.vector<4xsi32> +} + +func @sve_udot(%a: !sve.vector<16xui8>, + %b: !sve.vector<16xui8>, + %c: !sve.vector<4xui32>) -> !sve.vector<4xui32> +{ + // CHECK: sve.udot {{.*}}: !sve.vector<16xui8> -> !sve.vector<4xui32> + %0 = sve.udot %c, %a, %b : !sve.vector<16xui8> -> 
!sve.vector<4xui32> + return %0 : !sve.vector<4xui32> +} + +func @sve_ummla(%a: !sve.vector<16xui8>, + %b: !sve.vector<16xui8>, + %c: !sve.vector<4xui32>) -> !sve.vector<4xui32> +{ + // CHECK: sve.ummla {{.*}}: !sve.vector<16xui8> -> !sve.vector<4xui32> + %0 = sve.ummla %c, %a, %b : !sve.vector<16xui8> -> !sve.vector<4xui32> + return %0 : !sve.vector<4xui32> +} + +func @get_vscale() -> index +{ + // CHECK: sve.vscale : index + %0 = sve.vscale : index + return %0 : index +} diff --git a/mlir/test/Target/sve.mlir b/mlir/test/Target/sve.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Target/sve.mlir @@ -0,0 +1,58 @@ +// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | mlir-translate --sve-mlir-to-llvmir | FileCheck %s + +// CHECK-LABEL: define @sve_sdot +llvm.func @sve_sdot(%arg0: !llvm.vec, + %arg1: !llvm.vec, + %arg2: !llvm.vec) + -> !llvm.vec +{ + // CHECK: call @llvm.aarch64.sve.sdot.nxv4i32(, !llvm.vec, !llvm.vec) + -> !llvm.vec + llvm.return %0 : !llvm.vec +} + +// CHECK-LABEL: define @sve_smmla +llvm.func @sve_smmla(%arg0: !llvm.vec, + %arg1: !llvm.vec, + %arg2: !llvm.vec) + -> !llvm.vec { + // CHECK: call @llvm.aarch64.sve.smmla.nxv4i32(, !llvm.vec, !llvm.vec) + -> !llvm.vec + llvm.return %0 : !llvm.vec +} + +// CHECK-LABEL: define @sve_udot +llvm.func @sve_udot(%arg0: !llvm.vec, + %arg1: !llvm.vec, + %arg2: !llvm.vec) + -> !llvm.vec +{ + // CHECK: call @llvm.aarch64.sve.udot.nxv4i32(, !llvm.vec, !llvm.vec) + -> !llvm.vec + llvm.return %0 : !llvm.vec +} + +// CHECK-LABEL: define @sve_ummla +llvm.func @sve_ummla(%arg0: !llvm.vec, + %arg1: !llvm.vec, + %arg2: !llvm.vec) + -> !llvm.vec { + // CHECK: call @llvm.aarch64.sve.ummla.nxv4i32(, !llvm.vec, !llvm.vec) + -> !llvm.vec + llvm.return %0 : !llvm.vec +} + +// CHECK-LABEL: define i64 @get_vscale() +llvm.func @get_vscale() -> !llvm.i64 { + // CHECK: call i64 @llvm.vscale.i64() + %0 = "llvm_sve.vscale"() : () -> !llvm.i64 + llvm.return %0 : !llvm.i64 +}