RISCV ISel: share splat-immediate matching between the simm and uimm
splat selectors in RISCVISelDAGToDAG.cpp.

  * Add a static findVSplat() helper that returns N when it is a
    RISCVISD::VMV_V_X_VL whose passthru (operand 0) is undef, asserting
    it has 3 operands, and an empty SDValue otherwise. selectVSplat now
    uses it.
  * Rename selectVSplatSimmHelper to selectVSplatImmHelper and change its
    validator from a function pointer (ValidateFn) to
    std::function<bool(int64_t)>, so a capturing lambda can be passed;
    selectVSplatUimm uses this to capture `Bits`.
  * The helper now truncates or sign-extends the XLenVT splat constant to
    the vector element width with APInt::sextOrTrunc before validating
    it, replacing the SignExtend64 that was only applied when
    EltVT.bitsLT(XLenVT).
  * Reimplement selectVSplatUimm on top of selectVSplatImmHelper.

NOTE(review): routing selectVSplatUimm through the helper changes what
isUIntN(Bits, ...) sees: the element-width sign-extended immediate
instead of the full XLenVT sign-extended value. Bits above the element
width are now ignored (an e8 splat of 0x105 matches uimm5 as 5), and if
Bits ever equals the element width, values with the element sign bit
set are rejected (an e8 splat of 200 no longer passes Bits == 8).
Confirm this is intended.
NOTE(review): std::function needs <functional>; assumed to be reachable
through the file's existing includes -- TODO confirm.

diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -2955,27 +2955,35 @@
   return true;
 }
 
+static SDValue findVSplat(SDValue N) {
+  SDValue Splat = N;
+  if (Splat.getOpcode() != RISCVISD::VMV_V_X_VL ||
+      !Splat.getOperand(0).isUndef())
+    return SDValue();
+  assert(Splat.getNumOperands() == 3 && "Unexpected number of operands");
+  return Splat;
+}
+
 bool RISCVDAGToDAGISel::selectVSplat(SDValue N, SDValue &SplatVal) {
-  if (N.getOpcode() != RISCVISD::VMV_V_X_VL || !N.getOperand(0).isUndef())
+  SDValue Splat = findVSplat(N);
+  if (!Splat)
     return false;
-  assert(N.getNumOperands() == 3 && "Unexpected number of operands");
-  SplatVal = N.getOperand(1);
+
+  SplatVal = Splat.getOperand(1);
   return true;
 }
 
-using ValidateFn = bool (*)(int64_t);
-
-static bool selectVSplatSimmHelper(SDValue N, SDValue &SplatVal,
-                                   SelectionDAG &DAG,
-                                   const RISCVSubtarget &Subtarget,
-                                   ValidateFn ValidateImm) {
-  if (N.getOpcode() != RISCVISD::VMV_V_X_VL || !N.getOperand(0).isUndef() ||
-      !isa<ConstantSDNode>(N.getOperand(1)))
+static bool selectVSplatImmHelper(SDValue N, SDValue &SplatVal,
+                                  SelectionDAG &DAG,
+                                  const RISCVSubtarget &Subtarget,
+                                  std::function<bool(int64_t)> ValidateImm) {
+  SDValue Splat = findVSplat(N);
+  if (!Splat || !isa<ConstantSDNode>(Splat.getOperand(1)))
     return false;
-  assert(N.getNumOperands() == 3 && "Unexpected number of operands");
 
-  int64_t SplatImm =
-      cast<ConstantSDNode>(N.getOperand(1))->getSExtValue();
+  const unsigned SplatEltSize = Splat.getScalarValueSizeInBits();
+  assert(Subtarget.getXLenVT() == Splat.getOperand(1).getSimpleValueType() &&
+         "Unexpected splat operand type");
 
   // The semantics of RISCVISD::VMV_V_X_VL is that when the operand
   // type is wider than the resulting vector element type: an implicit
@@ -2984,34 +2992,31 @@
   // any zero-extended immediate.
   // For example, we wish to match (i8 -1) -> (XLenVT 255) as a simm5 by first
   // sign-extending to (XLenVT -1).
-  MVT XLenVT = Subtarget.getXLenVT();
-  assert(XLenVT == N.getOperand(1).getSimpleValueType() &&
-         "Unexpected splat operand type");
-  MVT EltVT = N.getSimpleValueType().getVectorElementType();
-  if (EltVT.bitsLT(XLenVT))
-    SplatImm = SignExtend64(SplatImm, EltVT.getSizeInBits());
+  APInt SplatConst = Splat.getConstantOperandAPInt(1).sextOrTrunc(SplatEltSize);
+
+  int64_t SplatImm = SplatConst.getSExtValue();
 
   if (!ValidateImm(SplatImm))
     return false;
 
-  SplatVal = DAG.getTargetConstant(SplatImm, SDLoc(N), XLenVT);
+  SplatVal = DAG.getTargetConstant(SplatImm, SDLoc(N), Subtarget.getXLenVT());
   return true;
 }
 
 bool RISCVDAGToDAGISel::selectVSplatSimm5(SDValue N, SDValue &SplatVal) {
-  return selectVSplatSimmHelper(N, SplatVal, *CurDAG, *Subtarget,
-                                [](int64_t Imm) { return isInt<5>(Imm); });
+  return selectVSplatImmHelper(N, SplatVal, *CurDAG, *Subtarget,
+                               [](int64_t Imm) { return isInt<5>(Imm); });
 }
 
 bool RISCVDAGToDAGISel::selectVSplatSimm5Plus1(SDValue N, SDValue &SplatVal) {
-  return selectVSplatSimmHelper(
+  return selectVSplatImmHelper(
       N, SplatVal, *CurDAG, *Subtarget,
       [](int64_t Imm) { return (isInt<5>(Imm) && Imm != -16) || Imm == 16; });
 }
 
 bool RISCVDAGToDAGISel::selectVSplatSimm5Plus1NonZero(SDValue N,
                                                       SDValue &SplatVal) {
-  return selectVSplatSimmHelper(
+  return selectVSplatImmHelper(
       N, SplatVal, *CurDAG, *Subtarget, [](int64_t Imm) {
         return Imm != 0 && ((isInt<5>(Imm) && Imm != -16) || Imm == 16);
       });
@@ -3019,20 +3024,9 @@
 
 bool RISCVDAGToDAGISel::selectVSplatUimm(SDValue N, unsigned Bits,
                                          SDValue &SplatVal) {
-  if (N.getOpcode() != RISCVISD::VMV_V_X_VL || !N.getOperand(0).isUndef() ||
-      !isa<ConstantSDNode>(N.getOperand(1)))
-    return false;
-
-  int64_t SplatImm =
-      cast<ConstantSDNode>(N.getOperand(1))->getSExtValue();
-
-  if (!isUIntN(Bits, SplatImm))
-    return false;
-
-  SplatVal =
-      CurDAG->getTargetConstant(SplatImm, SDLoc(N), Subtarget->getXLenVT());
-
-  return true;
+  return selectVSplatImmHelper(
+      N, SplatVal, *CurDAG, *Subtarget,
+      [Bits](int64_t Imm) { return isUIntN(Bits, Imm); });
 }
 
 bool RISCVDAGToDAGISel::selectLow8BitsVSplat(SDValue N, SDValue &SplatVal) {