diff --git a/llvm/include/llvm/IR/IntrinsicsRISCV.td b/llvm/include/llvm/IR/IntrinsicsRISCV.td
--- a/llvm/include/llvm/IR/IntrinsicsRISCV.td
+++ b/llvm/include/llvm/IR/IntrinsicsRISCV.td
@@ -65,4 +65,26 @@
   // @llvm.riscv.masked.cmpxchg.{i32,i64}.<p>(...)
   defm int_riscv_masked_cmpxchg : MaskedAtomicRMWFiveArgIntrinsics;
+
+  // RISC-V V extension
+  // For (vector, vector) and (vector, scalar) binary arithmetic.
+  // Input: (vector_in, vector_in/scalar_in, vl)
+  class RVVBinaryCommon
+        : Intrinsic<[llvm_anyvector_ty],
+                    [LLVMMatchType<0>, llvm_any_ty, llvm_i64_ty],
+                    [IntrNoMem]>;
+  // For (vector, vector) and (vector, scalar) binary arithmetic with mask.
+  // Input: (maskedoff, vector_in, vector_in/scalar_in, mask, vl)
+  class RVVBinaryCommonMask
+        : Intrinsic<[llvm_anyvector_ty],
+                    [LLVMMatchType<0>, LLVMMatchType<0>, llvm_any_ty,
+                     llvm_anyvector_ty, llvm_i64_ty],
+                    [IntrNoMem]>;
+
+  multiclass riscv_binary {
+    def "int_riscv_" # NAME : RVVBinaryCommon;
+    def "int_riscv_" # NAME # "_mask" : RVVBinaryCommonMask;
+  }
+
+  defm vadd : riscv_binary;
+
 } // TargetPrefix = "riscv"
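Note: the definitions above instantiate to one IR intrinsic per vector type. A
sketch of the nxv1i8 instances follows; the unmasked declaration matches the
vadd.ll test below, while the masked name and its mangled suffixes are a
reading of RVVBinaryCommonMask's overloaded operands, not taken from the
patch's tests:

    ; Unmasked: (vector, vector/scalar, vl) -> vector
    declare <vscale x 1 x i8> @llvm.riscv.vadd.nxv1i8.nxv1i8(
      <vscale x 1 x i8>, <vscale x 1 x i8>, i64)

    ; Masked: (maskedoff, vector, vector/scalar, mask, vl) -> vector
    declare <vscale x 1 x i8> @llvm.riscv.vadd.mask.nxv1i8.nxv1i8.nxv1i1(
      <vscale x 1 x i8>, <vscale x 1 x i8>, <vscale x 1 x i8>,
      <vscale x 1 x i1>, i64)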
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -87,6 +87,8 @@
   explicit RISCVTargetLowering(const TargetMachine &TM,
                                const RISCVSubtarget &STI);

+  const RISCVSubtarget &getSubtarget() const { return Subtarget; }
+
   bool getTgtMemIntrinsic(IntrinsicInfo &Info, const CallInst &I,
                           MachineFunction &MF,
                           unsigned Intrinsic) const override;
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -2038,6 +2038,15 @@
   RISCV::F10_D, RISCV::F11_D, RISCV::F12_D, RISCV::F13_D,
   RISCV::F14_D, RISCV::F15_D, RISCV::F16_D, RISCV::F17_D
 };
+static const MCPhysReg ArgVRs[] = {
+  RISCV::V16, RISCV::V17, RISCV::V18, RISCV::V19, RISCV::V20,
+  RISCV::V21, RISCV::V22, RISCV::V23
+};
+static const MCPhysReg ArgVRM2s[] = {
+  RISCV::V16M2, RISCV::V18M2, RISCV::V20M2, RISCV::V22M2
+};
+static const MCPhysReg ArgVRM4s[] = {RISCV::V16M4, RISCV::V20M4};
+static const MCPhysReg ArgVRM8s[] = {RISCV::V16M8};

 // Pass a 2*XLEN argument that has been split into two XLEN values through
 // registers or the stack as necessary.
@@ -2082,7 +2091,7 @@
 static bool CC_RISCV(const DataLayout &DL, RISCVABI::ABI ABI, unsigned ValNo,
                      MVT ValVT, MVT LocVT, CCValAssign::LocInfo LocInfo,
                      ISD::ArgFlagsTy ArgFlags, CCState &State, bool IsFixed,
-                     bool IsRet, Type *OrigTy) {
+                     bool IsRet, Type *OrigTy, const RISCVTargetLowering *TLI) {
   unsigned XLen = DL.getLargestLegalIntTypeSizeInBits();
   assert(XLen == 32 || XLen == 64);
   MVT XLenVT = XLen == 32 ? MVT::i32 : MVT::i64;
@@ -2215,7 +2224,26 @@
     Reg = State.AllocateReg(ArgFPR32s);
   else if (ValVT == MVT::f64 && !UseGPRForF64)
     Reg = State.AllocateReg(ArgFPR64s);
-  else
+  else if (ValVT.isScalableVector()) {
+    const TargetRegisterClass *RC = TLI->getRegClassFor(ValVT);
+    if (RC->hasSuperClassEq(&RISCV::VRRegClass)) {
+      Reg = State.AllocateReg(ArgVRs);
+    } else if (RC->hasSuperClassEq(&RISCV::VRM2RegClass)) {
+      Reg = State.AllocateReg(ArgVRM2s);
+    } else if (RC->hasSuperClassEq(&RISCV::VRM4RegClass)) {
+      Reg = State.AllocateReg(ArgVRM4s);
+    } else if (RC->hasSuperClassEq(&RISCV::VRM8RegClass)) {
+      Reg = State.AllocateReg(ArgVRM8s);
+    } else {
+      llvm_unreachable("Unhandled register class for ValueType");
+    }
+    if (!Reg) {
+      LocInfo = CCValAssign::Indirect;
+      // Try using a GPR to pass the address.
+      Reg = State.AllocateReg(ArgGPRs);
+      LocVT = XLenVT;
+    }
+  } else
     Reg = State.AllocateReg(ArgGPRs);
   unsigned StackOffset =
       Reg ? 0 : State.AllocateStack(XLen / 8, Align(XLen / 8));
@@ -2238,7 +2266,8 @@
     return false;
   }

-  assert((!UseGPRForF16_F32 || !UseGPRForF64 || LocVT == XLenVT) &&
+  assert((!UseGPRForF16_F32 || !UseGPRForF64 || LocVT == XLenVT ||
+          (TLI->getSubtarget().hasStdExtV() && ValVT.isScalableVector())) &&
          "Expected an XLenVT at this stage");

   if (Reg) {
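Note on the tables above: scalable vector arguments are allocated from v16
upward (v16 to v23 for LMUL=1, v16m2/v18m2/v20m2/v22m2 for LMUL=2, and so on),
and once those run out the value is passed indirectly through a GPR holding
its address. A hypothetical illustration; the function and the register
comments are inferred from CC_RISCV above, not taken from the patch's tests:

    ; %v0..%v7 are assigned v16..v23; %v8 no longer fits, so the caller
    ; stores it to memory and passes its address in a GPR (Indirect).
    define <vscale x 1 x i8> @nine_args(
        <vscale x 1 x i8> %v0, <vscale x 1 x i8> %v1, <vscale x 1 x i8> %v2,
        <vscale x 1 x i8> %v3, <vscale x 1 x i8> %v4, <vscale x 1 x i8> %v5,
        <vscale x 1 x i8> %v6, <vscale x 1 x i8> %v7,
        <vscale x 1 x i8> %v8) nounwind {
      ret <vscale x 1 x i8> %v8
    }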
@@ -2274,7 +2303,7 @@
     RISCVABI::ABI ABI = MF.getSubtarget<RISCVSubtarget>().getTargetABI();

     if (CC_RISCV(MF.getDataLayout(), ABI, i, ArgVT, ArgVT, CCValAssign::Full,
-                 ArgFlags, CCInfo, /*IsFixed=*/true, IsRet, ArgTy)) {
+                 ArgFlags, CCInfo, /*IsFixed=*/true, IsRet, ArgTy, this)) {
       LLVM_DEBUG(dbgs() << "InputArg #" << i << " has unhandled type "
                         << EVT(ArgVT).getEVTString() << '\n');
       llvm_unreachable(nullptr);
@@ -2295,7 +2324,7 @@
     RISCVABI::ABI ABI = MF.getSubtarget<RISCVSubtarget>().getTargetABI();

     if (CC_RISCV(MF.getDataLayout(), ABI, i, ArgVT, ArgVT, CCValAssign::Full,
-                 ArgFlags, CCInfo, Outs[i].IsFixed, IsRet, OrigTy)) {
+                 ArgFlags, CCInfo, Outs[i].IsFixed, IsRet, OrigTy, this)) {
       LLVM_DEBUG(dbgs() << "OutputArg #" << i << " has unhandled type "
                         << EVT(ArgVT).getEVTString() << "\n");
       llvm_unreachable(nullptr);
@@ -2327,29 +2356,34 @@
 // The caller is responsible for loading the full value if the argument is
 // passed with CCValAssign::Indirect.
 static SDValue unpackFromRegLoc(SelectionDAG &DAG, SDValue Chain,
-                                const CCValAssign &VA, const SDLoc &DL) {
+                                const CCValAssign &VA, const SDLoc &DL,
+                                const RISCVTargetLowering *TLI) {
   MachineFunction &MF = DAG.getMachineFunction();
   MachineRegisterInfo &RegInfo = MF.getRegInfo();
   EVT LocVT = VA.getLocVT();
   SDValue Val;
   const TargetRegisterClass *RC;

-  switch (LocVT.getSimpleVT().SimpleTy) {
-  default:
-    llvm_unreachable("Unexpected register type");
-  case MVT::i32:
-  case MVT::i64:
-    RC = &RISCV::GPRRegClass;
-    break;
-  case MVT::f16:
-    RC = &RISCV::FPR16RegClass;
-    break;
-  case MVT::f32:
-    RC = &RISCV::FPR32RegClass;
-    break;
-  case MVT::f64:
-    RC = &RISCV::FPR64RegClass;
-    break;
+  if (LocVT.getSimpleVT().isScalableVector()) {
+    RC = TLI->getRegClassFor(LocVT.getSimpleVT());
+  } else {
+    switch (LocVT.getSimpleVT().SimpleTy) {
+    default:
+      llvm_unreachable("Unexpected register type");
+    case MVT::i32:
+    case MVT::i64:
+      RC = &RISCV::GPRRegClass;
+      break;
+    case MVT::f16:
+      RC = &RISCV::FPR16RegClass;
+      break;
+    case MVT::f32:
+      RC = &RISCV::FPR32RegClass;
+      break;
+    case MVT::f64:
+      RC = &RISCV::FPR64RegClass;
+      break;
+    }
   }

   Register VReg = RegInfo.createVirtualRegister(RC);
@@ -2623,7 +2657,7 @@
     if (VA.getLocVT() == MVT::i32 && VA.getValVT() == MVT::f64)
       ArgValue = unpackF64OnRV32DSoftABI(DAG, Chain, VA, DL);
     else if (VA.isRegLoc())
-      ArgValue = unpackFromRegLoc(DAG, Chain, VA, DL);
+      ArgValue = unpackFromRegLoc(DAG, Chain, VA, DL, this);
     else
       ArgValue = unpackFromMemLoc(DAG, Chain, VA, DL);

@@ -3076,7 +3110,8 @@
     ISD::ArgFlagsTy ArgFlags = Outs[i].Flags;
     RISCVABI::ABI ABI = MF.getSubtarget<RISCVSubtarget>().getTargetABI();
     if (CC_RISCV(MF.getDataLayout(), ABI, i, VT, VT, CCValAssign::Full,
-                 ArgFlags, CCInfo, /*IsFixed=*/true, /*IsRet=*/true, nullptr))
+                 ArgFlags, CCInfo, /*IsFixed=*/true, /*IsRet=*/true, nullptr,
+                 this))
       return false;
   }
   return true;
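With unpackFromRegLoc and CanLowerReturn extended this way, a simple callee
can receive scalable vector operands and return one in a vector register. A
minimal sketch; the register assignments (%a in v16, %b in v17, %vl in a0,
result back in v16) are assumed from the tables above and the vadd.ll test
below, not spelled out by the patch:

    define <vscale x 1 x i8> @callee(<vscale x 1 x i8> %a,
                                     <vscale x 1 x i8> %b, i64 %vl) nounwind {
      %sum = call <vscale x 1 x i8> @llvm.riscv.vadd.nxv1i8.nxv1i8(
               <vscale x 1 x i8> %a, <vscale x 1 x i8> %b, i64 %vl)
      ret <vscale x 1 x i8> %sum
    }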
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
@@ -14,6 +14,21 @@
 ///
 //===----------------------------------------------------------------------===//

+// A constant 0 AVL must not be selected as X0, since X0 as the AVL operand
+// means VLMAX to vsetvli; materialize the zero with an ADDI instead.
+def NoX0 : SDNodeXForm<null_frag,
+                       [{
+  SDLoc DL(N);
+  if (auto *C = dyn_cast<ConstantSDNode>(N)) {
+    if (C->isNullValue()) {
+      return SDValue(CurDAG->getMachineNode(RISCV::ADDI, DL, MVT::i64,
+                       CurDAG->getRegister(RISCV::X0, MVT::i64),
+                       CurDAG->getTargetConstant(0, DL, MVT::i64)), 0);
+    }
+  }
+
+  return SDValue(N, 0);
+}]>;
+
 //===----------------------------------------------------------------------===//
 // Utilities.
 //===----------------------------------------------------------------------===//
@@ -50,6 +65,23 @@
 // List of EEW.
 defvar EEWList = [8, 16, 32, 64];

+class swap_helper<dag Prefix, dag A, dag B, dag Suffix, bit swap> {
+  dag Value = !con(
+    Prefix,
+    !if(swap, B, A),
+    !if(swap, A, B),
+    Suffix);
+}
+
+class ToFPR32<ValueType type, DAGOperand operand, string name> {
+  dag ret = !cond(!eq(!cast<string>(operand), !cast<string>(FPR64)):
+                    (EXTRACT_SUBREG !dag(type, [FPR64], [name]), sub_32),
+                  !eq(!cast<string>(operand), !cast<string>(FPR16)):
+                    (SUBREG_TO_REG (i16 -1), !dag(type, [FPR16], [name]), sub_16),
+                  !eq(1, 1):
+                    !dag(type, [operand], [name]));
+}
+
 //===----------------------------------------------------------------------===//
 // Vector register and vector group type information.
 //===----------------------------------------------------------------------===//
@@ -227,10 +259,8 @@
 def : Pat<(result_type (vop (op_type op_reg_class:$rs1),
                             (op_type op_reg_class:$rs2))),
-          (instruction (result_type (IMPLICIT_DEF)),
-                       op_reg_class:$rs1,
+          (instruction op_reg_class:$rs1,
                        op_reg_class:$rs2,
-                       (mask_type zero_reg),
                        VLMax, sew)>;
 }
@@ -244,6 +274,66 @@
                               vti.LMul, vti.RegClass, vti.RegClass>;
 }

+multiclass pat_intrinsic_binary<string intrinsic_name,
+                                string instruction_name,
+                                string kind,
+                                ValueType result_type,
+                                ValueType op1_type,
+                                ValueType op2_type,
+                                ValueType mask_type,
+                                int sew,
+                                LMULInfo vlmul,
+                                VReg result_reg_class,
+                                VReg op1_reg_class,
+                                DAGOperand op2_kind,
+                                bit swap = 0>
+{
+  defvar inst = !cast<Instruction>(instruction_name#"_"#kind#"_"#vlmul.MX);
+  defvar inst_mask = !cast<Instruction>(instruction_name#"_"#kind#"_"#vlmul.MX#"_MASK");
+
+  def : Pat<(result_type (!cast<Intrinsic>(intrinsic_name)
+                         (op1_type op1_reg_class:$rs1),
+                         (op2_type op2_kind:$rs2),
+                         (i64 GPR:$vl))),
+            (inst (op1_type op1_reg_class:$rs1),
+                  (op2_type op2_kind:$rs2),
+                  (NoX0 GPR:$vl), sew)>;
+
+  def : Pat<(result_type (!cast<Intrinsic>(intrinsic_name#"_mask")
+                         (result_type result_reg_class:$merge),
+                         (op1_type op1_reg_class:$rs1),
+                         (op2_type op2_kind:$rs2),
+                         (mask_type V0),
+                         (i64 GPR:$vl))),
+            swap_helper<
+              (inst_mask result_reg_class:$merge),
+              (inst_mask op1_reg_class:$rs1),
+              (inst_mask ToFPR32<op2_type, op2_kind, "rs2">.ret),
+              (inst_mask (mask_type V0), (NoX0 GPR:$vl), sew),
+              swap>.Value>;
+}
+
+multiclass pat_intrinsic_binary_int_v_vv_vx_vi<string intrinsic,
+                                               string instruction>
+{
+  foreach vti = AllIntegerVectors in
+  {
+    defm : pat_intrinsic_binary<intrinsic, instruction, "VV",
+                                vti.Vector, vti.Vector, vti.Vector, vti.Mask,
+                                vti.SEW, vti.LMul, vti.RegClass,
+                                vti.RegClass, vti.RegClass>;
+    defm : pat_intrinsic_binary<intrinsic, instruction, "VX",
+                                vti.Vector, vti.Vector, XLenVT, vti.Mask,
+                                vti.SEW, vti.LMul, vti.RegClass,
+                                vti.RegClass, GPR>;
+    defm : pat_intrinsic_binary<intrinsic, instruction, "VI",
+                                vti.Vector, vti.Vector, XLenVT, vti.Mask,
+                                vti.SEW, vti.LMul, vti.RegClass,
+                                vti.RegClass, simm5>;
+  }
+}
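The three instantiations (VV, VX, VI) are what let a single intrinsic select
into vadd.vv, vadd.vx, or vadd.vi depending on its second operand. A sketch in
IR; the .i64-mangled names follow from the llvm_any_ty scalar operand and are
an assumption, since only the .vv form is covered by the test below:

    %vv = call <vscale x 1 x i8> @llvm.riscv.vadd.nxv1i8.nxv1i8(
            <vscale x 1 x i8> %a, <vscale x 1 x i8> %b, i64 %vl)  ; -> vadd.vv
    %vx = call <vscale x 1 x i8> @llvm.riscv.vadd.nxv1i8.i64(
            <vscale x 1 x i8> %a, i64 %x, i64 %vl)                ; -> vadd.vx
    %vi = call <vscale x 1 x i8> @llvm.riscv.vadd.nxv1i8.i64(
            <vscale x 1 x i8> %a, i64 9, i64 %vl)                 ; -> vadd.vi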
@@ -354,7 +444,13 @@
 // Pseudo instructions.
 defm PseudoVADD : VPseudoBinary_VV_VX_VI;

+//===----------------------------------------------------------------------===//
+// Patterns.
+//===----------------------------------------------------------------------===//
+
 // Whole-register vector patterns.
 defm "" : pat_vop_binary_common<add, "PseudoVADD", AllIntegerVectors>;
+defm "" : pat_intrinsic_binary_int_v_vv_vx_vi<"int_riscv_vadd", "PseudoVADD">;
+
 } // Predicates = [HasStdExtV]
diff --git a/llvm/test/CodeGen/RISCV/rvv/vadd.ll b/llvm/test/CodeGen/RISCV/rvv/vadd.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rvv/vadd.ll
@@ -0,0 +1,24 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv64 -mattr=+m,+f,+d,+a,+c,+experimental-v \
+; RUN:   -verify-machineinstrs --riscv-no-aliases < %s \
+; RUN:   | FileCheck %s
+
+declare <vscale x 1 x i8> @llvm.riscv.vadd.nxv1i8.nxv1i8(
+  <vscale x 1 x i8>,
+  <vscale x 1 x i8>,
+  i64);
+
+define <vscale x 1 x i8> @test_vadd() nounwind {
+; CHECK-LABEL: test_vadd:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vsetvli a0, a0, e8,mf8,tu,mu
+; CHECK-NEXT:    vadd.vv v16, v25, v25
+; CHECK-NEXT:    c.jr ra
+entry:
+  %a = call <vscale x 1 x i8> @llvm.riscv.vadd.nxv1i8.nxv1i8(
+    <vscale x 1 x i8> undef,
+    <vscale x 1 x i8> undef,
+    i64 undef)
+
+  ret <vscale x 1 x i8> %a
+}
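A follow-up test for the masked form would exercise the _MASK pseudo, the
merge operand, and the v0.t encoding. A sketch only; the declaration's mangled
suffixes and the expected pseudo name are assumed as discussed above, and
nothing here is autogenerated:

    declare <vscale x 1 x i8> @llvm.riscv.vadd.mask.nxv1i8.nxv1i8.nxv1i1(
      <vscale x 1 x i8>, <vscale x 1 x i8>, <vscale x 1 x i8>,
      <vscale x 1 x i1>, i64)

    ; Expected to select PseudoVADD_VV_MF8_MASK, i.e. "vadd.vv ..., v0.t".
    define <vscale x 1 x i8> @test_vadd_mask(<vscale x 1 x i8> %maskedoff,
                                             <vscale x 1 x i8> %a,
                                             <vscale x 1 x i8> %b,
                                             <vscale x 1 x i1> %mask,
                                             i64 %vl) nounwind {
    entry:
      %r = call <vscale x 1 x i8> @llvm.riscv.vadd.mask.nxv1i8.nxv1i8.nxv1i1(
        <vscale x 1 x i8> %maskedoff,
        <vscale x 1 x i8> %a,
        <vscale x 1 x i8> %b,
        <vscale x 1 x i1> %mask,
        i64 %vl)
      ret <vscale x 1 x i8> %r
    }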