diff --git a/mlir/include/mlir/Conversion/ArmSVEToLLVM/ConvertArmSVEToLLVM.h b/mlir/include/mlir/Conversion/ArmSVEToLLVM/ConvertArmSVEToLLVM.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Conversion/ArmSVEToLLVM/ConvertArmSVEToLLVM.h @@ -0,0 +1,40 @@ +//===- ConvertArmSVEToLLVM.h - Conversion Patterns from Arm SVE to LLVM ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_ARM_SVETOLLVM_CONVERTARMSVETOLLVM_H_ +#define MLIR_CONVERSION_ARM_SVETOLLVM_CONVERTARMSVETOLLVM_H_ + +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" +#include <memory> + +namespace mlir { +class LLVMTypeConverter; +class ModuleOp; +template <typename T> +class OperationPass; +class OwningRewritePatternList; + +//===----------------------------------------------------------------------===// +// Arm SVE Scalable Vector Type Conversion +//===----------------------------------------------------------------------===// + +class ArmSVETypeConverter : public LLVMTypeConverter { +public: + explicit ArmSVETypeConverter(MLIRContext *ctx); +}; + +/// Collect a set of patterns to convert from the Arm SVE dialect to LLVM. +void populateArmSVEToLLVMConversionPatterns(ArmSVETypeConverter &converter, + OwningRewritePatternList &patterns); + +/// Create a pass to convert Arm SVE operations to the LLVMIR dialect.
+std::unique_ptr<OperationPass<ModuleOp>> createConvertArmSVEToLLVMPass(); + +} // namespace mlir + +#endif // MLIR_CONVERSION_ARM_SVETOLLVM_CONVERTARMSVETOLLVM_H_ diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h --- a/mlir/include/mlir/Conversion/Passes.h +++ b/mlir/include/mlir/Conversion/Passes.h @@ -11,6 +11,7 @@ #include "mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" +#include "mlir/Conversion/ArmSVEToLLVM/ConvertArmSVEToLLVM.h" #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -379,6 +379,17 @@ let dependentDialects = ["spirv::SPIRVDialect"]; } +//===----------------------------------------------------------------------===// +// ArmSVEToLLVM +//===----------------------------------------------------------------------===// + +def ConvertArmSVEToLLVM : Pass<"convert-arm-sve-to-llvm", "ModuleOp"> { + let summary = "Convert the operations from the arm_sve dialect into the LLVM " + "dialect"; + let constructor = "mlir::createConvertArmSVEToLLVMPass()"; + let dependentDialects = ["LLVM::LLVMDialect", "LLVM::LLVMArmSVEDialect"]; +} + //===----------------------------------------------------------------------===// // VectorToSCF //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt --- a/mlir/include/mlir/Dialect/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/CMakeLists.txt @@ -13,5 +13,6 @@ add_subdirectory(Shape) add_subdirectory(SPIRV) add_subdirectory(StandardOps) +add_subdirectory(Target) add_subdirectory(Tosa) add_subdirectory(Vector) diff --git
a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt --- a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt @@ -30,3 +30,9 @@ set(LLVM_TARGET_DEFINITIONS LLVMAVX512.td) mlir_tablegen(LLVMAVX512Conversions.inc -gen-llvmir-conversions) add_public_tablegen_target(MLIRLLVMAVX512ConversionsIncGen) + +add_mlir_dialect(LLVMArmSVE llvm_arm_sve LLVMArmSVE) + +set(LLVM_TARGET_DEFINITIONS LLVMArmSVE.td) +mlir_tablegen(LLVMArmSVEConversions.inc -gen-llvmir-conversions) +add_public_tablegen_target(MLIRLLVMArmSVEConversionsIncGen) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMArmSVE.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMArmSVE.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMArmSVE.td @@ -0,0 +1,75 @@ +//===-- LLVMArmSVE.td - LLVMARMSVE dialect op definitions --*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the basic operations for the LLVMArmSVE dialect. 
+// +//===----------------------------------------------------------------------===// + +#ifndef LLVMIR_ARM_SVE_OPS +#define LLVMIR_ARM_SVE_OPS + +include "mlir/Dialect/LLVMIR/LLVMOpBase.td" + +//===----------------------------------------------------------------------===// +// LLVMArmSVE dialect definition +//===----------------------------------------------------------------------===// + +def LLVMArmSVE_Dialect : Dialect { + let name = "llvm_arm_sve"; + let cppNamespace = "::mlir::LLVM"; +} + +//----------------------------------------------------------------------------// +// MLIR LLVM Arm SVE intrinsics using the MLIR LLVM Dialect type system +//----------------------------------------------------------------------------// + +class LLVMArmSVE_NonSVEIntrUnaryOverloadedOp traits =[]> : + LLVM_IntrOpBase overloadedResults=*/[0], + /*list overloadedOperands=*/[], // defined by result overload + /*list traits=*/traits, + /*int numResults=*/1>; + +class LLVMArmSVE_IntrOp traits = []> : + LLVM_IntrOpBase; + +class LLVMArmSVE_IntrBinaryOverloadedOp traits = []> : + LLVM_IntrOpBase overloadedResults=*/[0], + /*list overloadedOperands=*/[], // defined by result overload + /*list traits=*/traits, + /*int numResults=*/1>; + +def LLVM_aarch64_sve_ummla : + LLVMArmSVE_IntrBinaryOverloadedOp<"ummla">, + Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>; + +def LLVM_aarch64_sve_smmla : + LLVMArmSVE_IntrBinaryOverloadedOp<"smmla">, + Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>; + +def LLVM_aarch64_sve_sdot : + LLVMArmSVE_IntrBinaryOverloadedOp<"sdot">, + Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>; + +def LLVM_aarch64_sve_udot : + LLVMArmSVE_IntrBinaryOverloadedOp<"udot">, + Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>; + +def LLVM_vector_scale : + LLVMArmSVE_NonSVEIntrUnaryOverloadedOp<"vscale">; + +#endif // LLVMIR_ARM_SVE_OPS diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h new file mode 100644 ---
/dev/null +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h @@ -0,0 +1,24 @@ +//===- LLVMArmSVEDialect.h - MLIR Dialect for LLVMArmSVE --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the LLVMArmSVE dialect and its operations in MLIR. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_LLVMIR_LLVMARMSVEDIALECT_H_ +#define MLIR_DIALECT_LLVMIR_LLVMARMSVEDIALECT_H_ + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" + +#define GET_OP_CLASSES +#include "mlir/Dialect/LLVMIR/LLVMArmSVE.h.inc" + +#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h.inc" + +#endif // MLIR_DIALECT_LLVMIR_LLVMARMSVEDIALECT_H_ diff --git a/mlir/include/mlir/Dialect/Target/ArmSVE/ArmSVE.td b/mlir/include/mlir/Dialect/Target/ArmSVE/ArmSVE.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Target/ArmSVE/ArmSVE.td @@ -0,0 +1,255 @@ +//===-- ArmSVE.td - ArmSVE dialect operation definitions ---*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the basic operations for the ArmSVE dialect.
+// +//===----------------------------------------------------------------------===// + +#ifndef ARM_SVE_OPS +#define ARM_SVE_OPS + +include "mlir/Interfaces/SideEffectInterfaces.td" + +//===----------------------------------------------------------------------===// +// ArmSVE dialect definition +//===----------------------------------------------------------------------===// + +def ArmSVE_Dialect : Dialect { + let name = "arm_sve"; + let cppNamespace = "::mlir::arm_sve"; + let summary = "Basic dialect to target Arm SVE architectures"; + let description = [{ + This dialect contains the definitions necessary to target Arm SVE scalable + vector operations, including a scalable vector type and intrinsics for + some Arm SVE instructions. + }]; +} + +//===----------------------------------------------------------------------===// +// ArmSVE type definitions +//===----------------------------------------------------------------------===// + +def ArmSVE_ScalableVectorType : DialectType()">, + "scalable vector type">, + BuildableType<"$_builder.getType()"> { + let typeDescription = [{ + `arm_sve.vector` represents vectors that will be processed by a scalable + vector architecture. + }]; +} + +class ArmSVE_Type : TypeDef { } + +def ScalableVectorType : ArmSVE_Type<"ScalableVector"> { + let mnemonic = "vector"; + + let summary = "Scalable vector type"; + + let description = [{ + A type representing scalable length SIMD vectors. Unlike fixed-length SIMD + vectors, whose size is constant and known at compile time, scalable + vectors' length is constant but determined by the specific hardware at + run time. 
+ }]; + + let parameters = (ins + ArrayRefParameter<"int64_t", "Vector shape">:$shape, + "Type":$elementType + ); + + let printer = [{ + $_printer << "vector<"; + for (int64_t dim : getShape()) + $_printer << dim << 'x'; + $_printer << getElementType() << '>'; + }]; + + let parser = [{ + VectorType vector; + if ($_parser.parseType(vector)) + return Type(); + return get(ctxt, vector.getShape(), vector.getElementType()); + }]; +} + +//===----------------------------------------------------------------------===// +// ArmSVE type traits +//===----------------------------------------------------------------------===// + +def IsScalableVectorTypePred : + CPred<"$_self.isa<::mlir::arm_sve::ScalableVectorType>()">; + +class ScalableVectorOf allowedTypes> : + ContainerType, IsScalableVectorTypePred, + "$_self.cast<::mlir::arm_sve::ScalableVectorType>().getElementType()", + "scalable vector">; + +//===----------------------------------------------------------------------===// +// ArmSVE op definitions +//===----------------------------------------------------------------------===// + +class ArmSVE_Op traits = []> : + Op {} + +def UmmlaOp : ArmSVE_Op<"ummla", + [NoSideEffect, + AllTypesMatch<["src1", "src2"]>, + AllTypesMatch<["acc", "dst"]>, + ]> { + let summary = "Matrix-matrix multiply and accumulate op"; + let description = [{ + The ummla op is an Arm SVE specific op that can lower to the proper + LLVMArmSVE operation: `llvm.aarch64.sve.ummla` instruction. + + UMMLA: Unsigned integer matrix multiply-accumulate. + + The unsigned integer matrix multiply-accumulate instruction multiplies the + 2×8 matrix of unsigned 8-bit integer values held in each <16xui8> segment + of the `$src1` vector by the 8×2 matrix of unsigned 8-bit integer values + in the corresponding <16xui8> segment of the `$src2` vector.
The resulting + 2×2 widened 32-bit integer matrix product is then destructively added to the + 32-bit integer matrix accumulator held in the corresponding <4xui32> segment + of the `$acc` vector. This is equivalent to performing an 8-way dot product + per destination element. + + Elements of the `$src1` matrix are stored in row-major order, in blocks of + 2x8 elements per segment. Elements of the `$src2` matrix are stored in + column-major order, in blocks of 8x2 elements per segment. The elements of + the `$acc` matrix are stored in row-major order, in blocks of 2x2 elements + per segment. + }]; + // Supports vector<16xi8>. + let arguments = (ins + ScalableVectorOf<[I32]>:$acc, + ScalableVectorOf<[I8]>:$src1, + ScalableVectorOf<[I8]>:$src2 + ); + let results = (outs ScalableVectorOf<[I32]>:$dst); + let assemblyFormat = + "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `->` type($dst)"; +} + +def SmmlaOp : ArmSVE_Op<"smmla", + [NoSideEffect, + AllTypesMatch<["src1", "src2"]>, + AllTypesMatch<["acc", "dst"]>, + ]> { + let summary = "Matrix-matrix multiply and accumulate op"; + let description = [{ + The smmla op is an Arm SVE specific op that can lower to the proper + LLVMArmSVE operation: `llvm.aarch64.sve.smmla` instruction. + + SMMLA: Signed integer matrix multiply-accumulate. + + The signed integer matrix multiply-accumulate instruction multiplies the + 2×8 matrix of signed 8-bit integer values held in each <16xsi8> segment + of the `$src1` vector by the 8×2 matrix of signed 8-bit integer values + in the corresponding <16xsi8> segment of the `$src2` vector. The resulting + 2×2 widened 32-bit integer matrix product is then destructively added to + the 32-bit integer matrix accumulator held in the corresponding <4xsi32> + segment of the `$acc` vector. This is equivalent to performing an 8-way dot + product per destination element. + + Elements of the `$src1` matrix are stored in row-major order, in blocks of + 2x8 elements per segment.
Elements of the `$src2` matrix are stored in + column-major order, in blocks of 8x2 elements per segment. The elements of + the `$acc` matrix are stored in row-major order, in blocks of 2x2 elements + per segment. + }]; + // Supports vector<16 x si8>. + let arguments = (ins + ScalableVectorOf<[I32]>:$acc, + ScalableVectorOf<[I8]>:$src1, + ScalableVectorOf<[I8]>:$src2 + ); + let results = (outs ScalableVectorOf<[I32]>:$dst); + let assemblyFormat = + "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `->` type($dst)"; +} + +def SdotOp : ArmSVE_Op<"sdot", + [NoSideEffect, + AllTypesMatch<["src1", "src2"]>, + AllTypesMatch<["acc", "dst"]>, + ]> { + let summary = "Vector-vector dot product and accumulate op"; + let description = [{ + The sdot op is an Arm SVE specific op that can lower to the proper + LLVMArmSVE operation: `llvm.aarch64.sve.sdot` + + SDOT: Signed integer dot product. + + The signed integer dot product instruction computes the dot product of a + group of four signed 8-bit or 16-bit integer values held in each 32-bit or + 64-bit element of the `$src1` vector multiplied by a group of four signed + 8-bit or 16-bit integer values in the corresponding 32-bit or 64-bit + element of the `$src2` vector, and then destructively adds the widened dot + product to the corresponding 32-bit or 64-bit element of the `$acc` vector. + }]; + // Supports vector<16xi8> and vector<8xi16>. 
+ let arguments = (ins + ScalableVectorOf<[I32, I64]>:$acc, + ScalableVectorOf<[I8, I16]>:$src1, + ScalableVectorOf<[I8, I16]>:$src2 + ); + let results = (outs ScalableVectorOf<[I32, I64]>:$dst); + let assemblyFormat = + "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `->` type($dst)"; +} + +def UdotOp : ArmSVE_Op<"udot", + [NoSideEffect, + AllTypesMatch<["src1", "src2"]>, + AllTypesMatch<["acc", "dst"]>, + ]> { + let summary = "Vector-vector dot product and accumulate op"; + let description = [{ + The udot op is an Arm SVE specific op that can lower to the proper + LLVMArmSVE operation: `llvm.aarch64.sve.udot` + + UDOT: Unsigned integer dot product. + + The unsigned integer dot product instruction computes the dot product of a + group of four unsigned 8-bit or 16-bit integer values held in each 32-bit + or 64-bit element of the `$src1` vector multiplied by a group of four + unsigned 8-bit or 16-bit integer values in the corresponding 32-bit or + 64-bit element of the `$src2` vector, and then destructively adds the + widened dot product to the corresponding 32-bit or 64-bit element of the + `$acc` vector. + }]; + // Supports vector<16xi8> and vector<8xi16>. + let arguments = (ins + ScalableVectorOf<[I32, I64]>:$acc, + ScalableVectorOf<[I8, I16]>:$src1, + ScalableVectorOf<[I8, I16]>:$src2 + ); + let results = (outs ScalableVectorOf<[I32, I64]>:$dst); + let assemblyFormat = + "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `->` type($dst)"; +} + +def VectorScaleOp : ArmSVE_Op<"vector_scale", + [NoSideEffect]> { + let summary = "Load vector scale size"; + let description = [{ + The vector_scale op returns the scale of the scalable vectors, a positive + integer value that is constant at runtime but unknown at compile time. + The scale of the vector indicates the multiplicity of the vectors and + vector operations. 
I.e.: an !arm_sve.vector<4xi32> is equivalent to + vector_scale consecutive vector<4xi32>; and an operation on an + !arm_sve.vector<4xi32> is equivalent to performing that operation vector_scale + times, once on each <4xi32> segment of the scalable vector. The vector_scale + op can be used to calculate the step in vector-length agnostic (VLA) loops. + }]; + let results = (outs Index:$res); + let assemblyFormat = + "attr-dict `:` type($res)"; +} + +#endif // ARM_SVE_OPS diff --git a/mlir/include/mlir/Dialect/Target/ArmSVE/ArmSVEDialect.h b/mlir/include/mlir/Dialect/Target/ArmSVE/ArmSVEDialect.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Target/ArmSVE/ArmSVEDialect.h @@ -0,0 +1,29 @@ +//===- ArmSVEDialect.h - MLIR Dialect for Arm SVE ---------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the Target dialect for Arm SVE in MLIR. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_TARGET_ARM_SVE_DIALECT_H_ +#define MLIR_DIALECT_TARGET_ARM_SVE_DIALECT_H_ + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" + +#include "mlir/Dialect/Target/ArmSVE/ArmSVEDialect.h.inc" + +#define GET_TYPEDEF_CLASSES +#include "mlir/Dialect/Target/ArmSVE/ArmSVETypes.h.inc" + +#define GET_OP_CLASSES +#include "mlir/Dialect/Target/ArmSVE/ArmSVE.h.inc" + +#endif // MLIR_DIALECT_TARGET_ARM_SVE_DIALECT_H_ diff --git a/mlir/include/mlir/Dialect/Target/ArmSVE/CMakeLists.txt b/mlir/include/mlir/Dialect/Target/ArmSVE/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Target/ArmSVE/CMakeLists.txt @@ -0,0 +1 @@ +add_mlir_dialect(ArmSVE arm_sve ArmSVE) diff --git a/mlir/include/mlir/Dialect/Target/CMakeLists.txt b/mlir/include/mlir/Dialect/Target/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Target/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(ArmSVE) diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -19,6 +19,7 @@ #include "mlir/Dialect/Async/IR/Async.h" #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h" +#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" @@ -33,6 +34,7 @@ #include "mlir/Dialect/SPIRV/SPIRVDialect.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Target/ArmSVE/ArmSVEDialect.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/IR/Dialect.h" @@ -49,6 +51,7 @@ gpu::GPUDialect, 
LLVM::LLVMAVX512Dialect, LLVM::LLVMDialect, + LLVM::LLVMArmSVEDialect, linalg::LinalgDialect, scf::SCFDialect, omp::OpenMPDialect, @@ -57,6 +60,7 @@ quant::QuantizationDialect, spirv::SPIRVDialect, StandardOpsDialect, + arm_sve::ArmSVEDialect, vector::VectorDialect, NVVM::NVVMDialect, ROCDL::ROCDLDialect, diff --git a/mlir/include/mlir/InitAllTranslations.h b/mlir/include/mlir/InitAllTranslations.h --- a/mlir/include/mlir/InitAllTranslations.h +++ b/mlir/include/mlir/InitAllTranslations.h @@ -23,6 +23,7 @@ void registerToNVVMIRTranslation(); void registerToROCDLIRTranslation(); void registerAVX512ToLLVMIRTranslation(); +void registerArmSVEToLLVMIRTranslation(); // This function should be called before creating any MLIRContext if one // expects all the possible translations to be made available to the context @@ -36,6 +37,7 @@ registerToNVVMIRTranslation(); registerToROCDLIRTranslation(); registerAVX512ToLLVMIRTranslation(); + registerArmSVEToLLVMIRTranslation(); return true; }(); (void)initOnce; diff --git a/mlir/lib/Conversion/ArmSVEToLLVM/CMakeLists.txt b/mlir/lib/Conversion/ArmSVEToLLVM/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/ArmSVEToLLVM/CMakeLists.txt @@ -0,0 +1,19 @@ +add_mlir_conversion_library(MLIRArmSVEToLLVM + ConvertArmSVEToLLVM.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArmSVEToLLVM + + DEPENDS + MLIRConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRArmSVE + MLIRLLVMArmSVE + MLIRLLVMIR + MLIRStandardToLLVM + MLIRTransforms + ) diff --git a/mlir/lib/Conversion/ArmSVEToLLVM/ConvertArmSVEToLLVM.cpp b/mlir/lib/Conversion/ArmSVEToLLVM/ConvertArmSVEToLLVM.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/ArmSVEToLLVM/ConvertArmSVEToLLVM.cpp @@ -0,0 +1,257 @@ +//===- ConvertArmSVEToLLVM.cpp - Convert Arm SVE to the LLVM dialect ------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/ArmSVEToLLVM/ConvertArmSVEToLLVM.h" + +#include "../PassDetail.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" +#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" +#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Target/ArmSVE/ArmSVEDialect.h" +#include "mlir/Dialect/Vector/VectorOps.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; +using namespace mlir::vector; +using namespace mlir::arm_sve; + +template <typename OpTy> +static Type getSrc1VectorElementType(OpTy op) { + return op.src1() + .getType() + .template cast<mlir::arm_sve::ScalableVectorType>() + .getElementType(); +} +template <typename OpTy> +static Type getSrc2VectorElementType(OpTy op) { + return op.src2() + .getType() + .template cast<mlir::arm_sve::ScalableVectorType>() + .getElementType(); +} +template <typename OpTy> +static Type getAccVectorElementType(OpTy op) { + return op.acc() + .getType() + .template cast<mlir::arm_sve::ScalableVectorType>() + .getElementType(); +} + +/// Basic lowering implementation for one-to-one rewriting from Arm SVE Ops to +/// LLVM Dialect Ops. Convert the type of the result to an LLVM type, pass +/// operands as is, preserve attributes.
+template +static LogicalResult +matchAndRewriteOneToOne(const ConvertToLLVMPattern &lowering, + LLVMTypeConverter &typeConverter, Operation *op, + ArrayRef operands, + ConversionPatternRewriter &rewriter) { + unsigned numResults = op->getNumResults(); + + Type packedType; + if (numResults != 0) { + packedType = typeConverter.packFunctionResults(op->getResultTypes()); + if (!packedType) + return failure(); + } + + auto newOp = rewriter.create(op->getLoc(), packedType, operands, + op->getAttrs()); + + // If the operation produced 0 or 1 result, return them immediately. + if (numResults == 0) + return rewriter.eraseOp(op), success(); + if (numResults == 1) + return rewriter.replaceOp(op, newOp.getOperation()->getResult(0)), + success(); + + // Otherwise, it had been converted to an operation producing a structure. + // Extract individual results from the structure and return them as list. + SmallVector results; + results.reserve(numResults); + for (unsigned i = 0; i < numResults; ++i) { + auto type = typeConverter.convertType(op->getResult(i).getType()); + results.push_back(rewriter.create( + op->getLoc(), type, newOp.getOperation()->getResult(0), + rewriter.getI64ArrayAttr(i))); + } + rewriter.replaceOp(op, results); + return success(); +} + +namespace { + +struct UmmlaOpConversion : public ConvertToLLVMPattern { + explicit UmmlaOpConversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : ConvertToLLVMPattern(UmmlaOp::getOperationName(), context, + typeConverter) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + if (!(getAccVectorElementType(cast(op)).isInteger(32) && + getSrc1VectorElementType(cast(op)).isInteger(8) && + getSrc2VectorElementType(cast(op)).isInteger(8))) + return failure(); + return matchAndRewriteOneToOne( + *this, this->typeConverter, op, operands, rewriter); + } +}; + +struct SmmlaOpConversion : public ConvertToLLVMPattern { + explicit 
SmmlaOpConversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : ConvertToLLVMPattern(SmmlaOp::getOperationName(), context, + typeConverter) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + if (!(getAccVectorElementType(cast(op)).isInteger(32) && + getSrc1VectorElementType(cast(op)).isInteger(8) && + getSrc2VectorElementType(cast(op)).isInteger(8))) + return failure(); + return matchAndRewriteOneToOne( + *this, this->typeConverter, op, operands, rewriter); + } +}; + +struct SdotOpConversion : public ConvertToLLVMPattern { + explicit SdotOpConversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : ConvertToLLVMPattern(SdotOp::getOperationName(), context, + typeConverter) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + if (!((getAccVectorElementType(cast(op)).isInteger(32) && + getSrc1VectorElementType(cast(op)).isInteger(8) && + getSrc2VectorElementType(cast(op)).isInteger(8)) || + (getAccVectorElementType(cast(op)).isInteger(64) && + getSrc1VectorElementType(cast(op)).isInteger(16) && + getSrc2VectorElementType(cast(op)).isInteger(16)))) + return failure(); + return matchAndRewriteOneToOne( + *this, this->typeConverter, op, operands, rewriter); + } +}; + +struct UdotOpConversion : public ConvertToLLVMPattern { + explicit UdotOpConversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : ConvertToLLVMPattern(UdotOp::getOperationName(), context, + typeConverter) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + if (!((getAccVectorElementType(cast(op)).isInteger(32) && + getSrc1VectorElementType(cast(op)).isInteger(8) && + getSrc2VectorElementType(cast(op)).isInteger(8)) || + (getAccVectorElementType(cast(op)).isInteger(64) && + getSrc1VectorElementType(cast(op)).isInteger(16) && + 
getSrc2VectorElementType(cast(op)).isInteger(16)))) + return failure(); + return matchAndRewriteOneToOne( + *this, this->typeConverter, op, operands, rewriter); + } +}; + +struct VectorScaleOpConversion : public ConvertToLLVMPattern { + explicit VectorScaleOpConversion(MLIRContext *context, + LLVMTypeConverter &typeConverter) + : ConvertToLLVMPattern(VectorScaleOp::getOperationName(), context, + typeConverter) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + return matchAndRewriteOneToOne( + *this, this->typeConverter, op, operands, rewriter); + } +}; + +} // namespace + +/// Populate the given list with patterns that convert from Arm SVE to LLVM. +void mlir::populateArmSVEToLLVMConversionPatterns( + ArmSVETypeConverter &converter, OwningRewritePatternList &patterns) { + MLIRContext *ctx = converter.getDialect()->getContext(); + // clang-format off + patterns.insert(ctx, converter); + // clang-format on +} + +namespace { +struct ConvertArmSVEToLLVMPass + : public ConvertArmSVEToLLVMBase { + void runOnOperation() override; +}; +} // namespace + +// Extract an LLVM IR type from the LLVM IR dialect type. 
+static LLVM::LLVMType unwrap(Type type) { + if (!type) + return nullptr; + auto *mlirContext = type.getContext(); + auto wrappedLLVMType = type.dyn_cast(); + if (!wrappedLLVMType) + emitError(UnknownLoc::get(mlirContext), + "conversion resulted in a non-LLVM type"); + return wrappedLLVMType; +} + +static Optional convertScalableVectorType(ScalableVectorType svType, + LLVMTypeConverter &converter) { + auto elementType = unwrap(converter.convertType(svType.getElementType())); + if (!elementType) + return {}; + + auto sVectorType = + LLVM::LLVMScalableVectorType::get(elementType, svType.getShape().back()); + return sVectorType; +} + +ArmSVETypeConverter::ArmSVETypeConverter(MLIRContext *ctx) + : LLVMTypeConverter(ctx) { + addConversion([this](ScalableVectorType svType) { + return convertScalableVectorType(svType, *this); + }); +} + +void ConvertArmSVEToLLVMPass::runOnOperation() { + OwningRewritePatternList patterns; + ArmSVETypeConverter sveConverter(&getContext()); + populateArmSVEToLLVMConversionPatterns(sveConverter, patterns); + populateVectorToLLVMConversionPatterns(sveConverter, patterns); + populateStdToLLVMConversionPatterns(sveConverter, patterns); + + ConversionTarget target(getContext()); + target.addLegalDialect(); + target.addLegalDialect(); + target.addIllegalDialect(); + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) { + signalPassFailure(); + } +} + +std::unique_ptr> mlir::createConvertArmSVEToLLVMPass() { + return std::make_unique(); +} diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -19,6 +19,7 @@ add_subdirectory(SPIRVToLLVM) add_subdirectory(StandardToLLVM) add_subdirectory(StandardToSPIRV) +add_subdirectory(ArmSVEToLLVM) add_subdirectory(VectorToROCDL) add_subdirectory(VectorToLLVM) add_subdirectory(VectorToSCF) diff --git a/mlir/lib/Conversion/PassDetail.h 
b/mlir/lib/Conversion/PassDetail.h --- a/mlir/lib/Conversion/PassDetail.h +++ b/mlir/lib/Conversion/PassDetail.h @@ -27,6 +27,7 @@ namespace LLVM { class LLVMDialect; class LLVMAVX512Dialect; +class LLVMArmSVEDialect; } // end namespace LLVM namespace NVVM { diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt --- a/mlir/lib/Dialect/CMakeLists.txt +++ b/mlir/lib/Dialect/CMakeLists.txt @@ -14,6 +14,7 @@ add_subdirectory(Shape) add_subdirectory(SPIRV) add_subdirectory(StandardOps) +add_subdirectory(Target) add_subdirectory(Tosa) add_subdirectory(Vector) diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt --- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt +++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt @@ -49,6 +49,27 @@ MLIRSideEffectInterfaces ) +add_mlir_dialect_library(MLIRLLVMArmSVE + IR/LLVMArmSVEDialect.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LLVMIR + + DEPENDS + MLIRLLVMArmSVEIncGen + MLIRLLVMArmSVEConversionsIncGen + intrinsics_gen + + LINK_COMPONENTS + AsmParser + Core + + LINK_LIBS PUBLIC + MLIRIR + MLIRLLVMIR + MLIRSideEffectInterfaces + ) + add_mlir_dialect_library(MLIRNVVMIR IR/NVVMDialect.cpp diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMArmSVEDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMArmSVEDialect.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMArmSVEDialect.cpp @@ -0,0 +1,31 @@ +//===- LLVMArmSVEDialect.cpp - MLIR LLVMArmSVE ops implementation ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the LLVMArmSVE dialect and its operations.
+// +//===----------------------------------------------------------------------===// + +#include "llvm/IR/IntrinsicsAArch64.h" + +#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/TypeUtilities.h" + +using namespace mlir; + +void LLVM::LLVMArmSVEDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "mlir/Dialect/LLVMIR/LLVMArmSVE.cpp.inc" + >(); +} + +#define GET_OP_CLASSES +#include "mlir/Dialect/LLVMIR/LLVMArmSVE.cpp.inc" diff --git a/mlir/lib/Dialect/Target/ArmSVE/CMakeLists.txt b/mlir/lib/Dialect/Target/ArmSVE/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Target/ArmSVE/CMakeLists.txt @@ -0,0 +1,14 @@ +add_mlir_dialect_library(MLIRArmSVE + IR/ArmSVEDialect.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSVE + + DEPENDS + MLIRArmSVEIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRSideEffectInterfaces + MLIRVectorToLLVM + ) diff --git a/mlir/lib/Dialect/Target/ArmSVE/IR/ArmSVEDialect.cpp b/mlir/lib/Dialect/Target/ArmSVE/IR/ArmSVEDialect.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Target/ArmSVE/IR/ArmSVEDialect.cpp @@ -0,0 +1,63 @@ +//===- ArmSVEDialect.cpp - MLIR ARM SVE dialect implementation ------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the Arm SVE dialect and its operations. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Target/ArmSVE/ArmSVEDialect.h" +#include "mlir/Dialect/Vector/VectorOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; + +void arm_sve::ArmSVEDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "mlir/Dialect/Target/ArmSVE/ArmSVE.cpp.inc" + >(); + addTypes< +#define GET_TYPEDEF_LIST +#include "mlir/Dialect/Target/ArmSVE/ArmSVETypes.cpp.inc" + >(); +} + +#define GET_OP_CLASSES +#include "mlir/Dialect/Target/ArmSVE/ArmSVE.cpp.inc" + +#define GET_TYPEDEF_CLASSES +#include "mlir/Dialect/Target/ArmSVE/ArmSVETypes.cpp.inc" + +namespace mlir { +namespace arm_sve { + +//===----------------------------------------------------------------------===// +// VectorType +//===----------------------------------------------------------------------===// + +Type ArmSVEDialect::parseType(DialectAsmParser &parser) const { + llvm::SMLoc typeLoc = parser.getCurrentLocation(); + auto genType = generatedTypeParser(getContext(), parser, "vector"); + if (genType != Type()) + return genType; + parser.emitError(typeLoc, "unknown type in Arm SVE dialect"); + return Type(); +} + +void ArmSVEDialect::printType(Type type, DialectAsmPrinter &os) const { + if (failed(generatedTypePrinter(type, os))) + llvm_unreachable("unexpected 'arm_sve' type kind"); +} + +} // namespace arm_sve +} // namespace mlir diff --git a/mlir/lib/Dialect/Target/CMakeLists.txt b/mlir/lib/Dialect/Target/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Target/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(ArmSVE) diff --git a/mlir/lib/Target/CMakeLists.txt b/mlir/lib/Target/CMakeLists.txt --- a/mlir/lib/Target/CMakeLists.txt +++ b/mlir/lib/Target/CMakeLists.txt @@ -94,3 +94,22 
@@ MLIRROCDLIR MLIRTargetLLVMIRModuleTranslation ) + +add_mlir_translation_library(MLIRTargetArmSVE + LLVMIR/LLVMArmSVEIntr.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Target/LLVMIR + + DEPENDS + MLIRLLVMArmSVEConversionsIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRIR + MLIRLLVMArmSVE + MLIRLLVMIR + MLIRTargetLLVMIRModuleTranslation + ) diff --git a/mlir/lib/Target/LLVMIR/LLVMArmSVEIntr.cpp b/mlir/lib/Target/LLVMIR/LLVMArmSVEIntr.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Target/LLVMIR/LLVMArmSVEIntr.cpp @@ -0,0 +1,63 @@ +//===- LLVMArmSVEIntr.cpp - Convert MLIR LLVM dialect to LLVM intrinsics --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a translation between the MLIR LLVM and Arm SVE dialects +// and LLVM IR with Arm SVE intrinsics. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h" +#include "mlir/Target/LLVMIR/ModuleTranslation.h" +#include "mlir/Translation.h" +#include "llvm/IR/IntrinsicsAArch64.h" + +using namespace mlir; + +namespace { +class LLVMArmSVEModuleTranslation : public LLVM::ModuleTranslation { + friend LLVM::ModuleTranslation; + +public: + using LLVM::ModuleTranslation::ModuleTranslation; + +protected: + LogicalResult convertOperation(Operation &opInst, + llvm::IRBuilder<> &builder) override { +#include "mlir/Dialect/LLVMIR/LLVMArmSVEConversions.inc" + + return LLVM::ModuleTranslation::convertOperation(opInst, builder); + } +}; + +std::unique_ptr +translateLLVMArmSVEModuleToLLVMIR(Operation *m, llvm::LLVMContext &llvmContext, + StringRef name) { + return LLVM::ModuleTranslation::translateModule( + m, llvmContext, name); +} +} // end namespace + +namespace mlir { +void registerArmSVEToLLVMIRTranslation() { + TranslateFromMLIRRegistration reg( + "arm-sve-mlir-to-llvmir", + [](ModuleOp module, raw_ostream &output) { + llvm::LLVMContext llvmContext; + auto llvmModule = translateLLVMArmSVEModuleToLLVMIR( + module, llvmContext, "LLVMDialectModule"); + if (!llvmModule) + return failure(); + + llvmModule->print(output, nullptr); + return success(); + }, + [](DialectRegistry ®istry) { + registry.insert(); + }); +} +} // namespace mlir diff --git a/mlir/test/Conversion/ArmSVEToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/ArmSVEToLLVM/convert-to-llvm.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/ArmSVEToLLVM/convert-to-llvm.mlir @@ -0,0 +1,52 @@ +// RUN: mlir-opt %s -convert-arm-sve-to-llvm | mlir-opt | FileCheck %s + +func @arm_sve_sdot(%a: !arm_sve.vector<16xi8>, + %b: !arm_sve.vector<16xi8>, + %c: !arm_sve.vector<4xi32>) + -> !arm_sve.vector<4xi32> +{ + // CHECK: llvm_arm_sve.sdot + %0 = arm_sve.sdot %c, %a, %b : + !arm_sve.vector<16xi8> -> !arm_sve.vector<4xi32> + return 
%0 : !arm_sve.vector<4xi32> +} + +func @arm_sve_smmla(%a: !arm_sve.vector<16xi8>, + %b: !arm_sve.vector<16xi8>, + %c: !arm_sve.vector<4xi32>) + -> !arm_sve.vector<4xi32> +{ + // CHECK: llvm_arm_sve.smmla + %0 = arm_sve.smmla %c, %a, %b : + !arm_sve.vector<16xi8> -> !arm_sve.vector<4xi32> + return %0 : !arm_sve.vector<4xi32> +} + +func @arm_sve_udot(%a: !arm_sve.vector<16xi8>, + %b: !arm_sve.vector<16xi8>, + %c: !arm_sve.vector<4xi32>) + -> !arm_sve.vector<4xi32> +{ + // CHECK: llvm_arm_sve.udot + %0 = arm_sve.udot %c, %a, %b : + !arm_sve.vector<16xi8> -> !arm_sve.vector<4xi32> + return %0 : !arm_sve.vector<4xi32> +} + +func @arm_sve_ummla(%a: !arm_sve.vector<16xi8>, + %b: !arm_sve.vector<16xi8>, + %c: !arm_sve.vector<4xi32>) + -> !arm_sve.vector<4xi32> +{ + // CHECK: llvm_arm_sve.ummla + %0 = arm_sve.ummla %c, %a, %b : + !arm_sve.vector<16xi8> -> !arm_sve.vector<4xi32> + return %0 : !arm_sve.vector<4xi32> +} + +func @get_vector_scale() -> index +{ + // CHECK: llvm_arm_sve.vscale + %0 = arm_sve.vector_scale : index + return %0 : index +} diff --git a/mlir/test/Dialect/Target/ArmSVE/roundtrip.mlir b/mlir/test/Dialect/Target/ArmSVE/roundtrip.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Target/ArmSVE/roundtrip.mlir @@ -0,0 +1,48 @@ +// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s + +func @arm_sve_sdot(%a: !arm_sve.vector<16xi8>, + %b: !arm_sve.vector<16xi8>, + %c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> +{ + // CHECK: arm_sve.sdot {{.*}}: !arm_sve.vector<16xi8> -> !arm_sve.vector<4xi32 + %0 = arm_sve.sdot %c, %a, %b : + !arm_sve.vector<16xi8> -> !arm_sve.vector<4xi32> + return %0 : !arm_sve.vector<4xi32> +} + +func @arm_sve_smmla(%a: !arm_sve.vector<16xi8>, + %b: !arm_sve.vector<16xi8>, + %c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> +{ + // CHECK: arm_sve.smmla {{.*}}: !arm_sve.vector<16xi8> -> !arm_sve.vector<4xi3 + %0 = arm_sve.smmla %c, %a, %b : + !arm_sve.vector<16xi8> -> !arm_sve.vector<4xi32> + return %0 : 
!arm_sve.vector<4xi32> +} + +func @arm_sve_udot(%a: !arm_sve.vector<16xi8>, + %b: !arm_sve.vector<16xi8>, + %c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> +{ + // CHECK: arm_sve.udot {{.*}}: !arm_sve.vector<16xi8> -> !arm_sve.vector<4xi32 + %0 = arm_sve.udot %c, %a, %b : + !arm_sve.vector<16xi8> -> !arm_sve.vector<4xi32> + return %0 : !arm_sve.vector<4xi32> +} + +func @arm_sve_ummla(%a: !arm_sve.vector<16xi8>, + %b: !arm_sve.vector<16xi8>, + %c: !arm_sve.vector<4xi32>) -> !arm_sve.vector<4xi32> +{ + // CHECK: arm_sve.ummla {{.*}}: !arm_sve.vector<16xi8> -> !arm_sve.vector<4xi32> + %0 = arm_sve.ummla %c, %a, %b : + !arm_sve.vector<16xi8> -> !arm_sve.vector<4xi32> + return %0 : !arm_sve.vector<4xi32> +} + +func @get_vector_scale() -> index +{ + // CHECK: arm_sve.vector_scale : index + %0 = arm_sve.vector_scale : index + return %0 : index +} diff --git a/mlir/test/Target/armsve.mlir b/mlir/test/Target/armsve.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Target/armsve.mlir @@ -0,0 +1,58 @@ +// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | mlir-translate --arm-sve-mlir-to-llvmir | FileCheck %s + +// CHECK-LABEL: define @arm_sve_sdot +llvm.func @arm_sve_sdot(%arg0: !llvm.vec, + %arg1: !llvm.vec, + %arg2: !llvm.vec) + -> !llvm.vec +{ + // CHECK: call @llvm.aarch64.sve.sdot.nxv4i32(, !llvm.vec, !llvm.vec) + -> !llvm.vec + llvm.return %0 : !llvm.vec +} + +// CHECK-LABEL: define @arm_sve_smmla +llvm.func @arm_sve_smmla(%arg0: !llvm.vec, + %arg1: !llvm.vec, + %arg2: !llvm.vec) + -> !llvm.vec { + // CHECK: call @llvm.aarch64.sve.smmla.nxv4i32(, !llvm.vec, !llvm.vec) + -> !llvm.vec + llvm.return %0 : !llvm.vec +} + +// CHECK-LABEL: define @arm_sve_udot +llvm.func @arm_sve_udot(%arg0: !llvm.vec, + %arg1: !llvm.vec, + %arg2: !llvm.vec) + -> !llvm.vec +{ + // CHECK: call @llvm.aarch64.sve.udot.nxv4i32(, !llvm.vec, !llvm.vec) + -> !llvm.vec + llvm.return %0 : !llvm.vec +} + +// CHECK-LABEL: define @arm_sve_ummla +llvm.func @arm_sve_ummla(%arg0: !llvm.vec, + 
%arg1: !llvm.vec, + %arg2: !llvm.vec) + -> !llvm.vec { + // CHECK: call @llvm.aarch64.sve.ummla.nxv4i32(, !llvm.vec, !llvm.vec) + -> !llvm.vec + llvm.return %0 : !llvm.vec +} + +// CHECK-LABEL: define i64 @get_vector_scale() +llvm.func @get_vector_scale() -> !llvm.i64 { + // CHECK: call i64 @llvm.vscale.i64() + %0 = "llvm_arm_sve.vscale"() : () -> !llvm.i64 + llvm.return %0 : !llvm.i64 +}