diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -808,7 +808,7 @@ operations. The lowering pass provides several options to control the kinds of optimizations that are allowed. It also provides options that enable the use of one or more architectural-specific dialects - (AMX, X86Vector, ArmNeon, ArmSVE, etc.) in combination with the + (AMX, X86Vector, ArmNeon, ArmSVE, RISCVV, etc.) in combination with the architectural-neutral vector dialect lowering. }]; @@ -838,7 +838,11 @@ Option<"x86Vector", "enable-x86vector", "bool", /*default=*/"false", "Enables the use of X86Vector dialect while lowering the vector " - "dialect."> + "dialect.">, + Option<"riscvv", "enable-riscvv", + "bool", /*default=*/"false", + "Enables the use of RISCVV dialect while lowering the vector " + "dialect."> ]; } diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h --- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h +++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h @@ -47,6 +47,10 @@ x86Vector = b; return *this; } + LowerVectorToLLVMOptions &enableRISCVV(bool b = true) { + riscvv = b; + return *this; + } bool reassociateFPReductions{false}; bool force32BitVectorIndices{true}; @@ -54,6 +58,7 @@ bool armSVE{false}; bool amx{false}; bool x86Vector{false}; + bool riscvv{false}; }; /// Collect a set of patterns to convert from Vector contractions to LLVM Matrix diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt --- a/mlir/include/mlir/Dialect/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/CMakeLists.txt @@ -20,6 +20,7 @@ add_subdirectory(PDL) add_subdirectory(PDLInterp) add_subdirectory(Quant) +add_subdirectory(RISCVV) add_subdirectory(SCF) add_subdirectory(Shape) 
add_subdirectory(SparseTensor) diff --git a/mlir/include/mlir/Dialect/RISCVV/CMakeLists.txt b/mlir/include/mlir/Dialect/RISCVV/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/RISCVV/CMakeLists.txt @@ -0,0 +1,6 @@ +add_mlir_dialect(RISCVV riscvv) +add_mlir_doc(RISCVV RISCVV Dialects/ -gen-dialect-doc) + +set(LLVM_TARGET_DEFINITIONS RISCVV.td) +mlir_tablegen(RISCVVConversions.inc -gen-llvmir-conversions) +add_public_tablegen_target(MLIRRISCVVConversionsIncGen) diff --git a/mlir/include/mlir/Dialect/RISCVV/RISCVV.td b/mlir/include/mlir/Dialect/RISCVV/RISCVV.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/RISCVV/RISCVV.td @@ -0,0 +1,228 @@ +//===-- RISCVV.td - RISCVV dialect operation definitions ---*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the basic operations for the RISCVV dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef RISCVV_OPS +#define RISCVV_OPS + +include "mlir/Dialect/LLVMIR/LLVMOpBase.td" + +//===----------------------------------------------------------------------===// +// RISCVV dialect definition +//===----------------------------------------------------------------------===// + +def RISCVV_Dialect : Dialect { + let name = "riscvv"; + let cppNamespace = "::mlir::riscvv"; + let summary = "Basic dialect to target RISC-V Vector extension"; + let description = [{ + RISC-V vector extension (RISCVV) is the vector instruction set with scalable + vector types, and the RISCVV instructions are vector length agnostic (VLA). + For more details about RISCVV, please see the + [RISCVV specification](https://github.com/riscv/riscv-v-spec). 
+ This dialect contains the definitions of RISCVV operations and RISCVV + intrinsic operations. The former is used to interoperate with higher-level + dialects, and the latter is responsible for mapping to LLVM IR intrinsic. + }]; +} + +//===----------------------------------------------------------------------===// +// RISCVV operation definitions +//===----------------------------------------------------------------------===// + +class RISCVV_Op traits = []> : + Op {} + +def RISCVVSetVlOp : + RISCVV_Op<"setvl", + !listconcat([], [AllTypesMatch<["avl", "sew", "lmul", "vl"]>])>, + Arguments<(ins Index:$avl, Index:$sew, Index:$lmul)>, + Results<(outs Index:$vl)> { + let summary = "Set vector length according to AVL, SEW, and LMUL"; + let description = [{ + SetVl operation sets the vector length according to AVL, SEW, and LMUL. + RISC-V vector extension uses this to achieve a direct and portable + strip-mining approach, which is purposed to handle a large number of + elements. The return value of this operation is the number of elements + for a single iteration. + }]; + let assemblyFormat = "$avl `,` $sew `,` $lmul attr-dict `:` type($avl)"; +} + +def RISCVVLoadOp : RISCVV_Op<"load">, + Arguments<(ins Arg:$base, Index:$index, + Index:$length)>, + Results<(outs ScalableVectorOf<[AnyType]>:$result)> { + let summary = "Load scalable vector from memory"; + let description = [{ + Load a slice of memory into scalable vector with the given element length. 
+ }]; + let extraClassDeclaration = [{ + MemRefType getMemRefType() { + return base().getType().cast(); + } + }]; + let assemblyFormat = "$base `[` $index `]` `,` $length attr-dict `:` " + "type($base) `,` type($result) `,` type($length)"; +} + +def RISCVVStoreOp : RISCVV_Op<"store">, + Arguments<(ins ScalableVectorOf<[AnyType]>:$value, Arg:$base, Index:$index, Index:$length)> { + let summary = "Store scalable vector into memory"; + let description = [{ + Store the given element length of a scalable vector on a slice of memory. + }]; + let extraClassDeclaration = [{ + MemRefType getMemRefType() { + return base().getType().cast(); + } + }]; + let assemblyFormat = "$value `,` $base `[` $index `]` `,` $length attr-dict " + "`:` type($value) `,` type($base) `,` type($length)"; +} + +class RISCVV_BinaryAAXNoMask_Op traits = []> : + RISCVV_Op])> { + let summary = op_description # "for RISCVV scalable vectors"; + let description = [{ The `riscvv.}] # mnemonic # [{` operation can be of + vector-vector form or vector-scalar form. It also takes the vector length + value and returns a scalable vector with the result of the }] + # op_description # [{.}]; + let arguments = (ins + ScalableVectorOf<[AnyInteger]>:$src1, + AnyType:$src2, + Index:$length + ); + let results = (outs ScalableVectorOf<[AnyInteger]>:$dst); + let assemblyFormat = "$src1 `,` $src2 `,` $length attr-dict `:` type($src1) " + "`,` type($src2) `,` type($length)"; +} + +class RISCVV_BinaryAAXMask_Op traits = []> : + RISCVV_Op])> { + let summary = op_description # "for RISCVV scalable vectors"; + let description = [{ The `riscvv.}] # mnemonic # [{` operation can be of + vector-vector form or vector-scalar form. 
It also takes the mask vector, + maskedoff vector, vector length value and returns a scalable vector with + the result of the }] # op_description # [{.}]; + let arguments = (ins + ScalableVectorOf<[AnyInteger]>:$maskedoff, + ScalableVectorOf<[AnyInteger]>:$src1, + AnyType:$src2, + ScalableVectorOf<[I1]>:$mask, + Index:$length, + OptionalAttr:$policy + ); + let results = (outs ScalableVectorOf<[AnyInteger]>:$dst); + let assemblyFormat = + "$maskedoff `,` $src1 `,` $src2 `,` $mask `,` $length attr-dict " + "`:` type($maskedoff) `,` type($src2) `,` type($mask) `,` type($length)"; +} + +def RISCVVAddOp : RISCVV_BinaryAAXNoMask_Op<"add", "addition">; +def RISCVVSubOp : RISCVV_BinaryAAXNoMask_Op<"sub", "subtraction">; +def RISCVVMulOp : RISCVV_BinaryAAXNoMask_Op<"mul", "multiplication">; +def RISCVVDivOp : RISCVV_BinaryAAXNoMask_Op<"div", "division">; + +def RISCVVMaskedAddOp : RISCVV_BinaryAAXMask_Op<"masked.add", + "masked addition">; +def RISCVVMaskedSubOp : RISCVV_BinaryAAXMask_Op<"masked.sub", + "masked subtraction">; +def RISCVVMaskedMulOp : RISCVV_BinaryAAXMask_Op<"masked.mul", + "masked multiplication">; +def RISCVVMaskedDivOp : RISCVV_BinaryAAXMask_Op<"masked.div", + "masked division">; + +//===----------------------------------------------------------------------===// +// RISCVV intrinsic operation definitions +//===----------------------------------------------------------------------===// + +class RISCVV_VSetVlI_IntrOp traits = []> : + LLVM_IntrOpBase overloadedResults=*/[0], + /*list overloadedOperands=*/[], + /*list traits=*/traits, + /*int numResults=*/1>; + +class RISCVV_USLoad_IntrOp traits = []> : + LLVM_IntrOpBase overloadedResults=*/[0], + /*list overloadedOperands=*/[2], + /*list traits=*/traits, + /*int numResults=*/1>; + +class RISCVV_USStore_IntrOp traits = []> : + LLVM_IntrOpBase overloadedResults=*/[], + /*list overloadedOperands=*/[0, 2], + /*list traits=*/traits, + /*int numResults=*/0>; + +class RISCVV_BinaryAAXNoMask_IntrOp traits = []> : + 
LLVM_IntrOpBase overloadedResults=*/[0], + /*list overloadedOperands=*/[2, 3], + /*list traits=*/traits, + /*int numResults=*/1>; + +class RISCVV_BinaryAAXMask_IntrOp traits = []> : + LLVM_IntrOpBase overloadedResults=*/[0], + /*list overloadedOperands=*/[2, 4], + /*list traits=*/traits, + /*int numResults=*/1>; + +def RISCVVIntrSetVlIOp : RISCVV_VSetVlI_IntrOp<"vsetvli">, + Arguments<(ins AnyInteger, AnyInteger, AnyInteger)>; + +def RISCVVIntrLoadEleOp : RISCVV_USLoad_IntrOp<"vle">, + Arguments<(ins AnyScalableVector, LLVM_AnyPointer, AnyInteger)>; +def RISCVVIntrStoreEleOp : RISCVV_USStore_IntrOp<"vse">, + Arguments<(ins AnyScalableVector, LLVM_AnyPointer, AnyInteger)>; + +def RISCVVIntrAddOp : RISCVV_BinaryAAXNoMask_IntrOp<"vadd">, + Arguments<(ins AnyScalableVector, AnyScalableVector, AnyType, AnyInteger)>; +def RISCVVIntrSubOp : RISCVV_BinaryAAXNoMask_IntrOp<"vsub">, + Arguments<(ins AnyScalableVector, AnyScalableVector, AnyType, AnyInteger)>; +def RISCVVIntrMulOp : RISCVV_BinaryAAXNoMask_IntrOp<"vmul">, + Arguments<(ins AnyScalableVector, AnyScalableVector, AnyType, AnyInteger)>; +def RISCVVIntrDivOp : RISCVV_BinaryAAXNoMask_IntrOp<"vdiv">, + Arguments<(ins AnyScalableVector, AnyScalableVector, AnyType, AnyInteger)>; + +def RISCVVMaskedIntrAddOp : RISCVV_BinaryAAXMask_IntrOp<"vadd_mask">, + Arguments<(ins AnyScalableVector, AnyScalableVector, AnyType, + AnyScalableVector, AnyInteger, AnyInteger)>; +def RISCVVMaskedIntrSubOp : RISCVV_BinaryAAXMask_IntrOp<"vsub_mask">, + Arguments<(ins AnyScalableVector, AnyScalableVector, AnyType, + AnyScalableVector, AnyInteger, AnyInteger)>; +def RISCVVMaskedIntrMulOp : RISCVV_BinaryAAXMask_IntrOp<"vmul_mask">, + Arguments<(ins AnyScalableVector, AnyScalableVector, AnyType, + AnyScalableVector, AnyInteger, AnyInteger)>; +def RISCVVMaskedIntrDivOp : RISCVV_BinaryAAXMask_IntrOp<"vdiv_mask">, + Arguments<(ins AnyScalableVector, AnyScalableVector, AnyType, + AnyScalableVector, AnyInteger, AnyInteger)>; + +#endif // RISCVV_OPS 
diff --git a/mlir/include/mlir/Dialect/RISCVV/RISCVVDialect.h b/mlir/include/mlir/Dialect/RISCVV/RISCVVDialect.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/RISCVV/RISCVVDialect.h @@ -0,0 +1,23 @@ +//===- RISCVVDialect.h - MLIR Dialect for RISC-V vector extension ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_RISCVV_RISCVVDIALECT_H +#define MLIR_DIALECT_RISCVV_RISCVVDIALECT_H + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/RISCVV/RISCVVDialect.h.inc" + +#define GET_OP_CLASSES +#include "mlir/Dialect/RISCVV/RISCVV.h.inc" + +#endif // MLIR_DIALECT_RISCVV_RISCVVDIALECT_H diff --git a/mlir/include/mlir/Dialect/RISCVV/Transforms.h b/mlir/include/mlir/Dialect/RISCVV/Transforms.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/RISCVV/Transforms.h @@ -0,0 +1,30 @@ +//===- Transforms.h - RISCVV Dialect Transformation Entrypoints -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_RISCVV_TRANSFORMS_H +#define MLIR_DIALECT_RISCVV_TRANSFORMS_H + +namespace mlir { + +class LLVMConversionTarget; +class LLVMTypeConverter; +class RewritePatternSet; +using OwningRewritePatternList = RewritePatternSet; + +/// Collect a set of patterns to lower RISCVV ops to ops that map to LLVM +/// intrinsics. 
+void populateRISCVVLegalizeForLLVMExportPatterns(LLVMTypeConverter &converter, + RewritePatternSet &patterns); + +/// Configure the target to support lowering RISCVV ops to ops that map to LLVM +/// intrinsics. +void configureRISCVVLegalizeForExportTarget(LLVMConversionTarget &target); + +} // namespace mlir + +#endif // MLIR_DIALECT_RISCVV_TRANSFORMS_H diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -40,6 +40,7 @@ #include "mlir/Dialect/PDL/IR/PDL.h" #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" #include "mlir/Dialect/Quant/QuantOps.h" +#include "mlir/Dialect/RISCVV/RISCVVDialect.h" #include "mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" @@ -82,6 +83,7 @@ pdl::PDLDialect, pdl_interp::PDLInterpDialect, quant::QuantizationDialect, + riscvv::RISCVVDialect, spirv::SPIRVDialect, arm_sve::ArmSVEDialect, vector::VectorDialect, diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h --- a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h +++ b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h @@ -21,6 +21,7 @@ #include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/RISCVV/RISCVVToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/X86Vector/X86VectorToLLVMIRTranslation.h" @@ -37,6 +38,7 @@ registerNVVMDialectTranslation(registry); registerOpenACCDialectTranslation(registry); registerOpenMPDialectTranslation(registry); + registerRISCVVDialectTranslation(registry); registerROCDLDialectTranslation(registry); 
registerX86VectorDialectTranslation(registry); } diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/RISCVV/RISCVVToLLVMIRTranslation.h b/mlir/include/mlir/Target/LLVMIR/Dialect/RISCVV/RISCVVToLLVMIRTranslation.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Target/LLVMIR/Dialect/RISCVV/RISCVVToLLVMIRTranslation.h @@ -0,0 +1,31 @@ +//=======- RISCVVToLLVMIRTranslation.h - RISCVV to LLVM IR ------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This provides registration calls for RISCVV dialect to LLVM IR translation. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TARGET_LLVMIR_DIALECT_RISCVV_RISCVVTOLLVMIRTRANSLATION_H +#define MLIR_TARGET_LLVMIR_DIALECT_RISCVV_RISCVVTOLLVMIRTRANSLATION_H + +namespace mlir { + +class DialectRegistry; +class MLIRContext; + +/// Register the RISCVV dialect and the translation from it to the LLVM IR in +/// the given registry. +void registerRISCVVDialectTranslation(DialectRegistry &registry); + +/// Register the RISCVV dialect and the translation from it in the registry +/// associated with the given context. 
+void registerRISCVVDialectTranslation(MLIRContext &context); + +} // namespace mlir + +#endif // MLIR_TARGET_LLVMIR_DIALECT_RISCVV_RISCVVTOLLVMIRTRANSLATION_H diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt --- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt +++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt @@ -22,6 +22,8 @@ MLIRLLVMCommonConversion MLIRLLVMIR MLIRMemRef + MLIRRISCVV + MLIRRISCVVTransforms MLIRTargetLLVMIRExport MLIRTransforms MLIRVector diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -21,6 +21,8 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/RISCVV/RISCVVDialect.h" +#include "mlir/Dialect/RISCVV/Transforms.h" #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" #include "mlir/Dialect/X86Vector/Transforms.h" #include "mlir/Dialect/X86Vector/X86VectorDialect.h" @@ -38,6 +40,7 @@ this->armNeon = options.armNeon; this->armSVE = options.armSVE; this->amx = options.amx; + this->riscvv = options.riscvv; this->x86Vector = options.x86Vector; } // Override explicitly to allow conditional dialect dependence. 
@@ -51,6 +54,8 @@ registry.insert(); if (amx) registry.insert(); + if (riscvv) + registry.insert(); if (x86Vector) registry.insert(); } @@ -107,6 +112,10 @@ configureX86VectorLegalizeForExportTarget(target); populateX86VectorLegalizeForLLVMExportPatterns(converter, patterns); } + if (riscvv) { + configureRISCVVLegalizeForExportTarget(target); + populateRISCVVLegalizeForLLVMExportPatterns(converter, patterns); + } if (failed( applyPartialConversion(getOperation(), target, std::move(patterns)))) diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt --- a/mlir/lib/Dialect/CMakeLists.txt +++ b/mlir/lib/Dialect/CMakeLists.txt @@ -20,6 +20,7 @@ add_subdirectory(PDL) add_subdirectory(PDLInterp) add_subdirectory(Quant) +add_subdirectory(RISCVV) add_subdirectory(SCF) add_subdirectory(Shape) add_subdirectory(SparseTensor) diff --git a/mlir/lib/Dialect/RISCVV/CMakeLists.txt b/mlir/lib/Dialect/RISCVV/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/RISCVV/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/mlir/lib/Dialect/RISCVV/IR/CMakeLists.txt b/mlir/lib/Dialect/RISCVV/IR/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/RISCVV/IR/CMakeLists.txt @@ -0,0 +1,10 @@ +add_mlir_dialect_library(MLIRRISCVV + RISCVVDialect.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/RISCVV + + LINK_LIBS PUBLIC + MLIRIR + MLIRLLVMIR + ) diff --git a/mlir/lib/Dialect/RISCVV/IR/RISCVVDialect.cpp b/mlir/lib/Dialect/RISCVV/IR/RISCVVDialect.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/RISCVV/IR/RISCVVDialect.cpp @@ -0,0 +1,34 @@ +//===- RISCVVDialect.cpp - MLIR RISCVV dialect implementation -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the RISCVV dialect and its operations. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/RISCVV/RISCVVDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/TypeUtilities.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; + +#include "mlir/Dialect/RISCVV/RISCVVDialect.cpp.inc" + +#define GET_OP_CLASSES +#include "mlir/Dialect/RISCVV/RISCVV.cpp.inc" + +void riscvv::RISCVVDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "mlir/Dialect/RISCVV/RISCVV.cpp.inc" + >(); +} diff --git a/mlir/lib/Dialect/RISCVV/Transforms/CMakeLists.txt b/mlir/lib/Dialect/RISCVV/Transforms/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/RISCVV/Transforms/CMakeLists.txt @@ -0,0 +1,12 @@ +add_mlir_dialect_library(MLIRRISCVVTransforms + LegalizeForLLVMExport.cpp + + DEPENDS + MLIRRISCVVConversionsIncGen + + LINK_LIBS PUBLIC + MLIRRISCVV + MLIRIR + MLIRLLVMCommonConversion + MLIRLLVMIR + ) diff --git a/mlir/lib/Dialect/RISCVV/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/RISCVV/Transforms/LegalizeForLLVMExport.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/RISCVV/Transforms/LegalizeForLLVMExport.cpp @@ -0,0 +1,284 @@ +//===- LegalizeForLLVMExport.cpp - Prepare RISCVV for LLVM translation ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/LLVMCommon/ConversionTarget.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/RISCVV/RISCVVDialect.h" +#include "mlir/Dialect/RISCVV/Transforms.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/PatternMatch.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; +using namespace mlir::riscvv; + +/// There are two strategies to control the mask/tail policy in the RISC-V LLVM +/// IR intrinsics: +/// - the `policy` argument at the end of the argument list. +/// - the `passthru` argument at the beginning of the argument list. +/// +/// The following two patterns (`ConvertPolicyOperandOpToLLVMPattern` and +/// `ConvertPassthruOperandOpToLLVMPattern`) are designed for the two +/// strategies, respectively. + +template +class ConvertPolicyOperandOpToLLVMPattern + : public ConvertOpToLLVMPattern { +public: + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + /// This pattern converts the `policy` attribute to a value, appends the + /// `policy` value to the operand list, and creates the intrinsic operation. + LogicalResult + matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + unsigned numResults = op->getNumResults(); + Type packedType; + ValueRange operands = adaptor.getOperands(); + SmallVector operandsVector(operands); + // Get the type of the `vl` value. + Type vlType = operands.back().getType(); + auto attrs = op->getAttrs(); + if (attrs.empty()) { + // Default attribute for the policy setting (policy = 1). + // Add the policy = 1 to the operand list. 
+ Attribute policyDefaultAttr = rewriter.getIntegerAttr( + vlType, APInt(vlType.cast().getWidth(), 1)); + Value policyDefaultValue = + rewriter.create(loc, vlType, policyDefaultAttr); + operandsVector.push_back(policyDefaultValue); + } else if (attrs.size() == 1) { + // Add the policy to the operand list according to the attribute value. + Attribute attr = attrs[0].getValue(); + IntegerAttr policyAttr = attr.cast(); + Value policyValue = + rewriter.create(loc, vlType, policyAttr); + operandsVector.push_back(policyValue); + } else { + return failure(); + } + + LLVMTypeConverter typeConverter = *this->getTypeConverter(); + if (numResults != 0) { + packedType = typeConverter.packFunctionResults(op->getResultTypes()); + if (!packedType) + return failure(); + } + // Create the intrinsic operation. + OperationState state(op->getLoc(), TargetOp::getOperationName()); + state.addTypes(packedType); + state.addOperands(operandsVector); + Operation *newOp = rewriter.create(state); + return rewriter.replaceOp(op, newOp->getResult(0)), success(); + } +}; + +template +class ConvertPassthruOperandOpToLLVMPattern + : public ConvertOpToLLVMPattern { +public: + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + /// This pattern creates an `undef` operation, inserts the `undef` + /// operation to the beginning of the operand list, and creates the intrinsic + /// operation. 
+ LogicalResult + matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + unsigned numResults = op->getNumResults(); + auto resultType = op->getResultTypes(); + Type packedType; + ValueRange operands = adaptor.getOperands(); + SmallVector operandsVector(operands); + Value passthru = rewriter.create(loc, resultType[0]); + operandsVector.insert(operandsVector.begin(), passthru); + + LLVMTypeConverter typeConverter = *this->getTypeConverter(); + if (numResults != 0) { + packedType = typeConverter.packFunctionResults(op->getResultTypes()); + if (!packedType) + return failure(); + } + + // Create the intrinsic operation. + OperationState state(loc, TargetOp::getOperationName()); + state.addTypes(packedType); + state.addOperands(operandsVector); + Operation *newOp = rewriter.create(state); + return rewriter.replaceOp(op, newOp->getResult(0)), success(); + } +}; + +template +class ForwardOperands : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + if (adaptor.getOperands().getTypes() == op->getOperands().getTypes()) + return rewriter.notifyMatchFailure(op, "operand types already match"); + + rewriter.updateRootInPlace( + op, [&]() { op->setOperands(adaptor.getOperands()); }); + return success(); + } +}; + +class ReturnOpTypeConversion : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + rewriter.updateRootInPlace( + op, [&]() { op->setOperands(adaptor.getOperands()); }); + return success(); + } +}; + +struct RISCVVLoadOpLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + 
matchAndRewrite(RISCVVLoadOp loadOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto type = loadOp.getMemRefType(); + if (!isConvertibleAndHasIdentityMaps(type)) + return failure(); + + LLVMTypeConverter converter(loadOp.getContext()); + + auto resultType = loadOp.result().getType(); + Value passthru = + rewriter.create(loadOp.getLoc(), resultType); + LLVM::LLVMPointerType llvmDataTypePtr = + LLVM::LLVMPointerType::get(resultType); + Value dataPtr = getStridedElementPtr(loadOp.getLoc(), type, adaptor.base(), + adaptor.index(), rewriter); + Value bitCastedPtr = rewriter.create( + loadOp.getLoc(), llvmDataTypePtr, dataPtr); + Value vl = loadOp.getOperand(2); + Value vlCast = rewriter + .create( + loadOp.getLoc(), rewriter.getI64Type(), vl) + .getResult(0); + rewriter.replaceOpWithNewOp( + loadOp, resultType, passthru, bitCastedPtr, vlCast); + return success(); + } +}; + +struct RISCVVStoreOpLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(RISCVVStoreOp storeOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto type = storeOp.getMemRefType(); + if (!isConvertibleAndHasIdentityMaps(type)) + return failure(); + + LLVMTypeConverter converter(storeOp.getContext()); + + auto resultType = storeOp.value().getType(); + LLVM::LLVMPointerType llvmDataTypePtr = + LLVM::LLVMPointerType::get(resultType); + ; + Value dataPtr = getStridedElementPtr(storeOp.getLoc(), type, adaptor.base(), + adaptor.index(), rewriter); + Value bitCastedPtr = rewriter.create( + storeOp.getLoc(), llvmDataTypePtr, dataPtr); + Value vl = storeOp.getOperand(3); + Value vlCast = rewriter + .create( + storeOp.getLoc(), rewriter.getI64Type(), vl) + .getResult(0); + rewriter.replaceOpWithNewOp(storeOp, adaptor.value(), + bitCastedPtr, vlCast); + return success(); + } +}; + +using RISCVVSetVlOpLowering = + OneToOneConvertToLLVMPattern; +using RISCVVAddOpLowering 
= + ConvertPassthruOperandOpToLLVMPattern; +using RISCVVSubOpLowering = + ConvertPassthruOperandOpToLLVMPattern; +using RISCVVMulOpLowering = + ConvertPassthruOperandOpToLLVMPattern; +using RISCVVDivOpLowering = + ConvertPassthruOperandOpToLLVMPattern; +using RISCVVMaskedAddOpLowering = + ConvertPolicyOperandOpToLLVMPattern; +using RISCVVMaskedSubOpLowering = + ConvertPolicyOperandOpToLLVMPattern; +using RISCVVMaskedMulOpLowering = + ConvertPolicyOperandOpToLLVMPattern; +using RISCVVMaskedDivOpLowering = + ConvertPolicyOperandOpToLLVMPattern; + +/// Populate the given list with patterns that convert from RISCVV to LLVM. +void mlir::populateRISCVVLegalizeForLLVMExportPatterns( + LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { + // clang-format off + patterns.add, + ForwardOperands, + ForwardOperands + >(converter, &converter.getContext()); + patterns.add(converter); + patterns.add(converter); + patterns.add(converter); + // clang-format on +} + +void mlir::configureRISCVVLegalizeForExportTarget( + LLVMConversionTarget &target) { + // clang-format off + target.addLegalOp(); + target.addIllegalOp(); + // clang-format on +} diff --git a/mlir/lib/Target/LLVMIR/CMakeLists.txt b/mlir/lib/Target/LLVMIR/CMakeLists.txt --- a/mlir/lib/Target/LLVMIR/CMakeLists.txt +++ b/mlir/lib/Target/LLVMIR/CMakeLists.txt @@ -40,6 +40,7 @@ MLIRArmNeonToLLVMIRTranslation MLIRArmSVEToLLVMIRTranslation MLIRAMXToLLVMIRTranslation + MLIRRISCVVToLLVMIRTranslation MLIRX86VectorToLLVMIRTranslation MLIRLLVMToLLVMIRTranslation MLIRNVVMToLLVMIRTranslation diff --git a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt --- a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt +++ b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt @@ -6,4 +6,5 @@ add_subdirectory(OpenACC) add_subdirectory(OpenMP) add_subdirectory(ROCDL) +add_subdirectory(RISCVV) add_subdirectory(X86Vector) diff --git a/mlir/lib/Target/LLVMIR/Dialect/RISCVV/CMakeLists.txt 
b/mlir/lib/Target/LLVMIR/Dialect/RISCVV/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Target/LLVMIR/Dialect/RISCVV/CMakeLists.txt @@ -0,0 +1,16 @@ +add_mlir_translation_library(MLIRRISCVVToLLVMIRTranslation + RISCVVToLLVMIRTranslation.cpp + + DEPENDS + MLIRRISCVVConversionsIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRIR + MLIRRISCVV + MLIRLLVMIR + MLIRSupport + MLIRTargetLLVMIRExport + ) diff --git a/mlir/lib/Target/LLVMIR/Dialect/RISCVV/RISCVVToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/RISCVV/RISCVVToLLVMIRTranslation.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Target/LLVMIR/Dialect/RISCVV/RISCVVToLLVMIRTranslation.cpp @@ -0,0 +1,56 @@ +//======- RISCVVToLLVMIRTranslation.cpp - Translate RISCVV to LLVM IR ----====// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a translation between the RISCVV dialect and LLVM IR. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Target/LLVMIR/Dialect/RISCVV/RISCVVToLLVMIRTranslation.h" +#include "mlir/Dialect/RISCVV/RISCVVDialect.h" +#include "mlir/IR/Operation.h" +#include "mlir/Target/LLVMIR/ModuleTranslation.h" + +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/IntrinsicsRISCV.h" + +using namespace mlir; +using namespace mlir::LLVM; + +namespace { +/// Implementation of the dialect interface that converts operations belonging +/// to the RISCVV dialect to LLVM IR. 
+class RISCVVDialectLLVMIRTranslationInterface + : public LLVMTranslationDialectInterface { +public: + using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface; + + /// Translates the given operation to LLVM IR using the provided IR builder + /// and saving the state in `moduleTranslation`. + LogicalResult + convertOperation(Operation *op, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) const final { + Operation &opInst = *op; +#include "mlir/Dialect/RISCVV/RISCVVConversions.inc" + + return failure(); + } +}; +} // end namespace + +void mlir::registerRISCVVDialectTranslation(DialectRegistry ®istry) { + registry.insert(); + registry.addExtension(+[](MLIRContext *ctx, riscvv::RISCVVDialect *dialect) { + dialect->addInterfaces(); + }); +} + +void mlir::registerRISCVVDialectTranslation(MLIRContext &context) { + DialectRegistry registry; + registerRISCVVDialectTranslation(registry); + context.appendDialectRegistry(registry); +} diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt --- a/mlir/test/CMakeLists.txt +++ b/mlir/test/CMakeLists.txt @@ -35,6 +35,18 @@ option(MLIR_RUN_X86VECTOR_TESTS "Run X86Vector tests.") option(MLIR_RUN_CUDA_TENSOR_CORE_TESTS "Run CUDA Tensor core WMMA tests.") option(MLIR_RUN_ARM_SVE_TESTS "Run Arm SVE tests.") + set(RISCV_VECTOR_QEMU_EXECUTABLE "" CACHE STRING + "If set, arch-specific integration tests are run with RISC-V QEMU emulator.") + set(RISCV_VECTOR_QEMU_OPTIONS "" CACHE STRING + "If arch-specific integration tests run emulated, pass these as parameters to the emulator.") + set(RISCV_QEMU_LLI_EXECUTABLE "" CACHE STRING + "If arch-specific integration tests run emulated, use this RISC-V native lli.") + set(RISCV_QEMU_UTILS_LIB_DIR "" CACHE STRING + "If arch-specific integration tests run emulated, find RISC-V native utility libraries in this directory.") + option(MLIR_RUN_AMX_TESTS "Run AMX tests.") + option(MLIR_RUN_X86VECTOR_TESTS "Run X86Vector tests.") + 
option(MLIR_RUN_CUDA_TENSOR_CORE_TESTS "Run CUDA Tensor core WMMA tests.") + option(MLIR_RUN_RISCV_VECTOR_TESTS "Run RISC-V vector tests.") # Passed to lit.site.cfg.py.in to set up the path where to find the libraries. set(MLIR_INTEGRATION_TEST_DIR ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}) diff --git a/mlir/test/Dialect/RISCVV/legalize-for-llvm.mlir b/mlir/test/Dialect/RISCVV/legalize-for-llvm.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/RISCVV/legalize-for-llvm.mlir @@ -0,0 +1,126 @@ +// RUN: mlir-opt %s -convert-vector-to-llvm="enable-riscvv" | mlir-opt | FileCheck %s + +// CHECK-LABEL: func @riscvv_setvl +func @riscvv_setvl(%avl: index, %sew: index, %lmul: index) -> index { + // CHECK: builtin.unrealized_conversion_cast{{.*}} : index to i64 + // CHECK-NEXT: builtin.unrealized_conversion_cast{{.*}} : index to i64 + // CHECK-NEXT: builtin.unrealized_conversion_cast{{.*}} : index to i64 + // CHECK-NEXT: riscvv.intr.vsetvli{{.*}} : (i64, i64, i64) -> i64 + // CHECK-NEXT: builtin.unrealized_conversion_cast{{.*}} : i64 to index + %vl = riscvv.setvl %avl, %sew, %lmul : index + return %vl : index +} + +// CHECK-LABEL: func @riscvv_memory +func @riscvv_memory(%v: vector<[8]xi32>, + %m: memref, + %vl: index) -> vector<[8]xi32> { + %c0 = arith.constant 0 : index + // CHECK: llvm.mlir.undef : vector<[8]xi32> + // CHECK-NEXT: llvm.extractvalue{{.*}} : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // CHECK-NEXT: llvm.getelementptr{{.*}} : (!llvm.ptr, i64) -> !llvm.ptr + // CHECK-NEXT: llvm.bitcast{{.*}} : !llvm.ptr to !llvm.ptr> + // CHECK-NEXT: builtin.unrealized_conversion_cast{{.*}} : index to i64 + // CHECK-NEXT: riscvv.intr.vle{{.*}} : (vector<[8]xi32>, !llvm.ptr>, i64) -> vector<[8]xi32> + %0 = riscvv.load %m[%c0], %vl : memref, vector<[8]xi32>, index + // CHECK: llvm.extractvalue {{.*}} : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // CHECK-NEXT: llvm.getelementptr{{.*}} : (!llvm.ptr, i64) -> !llvm.ptr + // CHECK-NEXT: 
llvm.bitcast{{.*}} : !llvm.ptr to !llvm.ptr> + // CHECK-NEXT: builtin.unrealized_conversion_cast{{.*}} : index to i64 + // CHECK-NEXT: riscvv.intr.vse{{.*}} : (vector<[8]xi32>, !llvm.ptr>, i64) -> () + riscvv.store %v, %m[%c0], %vl : vector<[8]xi32>, memref, index + return %0 : vector<[8]xi32> +} + +// CHECK-LABEL: func @riscvv_arith +func @riscvv_arith(%a: vector<[8]xi32>, + %b: vector<[8]xi32>, + %c: i32, + %vl: index) -> vector<[8]xi32> { + // CHECK: builtin.unrealized_conversion_cast{{.*}} : index to i64 + // CHECK: llvm.mlir.undef : vector<[8]xi32> + // CHECK: riscvv.intr.vadd{{.*}} : (vector<[8]xi32>, vector<[8]xi32>, vector<[8]xi32>, i64) -> vector<[8]xi32> + %0 = riscvv.add %a, %b, %vl : vector<[8]xi32>, vector<[8]xi32>, index + // CHECK: llvm.mlir.undef : vector<[8]xi32> + // CHECK: riscvv.intr.vadd{{.*}} : (vector<[8]xi32>, vector<[8]xi32>, i32, i64) -> vector<[8]xi32> + %1 = riscvv.add %a, %c, %vl : vector<[8]xi32>, i32, index + // CHECK: llvm.mlir.undef : vector<[8]xi32> + // CHECK: riscvv.intr.vsub{{.*}} : (vector<[8]xi32>, vector<[8]xi32>, vector<[8]xi32>, i64) -> vector<[8]xi32> + %2 = riscvv.sub %a, %b, %vl : vector<[8]xi32>, vector<[8]xi32>, index + // CHECK: llvm.mlir.undef : vector<[8]xi32> + // CHECK: riscvv.intr.vsub{{.*}} : (vector<[8]xi32>, vector<[8]xi32>, i32, i64) -> vector<[8]xi32> + %3 = riscvv.sub %a, %c, %vl : vector<[8]xi32>, i32, index + // CHECK: llvm.mlir.undef : vector<[8]xi32> + // CHECK: riscvv.intr.vmul{{.*}} : (vector<[8]xi32>, vector<[8]xi32>, vector<[8]xi32>, i64) -> vector<[8]xi32> + %4 = riscvv.mul %a, %b, %vl : vector<[8]xi32>, vector<[8]xi32>, index + // CHECK: llvm.mlir.undef : vector<[8]xi32> + // CHECK: riscvv.intr.vmul{{.*}} : (vector<[8]xi32>, vector<[8]xi32>, i32, i64) -> vector<[8]xi32> + %5 = riscvv.mul %a, %c, %vl : vector<[8]xi32>, i32, index + // CHECK: llvm.mlir.undef : vector<[8]xi32> + // CHECK: riscvv.intr.vdiv{{.*}} : (vector<[8]xi32>, vector<[8]xi32>, vector<[8]xi32>, i64) -> vector<[8]xi32> + %6 = 
riscvv.div %a, %b, %vl : vector<[8]xi32>, vector<[8]xi32>, index + // CHECK: llvm.mlir.undef : vector<[8]xi32> + // CHECK: riscvv.intr.vdiv{{.*}} : (vector<[8]xi32>, vector<[8]xi32>, i32, i64) -> vector<[8]xi32> + %7 = riscvv.div %a, %c, %vl : vector<[8]xi32>, i32, index + return %7 : vector<[8]xi32> +} + +// CHECK-LABEL: func @riscvv_masked_arith +func @riscvv_masked_arith(%maskedoff: vector<[8]xi32>, + %a: vector<[8]xi32>, + %b: vector<[8]xi32>, + %c: i32, + %mask: vector<[8]xi1>, + %vl: index) -> vector<[8]xi32> { + // CHECK: builtin.unrealized_conversion_cast{{.*}} : index to i64 + // CHECK: llvm.mlir.constant(1 : i64) : i64 + // CHECK: riscvv.intr.vadd_mask{{.*}} : (vector<[8]xi32>, vector<[8]xi32>, vector<[8]xi32>, vector<[8]xi1>, i64, i64) -> vector<[8]xi32> + %0 = riscvv.masked.add %maskedoff, %a, %b, %mask, %vl : vector<[8]xi32>, vector<[8]xi32>, vector<[8]xi1>, index + // CHECK: llvm.mlir.constant(1 : i64) : i64 + // CHECK: riscvv.intr.vadd_mask{{.*}} : (vector<[8]xi32>, vector<[8]xi32>, i32, vector<[8]xi1>, i64, i64) -> vector<[8]xi32> + %1 = riscvv.masked.add %maskedoff, %a, %c, %mask, %vl : vector<[8]xi32>, i32, vector<[8]xi1>, index + // CHECK: llvm.mlir.constant(1 : i64) : i64 + // CHECK: riscvv.intr.vsub_mask{{.*}} : (vector<[8]xi32>, vector<[8]xi32>, vector<[8]xi32>, vector<[8]xi1>, i64, i64) -> vector<[8]xi32> + %2 = riscvv.masked.sub %maskedoff, %a, %b, %mask, %vl : vector<[8]xi32>, vector<[8]xi32>, vector<[8]xi1>, index + // CHECK: llvm.mlir.constant(1 : i64) : i64 + // CHECK: riscvv.intr.vsub_mask{{.*}} : (vector<[8]xi32>, vector<[8]xi32>, i32, vector<[8]xi1>, i64, i64) -> vector<[8]xi32> + %3 = riscvv.masked.sub %maskedoff, %a, %c, %mask, %vl : vector<[8]xi32>, i32, vector<[8]xi1>, index + // CHECK: llvm.mlir.constant(1 : i64) : i64 + // CHECK: riscvv.intr.vmul_mask{{.*}} : (vector<[8]xi32>, vector<[8]xi32>, vector<[8]xi32>, vector<[8]xi1>, i64, i64) -> vector<[8]xi32> + %4 = riscvv.masked.mul %maskedoff, %a, %b, %mask, %vl : vector<[8]xi32>, 
vector<[8]xi32>, vector<[8]xi1>, index + // CHECK: llvm.mlir.constant(1 : i64) : i64 + // CHECK: riscvv.intr.vmul_mask{{.*}} : (vector<[8]xi32>, vector<[8]xi32>, i32, vector<[8]xi1>, i64, i64) -> vector<[8]xi32> + %5 = riscvv.masked.mul %maskedoff, %a, %c, %mask, %vl : vector<[8]xi32>, i32, vector<[8]xi1>, index + // CHECK: llvm.mlir.constant(1 : i64) : i64 + // CHECK: riscvv.intr.vdiv_mask{{.*}} : (vector<[8]xi32>, vector<[8]xi32>, vector<[8]xi32>, vector<[8]xi1>, i64, i64) -> vector<[8]xi32> + %6 = riscvv.masked.div %maskedoff, %a, %b, %mask, %vl : vector<[8]xi32>, vector<[8]xi32>, vector<[8]xi1>, index + // CHECK: llvm.mlir.constant(1 : i64) : i64 + // CHECK: riscvv.intr.vdiv_mask{{.*}} : (vector<[8]xi32>, vector<[8]xi32>, i32, vector<[8]xi1>, i64, i64) -> vector<[8]xi32> + %7 = riscvv.masked.div %maskedoff, %a, %c, %mask, %vl : vector<[8]xi32>, i32, vector<[8]xi1>, index + // CHECK: llvm.mlir.constant(0 : i64) : i64 + // CHECK: riscvv.intr.vadd_mask{{.*}} : (vector<[8]xi32>, vector<[8]xi32>, vector<[8]xi32>, vector<[8]xi1>, i64, i64) -> vector<[8]xi32> + %8 = riscvv.masked.add %maskedoff, %a, %b, %mask, %vl {policy = 0} : vector<[8]xi32>, vector<[8]xi32>, vector<[8]xi1>, index + // CHECK: llvm.mlir.constant(0 : i64) : i64 + // CHECK: riscvv.intr.vadd_mask{{.*}} : (vector<[8]xi32>, vector<[8]xi32>, i32, vector<[8]xi1>, i64, i64) -> vector<[8]xi32> + %9 = riscvv.masked.add %maskedoff, %a, %c, %mask, %vl {policy = 0} : vector<[8]xi32>, i32, vector<[8]xi1>, index + // CHECK: llvm.mlir.constant(0 : i64) : i64 + // CHECK: riscvv.intr.vsub_mask{{.*}} : (vector<[8]xi32>, vector<[8]xi32>, vector<[8]xi32>, vector<[8]xi1>, i64, i64) -> vector<[8]xi32> + %10 = riscvv.masked.sub %maskedoff, %a, %b, %mask, %vl {policy = 0} : vector<[8]xi32>, vector<[8]xi32>, vector<[8]xi1>, index + // CHECK: llvm.mlir.constant(0 : i64) : i64 + // CHECK: riscvv.intr.vsub_mask{{.*}} : (vector<[8]xi32>, vector<[8]xi32>, i32, vector<[8]xi1>, i64, i64) -> vector<[8]xi32> + %11 = riscvv.masked.sub 
%maskedoff, %a, %c, %mask, %vl {policy = 0} : vector<[8]xi32>, i32, vector<[8]xi1>, index + // CHECK: llvm.mlir.constant(0 : i64) : i64 + // CHECK: riscvv.intr.vmul_mask{{.*}} : (vector<[8]xi32>, vector<[8]xi32>, vector<[8]xi32>, vector<[8]xi1>, i64, i64) -> vector<[8]xi32> + %12 = riscvv.masked.mul %maskedoff, %a, %b, %mask, %vl {policy = 0} : vector<[8]xi32>, vector<[8]xi32>, vector<[8]xi1>, index + // CHECK: llvm.mlir.constant(0 : i64) : i64 + // CHECK: riscvv.intr.vmul_mask{{.*}} : (vector<[8]xi32>, vector<[8]xi32>, i32, vector<[8]xi1>, i64, i64) -> vector<[8]xi32> + %13 = riscvv.masked.mul %maskedoff, %a, %c, %mask, %vl {policy = 0} : vector<[8]xi32>, i32, vector<[8]xi1>, index + // CHECK: llvm.mlir.constant(0 : i64) : i64 + // CHECK: riscvv.intr.vdiv_mask{{.*}} : (vector<[8]xi32>, vector<[8]xi32>, vector<[8]xi32>, vector<[8]xi1>, i64, i64) -> vector<[8]xi32> + %14 = riscvv.masked.div %maskedoff, %a, %b, %mask, %vl {policy = 0} : vector<[8]xi32>, vector<[8]xi32>, vector<[8]xi1>, index + // CHECK: llvm.mlir.constant(0 : i64) : i64 + // CHECK: riscvv.intr.vdiv_mask{{.*}} : (vector<[8]xi32>, vector<[8]xi32>, i32, vector<[8]xi1>, i64, i64) -> vector<[8]xi32> + %15 = riscvv.masked.div %maskedoff, %a, %c, %mask, %vl {policy = 0} : vector<[8]xi32>, i32, vector<[8]xi1>, index + + return %15 : vector<[8]xi32> +} diff --git a/mlir/test/Dialect/RISCVV/roundtrip.mlir b/mlir/test/Dialect/RISCVV/roundtrip.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/RISCVV/roundtrip.mlir @@ -0,0 +1,86 @@ +// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s + +// CHECK-LABEL: func @riscvv_setvl +func @riscvv_setvl(%avl: index, %sew: index, %lmul: index) -> index { + // CHECK: riscvv.setvl {{.*}}: index + %vl = riscvv.setvl %avl, %sew, %lmul : index + return %vl : index +} + +// CHECK-LABEL: func @riscvv_memory +func @riscvv_memory(%v: vector<[8]xi32>, + %m: memref, + %vl: index) -> vector<[8]xi32> { + %c0 = arith.constant 0 : index + // CHECK: riscvv.load 
{{.*}}: memref, vector<[8]xi32>, index + %0 = riscvv.load %m[%c0], %vl : memref, vector<[8]xi32>, index + // CHECK: riscvv.store {{.*}}: vector<[8]xi32>, memref, index + riscvv.store %v, %m[%c0], %vl : vector<[8]xi32>, memref, index + return %0 : vector<[8]xi32> +} + +// CHECK-LABEL: func @riscvv_arith +func @riscvv_arith(%a: vector<[8]xi32>, + %b: vector<[8]xi32>, + %c: i32, + %vl: index) -> vector<[8]xi32> { + // CHECK: riscvv.add {{.*}} : vector<[8]xi32>, vector<[8]xi32>, index + %0 = riscvv.add %a, %b, %vl : vector<[8]xi32>, vector<[8]xi32>, index + // CHECK: riscvv.add {{.*}} : vector<[8]xi32>, i32, index + %1 = riscvv.add %a, %c, %vl : vector<[8]xi32>, i32, index + // CHECK: riscvv.sub {{.*}} : vector<[8]xi32>, vector<[8]xi32>, index + %2 = riscvv.sub %a, %b, %vl : vector<[8]xi32>, vector<[8]xi32>, index + // CHECK: riscvv.sub {{.*}} : vector<[8]xi32>, i32, index + %3 = riscvv.sub %a, %c, %vl : vector<[8]xi32>, i32, index + // CHECK: riscvv.mul {{.*}} : vector<[8]xi32>, vector<[8]xi32>, index + %4 = riscvv.mul %a, %b, %vl : vector<[8]xi32>, vector<[8]xi32>, index + // CHECK: riscvv.mul {{.*}} : vector<[8]xi32>, i32, index + %5 = riscvv.mul %a, %c, %vl : vector<[8]xi32>, i32, index + // CHECK: riscvv.div {{.*}} : vector<[8]xi32>, vector<[8]xi32>, index + %6 = riscvv.div %a, %b, %vl : vector<[8]xi32>, vector<[8]xi32>, index + // CHECK: riscvv.div {{.*}} : vector<[8]xi32>, i32, index + %7 = riscvv.div %a, %c, %vl : vector<[8]xi32>, i32, index + return %7 : vector<[8]xi32> +} + +// CHECK-LABEL: func @riscvv_masked_arith +func @riscvv_masked_arith(%maskedoff: vector<[8]xi32>, + %a: vector<[8]xi32>, + %b: vector<[8]xi32>, + %c: i32, + %mask: vector<[8]xi1>, + %vl: index) -> vector<[8]xi32> { + // CHECK: riscvv.masked.add {{.*}} : vector<[8]xi32>, vector<[8]xi32>, vector<[8]xi1>, index + %0 = riscvv.masked.add %maskedoff, %a, %b, %mask, %vl : vector<[8]xi32>, vector<[8]xi32>, vector<[8]xi1>, index + // CHECK: riscvv.masked.add {{.*}} : vector<[8]xi32>, i32, 
vector<[8]xi1>, index + %1 = riscvv.masked.add %maskedoff, %a, %c, %mask, %vl : vector<[8]xi32>, i32, vector<[8]xi1>, index + // CHECK: riscvv.masked.sub {{.*}} : vector<[8]xi32>, vector<[8]xi32>, vector<[8]xi1>, index + %2 = riscvv.masked.sub %maskedoff, %a, %b, %mask, %vl : vector<[8]xi32>, vector<[8]xi32>, vector<[8]xi1>, index + // CHECK: riscvv.masked.sub {{.*}} : vector<[8]xi32>, i32, vector<[8]xi1>, index + %3 = riscvv.masked.sub %maskedoff, %a, %c, %mask, %vl : vector<[8]xi32>, i32, vector<[8]xi1>, index + // CHECK: riscvv.masked.mul {{.*}} : vector<[8]xi32>, vector<[8]xi32>, vector<[8]xi1>, index + %4 = riscvv.masked.mul %maskedoff, %a, %b, %mask, %vl : vector<[8]xi32>, vector<[8]xi32>, vector<[8]xi1>, index + // CHECK: riscvv.masked.mul {{.*}} : vector<[8]xi32>, i32, vector<[8]xi1>, index + %5 = riscvv.masked.mul %maskedoff, %a, %c, %mask, %vl : vector<[8]xi32>, i32, vector<[8]xi1>, index + // CHECK: riscvv.masked.div {{.*}} : vector<[8]xi32>, vector<[8]xi32>, vector<[8]xi1>, index + %6 = riscvv.masked.div %maskedoff, %a, %b, %mask, %vl : vector<[8]xi32>, vector<[8]xi32>, vector<[8]xi1>, index + // CHECK: riscvv.masked.div {{.*}} : vector<[8]xi32>, i32, vector<[8]xi1>, index + %7 = riscvv.masked.div %maskedoff, %a, %c, %mask, %vl : vector<[8]xi32>, i32, vector<[8]xi1>, index + // CHECK: riscvv.masked.add {{.*}} {vta = 0 : i64} : vector<[8]xi32>, vector<[8]xi32>, vector<[8]xi1>, index + %8 = riscvv.masked.add %maskedoff, %a, %b, %mask, %vl {vta = 0} : vector<[8]xi32>, vector<[8]xi32>, vector<[8]xi1>, index + // CHECK: riscvv.masked.add {{.*}} {vta = 0 : i64} : vector<[8]xi32>, i32, vector<[8]xi1>, index + %9 = riscvv.masked.add %maskedoff, %a, %c, %mask, %vl {vta = 0} : vector<[8]xi32>, i32, vector<[8]xi1>, index + // CHECK: riscvv.masked.sub {{.*}} {vta = 0 : i64} : vector<[8]xi32>, vector<[8]xi32>, vector<[8]xi1>, index + %10 = riscvv.masked.sub %maskedoff, %a, %b, %mask, %vl {vta = 0} : vector<[8]xi32>, vector<[8]xi32>, vector<[8]xi1>, index + // CHECK: 
riscvv.masked.sub {{.*}} {vta = 0 : i64} : vector<[8]xi32>, i32, vector<[8]xi1>, index + %11 = riscvv.masked.sub %maskedoff, %a, %c, %mask, %vl {vta = 0} : vector<[8]xi32>, i32, vector<[8]xi1>, index + // CHECK: riscvv.masked.mul {{.*}} {vta = 0 : i64} : vector<[8]xi32>, vector<[8]xi32>, vector<[8]xi1>, index + %12 = riscvv.masked.mul %maskedoff, %a, %b, %mask, %vl {vta = 0} : vector<[8]xi32>, vector<[8]xi32>, vector<[8]xi1>, index + // CHECK: riscvv.masked.mul {{.*}} {vta = 0 : i64} : vector<[8]xi32>, i32, vector<[8]xi1>, index + %13 = riscvv.masked.mul %maskedoff, %a, %c, %mask, %vl {vta = 0} : vector<[8]xi32>, i32, vector<[8]xi1>, index + // CHECK: riscvv.masked.div {{.*}} {vta = 0 : i64} : vector<[8]xi32>, vector<[8]xi32>, vector<[8]xi1>, index + %14 = riscvv.masked.div %maskedoff, %a, %b, %mask, %vl {vta = 0} : vector<[8]xi32>, vector<[8]xi32>, vector<[8]xi1>, index + // CHECK: riscvv.masked.div {{.*}} {vta = 0 : i64} : vector<[8]xi32>, i32, vector<[8]xi1>, index + %15 = riscvv.masked.div %maskedoff, %a, %c, %mask, %vl {vta = 0} : vector<[8]xi32>, i32, vector<[8]xi1>, index + return %15 : vector<[8]xi32> +} diff --git a/mlir/test/Integration/Dialect/Vector/CPU/RISCVV/lit.local.cfg b/mlir/test/Integration/Dialect/Vector/CPU/RISCVV/lit.local.cfg new file mode 100644 --- /dev/null +++ b/mlir/test/Integration/Dialect/Vector/CPU/RISCVV/lit.local.cfg @@ -0,0 +1,28 @@ +import sys + +# RISC-V vector tests must be enabled via build flag. +if config.mlir_run_riscv_vector_tests != 'ON': + config.unsupported = True + +# No JIT on win32. 
+if sys.platform == 'win32': + config.unsupported = True + +lli_cmd = 'lli' +if config.riscv_qemu_lli_executable: + lli_cmd = config.riscv_qemu_lli_executable + +if config.riscv_qemu_utils_lib_dir: + config.substitutions.append(('%mlir_native_utils_lib_dir', config.riscv_qemu_utils_lib_dir)) +else: + config.substitutions.append(('%mlir_native_utils_lib_dir', config.mlir_integration_test_dir)) + +if config.riscv_vector_qemu_executable: + # Run test in qemu emulator. + emulation_cmd = config.riscv_vector_qemu_executable + if config.riscv_vector_qemu_options: + emulation_cmd = emulation_cmd + ' ' + config.riscv_vector_qemu_options + emulation_cmd = emulation_cmd + ' ' + lli_cmd + config.substitutions.append(('%lli', emulation_cmd)) +else: + config.substitutions.append(('%lli', lli_cmd)) diff --git a/mlir/test/Integration/Dialect/Vector/CPU/RISCVV/test-riscvv-arithmetic.mlir b/mlir/test/Integration/Dialect/Vector/CPU/RISCVV/test-riscvv-arithmetic.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Integration/Dialect/Vector/CPU/RISCVV/test-riscvv-arithmetic.mlir @@ -0,0 +1,126 @@ +// RUN: mlir-opt %s -convert-vector-to-llvm="enable-riscvv" -convert-scf-to-cf -convert-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \ +// RUN: mlir-translate -mlir-to-llvmir | \ +// RUN: %lli --entry-function=entry --march=riscv64 -mattr=+v -jit-linker=jitlink -relocation-model=pic --dlopen=%mlir_native_utils_lib_dir/libmlir_c_runner_utils%shlibext | \ +// RUN: FileCheck %s + +memref.global "private" @gv1 : memref<8xi32> = dense<[2, 4, 6, 8, 10, 12, 14, 16]> +memref.global "private" @gv2 : memref<8xi32> = dense<[1, 2, 3, 4, 5, 6, 7, 8]> +memref.global "private" @maskedoff : memref<8xi32> = dense<[0, 0, 0, 0, 0, 0, 0, 0]> + +func @entry() -> i32 { + %c0_idx = arith.constant 0 : index + %c4_idx = arith.constant 4 : index + %c2_i32 = arith.constant 2 : i32 + + %m1 = memref.get_global @gv1 : memref<8xi32> + %m2 = memref.get_global @gv2 : memref<8xi32> + %m_maskedoff = 
memref.get_global @maskedoff : memref<8xi32> + + %test = memref.alloc() : memref<8xi32> + %vl = memref.dim %m1, %c0_idx : memref<8xi32> + + %input_vector1 = riscvv.load %m1[%c0_idx], %vl : memref<8xi32>, vector<[4]xi32>, index + %input_vector2 = riscvv.load %m2[%c0_idx], %vl : memref<8xi32>, vector<[4]xi32>, index + %maskedoff_vector = riscvv.load %m_maskedoff[%c0_idx], %vl : memref<8xi32>, vector<[4]xi32>, index + %mask_vector = vector.create_mask %c4_idx : vector<[4]xi1> + + %add_vv_result_vector = riscvv.add %input_vector1, %input_vector2, %vl : vector<[4]xi32>, vector<[4]xi32>, index + riscvv.store %add_vv_result_vector, %test[%c0_idx], %vl : vector<[4]xi32>, memref<8xi32>, index + %load_add_vv_result_vector = vector.load %test[%c0_idx] : memref<8xi32>, vector<8xi32> + // CHECK: ( 3, 6, 9, 12, 15, 18, 21, 24 ) + vector.print %load_add_vv_result_vector : vector<8xi32> + + %add_vx_result_vector = riscvv.add %input_vector1, %c2_i32, %vl : vector<[4]xi32>, i32, index + riscvv.store %add_vx_result_vector, %test[%c0_idx], %vl : vector<[4]xi32>, memref<8xi32>, index + %load_add_vx_result_vector = vector.load %test[%c0_idx] : memref<8xi32>, vector<8xi32> + // CHECK: ( 4, 6, 8, 10, 12, 14, 16, 18 ) + vector.print %load_add_vx_result_vector : vector<8xi32> + + %masked_add_vv_result_vector = riscvv.masked.add %maskedoff_vector, %input_vector1, %input_vector2, %mask_vector, %vl : vector<[4]xi32>, vector<[4]xi32>, vector<[4]xi1>, index + riscvv.store %masked_add_vv_result_vector, %test[%c0_idx], %vl : vector<[4]xi32>, memref<8xi32>, index + %load_masked_add_vv_result_vector = vector.load %test[%c0_idx] : memref<8xi32>, vector<8xi32> + // CHECK: ( 3, 6, 9, 12, 0, 0, 0, 0 ) + vector.print %load_masked_add_vv_result_vector : vector<8xi32> + + %masked_add_vx_result_vector = riscvv.masked.add %maskedoff_vector, %input_vector1, %c2_i32, %mask_vector, %vl : vector<[4]xi32>, i32, vector<[4]xi1>, index + riscvv.store %masked_add_vx_result_vector, %test[%c0_idx], %vl : 
vector<[4]xi32>, memref<8xi32>, index + %load_masked_add_vx_result_vector = vector.load %test[%c0_idx] : memref<8xi32>, vector<8xi32> + // CHECK: ( 4, 6, 8, 10, 0, 0, 0, 0 ) + vector.print %load_masked_add_vx_result_vector : vector<8xi32> + + %sub_vv_result_vector = riscvv.sub %input_vector1, %input_vector2, %vl : vector<[4]xi32>, vector<[4]xi32>, index + riscvv.store %sub_vv_result_vector, %test[%c0_idx], %vl : vector<[4]xi32>, memref<8xi32>, index + %load_sub_vv_result_vector = vector.load %test[%c0_idx] : memref<8xi32>, vector<8xi32> + // CHECK: ( 1, 2, 3, 4, 5, 6, 7, 8 ) + vector.print %load_sub_vv_result_vector : vector<8xi32> + + %sub_vx_result_vector = riscvv.sub %input_vector1, %c2_i32, %vl : vector<[4]xi32>, i32, index + riscvv.store %sub_vx_result_vector, %test[%c0_idx], %vl : vector<[4]xi32>, memref<8xi32>, index + %load_sub_vx_result_vector = vector.load %test[%c0_idx] : memref<8xi32>, vector<8xi32> + // CHECK: ( 0, 2, 4, 6, 8, 10, 12, 14 ) + vector.print %load_sub_vx_result_vector : vector<8xi32> + + %masked_sub_vv_result_vector = riscvv.masked.sub %maskedoff_vector, %input_vector1, %input_vector2, %mask_vector, %vl : vector<[4]xi32>, vector<[4]xi32>, vector<[4]xi1>, index + riscvv.store %masked_sub_vv_result_vector, %test[%c0_idx], %vl : vector<[4]xi32>, memref<8xi32>, index + %load_masked_sub_vv_result_vector = vector.load %test[%c0_idx] : memref<8xi32>, vector<8xi32> + // CHECK: ( 1, 2, 3, 4, 0, 0, 0, 0 ) + vector.print %load_masked_sub_vv_result_vector : vector<8xi32> + + %masked_sub_vx_result_vector = riscvv.masked.sub %maskedoff_vector, %input_vector1, %c2_i32, %mask_vector, %vl : vector<[4]xi32>, i32, vector<[4]xi1>, index + riscvv.store %masked_sub_vx_result_vector, %test[%c0_idx], %vl : vector<[4]xi32>, memref<8xi32>, index + %load_masked_sub_vx_result_vector = vector.load %test[%c0_idx] : memref<8xi32>, vector<8xi32> + // CHECK: ( 0, 2, 4, 6, 0, 0, 0, 0 ) + vector.print %load_masked_sub_vx_result_vector : vector<8xi32> + + 
%mul_vv_result_vector = riscvv.mul %input_vector1, %input_vector2, %vl : vector<[4]xi32>, vector<[4]xi32>, index + riscvv.store %mul_vv_result_vector, %test[%c0_idx], %vl : vector<[4]xi32>, memref<8xi32>, index + %load_mul_vv_result_vector = vector.load %test[%c0_idx] : memref<8xi32>, vector<8xi32> + // CHECK: ( 2, 8, 18, 32, 50, 72, 98, 128 ) + vector.print %load_mul_vv_result_vector : vector<8xi32> + + %mul_vx_result_vector = riscvv.mul %input_vector1, %c2_i32, %vl : vector<[4]xi32>, i32, index + riscvv.store %mul_vx_result_vector, %test[%c0_idx], %vl : vector<[4]xi32>, memref<8xi32>, index + %load_mul_vx_result_vector = vector.load %test[%c0_idx] : memref<8xi32>, vector<8xi32> + // CHECK: ( 4, 8, 12, 16, 20, 24, 28, 32 ) + vector.print %load_mul_vx_result_vector : vector<8xi32> + + %masked_mul_vv_result_vector = riscvv.masked.mul %maskedoff_vector, %input_vector1, %input_vector2, %mask_vector, %vl : vector<[4]xi32>, vector<[4]xi32>, vector<[4]xi1>, index + riscvv.store %masked_mul_vv_result_vector, %test[%c0_idx], %vl : vector<[4]xi32>, memref<8xi32>, index + %load_masked_mul_vv_result_vector = vector.load %test[%c0_idx] : memref<8xi32>, vector<8xi32> + // CHECK: ( 2, 8, 18, 32, 0, 0, 0, 0 ) + vector.print %load_masked_mul_vv_result_vector : vector<8xi32> + + %masked_mul_vx_result_vector = riscvv.masked.mul %maskedoff_vector, %input_vector1, %c2_i32, %mask_vector, %vl : vector<[4]xi32>, i32, vector<[4]xi1>, index + riscvv.store %masked_mul_vx_result_vector, %test[%c0_idx], %vl : vector<[4]xi32>, memref<8xi32>, index + %load_masked_mul_vx_result_vector = vector.load %test[%c0_idx] : memref<8xi32>, vector<8xi32> + // CHECK: ( 4, 8, 12, 16, 0, 0, 0, 0 ) + vector.print %load_masked_mul_vx_result_vector : vector<8xi32> + + %div_vv_result_vector = riscvv.div %input_vector1, %input_vector2, %vl : vector<[4]xi32>, vector<[4]xi32>, index + riscvv.store %div_vv_result_vector, %test[%c0_idx], %vl : vector<[4]xi32>, memref<8xi32>, index + %load_div_vv_result_vector = 
vector.load %test[%c0_idx] : memref<8xi32>, vector<8xi32> + // CHECK: ( 2, 2, 2, 2, 2, 2, 2, 2 ) + vector.print %load_div_vv_result_vector : vector<8xi32> + + %div_vx_result_vector = riscvv.div %input_vector1, %c2_i32, %vl : vector<[4]xi32>, i32, index + riscvv.store %div_vx_result_vector, %test[%c0_idx], %vl : vector<[4]xi32>, memref<8xi32>, index + %load_div_vx_result_vector = vector.load %test[%c0_idx] : memref<8xi32>, vector<8xi32> + // CHECK: ( 1, 2, 3, 4, 5, 6, 7, 8 ) + vector.print %load_div_vx_result_vector : vector<8xi32> + + %masked_div_vv_result_vector = riscvv.masked.div %maskedoff_vector, %input_vector1, %input_vector2, %mask_vector, %vl : vector<[4]xi32>, vector<[4]xi32>, vector<[4]xi1>, index + riscvv.store %masked_div_vv_result_vector, %test[%c0_idx], %vl : vector<[4]xi32>, memref<8xi32>, index + %load_masked_div_vv_result_vector = vector.load %test[%c0_idx] : memref<8xi32>, vector<8xi32> + // CHECK: ( 2, 2, 2, 2, 0, 0, 0, 0 ) + vector.print %load_masked_div_vv_result_vector : vector<8xi32> + + %masked_div_vx_result_vector = riscvv.masked.div %maskedoff_vector, %input_vector1, %c2_i32, %mask_vector, %vl : vector<[4]xi32>, i32, vector<[4]xi1>, index + riscvv.store %masked_div_vx_result_vector, %test[%c0_idx], %vl : vector<[4]xi32>, memref<8xi32>, index + %load_masked_div_vx_result_vector = vector.load %test[%c0_idx] : memref<8xi32>, vector<8xi32> + // CHECK: ( 1, 2, 3, 4, 0, 0, 0, 0 ) + vector.print %load_masked_div_vx_result_vector : vector<8xi32> + + memref.dealloc %test : memref<8xi32> + %ret = arith.constant 0 : i32 + return %ret : i32 +} diff --git a/mlir/test/Integration/Dialect/Vector/CPU/RISCVV/test-riscvv-memory.mlir b/mlir/test/Integration/Dialect/Vector/CPU/RISCVV/test-riscvv-memory.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Integration/Dialect/Vector/CPU/RISCVV/test-riscvv-memory.mlir @@ -0,0 +1,21 @@ +// RUN: mlir-opt %s -convert-vector-to-llvm="enable-riscvv" --convert-memref-to-llvm --convert-func-to-llvm 
--reconcile-unrealized-casts | \ +// RUN: mlir-translate -mlir-to-llvmir | \ +// RUN: %lli --entry-function=entry --march=riscv64 -mattr=+v -jit-linker=jitlink -relocation-model=pic --dlopen=%mlir_native_utils_lib_dir/libmlir_c_runner_utils%shlibext | \ +// RUN: FileCheck %s + +memref.global "private" @gv : memref<8xi32> = dense<[0, 1, 2, 3, 4, 5, 6, 7]> + +func @entry() -> i32 { + %mem = memref.get_global @gv : memref<8xi32> + %c0 = arith.constant 0 : index + %vl = arith.constant 8 : index + %input1 = riscvv.load %mem[%c0], %vl : memref<8xi32>, vector<[4]xi32>, index + %res = memref.alloc() : memref<8xi32> + riscvv.store %input1, %res[%c0], %vl : vector<[4]xi32>, memref<8xi32>, index + %load_vec1 = vector.load %res[%c0] : memref<8xi32>, vector<8xi32> + // CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7 ) + vector.print %load_vec1 : vector<8xi32> + memref.dealloc %res : memref<8xi32> + %ret = arith.constant 0 : i32 + return %ret : i32 +} diff --git a/mlir/test/Integration/Dialect/Vector/CPU/RISCVV/test-riscvv-stripmining.mlir b/mlir/test/Integration/Dialect/Vector/CPU/RISCVV/test-riscvv-stripmining.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Integration/Dialect/Vector/CPU/RISCVV/test-riscvv-stripmining.mlir @@ -0,0 +1,43 @@ +// RUN: mlir-opt %s -convert-vector-to-llvm="enable-riscvv" -convert-scf-to-cf -convert-memref-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \ +// RUN: mlir-translate -mlir-to-llvmir | \ +// RUN: %lli --entry-function=entry --march=riscv64 -mattr=+v -jit-linker=jitlink -relocation-model=pic --dlopen=%mlir_native_utils_lib_dir/libmlir_c_runner_utils%shlibext | \ +// RUN: FileCheck %s + +memref.global "private" @gv : memref<20xi32> = dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]> + +func @entry() -> i32 { + %m = memref.get_global @gv : memref<20xi32> + %c0_idx = arith.constant 0 : index + %init_avl = memref.dim %m, %c0_idx : memref<20xi32> + %init_idx = arith.constant 0 : index + %c2_i32 = arith.constant 
2 : i32 + // SEW = 32 + %sew = arith.constant 2 : index + // LMUL = 2 + %lmul = arith.constant 1 : index + %res = memref.alloc() : memref<20xi32> + // While loop. + %a1, %a2 = scf.while (%avl = %init_avl, %idx = %init_idx) : (index, index) -> (index, index) { + // If avl greater than zero. + %cond = arith.cmpi sgt, %avl, %c0_idx : index + // Pass avl, idx to the after region. + scf.condition(%cond) %avl, %idx : index, index + } do { + ^bb0(%avl : index, %idx : index): + // Perform the calculation according to the vl. + %vl = riscvv.setvl %avl, %sew, %lmul : index + %input_vector = riscvv.load %m[%idx], %vl : memref<20xi32>, vector<[4]xi32>, index + %result_vector = riscvv.add %input_vector, %c2_i32, %vl : vector<[4]xi32>, i32, index + riscvv.store %result_vector, %res[%idx], %vl : vector<[4]xi32>, memref<20xi32>, index + // Update idx and avl. + %new_idx = arith.addi %idx, %vl : index + %new_avl = arith.subi %avl, %vl : index + scf.yield %new_avl, %new_idx : index, index + } + %result = vector.load %res[%c0_idx] : memref<20xi32>, vector<20xi32> + // CHECK: ( 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21 ) + vector.print %result : vector<20xi32> + %ret = arith.constant 0 : i32 + return %ret : i32 +} diff --git a/mlir/test/lit.site.cfg.py.in b/mlir/test/lit.site.cfg.py.in --- a/mlir/test/lit.site.cfg.py.in +++ b/mlir/test/lit.site.cfg.py.in @@ -59,6 +59,11 @@ config.arm_emulator_lli_executable = "@ARM_EMULATOR_LLI_EXECUTABLE@" config.arm_emulator_utils_lib_dir = "@ARM_EMULATOR_UTILS_LIB_DIR@" config.mlir_run_arm_sve_tests = "@MLIR_RUN_ARM_SVE_TESTS@" +config.riscv_vector_qemu_executable = "@RISCV_VECTOR_QEMU_EXECUTABLE@" +config.riscv_vector_qemu_options = "@RISCV_VECTOR_QEMU_OPTIONS@" +config.riscv_qemu_lli_executable = "@RISCV_QEMU_LLI_EXECUTABLE@" +config.riscv_qemu_utils_lib_dir = "@RISCV_QEMU_UTILS_LIB_DIR@" +config.mlir_run_riscv_vector_tests = "@MLIR_RUN_RISCV_VECTOR_TESTS@" import lit.llvm lit.llvm.initialize(lit_config, config) diff 
--git a/mlir/test/mlir-opt/commandline.mlir b/mlir/test/mlir-opt/commandline.mlir --- a/mlir/test/mlir-opt/commandline.mlir +++ b/mlir/test/mlir-opt/commandline.mlir @@ -24,6 +24,7 @@ // CHECK-NEXT: pdl // CHECK-NEXT: pdl_interp // CHECK-NEXT: quant +// CHECK-NEXT: riscvv // CHECK-NEXT: rocdl // CHECK-NEXT: scf // CHECK-NEXT: shape