diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -645,7 +645,7 @@ operations. The lowering pass provides several options to control the kinds of optimizations that are allowed. It also provides options that enable the use of one or more architectural-specific dialects - (AMX, X86Vector, ArmNeon, ArmSVE, etc.) in combination with the + (AMX, X86Vector, ArmNeon, ArmSVE, RISCVV, etc.) in combination with the architectural-neutral vector dialect lowering. }]; @@ -675,7 +675,11 @@ Option<"enableX86Vector", "enable-x86vector", "bool", /*default=*/"false", "Enables the use of X86Vector dialect while lowering the vector " - "dialect."> + "dialect.">, + Option<"enableRISCVV", "enable-riscvv", + "bool", /*default=*/"false", + "Enables the use of RISCVV dialect while lowering the vector " + "dialect."> ]; } diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h --- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h +++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h @@ -24,7 +24,7 @@ LowerVectorToLLVMOptions() : reassociateFPReductions(false), enableIndexOptimizations(true), enableArmNeon(false), enableArmSVE(false), enableAMX(false), - enableX86Vector(false) {} + enableX86Vector(false), enableRISCVV(false) {} LowerVectorToLLVMOptions &setReassociateFPReductions(bool b) { reassociateFPReductions = b; @@ -50,6 +50,10 @@ enableX86Vector = b; return *this; } + LowerVectorToLLVMOptions &setEnableRISCVV(bool b) { + enableRISCVV = b; + return *this; + } bool reassociateFPReductions; bool enableIndexOptimizations; @@ -57,6 +61,7 @@ bool enableArmSVE; bool enableAMX; bool enableX86Vector; + bool enableRISCVV; // Allow lowering through the RISCVV dialect. }; /// Collect a set of patterns to convert from Vector contractions to LLVM Matrix diff --git 
a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td --- a/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td +++ b/mlir/include/mlir/Dialect/ArmSVE/ArmSVE.td @@ -17,6 +17,7 @@ include "mlir/Dialect/LLVMIR/LLVMOpBase.td" include "mlir/Dialect/StandardOps/IR/StandardOpsBase.td" include "mlir/Dialect/ArmSVE/ArmSVEOpBase.td" +include "mlir/Dialect/Vector/ScalableVectorTypes.td" //===----------------------------------------------------------------------===// // ArmSVE dialect definition @@ -47,13 +48,8 @@ }]; } -class ArmSVE_Type : TypeDef { } - -def ScalableVectorType : ArmSVE_Type<"ScalableVector"> { - let mnemonic = "vector"; - - let summary = "Scalable vector type"; - +def ScalableVectorType + : ScalableVector_Type { let description = [{ A type representing scalable length SIMD vectors. Unlike fixed-length SIMD vectors, whose size is constant and known at compile time, scalable diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt --- a/mlir/include/mlir/Dialect/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/CMakeLists.txt @@ -16,6 +16,7 @@ add_subdirectory(PDL) add_subdirectory(PDLInterp) add_subdirectory(Quant) +add_subdirectory(RISCVV) add_subdirectory(SCF) add_subdirectory(Shape) add_subdirectory(SparseTensor) diff --git a/mlir/include/mlir/Dialect/RISCVV/CMakeLists.txt b/mlir/include/mlir/Dialect/RISCVV/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/RISCVV/CMakeLists.txt @@ -0,0 +1,6 @@ +add_mlir_dialect(RISCVV riscvv) +add_mlir_doc(RISCVV RISCVV Dialects/ -gen-dialect-doc) + +set(LLVM_TARGET_DEFINITIONS RISCVV.td) +mlir_tablegen(RISCVVConversions.inc -gen-llvmir-conversions) # LLVM IR conversion code generated from RISCVV.td +add_public_tablegen_target(MLIRRISCVVConversionsIncGen) diff --git a/mlir/include/mlir/Dialect/RISCVV/RISCVV.td b/mlir/include/mlir/Dialect/RISCVV/RISCVV.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/RISCVV/RISCVV.td @@ -0,0 +1,296 @@ +//===-- RISCVV.td - 
RISCVV dialect operation definitions ---*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the basic types and operations for the RISCVV dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef RISCVV_OPS +#define RISCVV_OPS + +include "mlir/Dialect/LLVMIR/LLVMOpBase.td" +include "mlir/Dialect/Vector/ScalableVectorTypes.td" + +//===----------------------------------------------------------------------===// +// RISCVV dialect definition +//===----------------------------------------------------------------------===// + +def RISCVV_Dialect : Dialect { + let name = "riscvv"; + let cppNamespace = "::mlir::riscvv"; + let summary = "Basic dialect to target RISC-V Vector extension"; + let description = [{ + RISC-V vector extension (RISCVV) is the vector instruction set with scalable + vector types, and the RISCVV instructions are vector length agnostic (VLA). + For more details about RISCVV, please see the + [RISCVV specification](https://github.com/riscv/riscv-v-spec). + This dialect contains the definitions of RISCVV operations and RISCVV + intrinsic operations. The former is used to interoperate with higher-level + dialects, and the latter is responsible for mapping to LLVM IR intrinsics.
+ }]; +} + +//===----------------------------------------------------------------------===// +// RISCVV LMUL type definitions +//===----------------------------------------------------------------------===// + +class RISCVV_LMULType traits = []> + : TypeDef { + let mnemonic = typeMnemonic; + let summary = "The vector register group multiplier (" # multiplier # ")."; + let printer = [{$_printer << "}]# typeMnemonic #[{";}]; + let parser = [{return get}]# name #[{($_ctxt);}]; +} + +def RISCVV_LMUL_MF8 : RISCVV_LMULType<"MF8", "mf8", "1/8"> {} +def RISCVV_LMUL_MF4 : RISCVV_LMULType<"MF4", "mf4", "1/4"> {} +def RISCVV_LMUL_MF2 : RISCVV_LMULType<"MF2", "mf2", "1/2"> {} +def RISCVV_LMUL_M1 : RISCVV_LMULType<"M1", "m1", "1"> {} +def RISCVV_LMUL_M2 : RISCVV_LMULType<"M2", "m2", "2"> {} +def RISCVV_LMUL_M4 : RISCVV_LMULType<"M4", "m4", "4"> {} +def RISCVV_LMUL_M8 : RISCVV_LMULType<"M8", "m8", "8"> {} + +//===----------------------------------------------------------------------===// +// RISCVV mask type definitions +//===----------------------------------------------------------------------===// + +class RISCVV_MaskType traits = []> + : TypeDef { + let mnemonic = typeMnemonic; + let summary = "The mask length (SEW/LMUL = " # maskLength # ") type."; + let printer = [{$_printer << "}]# typeMnemonic #[{";}]; + let parser = [{return get}]# name #[{($_ctxt);}]; +} + +def RISCVV_Mask1 : RISCVV_MaskType<"Mask1", "mask1", "1"> {} +def RISCVV_Mask2 : RISCVV_MaskType<"Mask2", "mask2", "2"> {} +def RISCVV_Mask4 : RISCVV_MaskType<"Mask4", "mask4", "4"> {} +def RISCVV_Mask8 : RISCVV_MaskType<"Mask8", "mask8", "8"> {} +def RISCVV_Mask16 : RISCVV_MaskType<"Mask16", "mask16", "16"> {} +def RISCVV_Mask32 : RISCVV_MaskType<"Mask32", "mask32", "32"> {} +def RISCVV_Mask64 : RISCVV_MaskType<"Mask64", "mask64", "64"> {} + +//===----------------------------------------------------------------------===// +// RISCVV scalable vector type definitions 
+//===----------------------------------------------------------------------===// + +def ScalableVectorType : ScalableVector_Type { + let description = [{ + RISCVV scalable vector type takes two parameters. The first is the vector + register group multiplier (LMUL) type or a mask type. The second is the + element type, which indicates the selected element width (SEW) setting. + The LMUL and SEW are used to configure scalable vector length at runtime. + }]; + + let parameters = (ins "Type":$sizeType, "Type":$elementType); + + let printer = [{ + $_printer << "vector<" << getImpl()->sizeType << ','; + $_printer << getImpl()->elementType << '>'; + }]; + + let parser = [{ + if ($_parser.parseLess()) return Type(); + Type sizeType; + if ($_parser.parseType(sizeType)) return Type(); + if ($_parser.parseComma()) return Type(); + Type elementType; + if ($_parser.parseType(elementType)) return Type(); + if ($_parser.parseGreater()) return Type(); + return get($_ctxt, sizeType, elementType); + }]; +} +//===----------------------------------------------------------------------===// +// Additional LLVM type constraints +//===----------------------------------------------------------------------===// + +def LLVMScalableVectorType : + Type()">, + "LLVM dialect scalable vector type">; + +def LLVMPointerType : + Type()">, + "LLVM pointer type">; + +//===----------------------------------------------------------------------===// +// RISCVV scalable vector type constraints +//===----------------------------------------------------------------------===// + +def IsScalableVectorTypePred : + CPred<"$_self.isa<::mlir::riscvv::ScalableVectorType>()">; + +class RISCVVScalableVectorOf allowedTypes> : + ContainerType, IsScalableVectorTypePred, + "$_self.cast<::mlir::riscvv::ScalableVectorType>().getElementType()", + "RISCVV scalable vector">; + +//===----------------------------------------------------------------------===// +// RISCVV operation definitions 
+//===----------------------------------------------------------------------===// + +class RISCVV_Op traits = []> : + Op {} + +def RISCVVLoadOp : RISCVV_Op<"load">, + Arguments<(ins Arg:$base, Index:$index, + AnyInteger:$length)>, + Results<(outs RISCVVScalableVectorOf<[AnyType]>:$result)> { + let summary = "Load scalable vector from memory"; + let description = [{ + Load a slice of memory into scalable vector with the given element length. + }]; + let extraClassDeclaration = [{ + MemRefType getMemRefType() { + return base().getType().cast(); + } + }]; + let assemblyFormat = "$base `[` $index `]` `,` $length attr-dict `:` " + "type($base) `,` type($result) `,` type($length)"; +} + +def RISCVVStoreOp : RISCVV_Op<"store">, + Arguments<(ins RISCVVScalableVectorOf<[AnyType]>:$value, Arg:$base, Index:$index, AnyInteger:$length)> { + let summary = "Store scalable vector into memory"; + let description = [{ + Store the given element length of a scalable vector on a slice of memory. + }]; + let extraClassDeclaration = [{ + MemRefType getMemRefType() { + return base().getType().cast(); + } + }]; + let assemblyFormat = "$value `,` $base `[` $index `]` `,` $length attr-dict " + "`:` type($value) `,` type($base) `,` type($length)"; +} + +class RISCVV_BinaryAAXNoMask_Op traits = []> : + RISCVV_Op])> { + let summary = op_description # " for RISCVV scalable vectors"; + let description = [{ The `riscvv.}] # mnemonic # [{` operation can be of + vector-vector form or vector-scalar form.
It also takes the vector length + value and returns a scalable vector with the result of the }] + # op_description # [{.}]; + let arguments = (ins + RISCVVScalableVectorOf<[AnyInteger]>:$src1, + AnyType:$src2, + AnyInteger:$length + ); + let results = (outs RISCVVScalableVectorOf<[AnyInteger]>:$dst); + let assemblyFormat = "$src1 `,` $src2 `,` $length attr-dict `:` type($src1) " + "`,` type($src2) `,` type($length)"; +} + +class RISCVV_BinaryAAXMask_Op traits = []> : + RISCVV_Op])> { + let summary = op_description # " for RISCVV scalable vectors"; + let description = [{ The `riscvv.}] # mnemonic # [{` operation can be of + vector-vector form or vector-scalar form. It also takes the mask vector, + maskedoff vector, vector length value and returns a scalable vector with + the result of the }] # op_description # [{.}]; + let arguments = (ins + RISCVVScalableVectorOf<[AnyInteger]>:$maskedoff, + RISCVVScalableVectorOf<[AnyInteger]>:$src1, + AnyType:$src2, + RISCVVScalableVectorOf<[I1]>:$mask, + AnyInteger:$length + ); + let results = (outs RISCVVScalableVectorOf<[AnyInteger]>:$dst); + let assemblyFormat = + "$maskedoff `,` $src1 `,` $src2 `,` $mask `,` $length attr-dict `:` " + "type($maskedoff) `,` type($src2) `,` type($mask) `,` type($length)"; +} + +def RISCVVAddOp : RISCVV_BinaryAAXNoMask_Op<"add", "addition">; +def RISCVVSubOp : RISCVV_BinaryAAXNoMask_Op<"sub", "subtraction">; +def RISCVVMulOp : RISCVV_BinaryAAXNoMask_Op<"mul", "multiplication">; +def RISCVVDivOp : RISCVV_BinaryAAXNoMask_Op<"div", "division">; + +def RISCVVMaskedAddOp : RISCVV_BinaryAAXMask_Op<"masked.add", + "masked addition">; +def RISCVVMaskedSubOp : RISCVV_BinaryAAXMask_Op<"masked.sub", + "masked subtraction">; +def RISCVVMaskedMulOp : RISCVV_BinaryAAXMask_Op<"masked.mul", + "masked multiplication">; +def RISCVVMaskedDivOp : RISCVV_BinaryAAXMask_Op<"masked.div", + "masked division">; + +//===----------------------------------------------------------------------===// +// RISCVV intrinsic 
operation definitions +//===----------------------------------------------------------------------===// + +class RISCVV_USLoad_IntrOp traits = []> : + LLVM_IntrOpBase overloadedResults=*/[0], + /*list overloadedOperands=*/[1], + /*list traits=*/traits, + /*int numResults=*/1>; + +class RISCVV_USStore_IntrOp traits = []> : + LLVM_IntrOpBase overloadedResults=*/[], + /*list overloadedOperands=*/[0, 2], + /*list traits=*/traits, + /*int numResults=*/0>; + +class RISCVV_BinaryAAXNoMask_IntrOp traits = []> : + LLVM_IntrOpBase overloadedResults=*/[0], + /*list overloadedOperands=*/[1, 2], + /*list traits=*/traits, + /*int numResults=*/1>; + +class RISCVV_BinaryAAXMask_IntrOp traits = []> : + LLVM_IntrOpBase overloadedResults=*/[0], + /*list overloadedOperands=*/[2, 4], + /*list traits=*/traits, + /*int numResults=*/1>; + +def RISCVVIntrLoadEleOp : RISCVV_USLoad_IntrOp<"vle">, + Arguments<(ins LLVMPointerType, AnyInteger)>; +def RISCVVIntrStoreEleOp : RISCVV_USStore_IntrOp<"vse">, + Arguments<(ins LLVMScalableVectorType, LLVMPointerType, AnyInteger)>; + +def RISCVVIntrAddOp : RISCVV_BinaryAAXNoMask_IntrOp<"vadd">, + Arguments<(ins LLVMScalableVectorType, AnyType, AnyInteger)>; +def RISCVVIntrSubOp : RISCVV_BinaryAAXNoMask_IntrOp<"vsub">, + Arguments<(ins LLVMScalableVectorType, AnyType, AnyInteger)>; +def RISCVVIntrMulOp : RISCVV_BinaryAAXNoMask_IntrOp<"vmul">, + Arguments<(ins LLVMScalableVectorType, AnyType, AnyInteger)>; +def RISCVVIntrDivOp : RISCVV_BinaryAAXNoMask_IntrOp<"vdiv">, + Arguments<(ins LLVMScalableVectorType, AnyType, AnyInteger)>; + +def RISCVVMaskedIntrAddOp : RISCVV_BinaryAAXMask_IntrOp<"vadd_mask">, + Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType, AnyType, + LLVMScalableVectorType, AnyInteger)>; +def RISCVVMaskedIntrSubOp : RISCVV_BinaryAAXMask_IntrOp<"vsub_mask">, + Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType, AnyType, + LLVMScalableVectorType, AnyInteger)>; +def RISCVVMaskedIntrMulOp : 
RISCVV_BinaryAAXMask_IntrOp<"vmul_mask">, + Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType, AnyType, + LLVMScalableVectorType, AnyInteger)>; +def RISCVVMaskedIntrDivOp : RISCVV_BinaryAAXMask_IntrOp<"vdiv_mask">, + Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType, AnyType, + LLVMScalableVectorType, AnyInteger)>; + +#endif // RISCVV_OPS diff --git a/mlir/include/mlir/Dialect/RISCVV/RISCVVDialect.h b/mlir/include/mlir/Dialect/RISCVV/RISCVVDialect.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/RISCVV/RISCVVDialect.h @@ -0,0 +1,70 @@ +//===- RISCVVDialect.h - MLIR RISC-V vector extension dialect ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_RISCVV_RISCVVDIALECT_H +#define MLIR_DIALECT_RISCVV_RISCVVDIALECT_H + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" + +namespace mlir { + +//===----------------------------------------------------------------------===// +// RISCVVLMULType +//===----------------------------------------------------------------------===// + +/// This RISCVVLMULType represents the vector register group multiplier (LMUL) +/// setting. When LMUL is 1 or greater, the multiplier (M1, M2, M4, M8) +/// represents the number of vector registers that are combined to form a +/// vector register group. The multiplier can also be fractional values (MF8, +/// MF4, MF2), which reduces the number of bits used in a vector register.
+class RISCVVLMULType : public Type { +public: + using Type::Type; + + static RISCVVLMULType getMF8(MLIRContext *ctx); + static RISCVVLMULType getMF4(MLIRContext *ctx); + static RISCVVLMULType getMF2(MLIRContext *ctx); + static RISCVVLMULType getM1(MLIRContext *ctx); + static RISCVVLMULType getM2(MLIRContext *ctx); + static RISCVVLMULType getM4(MLIRContext *ctx); + static RISCVVLMULType getM8(MLIRContext *ctx); +}; + +//===----------------------------------------------------------------------===// +// RISCVVMaskType +//===----------------------------------------------------------------------===// + +/// This RISCVVMaskType represents the mask length setting, which is equal to +/// the ratio of SEW to LMUL (n = SEW/LMUL). +class RISCVVMaskType : public Type { +public: + using Type::Type; + + static RISCVVMaskType getMask1(MLIRContext *ctx); + static RISCVVMaskType getMask2(MLIRContext *ctx); + static RISCVVMaskType getMask4(MLIRContext *ctx); + static RISCVVMaskType getMask8(MLIRContext *ctx); + static RISCVVMaskType getMask16(MLIRContext *ctx); + static RISCVVMaskType getMask32(MLIRContext *ctx); + static RISCVVMaskType getMask64(MLIRContext *ctx); +}; + +} // end namespace mlir + +#include "mlir/Dialect/RISCVV/RISCVVDialect.h.inc" + +#define GET_TYPEDEF_CLASSES +#include "mlir/Dialect/RISCVV/RISCVVTypes.h.inc" + +#define GET_OP_CLASSES +#include "mlir/Dialect/RISCVV/RISCVV.h.inc" + +#endif // MLIR_DIALECT_RISCVV_RISCVVDIALECT_H diff --git a/mlir/include/mlir/Dialect/RISCVV/Transforms.h b/mlir/include/mlir/Dialect/RISCVV/Transforms.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/RISCVV/Transforms.h @@ -0,0 +1,30 @@ +//===- Transforms.h - RISCVV Dialect Transformation Entrypoints -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_RISCVV_TRANSFORMS_H +#define MLIR_DIALECT_RISCVV_TRANSFORMS_H + +namespace mlir { + +class LLVMConversionTarget; +class LLVMTypeConverter; +class RewritePatternSet; +using OwningRewritePatternList = RewritePatternSet; + +/// Collect a set of patterns to lower RISCVV ops to ops that map to LLVM +/// intrinsics. +void populateRISCVVLegalizeForLLVMExportPatterns(LLVMTypeConverter &converter, + RewritePatternSet &patterns); + +/// Configure the target to support lowering RISCVV ops to ops that map to LLVM +/// intrinsics. +void configureRISCVVLegalizeForExportTarget(LLVMConversionTarget &target); + +} // namespace mlir + +#endif // MLIR_DIALECT_RISCVV_TRANSFORMS_H diff --git a/mlir/include/mlir/Dialect/Vector/ScalableVectorTypes.td b/mlir/include/mlir/Dialect/Vector/ScalableVectorTypes.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Vector/ScalableVectorTypes.td @@ -0,0 +1,23 @@ +//===- ScalableVectorTypes.td - Scalable vector types ------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines the common base class for scalable vector types.
+// +//===----------------------------------------------------------------------===// + +#ifndef SCALABLE_VECTOR_TYPES +#define SCALABLE_VECTOR_TYPES + +class ScalableVector_Type + : TypeDef { + let mnemonic = "vector"; + + let summary = "Scalable vector type"; +} + +#endif // SCALABLE_VECTOR_TYPES diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -34,6 +34,7 @@ #include "mlir/Dialect/PDL/IR/PDL.h" #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" #include "mlir/Dialect/Quant/QuantOps.h" +#include "mlir/Dialect/RISCVV/RISCVVDialect.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/Shape/IR/Shape.h" @@ -68,6 +69,7 @@ pdl::PDLDialect, pdl_interp::PDLInterpDialect, quant::QuantizationDialect, + riscvv::RISCVVDialect, spirv::SPIRVDialect, StandardOpsDialect, arm_sve::ArmSVEDialect, diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h --- a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h +++ b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h @@ -21,6 +21,7 @@ #include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/RISCVV/RISCVVToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/X86Vector/X86VectorToLLVMIRTranslation.h" @@ -38,6 +39,7 @@ registerOpenACCDialectTranslation(registry); registerOpenMPDialectTranslation(registry); + registerRISCVVDialectTranslation(registry); registerROCDLDialectTranslation(registry); registerX86VectorDialectTranslation(registry); } } // namespace mlir diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/RISCVV/RISCVVToLLVMIRTranslation.h 
b/mlir/include/mlir/Target/LLVMIR/Dialect/RISCVV/RISCVVToLLVMIRTranslation.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Target/LLVMIR/Dialect/RISCVV/RISCVVToLLVMIRTranslation.h @@ -0,0 +1,31 @@ +//=======- RISCVVToLLVMIRTranslation.h - RISCVV to LLVM IR ------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This provides registration calls for RISCVV dialect to LLVM IR translation. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TARGET_LLVMIR_DIALECT_RISCVV_RISCVVTOLLVMIRTRANSLATION_H +#define MLIR_TARGET_LLVMIR_DIALECT_RISCVV_RISCVVTOLLVMIRTRANSLATION_H + +namespace mlir { + +class DialectRegistry; +class MLIRContext; + +/// Register the RISCVV dialect and the translation from it to LLVM IR in +/// the given registry. +void registerRISCVVDialectTranslation(DialectRegistry &registry); + +/// Register the RISCVV dialect and the translation from it in the registry +/// associated with the given context. 
+void registerRISCVVDialectTranslation(MLIRContext &context); + +} // namespace mlir + +#endif // MLIR_TARGET_LLVMIR_DIALECT_RISCVV_RISCVVTOLLVMIRTRANSLATION_H diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt --- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt +++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt @@ -21,6 +21,8 @@ MLIRLLVMCommonConversion MLIRLLVMIR MLIRMemRef + MLIRRISCVV + MLIRRISCVVTransforms MLIRTargetLLVMIRExport MLIRTransforms MLIRVector diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -19,6 +19,8 @@ #include "mlir/Dialect/ArmSVE/Transforms.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/RISCVV/RISCVVDialect.h" +#include "mlir/Dialect/RISCVV/Transforms.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/Dialect/X86Vector/Transforms.h" @@ -37,6 +39,7 @@ this->enableArmNeon = options.enableArmNeon; this->enableArmSVE = options.enableArmSVE; this->enableAMX = options.enableAMX; + this->enableRISCVV = options.enableRISCVV; this->enableX86Vector = options.enableX86Vector; } // Override explicitly to allow conditional dialect dependence.
@@ -49,6 +52,8 @@ registry.insert(); if (enableAMX) registry.insert(); + if (enableRISCVV) + registry.insert(); if (enableX86Vector) registry.insert(); } @@ -102,6 +107,10 @@ configureX86VectorLegalizeForExportTarget(target); populateX86VectorLegalizeForLLVMExportPatterns(converter, patterns); } + if (enableRISCVV) { + configureRISCVVLegalizeForExportTarget(target); + populateRISCVVLegalizeForLLVMExportPatterns(converter, patterns); + } if (failed( applyPartialConversion(getOperation(), target, std::move(patterns)))) diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt --- a/mlir/lib/Dialect/CMakeLists.txt +++ b/mlir/lib/Dialect/CMakeLists.txt @@ -16,6 +16,7 @@ add_subdirectory(PDL) add_subdirectory(PDLInterp) add_subdirectory(Quant) +add_subdirectory(RISCVV) add_subdirectory(SCF) add_subdirectory(Shape) add_subdirectory(SparseTensor) diff --git a/mlir/lib/Dialect/RISCVV/CMakeLists.txt b/mlir/lib/Dialect/RISCVV/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/RISCVV/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/mlir/lib/Dialect/RISCVV/IR/CMakeLists.txt b/mlir/lib/Dialect/RISCVV/IR/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/RISCVV/IR/CMakeLists.txt @@ -0,0 +1,10 @@ +add_mlir_dialect_library(MLIRRISCVV + RISCVVDialect.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/RISCVV + + LINK_LIBS PUBLIC + MLIRIR + MLIRLLVMIR + ) diff --git a/mlir/lib/Dialect/RISCVV/IR/RISCVVDialect.cpp b/mlir/lib/Dialect/RISCVV/IR/RISCVVDialect.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/RISCVV/IR/RISCVVDialect.cpp @@ -0,0 +1,129 @@ +//===- RISCVVDialect.cpp - MLIR RISCVV dialect implementation -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the RISCVV dialect, its types, and its operations. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/RISCVV/RISCVVDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/Dialect/Vector/VectorOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/TypeUtilities.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; + +#include "mlir/Dialect/RISCVV/RISCVVDialect.cpp.inc" + +#define GET_OP_CLASSES +#include "mlir/Dialect/RISCVV/RISCVV.cpp.inc" + +#define GET_TYPEDEF_CLASSES +#include "mlir/Dialect/RISCVV/RISCVVTypes.cpp.inc" + +void riscvv::RISCVVDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "mlir/Dialect/RISCVV/RISCVV.cpp.inc" + >(); + addTypes< +#define GET_TYPEDEF_LIST +#include "mlir/Dialect/RISCVV/RISCVVTypes.cpp.inc" + >(); +} + +//===----------------------------------------------------------------------===// +// RISCVVLMULType +//===----------------------------------------------------------------------===// + +RISCVVLMULType RISCVVLMULType::getMF8(MLIRContext *ctx) { + return riscvv::MF8Type::get(ctx); +} + +RISCVVLMULType RISCVVLMULType::getMF4(MLIRContext *ctx) { + return riscvv::MF4Type::get(ctx); +} + +RISCVVLMULType RISCVVLMULType::getMF2(MLIRContext *ctx) { + return riscvv::MF2Type::get(ctx); +} + +RISCVVLMULType RISCVVLMULType::getM1(MLIRContext *ctx) { + return riscvv::M1Type::get(ctx); +} + +RISCVVLMULType RISCVVLMULType::getM2(MLIRContext *ctx) { + return riscvv::M2Type::get(ctx); +} + +RISCVVLMULType RISCVVLMULType::getM4(MLIRContext *ctx) { + return riscvv::M4Type::get(ctx); +} + +RISCVVLMULType RISCVVLMULType::getM8(MLIRContext *ctx) { + return riscvv::M8Type::get(ctx); +} + 
+//===----------------------------------------------------------------------===// +// RISCVVMaskType +//===----------------------------------------------------------------------===// + +RISCVVMaskType RISCVVMaskType::getMask1(MLIRContext *ctx) { + return riscvv::Mask1Type::get(ctx); +} + +RISCVVMaskType RISCVVMaskType::getMask2(MLIRContext *ctx) { + return riscvv::Mask2Type::get(ctx); +} + +RISCVVMaskType RISCVVMaskType::getMask4(MLIRContext *ctx) { + return riscvv::Mask4Type::get(ctx); +} + +RISCVVMaskType RISCVVMaskType::getMask8(MLIRContext *ctx) { + return riscvv::Mask8Type::get(ctx); +} + +RISCVVMaskType RISCVVMaskType::getMask16(MLIRContext *ctx) { + return riscvv::Mask16Type::get(ctx); +} + +RISCVVMaskType RISCVVMaskType::getMask32(MLIRContext *ctx) { + return riscvv::Mask32Type::get(ctx); +} + +RISCVVMaskType RISCVVMaskType::getMask64(MLIRContext *ctx) { + return riscvv::Mask64Type::get(ctx); +} + +//===----------------------------------------------------------------------===// +// Parser and Printer +//===----------------------------------------------------------------------===// + +Type riscvv::RISCVVDialect::parseType(DialectAsmParser &parser) const { + llvm::SMLoc typeLoc = parser.getCurrentLocation(); + StringRef mnemonic; + if (failed(parser.parseKeyword(&mnemonic))) + return Type(); // parseKeyword has already reported the error. + // Dispatch on the mnemonic to the TableGen-generated type parsers. + Type genType; + auto parseResult = generatedTypeParser(parser.getBuilder().getContext(), + parser, mnemonic, genType); + if (parseResult.hasValue()) + return genType; + parser.emitError(typeLoc, "unknown type in RISCVV dialect"); + return Type(); +} + +void riscvv::RISCVVDialect::printType(Type type, DialectAsmPrinter &os) const { + if (failed(generatedTypePrinter(type, os))) + llvm_unreachable("unexpected 'riscvv' type kind"); +} diff --git a/mlir/lib/Dialect/RISCVV/Transforms/CMakeLists.txt b/mlir/lib/Dialect/RISCVV/Transforms/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/RISCVV/Transforms/CMakeLists.txt @@ -0,0 +1,12 @@ 
+add_mlir_dialect_library(MLIRRISCVVTransforms + LegalizeForLLVMExport.cpp + + DEPENDS + MLIRRISCVVConversionsIncGen + + LINK_LIBS PUBLIC + MLIRRISCVV + MLIRIR + MLIRLLVMCommonConversion + MLIRLLVMIR + ) diff --git a/mlir/lib/Dialect/RISCVV/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/RISCVV/Transforms/LegalizeForLLVMExport.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/RISCVV/Transforms/LegalizeForLLVMExport.cpp @@ -0,0 +1,315 @@ +//===- LegalizeForLLVMExport.cpp - Prepare RISCVV for LLVM translation ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/LLVMCommon/ConversionTarget.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/RISCVV/RISCVVDialect.h" +#include "mlir/Dialect/RISCVV/Transforms.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/PatternMatch.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; +using namespace mlir::riscvv; + +// Return `type` unchanged, emitting an error if it is not LLVM-compatible. +static Type unwrap(Type type) { + if (!type) + return nullptr; + auto *mlirContext = type.getContext(); + if (!LLVM::isCompatibleType(type)) + emitError(UnknownLoc::get(mlirContext), + "conversion resulted in a non-LLVM type"); + return type; +} + +// Scalable vector type in RISCVV dialect uses LMUL and SEW as parameters to +// provide better semantics. This is the helper function to bridge the gap
+unsigned typeMapping(ScalableVectorType riscvvSVType) { // NOTE(review): not static; has external linkage.
+  auto elementType = riscvvSVType.getElementType();
+  auto *elementContext = elementType.getContext();
+  auto sizeType = riscvvSVType.getSizeType();
+  auto *sizeContext = sizeType.getContext();
+  // TODO: support more element types.
+  if (elementType.isa()) {
+    // Map the size type (LMUL or mask) to an element count for each SEW.
+    switch (elementType.cast().getWidth()) {
+    case 64:
+      if (sizeType.isa() || sizeType.isa() ||
+          sizeType.isa()) { // NOTE(review): falls through to the TypeSwitch.
+        emitError(UnknownLoc::get(sizeContext), "unsupported LMUL Type for ")
+            << elementType << " type.";
+      }
+      return llvm::TypeSwitch(sizeType)
+          .Case([&](Type) { return 1; })
+          .Case([&](Type) { return 2; })
+          .Case([&](Type) { return 4; })
+          .Case([&](Type) { return 8; })
+          .Default([](Type) -> unsigned { // any size type not matched above
+            llvm_unreachable("incompatible with RISC-V vector type");
+          });
+      break; // NOTE(review): unreachable; every case returns first.
+    case 32:
+      if (sizeType.isa() || sizeType.isa()) { // NOTE(review): same fall-through.
+        emitError(UnknownLoc::get(sizeContext), "unsupported LMUL Type for ")
+            << elementType << " type.";
+      }
+      return llvm::TypeSwitch(sizeType)
+          .Case([&](Type) { return 1; })
+          .Case([&](Type) { return 2; })
+          .Case([&](Type) { return 4; })
+          .Case([&](Type) { return 8; })
+          .Case([&](Type) { return 16; })
+          .Default([](Type) -> unsigned {
+            llvm_unreachable("incompatible with RISC-V vector type");
+          });
+      break;
+    case 16:
+      if (sizeType.isa()) { // NOTE(review): same fall-through.
+        emitError(UnknownLoc::get(sizeContext), "unsupported LMUL type for ")
+            << elementType << " type.";
+      }
+      return llvm::TypeSwitch(sizeType)
+          .Case([&](Type) { return 1; })
+          .Case([&](Type) { return 2; })
+          .Case([&](Type) { return 4; })
+          .Case([&](Type) { return 8; })
+          .Case([&](Type) { return 16; })
+          .Case([&](Type) { return 32; })
+          .Default([](Type) -> unsigned {
+            llvm_unreachable("incompatible with RISC-V vector type");
+          });
+      break;
+    case 8: // No validation: unmatched size types reach llvm_unreachable.
+      return llvm::TypeSwitch(sizeType)
+          .Case([&](Type) { return 1; })
+          .Case([&](Type) { return 2; })
+          .Case([&](Type) { return 4; })
+          .Case([&](Type) { return 8; })
+          .Case([&](Type) { return 16; })
+          .Case([&](Type) { return 32; })
+          .Case([&](Type) { return 64; })
+          .Default([](Type) -> unsigned {
+            llvm_unreachable("incompatible with RISC-V vector type");
+          });
+      break;
+    case 1: // Presumably mask vectors, keyed by the RISCVV mask type.
+      return llvm::TypeSwitch(sizeType)
+          .Case([&](Type) { return 1; })
+          .Case([&](Type) { return 2; })
+          .Case([&](Type) { return 4; })
+          .Case([&](Type) { return 8; })
+          .Case([&](Type) { return 16; })
+          .Case([&](Type) { return 32; })
+          .Case([&](Type) { return 64; })
+          .Default([](Type) -> unsigned {
+            llvm_unreachable("incompatible with RISC-V vector type");
+          });
+      break;
+    default: // Unsupported SEW: report, then fall out to `return 0`.
+      emitError(UnknownLoc::get(elementContext), "unsupported ")
+          << elementType << " SEW type.";
+    }
+  } else {
+    emitError(UnknownLoc::get(elementContext), "unsupported ")
+        << elementType << " SEW type.";
+  }
+  return 0; // NOTE(review): error sentinel, used unchecked as an element count.
+}
+
+static Optional // RISCVV scalable vector type -> LLVM dialect scalable vector.
+convertScalableVectorTypeToLLVM(ScalableVectorType svType,
+                                LLVMTypeConverter &converter) {
+  auto elementType = unwrap(converter.convertType(svType.getElementType()));
+  if (!elementType)
+    return {}; // The element type did not convert.
+  auto sVectorType =
+      LLVM::LLVMScalableVectorType::get(elementType, typeMapping(svType)); // NOTE(review): a 0 count is not rejected here.
+  return sVectorType;
+}
+
+template
+class ForwardOperands : public OpConversionPattern { // Re-applies type-converted operands in place.
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(OpTy op, ArrayRef operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    if (ValueRange(operands).getTypes() == op->getOperands().getTypes())
+      return rewriter.notifyMatchFailure(op, "operand types already match");
+
+    rewriter.updateRootInPlace(op, [&]() { op->setOperands(operands); });
+    return success();
+  }
+};
+
+struct RISCVVLoadOpLowering : public ConvertOpToLLVMPattern { // riscvv.load -> strided GEP + bitcast + load intrinsic.
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(RISCVVLoadOp loadOp, ArrayRef operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto type = loadOp.getMemRefType();
+    if (!isConvertibleAndHasIdentityMaps(type))
+      return failure();
+
+    RISCVVLoadOp::Adaptor transformed(operands);
+    LLVMTypeConverter converter(loadOp.getContext()); // NOTE(review): a fresh converter, not the pattern's own.
+
+    auto resultType = loadOp.result().getType();
+    LLVM::LLVMPointerType llvmDataTypePtr; // NOTE(review): stays null if neither branch matches.
+    if (resultType.isa()) {
+      llvmDataTypePtr =
+          LLVM::LLVMPointerType::get(resultType.cast());
+    } else if (resultType.isa()) {
+      llvmDataTypePtr = LLVM::LLVMPointerType::get(
+          convertScalableVectorTypeToLLVM(resultType.cast(),
+                                          converter)
+              .getValue()); // NOTE(review): asserts if the conversion failed.
+    }
+    Value dataPtr =
+        getStridedElementPtr(loadOp.getLoc(), type, transformed.base(),
+                             transformed.index(), rewriter);
+    Value bitCastedPtr = rewriter.create(
+        loadOp.getLoc(), llvmDataTypePtr, dataPtr);
+    Value vl = loadOp.getOperand(2); // NOTE(review): original operand, not the adaptor's.
+    rewriter.replaceOpWithNewOp(
+        loadOp,
+        convertScalableVectorTypeToLLVM(resultType.cast(),
+                                        converter)
+            .getValue(),
+        bitCastedPtr, vl);
+    return success();
+  }
+};
+
+struct RISCVVStoreOpLowering : public ConvertOpToLLVMPattern { // riscvv.store -> strided GEP + bitcast + store intrinsic.
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(RISCVVStoreOp storeOp, ArrayRef operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto type = storeOp.getMemRefType();
+    if (!isConvertibleAndHasIdentityMaps(type))
+      return failure();
+
+    RISCVVStoreOp::Adaptor transformed(operands);
+    LLVMTypeConverter converter(storeOp.getContext());
+
+    auto resultType = storeOp.value().getType(); // Type of the stored value.
+    LLVM::LLVMPointerType llvmDataTypePtr;
+    if (resultType.isa()) {
+      llvmDataTypePtr =
+          LLVM::LLVMPointerType::get(resultType.cast());
+    } else if (resultType.isa()) {
+      llvmDataTypePtr = LLVM::LLVMPointerType::get(
+          convertScalableVectorTypeToLLVM(resultType.cast(),
+                                          converter)
+              .getValue());
+    }
+    Value dataPtr =
+        getStridedElementPtr(storeOp.getLoc(), type, transformed.base(),
+                             transformed.index(), rewriter);
+    Value bitCastedPtr = rewriter.create(
+        storeOp.getLoc(), llvmDataTypePtr, dataPtr);
+    Value vl = storeOp.getOperand(3); // NOTE(review): original operand, not the adaptor's.
+    rewriter.replaceOpWithNewOp(
+        storeOp, transformed.value(), bitCastedPtr, vl);
+    return success();
+  }
+};
+
+using RISCVVAddOpLowering = // One-to-one lowerings of the riscvv arithmetic ops.
+    OneToOneConvertToLLVMPattern;
+using RISCVVSubOpLowering =
+    OneToOneConvertToLLVMPattern;
+using RISCVVMulOpLowering =
+    OneToOneConvertToLLVMPattern;
+using RISCVVDivOpLowering =
+    OneToOneConvertToLLVMPattern;
+using RISCVVMaskedAddOpLowering =
+    OneToOneConvertToLLVMPattern;
+using RISCVVMaskedSubOpLowering =
+    OneToOneConvertToLLVMPattern;
+using RISCVVMaskedMulOpLowering =
+    OneToOneConvertToLLVMPattern;
+using RISCVVMaskedDivOpLowering =
+    OneToOneConvertToLLVMPattern;
+
+/// Populate the given list with patterns that convert from RISCVV to LLVM.
+void mlir::populateRISCVVLegalizeForLLVMExportPatterns(
+    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
+  // Populate conversion patterns.
+  // Remove any RISCVV-specific types from function signatures and results.
+  populateFuncOpTypeConversionPattern(patterns, converter);
+  converter.addConversion([&converter](ScalableVectorType riscvvSVType) {
+    return convertScalableVectorTypeToLLVM(riscvvSVType, converter);
+  });
+
+  // clang-format off
+  patterns.add,
+               ForwardOperands,
+               ForwardOperands>(converter, &converter.getContext());
+  patterns.add(converter);
+  patterns.add(converter);
+  // clang-format on
+}
+
+void mlir::configureRISCVVLegalizeForExportTarget(
+    LLVMConversionTarget &target) {
+  // clang-format off
+  target.addLegalOp();
+  target.addIllegalOp();
+  // clang-format on
+
+  auto hasScalableVectorType = [](TypeRange types) { // Presumably tests for ScalableVectorType.
+    for (Type type : types)
+      if (type.isa())
+        return true;
+    return false;
+  };
+  target.addDynamicallyLegalOp([hasScalableVectorType](FuncOp op) { // Legal once the signature is converted.
+    return !hasScalableVectorType(op.getType().getInputs()) &&
+           !hasScalableVectorType(op.getType().getResults());
+  });
+  target.addDynamicallyLegalOp( // Presumably the ops handled by ForwardOperands.
+      [hasScalableVectorType](Operation *op) {
+        return !hasScalableVectorType(op->getOperandTypes()) &&
+               !hasScalableVectorType(op->getResultTypes());
+      });
+}
diff --git a/mlir/lib/Target/LLVMIR/CMakeLists.txt b/mlir/lib/Target/LLVMIR/CMakeLists.txt
--- a/mlir/lib/Target/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Target/LLVMIR/CMakeLists.txt
@@ -39,6 +39,7 @@
   MLIRArmNeonToLLVMIRTranslation
   MLIRArmSVEToLLVMIRTranslation
   MLIRAMXToLLVMIRTranslation
+  MLIRRISCVVToLLVMIRTranslation
   MLIRX86VectorToLLVMIRTranslation
   MLIRLLVMToLLVMIRTranslation
   MLIRNVVMToLLVMIRTranslation
diff --git a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
--- a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
+++ b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
@@ -6,4 +6,5 @@
 add_subdirectory(OpenACC)
 add_subdirectory(OpenMP)
 add_subdirectory(ROCDL)
+add_subdirectory(RISCVV)
 add_subdirectory(X86Vector)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/RISCVV/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/RISCVV/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Target/LLVMIR/Dialect/RISCVV/CMakeLists.txt
@@ -0,0 +1,16 @@
+add_mlir_translation_library(MLIRRISCVVToLLVMIRTranslation
+  RISCVVToLLVMIRTranslation.cpp
+
+  DEPENDS
+  MLIRRISCVVConversionsIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRRISCVV
+  MLIRLLVMIR
+  MLIRSupport
+  MLIRTargetLLVMIRExport
+  )
diff --git a/mlir/lib/Target/LLVMIR/Dialect/RISCVV/RISCVVToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/RISCVV/RISCVVToLLVMIRTranslation.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Target/LLVMIR/Dialect/RISCVV/RISCVVToLLVMIRTranslation.cpp
@@ -0,0 +1,55 @@
+//======- RISCVVToLLVMIRTranslation.cpp - Translate RISCVV to LLVM IR ----====//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a translation between the RISCVV dialect and LLVM IR.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Target/LLVMIR/Dialect/RISCVV/RISCVVToLLVMIRTranslation.h"
+#include "mlir/Dialect/RISCVV/RISCVVDialect.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Target/LLVMIR/ModuleTranslation.h"
+
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/IntrinsicsRISCV.h"
+
+using namespace mlir;
+using namespace mlir::LLVM;
+
+namespace {
+/// Implementation of the dialect interface that converts operations belonging
+/// to the RISCVV dialect to LLVM IR.
+class RISCVVDialectLLVMIRTranslationInterface
+    : public LLVMTranslationDialectInterface {
+public:
+  using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
+
+  /// Translates `op` to LLVM IR with the generated RISCVV conversions, using
+  /// `builder` and `moduleTranslation`; unmatched ops fall through to failure.
+  LogicalResult
+  convertOperation(Operation *op, llvm::IRBuilderBase &builder,
+                   LLVM::ModuleTranslation &moduleTranslation) const final {
+    Operation &opInst = *op; // Presumably referenced by the generated code.
+#include "mlir/Dialect/RISCVV/RISCVVConversions.inc"
+
+    return failure();
+  }
+};
+} // end namespace
+
+void mlir::registerRISCVVDialectTranslation(DialectRegistry &registry) {
+  registry.insert<riscvv::RISCVVDialect>(); // Dialect + its LLVM IR interface.
+  registry.addDialectInterface<riscvv::RISCVVDialect,
+                               RISCVVDialectLLVMIRTranslationInterface>();
+}
+
+void mlir::registerRISCVVDialectTranslation(MLIRContext &context) {
+  DialectRegistry registry; // Local registry whose contents go to `context`.
+  registerRISCVVDialectTranslation(registry);
+  context.appendDialectRegistry(registry);
+}
diff --git a/mlir/test/Dialect/RISCVV/legalize-for-llvm.mlir b/mlir/test/Dialect/RISCVV/legalize-for-llvm.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/RISCVV/legalize-for-llvm.mlir
@@ -0,0 +1,66 @@
+// RUN: mlir-opt %s -convert-vector-to-llvm="enable-riscvv" -convert-std-to-llvm | mlir-opt | FileCheck %s
+
+func @riscvv_memory(%v: !riscvv.vector, // load/store lower to the vle/vse intrinsics
+                    %m: memref,
+                    %vl: i64) -> !riscvv.vector {
+  %c0 = constant 0 : index
+  // CHECK: llvm.extractvalue {{.*}} : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK-NEXT: llvm.getelementptr {{.*}} : (!llvm.ptr, i64) -> !llvm.ptr
+  // CHECK-NEXT: llvm.bitcast {{.*}} : !llvm.ptr to !llvm.ptr>
+  // CHECK-NEXT: riscvv.intr.vle{{.*}} : (!llvm.ptr>, i64) -> !llvm.vec
+  %0 = riscvv.load %m[%c0], %vl : memref, !riscvv.vector, i64
+  // CHECK: llvm.extractvalue {{.*}} : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK-NEXT: llvm.getelementptr {{.*}} : (!llvm.ptr, i64) -> !llvm.ptr
+  // CHECK-NEXT: llvm.bitcast {{.*}} : !llvm.ptr to !llvm.ptr>
+  // CHECK-NEXT: riscvv.intr.vse{{.*}} : (!llvm.vec, !llvm.ptr>, i64) -> ()
+  riscvv.store %v, %m[%c0], %vl : !riscvv.vector, memref, i64
+  return %0 : !riscvv.vector
+}
+
+func @riscvv_arith(%a: !riscvv.vector, // vector-vector and vector-scalar forms
+                   %b: !riscvv.vector,
+                   %c: i32,
+                   %vl: i64) -> !riscvv.vector {
+  // CHECK: riscvv.intr.vadd{{.*}} : (!llvm.vec, !llvm.vec, i64) -> !llvm.vec
+  %0 = riscvv.add %a, %b, %vl : !riscvv.vector, !riscvv.vector, i64
+  // CHECK: riscvv.intr.vadd{{.*}} : (!llvm.vec, i32, i64) -> !llvm.vec
+  %1 = riscvv.add %a, %c, %vl : !riscvv.vector, i32, i64
+  // CHECK: riscvv.intr.vsub{{.*}} : (!llvm.vec, !llvm.vec, i64) -> !llvm.vec
+  %2 = riscvv.sub %a, %b, %vl : !riscvv.vector, !riscvv.vector, i64
+  // CHECK: riscvv.intr.vsub{{.*}} : (!llvm.vec, i32, i64) -> !llvm.vec
+  %3 = riscvv.sub %a, %c, %vl : !riscvv.vector, i32, i64
+  // CHECK: riscvv.intr.vmul{{.*}} : (!llvm.vec, !llvm.vec, i64) -> !llvm.vec
+  %4 = riscvv.mul %a, %b, %vl : !riscvv.vector, !riscvv.vector, i64
+  // CHECK: riscvv.intr.vmul{{.*}} : (!llvm.vec, i32, i64) -> !llvm.vec
+  %5 = riscvv.mul %a, %c, %vl : !riscvv.vector, i32, i64
+  // CHECK: riscvv.intr.vdiv{{.*}} : (!llvm.vec, !llvm.vec, i64) -> !llvm.vec
+  %6 = riscvv.div %a, %b, %vl : !riscvv.vector, !riscvv.vector, i64
+  // CHECK: riscvv.intr.vdiv{{.*}} : (!llvm.vec, i32, i64) -> !llvm.vec
+  %7 = riscvv.div %a, %c, %vl : !riscvv.vector, i32, i64
+  return %7 : !riscvv.vector
+}
+
+func @riscvv_masked_arith(%maskedoff: !riscvv.vector, // lower to the *_mask intrinsics
+                          %a: !riscvv.vector,
+                          %b: !riscvv.vector,
+                          %c: i32,
+                          %mask: !riscvv.vector,
+                          %vl: i64) -> !riscvv.vector {
+  // CHECK: riscvv.intr.vadd_mask{{.*}} : (!llvm.vec, !llvm.vec, !llvm.vec, !llvm.vec, i64) -> !llvm.vec
+  %0 = riscvv.masked.add %maskedoff, %a, %b, %mask, %vl: !riscvv.vector, !riscvv.vector, !riscvv.vector, i64
+  // CHECK: riscvv.intr.vadd_mask{{.*}} : (!llvm.vec, !llvm.vec, i32, !llvm.vec, i64) -> !llvm.vec
+  %1 = riscvv.masked.add %maskedoff, %a, %c, %mask, %vl: !riscvv.vector, i32, !riscvv.vector, i64
+  // CHECK: riscvv.intr.vsub_mask{{.*}} : (!llvm.vec, !llvm.vec, !llvm.vec, !llvm.vec, i64) -> !llvm.vec
+  %2 = riscvv.masked.sub %maskedoff, %a, %b, %mask, %vl: !riscvv.vector, !riscvv.vector, !riscvv.vector, i64
+  // CHECK: riscvv.intr.vsub_mask{{.*}} : (!llvm.vec, !llvm.vec, i32, !llvm.vec, i64) -> !llvm.vec
+  %3 = riscvv.masked.sub %maskedoff, %a, %c, %mask, %vl: !riscvv.vector, i32, !riscvv.vector, i64
+  // CHECK: riscvv.intr.vmul_mask{{.*}} : (!llvm.vec, !llvm.vec, !llvm.vec, !llvm.vec, i64) -> !llvm.vec
+  %4 = riscvv.masked.mul %maskedoff, %a, %b, %mask, %vl: !riscvv.vector, !riscvv.vector, !riscvv.vector, i64
+  // CHECK: riscvv.intr.vmul_mask{{.*}} : (!llvm.vec, !llvm.vec, i32, !llvm.vec, i64) -> !llvm.vec
+  %5 = riscvv.masked.mul %maskedoff, %a, %c, %mask, %vl: !riscvv.vector, i32, !riscvv.vector, i64
+  // CHECK: riscvv.intr.vdiv_mask{{.*}} : (!llvm.vec, !llvm.vec, !llvm.vec, !llvm.vec, i64) -> !llvm.vec
+  %6 = riscvv.masked.div %maskedoff, %a, %b, %mask, %vl: !riscvv.vector, !riscvv.vector, !riscvv.vector, i64
+  // CHECK: riscvv.intr.vdiv_mask{{.*}} : (!llvm.vec, !llvm.vec, i32, !llvm.vec, i64) -> !llvm.vec
+  %7 = riscvv.masked.div %maskedoff, %a, %c, %mask, %vl: !riscvv.vector, i32, !riscvv.vector, i64
+  return %7 : !riscvv.vector
+}
diff --git a/mlir/test/Dialect/RISCVV/roundtrip.mlir b/mlir/test/Dialect/RISCVV/roundtrip.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/RISCVV/roundtrip.mlir
@@ -0,0 +1,63 @@
+// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s
+
+// CHECK-LABEL: func @riscvv_memory
+func @riscvv_memory(%v: !riscvv.vector, // each op must survive a print/parse round trip
+                    %m: memref,
+                    %vl: i64) -> !riscvv.vector {
+  %c0 = constant 0 : index
+  // CHECK: riscvv.load {{.*}}: memref, !riscvv.vector, i64
+  %0 = riscvv.load %m[%c0], %vl : memref, !riscvv.vector, i64
+  // CHECK: riscvv.store {{.*}}: !riscvv.vector, memref, i64
+  riscvv.store %v, %m[%c0], %vl : !riscvv.vector, memref, i64
+  return %0 : !riscvv.vector
+}
+
+// CHECK-LABEL: func @riscvv_arith
+func @riscvv_arith(%a: !riscvv.vector,
+                   %b: !riscvv.vector,
+                   %c: i32,
+                   %vl: i64) -> !riscvv.vector {
+  // CHECK: riscvv.add {{.*}} : !riscvv.vector, !riscvv.vector, i64
+  %0 = riscvv.add %a, %b, %vl : !riscvv.vector, !riscvv.vector, i64
+  // CHECK: riscvv.add {{.*}} : !riscvv.vector, i32, i64
+  %1 = riscvv.add %a, %c, %vl : !riscvv.vector, i32, i64
+  // CHECK: riscvv.sub {{.*}} : !riscvv.vector, !riscvv.vector, i64
+  %2 = riscvv.sub %a, %b, %vl : !riscvv.vector, !riscvv.vector, i64
+  // CHECK: riscvv.sub {{.*}} : !riscvv.vector, i32, i64
+  %3 = riscvv.sub %a, %c, %vl : !riscvv.vector, i32, i64
+  // CHECK: riscvv.mul {{.*}} : !riscvv.vector, !riscvv.vector, i64
+  %4 = riscvv.mul %a, %b, %vl : !riscvv.vector, !riscvv.vector, i64
+  // CHECK: riscvv.mul {{.*}} : !riscvv.vector, i32, i64
+  %5 = riscvv.mul %a, %c, %vl : !riscvv.vector, i32, i64
+  // CHECK: riscvv.div {{.*}} : !riscvv.vector, !riscvv.vector, i64
+  %6 = riscvv.div %a, %b, %vl : !riscvv.vector, !riscvv.vector, i64
+  // CHECK: riscvv.div {{.*}} : !riscvv.vector, i32, i64
+  %7 = riscvv.div %a, %c, %vl : !riscvv.vector, i32, i64
+  return %7 : !riscvv.vector
+}
+
+// CHECK-LABEL: func @riscvv_masked_arith
+func @riscvv_masked_arith(%maskedoff: !riscvv.vector,
+                          %a: !riscvv.vector,
+                          %b: !riscvv.vector,
+                          %c: i32,
+                          %mask: !riscvv.vector,
+                          %vl: i64) -> !riscvv.vector {
+  // CHECK: riscvv.masked.add {{.*}} : !riscvv.vector, !riscvv.vector, !riscvv.vector, i64
+  %0 = riscvv.masked.add %maskedoff, %a, %b, %mask, %vl : !riscvv.vector, !riscvv.vector, !riscvv.vector, i64
+  // CHECK: riscvv.masked.add {{.*}} : !riscvv.vector, i32, !riscvv.vector, i64
+  %1 = riscvv.masked.add %maskedoff, %a, %c, %mask, %vl : !riscvv.vector, i32, !riscvv.vector, i64
+  // CHECK: riscvv.masked.sub {{.*}} : !riscvv.vector, !riscvv.vector, !riscvv.vector, i64
+  %2 = riscvv.masked.sub %maskedoff, %a, %b, %mask, %vl : !riscvv.vector, !riscvv.vector, !riscvv.vector, i64
+  // CHECK: riscvv.masked.sub {{.*}} : !riscvv.vector, i32, !riscvv.vector, i64
+  %3 = riscvv.masked.sub %maskedoff, %a, %c, %mask, %vl : !riscvv.vector, i32, !riscvv.vector, i64
+  // CHECK: riscvv.masked.mul {{.*}} : !riscvv.vector, !riscvv.vector, !riscvv.vector, i64
+  %4 = riscvv.masked.mul %maskedoff, %a, %b, %mask, %vl : !riscvv.vector, !riscvv.vector, !riscvv.vector, i64
+  // CHECK: riscvv.masked.mul {{.*}} : !riscvv.vector, i32, !riscvv.vector, i64
+  %5 = riscvv.masked.mul %maskedoff, %a, %c, %mask, %vl : !riscvv.vector, i32, !riscvv.vector, i64
+  // CHECK: riscvv.masked.div {{.*}} : !riscvv.vector, !riscvv.vector, !riscvv.vector, i64
+  %6 = riscvv.masked.div %maskedoff, %a, %b, %mask, %vl : !riscvv.vector, !riscvv.vector, !riscvv.vector, i64
+  // CHECK: riscvv.masked.div {{.*}} : !riscvv.vector, i32, !riscvv.vector, i64
+  %7 = riscvv.masked.div %maskedoff, %a, %c, %mask, %vl : !riscvv.vector, i32, !riscvv.vector, i64
+  return %7 : !riscvv.vector
+}
diff --git a/mlir/test/mlir-opt/commandline.mlir b/mlir/test/mlir-opt/commandline.mlir
--- a/mlir/test/mlir-opt/commandline.mlir
+++ b/mlir/test/mlir-opt/commandline.mlir
@@ -20,6 +20,7 @@
 // CHECK-NEXT: pdl
 // CHECK-NEXT: pdl_interp
 // CHECK-NEXT: quant
+// CHECK-NEXT: riscvv
 // CHECK-NEXT: rocdl
 // CHECK-NEXT: scf
 // CHECK-NEXT: shape