diff --git a/mlir/include/mlir/Conversion/ArmSVEToLLVM/ConvertArmSVEToLLVM.h b/mlir/include/mlir/Conversion/ArmSVEToLLVM/ConvertArmSVEToLLVM.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Conversion/ArmSVEToLLVM/ConvertArmSVEToLLVM.h @@ -0,0 +1,23 @@ +//===- ConvertArmSVEToLLVM.h - Conversion Patterns from Arm SVE to LLVM ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_ARMSVETOLLVM_CONVERTARMSVETOLLVM_H_ +#define MLIR_CONVERSION_ARMSVETOLLVM_CONVERTARMSVETOLLVM_H_ + +namespace mlir { + +class LLVMTypeConverter; +class OwningRewritePatternList; + +/// Collect a set of patterns to convert from the Arm SVE dialect to LLVM. +void populateArmSVEToLLVMConversionPatterns(LLVMTypeConverter &converter, + OwningRewritePatternList &patterns); + +} // namespace mlir + +#endif // MLIR_CONVERSION_ARMSVETOLLVM_CONVERTARMSVETOLLVM_H_ diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -401,7 +401,8 @@ }]; let constructor = "mlir::createConvertVectorToLLVMPass()"; - let dependentDialects = ["LLVM::LLVMDialect", "LLVM::LLVMAVX512Dialect"]; + let dependentDialects = ["LLVM::LLVMDialect", "LLVM::LLVMAVX512Dialect", + "LLVM::LLVMArmSVEDialect"]; let options = [ Option<"reassociateFPReductions", "reassociate-fp-reductions", "bool", /*default=*/"false", @@ -413,7 +414,11 @@ Option<"enableAVX512", "enable-avx512", "bool", /*default=*/"false", "Enables the use of AVX512 dialect while lowering the vector " - "dialect."> + "dialect.">, + Option<"enableArmSVE", "enable-arm-sve", + "bool", /*default=*/"false", + "Enables the use of ArmSVE dialect while lowering the vector 
" + "dialect."> ]; } diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h --- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h +++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h @@ -23,7 +23,7 @@ struct LowerVectorToLLVMOptions { LowerVectorToLLVMOptions() : reassociateFPReductions(false), enableIndexOptimizations(true), - enableAVX512(false) {} + enableAVX512(false), enableArmSVE(false) {} LowerVectorToLLVMOptions &setReassociateFPReductions(bool b) { reassociateFPReductions = b; @@ -37,10 +37,15 @@ enableAVX512 = b; return *this; } + LowerVectorToLLVMOptions &setEnableArmSVE(bool b) { + enableArmSVE = b; + return *this; + } bool reassociateFPReductions; bool enableIndexOptimizations; bool enableAVX512; + bool enableArmSVE; }; /// Collect a set of patterns to convert from Vector contractions to LLVM Matrix diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt --- a/mlir/include/mlir/Dialect/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/CMakeLists.txt @@ -13,5 +13,6 @@ add_subdirectory(Shape) add_subdirectory(SPIRV) add_subdirectory(StandardOps) +add_subdirectory(Target) add_subdirectory(Tosa) add_subdirectory(Vector) diff --git a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt --- a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt @@ -30,3 +30,9 @@ set(LLVM_TARGET_DEFINITIONS LLVMAVX512.td) mlir_tablegen(LLVMAVX512Conversions.inc -gen-llvmir-conversions) add_public_tablegen_target(MLIRLLVMAVX512ConversionsIncGen) + +add_mlir_dialect(LLVMArmSVE llvm_arm_sve LLVMArmSVE) + +set(LLVM_TARGET_DEFINITIONS LLVMArmSVE.td) +mlir_tablegen(LLVMArmSVEConversions.inc -gen-llvmir-conversions) +add_public_tablegen_target(MLIRLLVMArmSVEConversionsIncGen) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMArmSVE.td 
b/mlir/include/mlir/Dialect/LLVMIR/LLVMArmSVE.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMArmSVE.td @@ -0,0 +1,75 @@ +//===-- LLVMArmSVE.td - LLVMARMSVE dialect op definitions --*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the basic operations for the LLVMArmSVE dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVMIR_ARM_SVE_OPS +#define LLVMIR_ARM_SVE_OPS + +include "mlir/Dialect/LLVMIR/LLVMOpBase.td" + +//===----------------------------------------------------------------------===// +// LLVMArmSVE dialect definition +//===----------------------------------------------------------------------===// + +def LLVMArmSVE_Dialect : Dialect { + let name = "llvm_arm_sve"; + let cppNamespace = "::mlir::LLVM"; +} + +//----------------------------------------------------------------------------// +// MLIR LLVM Arm SVE intrinsics using the MLIR LLVM Dialect type system +//----------------------------------------------------------------------------// + +class LLVMArmSVE_NonSVEIntrUnaryOverloadedOp traits =[]> : + LLVM_IntrOpBase overloadedResults=*/[0], + /*list overloadedOperands=*/[], // defined by result overload + /*list traits=*/traits, + /*int numResults=*/1>; + +class LLVMArmSVE_IntrOp traits = []> : + LLVM_IntrOpBase; + +class LLVMArmSVE_IntrBinaryOverloadedOp traits = []> : + LLVM_IntrOpBase overloadedResults=*/[0], + /*list overloadedOperands=*/[], // defined by result overload + /*list traits=*/traits, + /*int numResults=*/1>; + +def LLVM_aarch64_sve_ummla : + LLVMArmSVE_IntrBinaryOverloadedOp<"ummla">, + Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>; + +def LLVM_aarch64_sve_smmla 
: + LLVMArmSVE_IntrBinaryOverloadedOp<"smmla">, + Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>; + +def LLVM_aarch64_sve_sdot : + LLVMArmSVE_IntrBinaryOverloadedOp<"sdot">, + Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>; + +def LLVM_aarch64_sve_udot : + LLVMArmSVE_IntrBinaryOverloadedOp<"udot">, + Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>; + +def LLVM_vector_scale : + LLVMArmSVE_NonSVEIntrUnaryOverloadedOp<"vscale">; + +#endif // LLVMIR_ARM_SVE_OPS diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h @@ -0,0 +1,24 @@ +//===- LLVMArmSVEDialect.h - MLIR Dialect for LLVMArmSVE --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the LLVMArmSVE dialect in MLIR. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_LLVMIR_LLVMARMSVEDIALECT_H_ +#define MLIR_DIALECT_LLVMIR_LLVMARMSVEDIALECT_H_ + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" + +#define GET_OP_CLASSES +#include "mlir/Dialect/LLVMIR/LLVMArmSVE.h.inc" + +#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h.inc" + +#endif // MLIR_DIALECT_LLVMIR_LLVMARMSVEDIALECT_H_ diff --git a/mlir/include/mlir/Dialect/Target/ArmSVE/ArmSVE.td b/mlir/include/mlir/Dialect/Target/ArmSVE/ArmSVE.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Target/ArmSVE/ArmSVE.td @@ -0,0 +1,238 @@ +//===-- ArmSVE.td - ArmSVE dialect operation definitions ---*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the basic operations for the ArmSVE dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef ARM_SVE_OPS +#define ARM_SVE_OPS + +include "mlir/Interfaces/SideEffectInterfaces.td" + +//===----------------------------------------------------------------------===// +// ArmSVE dialect definition +//===----------------------------------------------------------------------===// + +def ArmSVE_Dialect : Dialect { + let name = "arm_sve"; + let cppNamespace = "::mlir::arm_sve"; + let summary = "Basic dialect to target Arm SVE architectures"; + let description = [{ + This dialect contains the definitions necessary to target Arm SVE scalable + vector operations, including a scalable vector type and intrinsics for + some Arm SVE instructions. + }]; +} + +//===----------------------------------------------------------------------===// +// ArmSVE type definitions +//===----------------------------------------------------------------------===// + +def ArmSVE_ScalableVectorType : DialectType()">, + "scalable vector type">, + BuildableType<"$_builder.getType()"> { + let typeDescription = [{ + `arm_sve.vector` represents vectors that will be processed by a scalable + vector architecture. + }]; +} + +class ArmSVE_Type : TypeDef { } + +def ScalableVectorType : ArmSVE_Type<"ScalableVector"> { + let mnemonic = "vector"; + + let summary = "Scalable vector type"; + + let description = [{ + A type representing scalable length SIMD vectors. Unlike fixed-length SIMD + vectors, whose size is constant and known at compile time, scalable + vectors' length is constant but determined by the specific hardware at + run time. 
+ }]; + + let parameters = (ins + ArrayRefParameter<"int64_t", "Vector shape">:$shape, + "Type":$elementType + ); + + let printer = [{ + $_printer << "vector<"; + for (int64_t dim : getShape()) + $_printer << dim << 'x'; + $_printer << getElementType() << '>'; + }]; + + let parser = [{ + VectorType vector; + if ($_parser.parseType(vector)) + return Type(); + return get(ctxt, vector.getShape(), vector.getElementType()); + }]; +} + +//===----------------------------------------------------------------------===// +// ArmSVE type traits +//===----------------------------------------------------------------------===// + +def IsScalableVectorTypePred : + CPred<"$_self.isa<::mlir::arm_sve::ScalableVectorType>()">; + +class ScalableVectorOf allowedTypes> : + ContainerType, IsScalableVectorTypePred, + "$_self.cast<::mlir::arm_sve::ScalableVectorType>().getElementType()", + "scalable vector">; + +//===----------------------------------------------------------------------===// +// ArmSVE op definitions +//===----------------------------------------------------------------------===// + +class ArmSVE_Op traits = []> : + Op {} + +def UmmlaOp : ArmSVE_Op<"ummla", + [NoSideEffect, + AllTypesMatch<["src1", "src2"]>, + AllTypesMatch<["acc", "dst"]>, + ]> { + let summary = "Matrix-matrix multiply and accumulate op"; + let description = [{ + UMMLA: Unsigned integer matrix multiply-accumulate. + + This function maps to the UMMLA instruction. It partitions the inputs into + 128-bit quadwords, with the first input containing a row-by-row 2×2 matrix + of 32-bit unsigned integers, the second input containing a row-by-row 2×8 + matrix of 8-bit unsigned integers, and the third input containing a + column-by-column 8×2 matrix of 8-bit unsigned integers. For each quadword, + it multiplies the second input matrix by the third input matrix using + natural arithmetic and then adds the result to the first input using + modular arithmetic. 
+ + Source: + https://developer.arm.com/documentation/100987/0000 + }]; + // Supports vector<16xi8>. + let arguments = (ins + ScalableVectorOf<[I32]>:$acc, + ScalableVectorOf<[I8]>:$src1, + ScalableVectorOf<[I8]>:$src2 + ); + let results = (outs ScalableVectorOf<[I32]>:$dst); + let assemblyFormat = + "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `->` type($dst)"; +} + +def SmmlaOp : ArmSVE_Op<"smmla", + [NoSideEffect, + AllTypesMatch<["src1", "src2"]>, + AllTypesMatch<["acc", "dst"]>, + ]> { + let summary = "Matrix-matrix multiply and accumulate op"; + let description = [{ + SMMLA: Signed integer matrix multiply-accumulate. + + This function maps to the SMMLA instruction. It partitions the inputs into + 128-bit quadwords, with the first input containing a row-by-row 2×2 matrix + of 32-bit signed integers, the second input containing a row-by-row 2×8 + matrix of 8-bit signed integers, and the third input containing a + column-by-column 8×2 matrix of 8-bit signed integers. For each quadword, + it multiplies the second input matrix by the third input matrix using + natural arithmetic and then adds the result to the first input using + modular arithmetic. + + Source: + https://developer.arm.com/documentation/100987/0000 + }]; + // Supports vector<16xi8>. + let arguments = (ins + ScalableVectorOf<[I32]>:$acc, + ScalableVectorOf<[I8]>:$src1, + ScalableVectorOf<[I8]>:$src2 + ); + let results = (outs ScalableVectorOf<[I32]>:$dst); + let assemblyFormat = + "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `->` type($dst)"; +} + +def SdotOp : ArmSVE_Op<"sdot", + [NoSideEffect, + AllTypesMatch<["src1", "src2"]>, + AllTypesMatch<["acc", "dst"]>, + ]> { + let summary = "Vector-vector dot product and accumulate op"; + let description = [{ + SDOT: Signed integer addition of dot product. + + This function partitions the second and third vector inputs into groups of + four elements. 
It calculates the dot product of each group (without loss + of precision) and then adds each result to the overlapping element of the + first vector input. This is the version for signed integers. + + Source: + https://developer.arm.com/documentation/100987/0000 + }]; + // Supports vector<16xi8> and vector<8xi16>. + let arguments = (ins + ScalableVectorOf<[I32, I64]>:$acc, + ScalableVectorOf<[I8, I16]>:$src1, + ScalableVectorOf<[I8, I16]>:$src2 + ); + let results = (outs ScalableVectorOf<[I32, I64]>:$dst); + let assemblyFormat = + "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `->` type($dst)"; +} + +def UdotOp : ArmSVE_Op<"udot", + [NoSideEffect, + AllTypesMatch<["src1", "src2"]>, + AllTypesMatch<["acc", "dst"]>, + ]> { + let summary = "Vector-vector dot product and accumulate op"; + let description = [{ + UDOT: Unsigned integer addition of dot product. + + This function partitions the second and third vector inputs into groups of + four elements. It calculates the dot product of each group (without loss + of precision) and then adds each result to the overlapping element of the + first vector input. This is the version for unsigned integers. + + Source: + https://developer.arm.com/documentation/100987/0000 + }]; + // Supports vector<16xi8> and vector<8xi16>. + let arguments = (ins + ScalableVectorOf<[I32, I64]>:$acc, + ScalableVectorOf<[I8, I16]>:$src1, + ScalableVectorOf<[I8, I16]>:$src2 + ); + let results = (outs ScalableVectorOf<[I32, I64]>:$dst); + let assemblyFormat = + "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `->` type($dst)"; +} + +def VectorScaleOp : ArmSVE_Op<"vector_scale", + [NoSideEffect]> { + let summary = "Return the vector scale"; + let description = [{ + The vector_scale op returns the scale of the scalable vectors, a positive + integer value that is constant at runtime but unknown at compile time. + The scale of the vector indicates the multiplicity of the vectors and + vector operations. 
I.e.: an !arm_sve.vector<4xi32> is equivalent to + vector_scale consecutive vector<4xi32>; and an operation on an + !arm_sve.vector<4xi32> is equivalent to performing that operation vector_scale + times, once on each <4xi32> segment of the scalable vector. The vector_scale + op can be used to calculate the step in vector-length agnostic (VLA) loops. + }]; + let results = (outs Index:$res); + let assemblyFormat = + "attr-dict `:` type($res)"; +} + +#endif // ARM_SVE_OPS diff --git a/mlir/include/mlir/Dialect/Target/ArmSVE/ArmSVEDialect.h b/mlir/include/mlir/Dialect/Target/ArmSVE/ArmSVEDialect.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Target/ArmSVE/ArmSVEDialect.h @@ -0,0 +1,29 @@ +//===- ArmSVEDialect.h - MLIR Dialect for Arm SVE ---------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the Target dialect for Arm SVE in MLIR. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_TARGET_ARMSVE_ARMSVEDIALECT_H +#define MLIR_DIALECT_TARGET_ARMSVE_ARMSVEDIALECT_H + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" + +#include "mlir/Dialect/Target/ArmSVE/ArmSVEDialect.h.inc" + +#define GET_TYPEDEF_CLASSES +#include "mlir/Dialect/Target/ArmSVE/ArmSVETypes.h.inc" + +#define GET_OP_CLASSES +#include "mlir/Dialect/Target/ArmSVE/ArmSVE.h.inc" + +#endif // MLIR_DIALECT_TARGET_ARMSVE_ARMSVEDIALECT_H diff --git a/mlir/include/mlir/Dialect/Target/ArmSVE/CMakeLists.txt b/mlir/include/mlir/Dialect/Target/ArmSVE/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Target/ArmSVE/CMakeLists.txt @@ -0,0 +1 @@ +add_mlir_dialect(ArmSVE arm_sve ArmSVE) diff --git a/mlir/include/mlir/Dialect/Target/CMakeLists.txt b/mlir/include/mlir/Dialect/Target/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Target/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(ArmSVE) diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -19,6 +19,7 @@ #include "mlir/Dialect/Async/IR/Async.h" #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h" +#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" @@ -33,6 +34,7 @@ #include "mlir/Dialect/SPIRV/SPIRVDialect.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Target/ArmSVE/ArmSVEDialect.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/IR/Dialect.h" @@ -49,6 +51,7 @@ 
gpu::GPUDialect, LLVM::LLVMAVX512Dialect, LLVM::LLVMDialect, + LLVM::LLVMArmSVEDialect, linalg::LinalgDialect, scf::SCFDialect, omp::OpenMPDialect, @@ -57,6 +60,7 @@ quant::QuantizationDialect, spirv::SPIRVDialect, StandardOpsDialect, + arm_sve::ArmSVEDialect, vector::VectorDialect, NVVM::NVVMDialect, ROCDL::ROCDLDialect, diff --git a/mlir/include/mlir/InitAllTranslations.h b/mlir/include/mlir/InitAllTranslations.h --- a/mlir/include/mlir/InitAllTranslations.h +++ b/mlir/include/mlir/InitAllTranslations.h @@ -23,6 +23,7 @@ void registerToNVVMIRTranslation(); void registerToROCDLIRTranslation(); void registerAVX512ToLLVMIRTranslation(); +void registerArmSVEToLLVMIRTranslation(); // This function should be called before creating any MLIRContext if one // expects all the possible translations to be made available to the context @@ -36,6 +37,7 @@ registerToNVVMIRTranslation(); registerToROCDLIRTranslation(); registerAVX512ToLLVMIRTranslation(); + registerArmSVEToLLVMIRTranslation(); return true; }(); (void)initOnce; diff --git a/mlir/lib/Conversion/ArmSVEToLLVM/CMakeLists.txt b/mlir/lib/Conversion/ArmSVEToLLVM/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/ArmSVEToLLVM/CMakeLists.txt @@ -0,0 +1,19 @@ +add_mlir_conversion_library(MLIRArmSVEToLLVM + ConvertArmSVEToLLVM.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArmSVEToLLVM + + DEPENDS + MLIRConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRArmSVE + MLIRLLVMArmSVE + MLIRLLVMIR + MLIRStandardToLLVM + MLIRTransforms + ) diff --git a/mlir/lib/Conversion/ArmSVEToLLVM/ConvertArmSVEToLLVM.cpp b/mlir/lib/Conversion/ArmSVEToLLVM/ConvertArmSVEToLLVM.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/ArmSVEToLLVM/ConvertArmSVEToLLVM.cpp @@ -0,0 +1,221 @@ +//===- ConvertArmSVEToLLVM.cpp - Convert Arm SVE to the LLVM dialect ------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/ArmSVEToLLVM/ConvertArmSVEToLLVM.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" +#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Target/ArmSVE/ArmSVEDialect.h" +#include "mlir/Dialect/Vector/VectorOps.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/PatternMatch.h" + +using namespace mlir; +using namespace mlir::vector; +using namespace mlir::arm_sve; + +template +static Type getSrc1VectorElementType(OpTy op) { + return op.src1() + .getType() + .template cast() + .getElementType(); +} +template +static Type getSrc2VectorElementType(OpTy op) { + return op.src2() + .getType() + .template cast() + .getElementType(); +} +template +static Type getAccVectorElementType(OpTy op) { + return op.acc() + .getType() + .template cast() + .getElementType(); +} + +/// Basic lowering implementation for one-to-one rewriting from Arm SVE Ops to +/// LLVM Dialect Ops. Convert the result type(s) to LLVM (multiple results are +/// packed into a struct), pass operands as is, and preserve attributes. +template +static LogicalResult +matchAndRewriteOneToOne(const ConvertToLLVMPattern &lowering, + LLVMTypeConverter &typeConverter, Operation *op, + ArrayRef operands, + ConversionPatternRewriter &rewriter) { + unsigned numResults = op->getNumResults(); + + Type packedType; + if (numResults != 0) { + packedType = typeConverter.packFunctionResults(op->getResultTypes()); + if (!packedType) + return failure(); + } + + auto newOp = rewriter.create(op->getLoc(), packedType, operands, + op->getAttrs()); + + // If the operation produced 0 or 1 result, return them immediately. 
+ if (numResults == 0) + return rewriter.eraseOp(op), success(); + if (numResults == 1) + return rewriter.replaceOp(op, newOp.getOperation()->getResult(0)), + success(); + + // Otherwise, it had been converted to an operation producing a structure. + // Extract individual results from the structure and return them as a list. + SmallVector results; + results.reserve(numResults); + for (unsigned i = 0; i < numResults; ++i) { + auto type = typeConverter.convertType(op->getResult(i).getType()); + results.push_back(rewriter.create( + op->getLoc(), type, newOp.getOperation()->getResult(0), + rewriter.getI64ArrayAttr(i))); + } + rewriter.replaceOp(op, results); + return success(); +} + +namespace { + +struct UmmlaOpConversion : public ConvertToLLVMPattern { + explicit UmmlaOpConversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : ConvertToLLVMPattern(UmmlaOp::getOperationName(), context, + typeConverter) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + if (!(getAccVectorElementType(cast(op)).isInteger(32) && + getSrc1VectorElementType(cast(op)).isInteger(8) && + getSrc2VectorElementType(cast(op)).isInteger(8))) + return failure(); + return matchAndRewriteOneToOne( + *this, *getTypeConverter(), op, operands, rewriter); + } +}; + +struct SmmlaOpConversion : public ConvertToLLVMPattern { + explicit SmmlaOpConversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : ConvertToLLVMPattern(SmmlaOp::getOperationName(), context, + typeConverter) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + if (!(getAccVectorElementType(cast(op)).isInteger(32) && + getSrc1VectorElementType(cast(op)).isInteger(8) && + getSrc2VectorElementType(cast(op)).isInteger(8))) + return failure(); + return matchAndRewriteOneToOne( + *this, *getTypeConverter(), op, operands, rewriter); + } +}; + +struct SdotOpConversion 
: public ConvertToLLVMPattern { + explicit SdotOpConversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : ConvertToLLVMPattern(SdotOp::getOperationName(), context, + typeConverter) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + if (!((getAccVectorElementType(cast(op)).isInteger(32) && + getSrc1VectorElementType(cast(op)).isInteger(8) && + getSrc2VectorElementType(cast(op)).isInteger(8)) || + (getAccVectorElementType(cast(op)).isInteger(64) && + getSrc1VectorElementType(cast(op)).isInteger(16) && + getSrc2VectorElementType(cast(op)).isInteger(16)))) + return failure(); + return matchAndRewriteOneToOne( + *this, *getTypeConverter(), op, operands, rewriter); + } +}; + +struct UdotOpConversion : public ConvertToLLVMPattern { + explicit UdotOpConversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : ConvertToLLVMPattern(UdotOp::getOperationName(), context, + typeConverter) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + if (!((getAccVectorElementType(cast(op)).isInteger(32) && + getSrc1VectorElementType(cast(op)).isInteger(8) && + getSrc2VectorElementType(cast(op)).isInteger(8)) || + (getAccVectorElementType(cast(op)).isInteger(64) && + getSrc1VectorElementType(cast(op)).isInteger(16) && + getSrc2VectorElementType(cast(op)).isInteger(16)))) + return failure(); + return matchAndRewriteOneToOne( + *this, *getTypeConverter(), op, operands, rewriter); + } +}; + +struct VectorScaleOpConversion : public ConvertToLLVMPattern { + explicit VectorScaleOpConversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : ConvertToLLVMPattern(VectorScaleOp::getOperationName(), context, + typeConverter) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + return matchAndRewriteOneToOne( + *this, 
*getTypeConverter(), op, operands, rewriter); + } +}; + +} // namespace + +// Extract an LLVM IR type from the LLVM IR dialect type. +static LLVM::LLVMType unwrap(Type type) { + if (!type) + return nullptr; + auto *mlirContext = type.getContext(); + auto wrappedLLVMType = type.dyn_cast(); + if (!wrappedLLVMType) + emitError(UnknownLoc::get(mlirContext), + "conversion resulted in a non-LLVM type"); + return wrappedLLVMType; +} + +static Optional +convertScalableVectorTypeToLLVM(ScalableVectorType svType, + LLVMTypeConverter &converter) { + auto elementType = unwrap(converter.convertType(svType.getElementType())); + if (!elementType || svType.getShape().size() != 1) // Only 1-D is supported. + return {}; + + auto sVectorType = + LLVM::LLVMScalableVectorType::get(elementType, svType.getShape().back()); + return sVectorType; +} + +/// Populate the given list with patterns that convert from Arm SVE to LLVM. +void mlir::populateArmSVEToLLVMConversionPatterns( + LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { + MLIRContext *ctx = converter.getDialect()->getContext(); + converter.addConversion([&converter](ScalableVectorType svType) { + return convertScalableVectorTypeToLLVM(svType, converter); + }); + // clang-format off + patterns.insert(ctx, converter); + // clang-format on +} diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -19,6 +19,7 @@ add_subdirectory(SPIRVToLLVM) add_subdirectory(StandardToLLVM) add_subdirectory(StandardToSPIRV) +add_subdirectory(ArmSVEToLLVM) add_subdirectory(VectorToROCDL) add_subdirectory(VectorToLLVM) add_subdirectory(VectorToSCF) diff --git a/mlir/lib/Conversion/PassDetail.h b/mlir/lib/Conversion/PassDetail.h --- a/mlir/lib/Conversion/PassDetail.h +++ b/mlir/lib/Conversion/PassDetail.h @@ -27,6 +27,7 @@ namespace LLVM { class LLVMDialect; class LLVMAVX512Dialect; +class LLVMArmSVEDialect; } // end namespace LLVM namespace NVVM { diff --git 
a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt --- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt +++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt @@ -16,6 +16,9 @@ MLIRAVX512 MLIRAVX512ToLLVM MLIRLLVMAVX512 + MLIRArmSVE + MLIRArmSVEToLLVM + MLIRLLVMArmSVE MLIRLLVMIR MLIRStandardToLLVM MLIRTargetLLVMIRModuleTranslation diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -11,11 +11,14 @@ #include "../PassDetail.h" #include "mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h" +#include "mlir/Conversion/ArmSVEToLLVM/ConvertArmSVEToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Dialect/AVX512/AVX512Dialect.h" #include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h" +#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Target/ArmSVE/ArmSVEDialect.h" #include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -29,6 +32,7 @@ this->reassociateFPReductions = options.reassociateFPReductions; this->enableIndexOptimizations = options.enableIndexOptimizations; this->enableAVX512 = options.enableAVX512; + this->enableArmSVE = options.enableArmSVE; } void runOnOperation() override; }; @@ -61,6 +65,11 @@ target.addIllegalDialect(); populateAVX512ToLLVMConversionPatterns(converter, patterns); } + if (enableArmSVE) { + target.addLegalDialect(); + target.addIllegalDialect(); + populateArmSVEToLLVMConversionPatterns(converter, patterns); + } if (failed( applyPartialConversion(getOperation(), target, std::move(patterns)))) diff --git a/mlir/lib/Dialect/CMakeLists.txt 
b/mlir/lib/Dialect/CMakeLists.txt --- a/mlir/lib/Dialect/CMakeLists.txt +++ b/mlir/lib/Dialect/CMakeLists.txt @@ -14,6 +14,7 @@ add_subdirectory(Shape) add_subdirectory(SPIRV) add_subdirectory(StandardOps) +add_subdirectory(Target) add_subdirectory(Tosa) add_subdirectory(Vector) diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt --- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt +++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt @@ -49,6 +49,27 @@ MLIRSideEffectInterfaces ) +add_mlir_dialect_library(MLIRLLVMArmSVE + IR/LLVMArmSVEDialect.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LLVMIR + + DEPENDS + MLIRLLVMArmSVEIncGen + MLIRLLVMArmSVEConversionsIncGen + intrinsics_gen + + LINK_COMPONENTS + AsmParser + Core + + LINK_LIBS PUBLIC + MLIRIR + MLIRLLVMIR + MLIRSideEffectInterfaces + ) + add_mlir_dialect_library(MLIRNVVMIR IR/NVVMDialect.cpp diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMArmSVEDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMArmSVEDialect.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMArmSVEDialect.cpp @@ -0,0 +1,31 @@ +//===- LLVMArmSVEDialect.cpp - MLIR LLVMArmSVE ops implementation ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the LLVMArmSVE dialect and its operations. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/IR/IntrinsicsAArch64.h" + +#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/TypeUtilities.h" + +using namespace mlir; + +void LLVM::LLVMArmSVEDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "mlir/Dialect/LLVMIR/LLVMArmSVE.cpp.inc" + >(); +} + +#define GET_OP_CLASSES +#include "mlir/Dialect/LLVMIR/LLVMArmSVE.cpp.inc" diff --git a/mlir/lib/Dialect/Target/ArmSVE/CMakeLists.txt b/mlir/lib/Dialect/Target/ArmSVE/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Target/ArmSVE/CMakeLists.txt @@ -0,0 +1,13 @@ +add_mlir_dialect_library(MLIRArmSVE + IR/ArmSVEDialect.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Target/ArmSVE + + DEPENDS + MLIRArmSVEIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRSideEffectInterfaces + ) diff --git a/mlir/lib/Dialect/Target/ArmSVE/IR/ArmSVEDialect.cpp b/mlir/lib/Dialect/Target/ArmSVE/IR/ArmSVEDialect.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Target/ArmSVE/IR/ArmSVEDialect.cpp @@ -0,0 +1,63 @@ +//===- ArmSVEDialect.cpp - MLIR ARM SVE dialect implementation ------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the Arm SVE dialect and its operations. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Target/ArmSVE/ArmSVEDialect.h" +#include "mlir/Dialect/Vector/VectorOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; + +void arm_sve::ArmSVEDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "mlir/Dialect/Target/ArmSVE/ArmSVE.cpp.inc" + >(); + addTypes< +#define GET_TYPEDEF_LIST +#include "mlir/Dialect/Target/ArmSVE/ArmSVETypes.cpp.inc" + >(); +} + +#define GET_OP_CLASSES +#include "mlir/Dialect/Target/ArmSVE/ArmSVE.cpp.inc" + +#define GET_TYPEDEF_CLASSES +#include "mlir/Dialect/Target/ArmSVE/ArmSVETypes.cpp.inc" + +namespace mlir { +namespace arm_sve { + +//===----------------------------------------------------------------------===// +// ScalableVectorType +//===----------------------------------------------------------------------===// + +Type ArmSVEDialect::parseType(DialectAsmParser &parser) const { + llvm::SMLoc typeLoc = parser.getCurrentLocation(); + auto genType = generatedTypeParser(getContext(), parser, "vector"); + if (genType != Type()) + return genType; + parser.emitError(typeLoc, "unknown type in Arm SVE dialect"); + return Type(); +} + +void ArmSVEDialect::printType(Type type, DialectAsmPrinter &os) const { + if (failed(generatedTypePrinter(type, os))) + llvm_unreachable("unexpected 'arm_sve' type kind"); +} + +} // namespace arm_sve +} // namespace mlir diff --git a/mlir/lib/Dialect/Target/CMakeLists.txt b/mlir/lib/Dialect/Target/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Target/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(ArmSVE) diff --git a/mlir/lib/Target/CMakeLists.txt b/mlir/lib/Target/CMakeLists.txt --- a/mlir/lib/Target/CMakeLists.txt +++ b/mlir/lib/Target/CMakeLists.txt @@ -94,3 +94,22 
@@ MLIRROCDLIR MLIRTargetLLVMIRModuleTranslation ) + +add_mlir_translation_library(MLIRTargetArmSVE + LLVMIR/LLVMArmSVEIntr.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Target/LLVMIR + + DEPENDS + MLIRLLVMArmSVEConversionsIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRIR + MLIRLLVMArmSVE + MLIRLLVMIR + MLIRTargetLLVMIRModuleTranslation + ) diff --git a/mlir/lib/Target/LLVMIR/LLVMArmSVEIntr.cpp b/mlir/lib/Target/LLVMIR/LLVMArmSVEIntr.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Target/LLVMIR/LLVMArmSVEIntr.cpp @@ -0,0 +1,63 @@ +//===- LLVMArmSVEIntr.cpp - Convert MLIR LLVM dialect to LLVM intrinsics --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a translation from the MLIR LLVM and LLVM Arm SVE +// dialects (llvm, llvm_arm_sve) to LLVM IR with Arm SVE intrinsics.
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h" +#include "mlir/Target/LLVMIR/ModuleTranslation.h" +#include "mlir/Translation.h" +#include "llvm/IR/IntrinsicsAArch64.h" + +using namespace mlir; + +namespace { +class LLVMArmSVEModuleTranslation : public LLVM::ModuleTranslation { + friend LLVM::ModuleTranslation; + +public: + using LLVM::ModuleTranslation::ModuleTranslation; + +protected: + LogicalResult convertOperation(Operation &opInst, + llvm::IRBuilder<> &builder) override { +#include "mlir/Dialect/LLVMIR/LLVMArmSVEConversions.inc" + + return LLVM::ModuleTranslation::convertOperation(opInst, builder); + } +}; + +std::unique_ptr<llvm::Module> // Null on failure. +translateLLVMArmSVEModuleToLLVMIR(Operation *m, llvm::LLVMContext &llvmContext, + StringRef name) { + return LLVM::ModuleTranslation::translateModule<LLVMArmSVEModuleTranslation>( + m, llvmContext, name); +} +} // end namespace + +namespace mlir { +void registerArmSVEToLLVMIRTranslation() { + TranslateFromMLIRRegistration reg( + "arm-sve-mlir-to-llvmir", + [](ModuleOp module, raw_ostream &output) { // Translate, then print. + llvm::LLVMContext llvmContext; + auto llvmModule = translateLLVMArmSVEModuleToLLVMIR( + module, llvmContext, "LLVMDialectModule"); + if (!llvmModule) + return failure(); + + llvmModule->print(output, nullptr); + return success(); + }, + [](DialectRegistry &registry) { // Dialects registered for the input. + registry.insert<LLVM::LLVMArmSVEDialect, LLVM::LLVMDialect>(); + }); +} +} // namespace mlir diff --git a/mlir/test/Conversion/ArmSVEToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/ArmSVEToLLVM/convert-to-llvm.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/ArmSVEToLLVM/convert-to-llvm.mlir @@ -0,0 +1,52 @@ +// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sve" | mlir-opt | FileCheck %s + +func @arm_sve_sdot(%a: !arm_sve.vector<16xi8>, + %b: !arm_sve.vector<16xi8>, + %c: !arm_sve.vector<4xi32>) + -> !arm_sve.vector<4xi32> +{ + // CHECK: llvm_arm_sve.sdot + %0 = arm_sve.sdot %c, %a, %b : + !arm_sve.vector<16xi8> ->
!arm_sve.vector<4xi32> + return %0 : !arm_sve.vector<4xi32> +} + +func @arm_sve_smmla(%a: !arm_sve.vector<16xi8>, + %b: !arm_sve.vector<16xi8>, + %c: !arm_sve.vector<4xi32>) + -> !arm_sve.vector<4xi32> +{ + // CHECK: llvm_arm_sve.smmla + %0 = arm_sve.smmla %c, %a, %b : + !arm_sve.vector<16xi8> -> !arm_sve.vector<4xi32> + return %0 : !arm_sve.vector<4xi32> +} + +func @arm_sve_udot(%a: !arm_sve.vector<16xi8>, + %b: !arm_sve.vector<16xi8>, + %c: !arm_sve.vector<4xi32>) + -> !arm_sve.vector<4xi32> +{ + // CHECK: llvm_arm_sve.udot + %0 = arm_sve.udot %c, %a, %b : + !arm_sve.vector<16xi8> -> !arm_sve.vector<4xi32> + return %0 : !arm_sve.vector<4xi32> +} + +func @arm_sve_ummla(%a: !arm_sve.vector<16xi8>, + %b: !arm_sve.vector<16xi8>, + %c: !arm_sve.vector<4xi32>) + -> !arm_sve.vector<4xi32> +{ + // CHECK: llvm_arm_sve.ummla + %0 = arm_sve.ummla %c, %a, %b : + !arm_sve.vector<16xi8> -> !arm_sve.vector<4xi32> + return %0 : !arm_sve.vector<4xi32> +} + +func @get_vector_scale() -> index +{ + // CHECK: llvm_arm_sve.vscale + %0 = arm_sve.vector_scale : index + return %0 : index +} diff --git a/mlir/test/Dialect/Target/ArmSVE/roundtrip.mlir b/mlir/test/Dialect/Target/ArmSVE/roundtrip.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Target/ArmSVE/roundtrip.mlir @@ -0,0 +1,48 @@ +// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s + +func @arm_sve_sdot(%a: !arm_sve.vector<16xi8>, + %b: !arm_sve.vector<16xi8>, + %c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> +{ + // CHECK: arm_sve.sdot {{.*}}: !arm_sve.vector<16xi8> -> !arm_sve.vector<4xi32 + %0 = arm_sve.sdot %c, %a, %b : + !arm_sve.vector<16xi8> -> !arm_sve.vector<4xi32> + return %0 : !arm_sve.vector<4xi32> +} + +func @arm_sve_smmla(%a: !arm_sve.vector<16xi8>, + %b: !arm_sve.vector<16xi8>, + %c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> +{ + // CHECK: arm_sve.smmla {{.*}}: !arm_sve.vector<16xi8> -> !arm_sve.vector<4xi3 + %0 = arm_sve.smmla %c, %a, %b : + !arm_sve.vector<16xi8> ->
!arm_sve.vector<4xi32> + return %0 : !arm_sve.vector<4xi32> +} + +func @arm_sve_udot(%a: !arm_sve.vector<16xi8>, + %b: !arm_sve.vector<16xi8>, + %c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> +{ + // CHECK: arm_sve.udot {{.*}}: !arm_sve.vector<16xi8> -> !arm_sve.vector<4xi32 + %0 = arm_sve.udot %c, %a, %b : + !arm_sve.vector<16xi8> -> !arm_sve.vector<4xi32> + return %0 : !arm_sve.vector<4xi32> +} + +func @arm_sve_ummla(%a: !arm_sve.vector<16xi8>, + %b: !arm_sve.vector<16xi8>, + %c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> +{ + // CHECK: arm_sve.ummla {{.*}}: !arm_sve.vector<16xi8> -> !arm_sve.vector<4xi3 + %0 = arm_sve.ummla %c, %a, %b : + !arm_sve.vector<16xi8> -> !arm_sve.vector<4xi32> + return %0 : !arm_sve.vector<4xi32> +} + +func @get_vector_scale() -> index +{ + // CHECK: arm_sve.vector_scale : index + %0 = arm_sve.vector_scale : index + return %0 : index +} diff --git a/mlir/test/Target/armsve.mlir b/mlir/test/Target/armsve.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Target/armsve.mlir @@ -0,0 +1,58 @@ +// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | mlir-translate --arm-sve-mlir-to-llvmir | FileCheck %s + +// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_sdot +llvm.func @arm_sve_sdot(%arg0: !llvm.vec<? x 16 x i8>, + %arg1: !llvm.vec<? x 16 x i8>, + %arg2: !llvm.vec<? x 4 x i32>) + -> !llvm.vec<? x 4 x i32> +{ + // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.sdot.nxv4i32(<vscale x 4 + %0 = "llvm_arm_sve.sdot"(%arg2, %arg0, %arg1) : + (!llvm.vec<? x 4 x i32>, !llvm.vec<? x 16 x i8>, !llvm.vec<? x 16 x i8>) + -> !llvm.vec<? x 4 x i32> + llvm.return %0 : !llvm.vec<? x 4 x i32> +} + +// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_smmla +llvm.func @arm_sve_smmla(%arg0: !llvm.vec<? x 16 x i8>, + %arg1: !llvm.vec<? x 16 x i8>, + %arg2: !llvm.vec<? x 4 x i32>) + -> !llvm.vec<? x 4 x i32> { + // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.smmla.nxv4i32(<vscale x 4 + %0 = "llvm_arm_sve.smmla"(%arg2, %arg0, %arg1) : + (!llvm.vec<? x 4 x i32>, !llvm.vec<? x 16 x i8>, !llvm.vec<? x 16 x i8>) + -> !llvm.vec<? x 4 x i32> + llvm.return %0 : !llvm.vec<? x 4 x i32> +} + +// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_udot +llvm.func @arm_sve_udot(%arg0: !llvm.vec<? x 16 x i8>, + %arg1: !llvm.vec<? x 16 x i8>, + %arg2: !llvm.vec<? x 4 x i32>) + -> !llvm.vec<? x 4 x i32> +{ + // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.udot.nxv4i32(<vscale x 4 + %0 = "llvm_arm_sve.udot"(%arg2, %arg0, %arg1) : + (!llvm.vec<? x 4 x i32>, !llvm.vec<? x 16 x i8>, !llvm.vec<? x 16 x i8>) + -> !llvm.vec<? x 4 x i32> + llvm.return %0 : !llvm.vec<? x 4 x i32> +} + +// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_ummla +llvm.func
@arm_sve_ummla(%arg0: !llvm.vec<? x 16 x i8>, + %arg1: !llvm.vec<? x 16 x i8>, + %arg2: !llvm.vec<? x 4 x i32>) + -> !llvm.vec<? x 4 x i32> { + // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.ummla.nxv4i32(<vscale x 4 + %0 = "llvm_arm_sve.ummla"(%arg2, %arg0, %arg1) : + (!llvm.vec<? x 4 x i32>, !llvm.vec<? x 16 x i8>, !llvm.vec<? x 16 x i8>) + -> !llvm.vec<? x 4 x i32> + llvm.return %0 : !llvm.vec<? x 4 x i32> +} + +// CHECK-LABEL: define i64 @get_vector_scale() +llvm.func @get_vector_scale() -> !llvm.i64 { + // CHECK: call i64 @llvm.vscale.i64() + %0 = "llvm_arm_sve.vscale"() : () -> !llvm.i64 + llvm.return %0 : !llvm.i64 +}