diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -620,7 +620,7 @@ operations. The lowering pass provides several options to control the kinds of optimizations that are allowed. It also provides options that enable the use of one or more architectural-specific dialects - (AMX, X86Vector, ArmNeon, ArmSVE, etc.) in combination with the + (AMX, X86Vector, ArmNeon, ArmSVE, RVV, etc.) in combination with the architectural-neutral vector dialect lowering. }]; @@ -650,7 +650,10 @@ Option<"enableX86Vector", "enable-x86vector", "bool", /*default=*/"false", "Enables the use of X86Vector dialect while lowering the vector " - "dialect."> + "dialect.">, + Option<"enableRVV", "enable-rvv", + "bool", /*default=*/"false", + "Enables the use of RVV dialect while lowering the vector dialect."> ]; } diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h --- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h +++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h @@ -24,7 +24,7 @@ LowerVectorToLLVMOptions() : reassociateFPReductions(false), enableIndexOptimizations(true), enableArmNeon(false), enableArmSVE(false), enableAMX(false), - enableX86Vector(false) {} + enableX86Vector(false), enableRVV(false) {} LowerVectorToLLVMOptions &setReassociateFPReductions(bool b) { reassociateFPReductions = b; @@ -50,6 +50,10 @@ enableX86Vector = b; return *this; } + LowerVectorToLLVMOptions &setEnableRVV(bool b) { // Returns *this so setters can be chained. + enableRVV = b; + return *this; + } bool reassociateFPReductions; bool enableIndexOptimizations; @@ -57,6 +61,7 @@ bool enableArmSVE; bool enableAMX; bool enableX86Vector; + bool enableRVV; // Lower using the RVV dialect; false unless set via setEnableRVV. }; /// Collect a set of patterns to convert from Vector contractions to LLVM Matrix diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt 
b/mlir/include/mlir/Dialect/CMakeLists.txt --- a/mlir/include/mlir/Dialect/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/CMakeLists.txt @@ -16,6 +16,7 @@ add_subdirectory(PDL) add_subdirectory(PDLInterp) add_subdirectory(Quant) +add_subdirectory(RVV) add_subdirectory(SCF) add_subdirectory(Shape) add_subdirectory(SparseTensor) diff --git a/mlir/include/mlir/Dialect/RVV/CMakeLists.txt b/mlir/include/mlir/Dialect/RVV/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/RVV/CMakeLists.txt @@ -0,0 +1,6 @@ +add_mlir_dialect(RVV rvv) +add_mlir_doc(RVV RVV Dialects/ -gen-dialect-doc) + +set(LLVM_TARGET_DEFINITIONS RVV.td) +mlir_tablegen(RVVConversions.inc -gen-llvmir-conversions) # Included by RVVToLLVMIRTranslation.cpp's convertOperation. +add_public_tablegen_target(MLIRRVVConversionsIncGen) diff --git a/mlir/include/mlir/Dialect/RVV/RVV.td b/mlir/include/mlir/Dialect/RVV/RVV.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/RVV/RVV.td @@ -0,0 +1,271 @@ +//===-- RVV.td - RVV dialect operation definitions ---------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the basic types and operations for the RVV dialect.
+// +//===----------------------------------------------------------------------===// + +#ifndef RVV_OPS +#define RVV_OPS + +include "mlir/Dialect/LLVMIR/LLVMOpBase.td" + +//===----------------------------------------------------------------------===// +// RVV dialect definition +//===----------------------------------------------------------------------===// + +def RVV_Dialect : Dialect { + let name = "rvv"; + let cppNamespace = "::mlir::rvv"; + let summary = "Basic dialect to target RISC-V Vector extension"; + let description = [{ + RISC-V vector extension (RVV) is the vector instruction set with scalable + vector types, and the RVV instructions are vector length agnostic (VLA). + For more details about RVV, please see the + [RVV specification](https://github.com/riscv/riscv-v-spec). + This dialect contains the definitions of RVV operations and RVV intrinsic + operations. The former is used to interoperate with higher-level dialects, + and the latter is responsible for mapping to LLVM IR intrinsics. + }]; +} + +//===----------------------------------------------------------------------===// +// RVV type definitions +//===----------------------------------------------------------------------===// + +class RVV_Type : TypeDef { } + +def RVVScalableVectorType : RVV_Type<"RVVScalableVector"> { + let mnemonic = "vector"; + + let summary = "RVV Scalable vector type"; + + let description = [{ + RVV scalable vector type shares the same syntax as the standard vector type + but has different semantics. It can reflect LMUL and SEW, which are used + to configure scalable vector length at runtime.
+ }]; + + let parameters = (ins + ArrayRefParameter<"int64_t", "Vector shape">:$shape, + "Type":$elementType + ); + + let printer = [{ + $_printer << "vector<"; + for (int64_t dim : getShape()) + $_printer << dim << 'x'; + $_printer << getElementType() << '>'; + }]; + + let parser = [{ + VectorType vector; + if ($_parser.parseType(vector)) + return Type(); + return get($_ctxt, vector.getShape(), vector.getElementType()); + }]; + + let extraClassDeclaration = [{ + bool hasStaticShape() const { + return llvm::none_of(getShape(), ShapedType::isDynamic); + } + int64_t getNumElements() const { + assert(hasStaticShape() && + "cannot get element count of dynamic shaped type"); + ArrayRef shape = getShape(); + int64_t num = 1; + for (auto dim : shape) + num *= dim; + return num; + } + }]; +} + +//===----------------------------------------------------------------------===// +// Additional LLVM type constraints +//===----------------------------------------------------------------------===// + +def LLVMScalableVectorType : + Type()">, + "LLVM dialect scalable vector type">; + +def LLVMPointerType : + Type()">, + "LLVM pointer type">; + +//===----------------------------------------------------------------------===// +// RVV scalable vector type constraints +//===----------------------------------------------------------------------===// + +def IsRVVScalableVectorTypePred : + CPred<"$_self.isa<::mlir::rvv::RVVScalableVectorType>()">; + +class RVVScalableVectorOf allowedTypes> : + ContainerType, IsRVVScalableVectorTypePred, + "$_self.cast<::mlir::rvv::RVVScalableVectorType>().getElementType()", + "RVV scalable vector">; + +//===----------------------------------------------------------------------===// +// RVV operation definitions +//===----------------------------------------------------------------------===// + +class RVV_Op traits = []> : + Op {} + +def RVVLoadOp : RVV_Op<"load">, + Arguments<(ins Arg:$base, Index:$index, + AnyInteger:$length)>, + Results<(outs
RVVScalableVectorOf<[AnyType]>:$result)> { + let summary = "Load scalable vector from memory"; + let description = [{ + Load a slice of memory into scalable vector with the given element length. + }]; + let extraClassDeclaration = [{ + MemRefType getMemRefType() { + return base().getType().cast(); + } + }]; + let assemblyFormat = "$base `[` $index `]` `,` $length attr-dict `:` " + "type($base) `,` type($result) `,` type($length)"; +} + +def RVVStoreOp : RVV_Op<"store">, + Arguments<(ins RVVScalableVectorOf<[AnyType]>:$value, Arg:$base, Index:$index, AnyInteger:$length)> { + let summary = "Store scalable vector into memory"; + let description = [{ + Store the given element length of a scalable vector on a slice of memory. + }]; + let extraClassDeclaration = [{ + MemRefType getMemRefType() { + return base().getType().cast(); + } + }]; + let assemblyFormat = "$value `,` $base `[` $index `]` `,` $length attr-dict " + "`:` type($value) `,` type($base) `,` type($length)"; +} + +class RVV_BinaryAAXNoMask_Op traits = []> : + RVV_Op])> { + let summary = op_description # " for RVV scalable vectors"; + let description = [{ The `rvv.}] # mnemonic # [{` operation can be of + vector-vector form or vector-scalar form. It also takes the vector length + value and returns a scalable vector with the result of the }] + # op_description # [{.}]; + let arguments = (ins + RVVScalableVectorOf<[AnyInteger]>:$src1, + AnyType:$src2, + AnyInteger:$length + ); + let results = (outs RVVScalableVectorOf<[AnyInteger]>:$dst); + let assemblyFormat = "$src1 `,` $src2 `,` $length attr-dict `:` type($src1) " + "`,` type($src2) `,` type($length)"; +} + +class RVV_BinaryAAXMask_Op traits = []> : + RVV_Op])> { + let summary = op_description # " for RVV scalable vectors"; + let description = [{ The `rvv.}] # mnemonic # [{` operation can be of + vector-vector form or vector-scalar form.
It also takes the mask vector, + maskedoff vector, vector length value and returns a scalable vector with + the result of the }] # op_description # [{.}]; + let arguments = (ins + RVVScalableVectorOf<[AnyInteger]>:$maskedoff, + RVVScalableVectorOf<[AnyInteger]>:$src1, + AnyType:$src2, + RVVScalableVectorOf<[I1]>:$mask, + AnyInteger:$length + ); + let results = (outs RVVScalableVectorOf<[AnyInteger]>:$dst); + let assemblyFormat = + "$maskedoff `,` $src1 `,` $src2 `,` $mask `,` $length attr-dict `:` " + "type($maskedoff) `,` type($src2) `,` type($mask) `,` type($length)"; +} + +def RVVAddOp : RVV_BinaryAAXNoMask_Op<"add", "addition">; +def RVVSubOp : RVV_BinaryAAXNoMask_Op<"sub", "subtraction">; +def RVVMulOp : RVV_BinaryAAXNoMask_Op<"mul", "multiplication">; +def RVVDivOp : RVV_BinaryAAXNoMask_Op<"div", "division">; + +def RVVMaskedAddOp : RVV_BinaryAAXMask_Op<"masked.add", "masked addition">; +def RVVMaskedSubOp : RVV_BinaryAAXMask_Op<"masked.sub", "masked subtraction">; +def RVVMaskedMulOp : RVV_BinaryAAXMask_Op<"masked.mul", + "masked multiplication">; +def RVVMaskedDivOp : RVV_BinaryAAXMask_Op<"masked.div", "masked division">; + +//===----------------------------------------------------------------------===// +// RVV intrinsic operation definitions +//===----------------------------------------------------------------------===// + +class RVV_USLoad_IntrOp traits = []> : + LLVM_IntrOpBase overloadedResults=*/[0], + /*list overloadedOperands=*/[1], + /*list traits=*/traits, + /*int numResults=*/1>; + +class RVV_USStore_IntrOp traits = []> : + LLVM_IntrOpBase overloadedResults=*/[], + /*list overloadedOperands=*/[0, 2], + /*list traits=*/traits, + /*int numResults=*/0>; + +class RVV_BinaryAAXNoMask_IntrOp traits = []> : + LLVM_IntrOpBase overloadedResults=*/[0], + /*list overloadedOperands=*/[1, 2], + /*list traits=*/traits, + /*int numResults=*/1>; + +class RVV_BinaryAAXMask_IntrOp traits = []> : + LLVM_IntrOpBase overloadedResults=*/[0], + /*list
overloadedOperands=*/[2, 4], + /*list traits=*/traits, + /*int numResults=*/1>; + +def RVVIntrLoadEleOp : RVV_USLoad_IntrOp<"vle">, + Arguments<(ins LLVMPointerType, AnyInteger)>; +def RVVIntrStoreEleOp : RVV_USStore_IntrOp<"vse">, + Arguments<(ins LLVMScalableVectorType, LLVMPointerType, AnyInteger)>; + +def RVVIntrAddOp : RVV_BinaryAAXNoMask_IntrOp<"vadd">, + Arguments<(ins LLVMScalableVectorType, AnyType, AnyInteger)>; +def RVVIntrSubOp : RVV_BinaryAAXNoMask_IntrOp<"vsub">, + Arguments<(ins LLVMScalableVectorType, AnyType, AnyInteger)>; +def RVVIntrMulOp : RVV_BinaryAAXNoMask_IntrOp<"vmul">, + Arguments<(ins LLVMScalableVectorType, AnyType, AnyInteger)>; +def RVVIntrDivOp : RVV_BinaryAAXNoMask_IntrOp<"vdiv">, + Arguments<(ins LLVMScalableVectorType, AnyType, AnyInteger)>; + +def RVVMaskedIntrAddOp : RVV_BinaryAAXMask_IntrOp<"vadd_mask">, + Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType, AnyType, + LLVMScalableVectorType, AnyInteger)>; +def RVVMaskedIntrSubOp : RVV_BinaryAAXMask_IntrOp<"vsub_mask">, + Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType, AnyType, + LLVMScalableVectorType, AnyInteger)>; +def RVVMaskedIntrMulOp : RVV_BinaryAAXMask_IntrOp<"vmul_mask">, + Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType, AnyType, + LLVMScalableVectorType, AnyInteger)>; +def RVVMaskedIntrDivOp : RVV_BinaryAAXMask_IntrOp<"vdiv_mask">, + Arguments<(ins LLVMScalableVectorType, LLVMScalableVectorType, AnyType, + LLVMScalableVectorType, AnyInteger)>; + +#endif // RVV_OPS diff --git a/mlir/include/mlir/Dialect/RVV/RVVDialect.h b/mlir/include/mlir/Dialect/RVV/RVVDialect.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/RVV/RVVDialect.h @@ -0,0 +1,29 @@ +//===- RVVDialect.h - MLIR Dialect for RISC-V vector extension --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the RVV (RISC-V vector extension) dialect, its types, and its operations. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_RVV_RVVDIALECT_H +#define MLIR_DIALECT_RVV_RVVDIALECT_H + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" + +#include "mlir/Dialect/RVV/RVVDialect.h.inc" + +#define GET_TYPEDEF_CLASSES +#include "mlir/Dialect/RVV/RVVTypes.h.inc" + +#define GET_OP_CLASSES +#include "mlir/Dialect/RVV/RVV.h.inc" + +#endif // MLIR_DIALECT_RVV_RVVDIALECT_H diff --git a/mlir/include/mlir/Dialect/RVV/Transforms.h b/mlir/include/mlir/Dialect/RVV/Transforms.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/RVV/Transforms.h @@ -0,0 +1,30 @@ +//===- Transforms.h - RVV Dialect Transformation Entrypoints ----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_RVV_TRANSFORMS_H +#define MLIR_DIALECT_RVV_TRANSFORMS_H + +namespace mlir { + +class LLVMConversionTarget; +class LLVMTypeConverter; +class RewritePatternSet; +using OwningRewritePatternList = RewritePatternSet; // Legacy spelling of RewritePatternSet. +/// Collect a set of patterns to lower RVV ops to ops that map to LLVM +/// intrinsics. Also registers the RVV vector type conversion on `converter`. +void populateRVVLegalizeForLLVMExportPatterns(LLVMTypeConverter &converter, + RewritePatternSet &patterns); + +/// Configure the target to support lowering RVV ops to ops that map to LLVM +/// intrinsics.
+void configureRVVLegalizeForExportTarget(LLVMConversionTarget &target); + +} // namespace mlir + +#endif // MLIR_DIALECT_RVV_TRANSFORMS_H diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -34,6 +34,7 @@ #include "mlir/Dialect/PDL/IR/PDL.h" #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" #include "mlir/Dialect/Quant/QuantOps.h" +#include "mlir/Dialect/RVV/RVVDialect.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/Shape/IR/Shape.h" @@ -68,6 +69,7 @@ pdl::PDLDialect, pdl_interp::PDLInterpDialect, quant::QuantizationDialect, + rvv::RVVDialect, spirv::SPIRVDialect, StandardOpsDialect, arm_sve::ArmSVEDialect, diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h --- a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h +++ b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h @@ -22,6 +22,7 @@ #include "mlir/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/RVV/RVVToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/X86Vector/X86VectorToLLVMIRTranslation.h" namespace mlir { @@ -38,6 +39,7 @@ registerOpenACCDialectTranslation(registry); registerOpenMPDialectTranslation(registry); registerROCDLDialectTranslation(registry); + registerRVVDialectTranslation(registry); registerX86VectorDialectTranslation(registry); } } // namespace mlir diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/RVV/RVVToLLVMIRTranslation.h b/mlir/include/mlir/Target/LLVMIR/Dialect/RVV/RVVToLLVMIRTranslation.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Target/LLVMIR/Dialect/RVV/RVVToLLVMIRTranslation.h @@ -0,0 +1,31 @@ +//=======- 
RVVToLLVMIRTranslation.h - RVV to
LLVM IR --------*- C++ -*-=======// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This provides registration calls for RVV dialect to LLVM IR translation. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TARGET_LLVMIR_DIALECT_RVV_RVVTOLLVMIRTRANSLATION_H +#define MLIR_TARGET_LLVMIR_DIALECT_RVV_RVVTOLLVMIRTRANSLATION_H + +namespace mlir { + +class DialectRegistry; +class MLIRContext; + +/// Register the RVV dialect and the translation from it to the LLVM IR in +/// the given registry. +void registerRVVDialectTranslation(DialectRegistry &registry); + +/// Register the RVV dialect and the translation from it in the registry +/// associated with the given context. +void registerRVVDialectTranslation(MLIRContext &context); + +} // namespace mlir + +#endif // MLIR_TARGET_LLVMIR_DIALECT_RVV_RVVTOLLVMIRTRANSLATION_H diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt --- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt +++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt @@ -21,6 +21,8 @@ MLIRLLVMCommonConversion MLIRLLVMIR MLIRMemRef + MLIRRVV + MLIRRVVTransforms MLIRTargetLLVMIRExport MLIRTransforms MLIRVector diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -19,6 +19,8 @@ #include "mlir/Dialect/ArmSVE/Transforms.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/RVV/RVVDialect.h" +#include "mlir/Dialect/RVV/Transforms.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/Dialect/X86Vector/Transforms.h" @@ -37,6 +39,7 @@ this->enableArmNeon = options.enableArmNeon; this->enableArmSVE = options.enableArmSVE; this->enableAMX = options.enableAMX; + this->enableRVV = options.enableRVV; this->enableX86Vector = options.enableX86Vector; } // Override explicitly to allow conditional dialect dependence. @@ -49,6 +52,8 @@ registry.insert(); if (enableAMX) registry.insert(); + if (enableRVV) + registry.insert(); if (enableX86Vector) registry.insert(); } @@ -102,6 +107,10 @@ configureX86VectorLegalizeForExportTarget(target); populateX86VectorLegalizeForLLVMExportPatterns(converter, patterns); } + if (enableRVV) { + configureRVVLegalizeForExportTarget(target); + populateRVVLegalizeForLLVMExportPatterns(converter, patterns); + } if (failed( applyPartialConversion(getOperation(), target, std::move(patterns)))) diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt --- a/mlir/lib/Dialect/CMakeLists.txt +++ b/mlir/lib/Dialect/CMakeLists.txt @@ -16,6 +16,7 @@ add_subdirectory(PDL) add_subdirectory(PDLInterp) add_subdirectory(Quant) +add_subdirectory(RVV) add_subdirectory(SCF) add_subdirectory(Shape) add_subdirectory(SparseTensor) diff --git a/mlir/lib/Dialect/RVV/CMakeLists.txt b/mlir/lib/Dialect/RVV/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/RVV/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/mlir/lib/Dialect/RVV/IR/CMakeLists.txt b/mlir/lib/Dialect/RVV/IR/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/RVV/IR/CMakeLists.txt @@ -0,0 +1,10 @@ +add_mlir_dialect_library(MLIRRVV + RVVDialect.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/RVV + + LINK_LIBS PUBLIC + MLIRIR + MLIRLLVMIR + ) diff --git a/mlir/lib/Dialect/RVV/IR/RVVDialect.cpp b/mlir/lib/Dialect/RVV/IR/RVVDialect.cpp new file mode 
100644 --- 
/dev/null +++ b/mlir/lib/Dialect/RVV/IR/RVVDialect.cpp @@ -0,0 +1,63 @@ +//===- RVVDialect.cpp - MLIR RVV dialect implementation -------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the RVV dialect, its types, and its operations. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/RVV/RVVDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/Dialect/Vector/VectorOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/TypeUtilities.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; + +#include "mlir/Dialect/RVV/RVVDialect.cpp.inc" + +#define GET_OP_CLASSES +#include "mlir/Dialect/RVV/RVV.cpp.inc" + +#define GET_TYPEDEF_CLASSES +#include "mlir/Dialect/RVV/RVVTypes.cpp.inc" + +void rvv::RVVDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "mlir/Dialect/RVV/RVV.cpp.inc" + >(); + addTypes< +#define GET_TYPEDEF_LIST +#include "mlir/Dialect/RVV/RVVTypes.cpp.inc" + >(); +} + +//===----------------------------------------------------------------------===// +// RVVScalableVectorType parsing and printing +//===----------------------------------------------------------------------===// + +Type rvv::RVVDialect::parseType(DialectAsmParser &parser) const { + llvm::SMLoc typeLoc = parser.getCurrentLocation(); + { // "vector" is the only mnemonic; it is not consumed here, so the type's custom parser reads the whole `vector<...>` as a builtin VectorType. 
+ Type genType; + auto parseResult = generatedTypeParser(parser.getBuilder().getContext(), + parser, "vector", genType); + if (parseResult.hasValue()) + return genType; + } + parser.emitError(typeLoc, "unknown type in RVV dialect"); + return Type(); +} + +void rvv::RVVDialect::printType(Type type, DialectAsmPrinter &os) const {
+ if (failed(generatedTypePrinter(type, os))) + llvm_unreachable("unexpected 'rvv' type kind"); +} diff --git a/mlir/lib/Dialect/RVV/Transforms/CMakeLists.txt b/mlir/lib/Dialect/RVV/Transforms/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/RVV/Transforms/CMakeLists.txt @@ -0,0 +1,12 @@ +add_mlir_dialect_library(MLIRRVVTransforms + LegalizeForLLVMExport.cpp + + DEPENDS + MLIRRVVConversionsIncGen + + LINK_LIBS PUBLIC + MLIRRVV + MLIRIR + MLIRLLVMCommonConversion + MLIRLLVMIR + ) diff --git a/mlir/lib/Dialect/RVV/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/RVV/Transforms/LegalizeForLLVMExport.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/RVV/Transforms/LegalizeForLLVMExport.cpp @@ -0,0 +1,214 @@ +//===- LegalizeForLLVMExport.cpp - Prepare RVV for LLVM translation -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/LLVMCommon/ConversionTarget.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/RVV/RVVDialect.h" +#include "mlir/Dialect/RVV/Transforms.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/PatternMatch.h" + +using namespace mlir; +using namespace mlir::rvv; + +// Extract an LLVM IR type from the LLVM IR dialect type. Returns null if +// `type` is null or is not LLVM-compatible (the latter also emits an error), +// so callers fail the conversion instead of building an invalid LLVM type.
+static Type unwrap(Type type) { + if (!type) + return nullptr; + if (!LLVM::isCompatibleType(type)) { + emitError(UnknownLoc::get(type.getContext()), + "conversion resulted in a non-LLVM type"); + return nullptr; + } + return type; +} + +// Converts an RVV scalable vector type to an LLVM scalable vector type whose +// minimum element count is the last RVV dimension (outer dimensions are not +// represented). Returns an empty optional if conversion is not possible. +static Optional +convertScalableVectorTypeToLLVM(RVVScalableVectorType rvvSVType, + LLVMTypeConverter &converter) { + auto elementType = unwrap(converter.convertType(rvvSVType.getElementType())); + if (!elementType || rvvSVType.getShape().empty()) + return {}; + auto sVectorType = LLVM::LLVMScalableVectorType::get( + elementType, rvvSVType.getShape().back()); + return sVectorType; +} + +// Updates `op` in place to use the type-converted `operands`; does not match +// when the operand types are already unchanged. +template +class ForwardOperands : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(OpTy op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + if (ValueRange(operands).getTypes() == op->getOperands().getTypes()) + return rewriter.notifyMatchFailure(op, "operand types already match"); + + rewriter.updateRootInPlace(op, [&]() { op->setOperands(operands); }); + return success(); + } +}; + +// Lowers `rvv.load` to `rvv.intr.vle`: computes the element address from the +// memref descriptor, bitcasts it to a pointer to the converted scalable vector +// type, and forwards the vector length operand. 
+struct RVVLoadOpLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(RVVLoadOp loadOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto type = loadOp.getMemRefType(); + if (!isConvertibleAndHasIdentityMaps(type)) + return failure(); + + RVVLoadOp::Adaptor transformed(operands); + // Use the pattern's converter so pass options (e.g. index bitwidth) are + // honored, rather than building a default-configured one per match. + LLVMTypeConverter &converter = *getTypeConverter(); + + // ODS constrains the result to an RVV scalable vector type. + auto llvmVectorType = convertScalableVectorTypeToLLVM( + loadOp.result().getType().cast<RVVScalableVectorType>(), converter); + if (!llvmVectorType) + return rewriter.notifyMatchFailure(loadOp, "unsupported element type"); + auto llvmDataTypePtr = LLVM::LLVMPointerType::get(*llvmVectorType); + + Value dataPtr = + getStridedElementPtr(loadOp.getLoc(), type, transformed.base(), + transformed.index(), rewriter); + Value bitCastedPtr = rewriter.create( + loadOp.getLoc(), llvmDataTypePtr, dataPtr); + rewriter.replaceOpWithNewOp(loadOp, *llvmVectorType, bitCastedPtr, + transformed.length()); + return success(); + } +};
+ +// Lowers `rvv.store` to `rvv.intr.vse`, mirroring RVVLoadOpLowering. +struct RVVStoreOpLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(RVVStoreOp storeOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto type = storeOp.getMemRefType(); + if (!isConvertibleAndHasIdentityMaps(type)) + return failure(); + + RVVStoreOp::Adaptor transformed(operands); + LLVMTypeConverter &converter = *getTypeConverter(); + + // ODS constrains the stored value to an RVV scalable vector type. + auto llvmVectorType = convertScalableVectorTypeToLLVM( + storeOp.value().getType().cast<RVVScalableVectorType>(), converter); + if (!llvmVectorType) + return rewriter.notifyMatchFailure(storeOp, "unsupported element type"); + auto llvmDataTypePtr = LLVM::LLVMPointerType::get(*llvmVectorType); + + Value dataPtr = + getStridedElementPtr(storeOp.getLoc(), type, transformed.base(), + transformed.index(), rewriter); + Value bitCastedPtr = rewriter.create( + storeOp.getLoc(), llvmDataTypePtr, dataPtr); + rewriter.replaceOpWithNewOp(storeOp, transformed.value(), + bitCastedPtr, transformed.length()); + return success(); + } +}; + +using RVVAddOpLowering = OneToOneConvertToLLVMPattern; +using RVVSubOpLowering = OneToOneConvertToLLVMPattern; +using RVVMulOpLowering = OneToOneConvertToLLVMPattern; +using RVVDivOpLowering = OneToOneConvertToLLVMPattern; +using RVVMaskedAddOpLowering = + OneToOneConvertToLLVMPattern; +using RVVMaskedSubOpLowering = + OneToOneConvertToLLVMPattern; +using RVVMaskedMulOpLowering = + OneToOneConvertToLLVMPattern; +using RVVMaskedDivOpLowering = + OneToOneConvertToLLVMPattern; + +/// Populate the given list with patterns that
convert from RVV to LLVM. +void mlir::populateRVVLegalizeForLLVMExportPatterns( + LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { + // Populate conversion patterns + // Remove any RVV-specific types from function signatures and results. + populateFuncOpTypeConversionPattern(patterns, converter); + converter.addConversion([&converter](RVVScalableVectorType rvvSVType) { + return convertScalableVectorTypeToLLVM(rvvSVType, converter); + }); + + // clang-format off + patterns.add, + ForwardOperands, + ForwardOperands>(converter, &converter.getContext()); + patterns.add(converter); + patterns.add(converter); + // clang-format on +} + +void mlir::configureRVVLegalizeForExportTarget(LLVMConversionTarget &target) { + // clang-format off + target.addLegalOp(); + target.addIllegalOp(); + // clang-format on + + auto hasScalableVectorType = [](TypeRange types) { + for (Type type : types) + if (type.isa()) + return true; + return false; + }; + target.addDynamicallyLegalOp([hasScalableVectorType](FuncOp op) { + return !hasScalableVectorType(op.getType().getInputs()) && + !hasScalableVectorType(op.getType().getResults()); + }); + target.addDynamicallyLegalOp( + [hasScalableVectorType](Operation *op) { + return !hasScalableVectorType(op->getOperandTypes()) && + !hasScalableVectorType(op->getResultTypes()); + }); +} diff --git a/mlir/lib/Target/LLVMIR/CMakeLists.txt b/mlir/lib/Target/LLVMIR/CMakeLists.txt --- a/mlir/lib/Target/LLVMIR/CMakeLists.txt +++ b/mlir/lib/Target/LLVMIR/CMakeLists.txt @@ -41,6 +41,7 @@ MLIRArmNeonToLLVMIRTranslation MLIRArmSVEToLLVMIRTranslation MLIRAMXToLLVMIRTranslation + MLIRRVVToLLVMIRTranslation MLIRX86VectorToLLVMIRTranslation MLIRLLVMToLLVMIRTranslation MLIRNVVMToLLVMIRTranslation diff --git a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt --- a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt +++ b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt @@ -6,4 +6,5 @@ add_subdirectory(OpenACC)
add_subdirectory(OpenMP) add_subdirectory(ROCDL) +add_subdirectory(RVV) add_subdirectory(X86Vector) diff --git a/mlir/lib/Target/LLVMIR/Dialect/RVV/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/RVV/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Target/LLVMIR/Dialect/RVV/CMakeLists.txt @@ -0,0 +1,16 @@ +add_mlir_translation_library(MLIRRVVToLLVMIRTranslation + RVVToLLVMIRTranslation.cpp + + DEPENDS + MLIRRVVConversionsIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRIR + MLIRRVV + MLIRLLVMIR + MLIRSupport + MLIRTargetLLVMIRExport + ) diff --git a/mlir/lib/Target/LLVMIR/Dialect/RVV/RVVToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/RVV/RVVToLLVMIRTranslation.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Target/LLVMIR/Dialect/RVV/RVVToLLVMIRTranslation.cpp @@ -0,0 +1,55 @@ +//======- RVVToLLVMIRTranslation.cpp - Translate RVV to LLVM IR -------=======// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a translation between the RVV dialect and LLVM IR. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Target/LLVMIR/Dialect/RVV/RVVToLLVMIRTranslation.h" +#include "mlir/Dialect/RVV/RVVDialect.h" +#include "mlir/IR/Operation.h" +#include "mlir/Target/LLVMIR/ModuleTranslation.h" + +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/IntrinsicsRISCV.h" + +using namespace mlir; +using namespace mlir::LLVM; + +namespace { +/// Implementation of the dialect interface that converts operations belonging +/// to the RVV dialect to LLVM IR.
+class RVVDialectLLVMIRTranslationInterface + : public LLVMTranslationDialectInterface { +public: + using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface; + + /// Translates the given operation to LLVM IR using the provided IR builder + /// and saving the state in `moduleTranslation`. + LogicalResult + convertOperation(Operation *op, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) const final { + Operation &opInst = *op; // Unused here; presumably referenced by the generated RVVConversions.inc. +#include "mlir/Dialect/RVV/RVVConversions.inc" + + return failure(); // Reached when none of the generated conversions handled `op`. + } +}; +} // end namespace + +void mlir::registerRVVDialectTranslation(DialectRegistry &registry) { + registry.insert(); + registry.addDialectInterface(); +} + +void mlir::registerRVVDialectTranslation(MLIRContext &context) { + DialectRegistry registry; + registerRVVDialectTranslation(registry); + context.appendDialectRegistry(registry); +} diff --git a/mlir/test/Dialect/RVV/legalize-for-llvm.mlir b/mlir/test/Dialect/RVV/legalize-for-llvm.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/RVV/legalize-for-llvm.mlir @@ -0,0 +1,66 @@ +// RUN: mlir-opt %s -convert-vector-to-llvm="enable-rvv" -convert-std-to-llvm | mlir-opt | FileCheck %s + +func @rvv_memory(%v: !rvv.vector<4xi64>, + %m: memref, + %vl: i64) -> !rvv.vector<4xi64> { + %c0 = constant 0 : index + // CHECK: llvm.extractvalue {{.*}} : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // CHECK-NEXT: llvm.getelementptr {{.*}} : (!llvm.ptr, i64) -> !llvm.ptr + // CHECK-NEXT: llvm.bitcast {{.*}} : !llvm.ptr to !llvm.ptr> + // CHECK-NEXT: rvv.intr.vle{{.*}} : (!llvm.ptr>, i64) -> !llvm.vec + %0 = rvv.load %m[%c0], %vl : memref, !rvv.vector<4xi64>, i64 + // CHECK: llvm.extractvalue {{.*}} : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + // CHECK-NEXT: llvm.getelementptr {{.*}} : (!llvm.ptr, i64) -> !llvm.ptr + // CHECK-NEXT: 
llvm.bitcast {{.*}} : !llvm.ptr to !llvm.ptr> + // CHECK-NEXT: rvv.intr.vse{{.*}} : (!llvm.vec, !llvm.ptr>, i64) -> () +
rvv.store %v, %m[%c0], %vl : !rvv.vector<4xi64>, memref, i64 + return %0 : !rvv.vector<4xi64> +} + +func @rvv_arith(%a: !rvv.vector<4xi64>, + %b: !rvv.vector<4xi64>, + %c: i64, + %vl: i64) -> !rvv.vector<4xi64> { + // CHECK: rvv.intr.vadd{{.*}} : (!llvm.vec, !llvm.vec, i64) -> !llvm.vec + %0 = rvv.add %a, %b, %vl : !rvv.vector<4xi64>, !rvv.vector<4xi64>, i64 + // CHECK: rvv.intr.vadd{{.*}} : (!llvm.vec, i64, i64) -> !llvm.vec + %1 = rvv.add %a, %c, %vl : !rvv.vector<4xi64>, i64, i64 + // CHECK: rvv.intr.vsub{{.*}} : (!llvm.vec, !llvm.vec, i64) -> !llvm.vec + %2 = rvv.sub %a, %b, %vl : !rvv.vector<4xi64>, !rvv.vector<4xi64>, i64 + // CHECK: rvv.intr.vsub{{.*}} : (!llvm.vec, i64, i64) -> !llvm.vec + %3 = rvv.sub %a, %c, %vl : !rvv.vector<4xi64>, i64, i64 + // CHECK: rvv.intr.vmul{{.*}} : (!llvm.vec, !llvm.vec, i64) -> !llvm.vec + %4 = rvv.mul %a, %b, %vl : !rvv.vector<4xi64>, !rvv.vector<4xi64>, i64 + // CHECK: rvv.intr.vmul{{.*}} : (!llvm.vec, i64, i64) -> !llvm.vec + %5 = rvv.mul %a, %c, %vl : !rvv.vector<4xi64>, i64, i64 + // CHECK: rvv.intr.vdiv{{.*}} : (!llvm.vec, !llvm.vec, i64) -> !llvm.vec + %6 = rvv.div %a, %b, %vl : !rvv.vector<4xi64>, !rvv.vector<4xi64>, i64 + // CHECK: rvv.intr.vdiv{{.*}} : (!llvm.vec, i64, i64) -> !llvm.vec + %7 = rvv.div %a, %c, %vl : !rvv.vector<4xi64>, i64, i64 + return %7 : !rvv.vector<4xi64> +} + +func @rvv_masked_arith(%maskedoff: !rvv.vector<4xi64>, + %a: !rvv.vector<4xi64>, + %b: !rvv.vector<4xi64>, + %c: i64, + %mask: !rvv.vector<4xi1>, + %vl: i64) -> !rvv.vector<4xi64> { + // CHECK: rvv.intr.vadd_mask{{.*}} : (!llvm.vec, !llvm.vec, !llvm.vec, !llvm.vec, i64) -> !llvm.vec + %0 = rvv.masked.add %maskedoff, %a, %b, %mask, %vl: !rvv.vector<4xi64>, !rvv.vector<4xi64>, !rvv.vector<4xi1>, i64 + // CHECK: rvv.intr.vadd_mask{{.*}} : (!llvm.vec, !llvm.vec, i64, !llvm.vec, i64) -> !llvm.vec + %1 = rvv.masked.add %maskedoff, %a, %c, %mask, %vl: !rvv.vector<4xi64>, i64, !rvv.vector<4xi1>, i64 + // CHECK: rvv.intr.vsub_mask{{.*}} : 
(!llvm.vec, !llvm.vec, !llvm.vec, !llvm.vec, i64) -> !llvm.vec + %2 = rvv.masked.sub %maskedoff, %a, %b, %mask, %vl: !rvv.vector<4xi64>, !rvv.vector<4xi64>, !rvv.vector<4xi1>, i64 + // CHECK: rvv.intr.vsub_mask{{.*}} : (!llvm.vec, !llvm.vec, i64, !llvm.vec, i64) -> !llvm.vec + %3 = rvv.masked.sub %maskedoff, %a, %c, %mask, %vl: !rvv.vector<4xi64>, i64, !rvv.vector<4xi1>, i64 + // CHECK: rvv.intr.vmul_mask{{.*}} : (!llvm.vec, !llvm.vec, !llvm.vec, !llvm.vec, i64) -> !llvm.vec + %4 = rvv.masked.mul %maskedoff, %a, %b, %mask, %vl: !rvv.vector<4xi64>, !rvv.vector<4xi64>, !rvv.vector<4xi1>, i64 + // CHECK: rvv.intr.vmul_mask{{.*}} : (!llvm.vec, !llvm.vec, i64, !llvm.vec, i64) -> !llvm.vec + %5 = rvv.masked.mul %maskedoff, %a, %c, %mask, %vl: !rvv.vector<4xi64>, i64, !rvv.vector<4xi1>, i64 + // CHECK: rvv.intr.vdiv_mask{{.*}} : (!llvm.vec, !llvm.vec, !llvm.vec, !llvm.vec, i64) -> !llvm.vec + %6 = rvv.masked.div %maskedoff, %a, %b, %mask, %vl: !rvv.vector<4xi64>, !rvv.vector<4xi64>, !rvv.vector<4xi1>, i64 + // CHECK: rvv.intr.vdiv_mask{{.*}} : (!llvm.vec, !llvm.vec, i64, !llvm.vec, i64) -> !llvm.vec + %7 = rvv.masked.div %maskedoff, %a, %c, %mask, %vl: !rvv.vector<4xi64>, i64, !rvv.vector<4xi1>, i64 + return %7 : !rvv.vector<4xi64> +} diff --git a/mlir/test/Dialect/RVV/roundtrip.mlir b/mlir/test/Dialect/RVV/roundtrip.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/RVV/roundtrip.mlir @@ -0,0 +1,63 @@ +// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s + +// CHECK-LABEL: func @rvv_memory +func @rvv_memory(%v: !rvv.vector<4xi64>, + %m: memref, + %vl: i64) -> !rvv.vector<4xi64> { + %c0 = constant 0 : index + // CHECK: rvv.load {{.*}}: memref, !rvv.vector<4xi64>, i64 + %0 = rvv.load %m[%c0], %vl : memref, !rvv.vector<4xi64>, i64 + // CHECK: rvv.store {{.*}}: !rvv.vector<4xi64>, memref, i64 + rvv.store %v, %m[%c0], %vl : !rvv.vector<4xi64>, memref, i64 + return %0 : !rvv.vector<4xi64> +} + +// CHECK-LABEL: func @rvv_arith +func @rvv_arith(%a: 
!rvv.vector<4xi64>, + %b: !rvv.vector<4xi64>, + %c: i64, + %vl: i64) -> !rvv.vector<4xi64> { + // CHECK: rvv.add {{.*}} : !rvv.vector<4xi64>, !rvv.vector<4xi64>, i64 + %0 = rvv.add %a, %b, %vl : !rvv.vector<4xi64>, !rvv.vector<4xi64>, i64 + // CHECK: rvv.add {{.*}} : !rvv.vector<4xi64>, i64, i64 + %1 = rvv.add %a, %c, %vl : !rvv.vector<4xi64>, i64, i64 + // CHECK: rvv.sub {{.*}} : !rvv.vector<4xi64>, !rvv.vector<4xi64>, i64 + %2 = rvv.sub %a, %b, %vl : !rvv.vector<4xi64>, !rvv.vector<4xi64>, i64 + // CHECK: rvv.sub {{.*}} : !rvv.vector<4xi64>, i64, i64 + %3 = rvv.sub %a, %c, %vl : !rvv.vector<4xi64>, i64, i64 + // CHECK: rvv.mul {{.*}} : !rvv.vector<4xi64>, !rvv.vector<4xi64>, i64 + %4 = rvv.mul %a, %b, %vl : !rvv.vector<4xi64>, !rvv.vector<4xi64>, i64 + // CHECK: rvv.mul {{.*}} : !rvv.vector<4xi64>, i64, i64 + %5 = rvv.mul %a, %c, %vl : !rvv.vector<4xi64>, i64, i64 + // CHECK: rvv.div {{.*}} : !rvv.vector<4xi64>, !rvv.vector<4xi64>, i64 + %6 = rvv.div %a, %b, %vl : !rvv.vector<4xi64>, !rvv.vector<4xi64>, i64 + // CHECK: rvv.div {{.*}} : !rvv.vector<4xi64>, i64, i64 + %7 = rvv.div %a, %c, %vl : !rvv.vector<4xi64>, i64, i64 + return %7 : !rvv.vector<4xi64> +} + +// CHECK-LABEL: func @rvv_masked_arith +func @rvv_masked_arith(%maskedoff: !rvv.vector<4xi64>, + %a: !rvv.vector<4xi64>, + %b: !rvv.vector<4xi64>, + %c: i64, + %mask: !rvv.vector<4xi1>, + %vl: i64) -> !rvv.vector<4xi64> { + // CHECK: rvv.masked.add {{.*}} : !rvv.vector<4xi64>, !rvv.vector<4xi64>, !rvv.vector<4xi1>, i64 + %0 = rvv.masked.add %maskedoff, %a, %b, %mask, %vl : !rvv.vector<4xi64>, !rvv.vector<4xi64>, !rvv.vector<4xi1>, i64 + // CHECK: rvv.masked.add {{.*}} : !rvv.vector<4xi64>, i64, !rvv.vector<4xi1>, i64 + %1 = rvv.masked.add %maskedoff, %a, %c, %mask, %vl : !rvv.vector<4xi64>, i64, !rvv.vector<4xi1>, i64 + // CHECK: rvv.masked.sub {{.*}} : !rvv.vector<4xi64>, !rvv.vector<4xi64>, !rvv.vector<4xi1>, i64 + %2 = rvv.masked.sub %maskedoff, %a, %b, %mask, %vl : !rvv.vector<4xi64>, !rvv.vector<4xi64>, 
!rvv.vector<4xi1>, i64 + // CHECK: rvv.masked.sub {{.*}} : !rvv.vector<4xi64>, i64, !rvv.vector<4xi1>, i64 + %3 = rvv.masked.sub %maskedoff, %a, %c, %mask, %vl : !rvv.vector<4xi64>, i64, !rvv.vector<4xi1>, i64 + // CHECK: rvv.masked.mul {{.*}} : !rvv.vector<4xi64>, !rvv.vector<4xi64>, !rvv.vector<4xi1>, i64 + %4 = rvv.masked.mul %maskedoff, %a, %b, %mask, %vl : !rvv.vector<4xi64>, !rvv.vector<4xi64>, !rvv.vector<4xi1>, i64 + // CHECK: rvv.masked.mul {{.*}} : !rvv.vector<4xi64>, i64, !rvv.vector<4xi1>, i64 + %5 = rvv.masked.mul %maskedoff, %a, %c, %mask, %vl : !rvv.vector<4xi64>, i64, !rvv.vector<4xi1>, i64 + // CHECK: rvv.masked.div {{.*}} : !rvv.vector<4xi64>, !rvv.vector<4xi64>, !rvv.vector<4xi1>, i64 + %6 = rvv.masked.div %maskedoff, %a, %b, %mask, %vl : !rvv.vector<4xi64>, !rvv.vector<4xi64>, !rvv.vector<4xi1>, i64 + // CHECK: rvv.masked.div {{.*}} : !rvv.vector<4xi64>, i64, !rvv.vector<4xi1>, i64 + %7 = rvv.masked.div %maskedoff, %a, %c, %mask, %vl : !rvv.vector<4xi64>, i64, !rvv.vector<4xi1>, i64 + return %7 : !rvv.vector<4xi64> +} diff --git a/mlir/test/mlir-opt/commandline.mlir b/mlir/test/mlir-opt/commandline.mlir --- a/mlir/test/mlir-opt/commandline.mlir +++ b/mlir/test/mlir-opt/commandline.mlir @@ -21,6 +21,7 @@ // CHECK-NEXT: pdl_interp // CHECK-NEXT: quant // CHECK-NEXT: rocdl +// CHECK-NEXT: rvv // CHECK-NEXT: scf // CHECK-NEXT: shape // CHECK-NEXT: sparse_tensor