Add a TargetLowering hook, getExtendForShiftAmount, that lets a target choose
how SelectionDAGBuilder::visitShift widens a shift amount that is narrower than
the target's shift-amount type. The default implementation returns ZERO_EXTEND,
which is what visitShift hard-coded before this change. The RISC-V override
returns ANY_EXTEND when the source type is at least 5 bits wide (6 bits on a
64-bit subtarget); the RV32I alu8/alu16 tests drop the `and`/`andi` that masked
the shift amount as a result.

NOTE(review): ANY_EXTEND leaves the upper bits of the shift amount unspecified
at the DAG level; confirm no generic combine depends on those bits before the
target shift instruction is selected.

Index: include/llvm/CodeGen/TargetLowering.h
===================================================================
--- include/llvm/CodeGen/TargetLowering.h
+++ include/llvm/CodeGen/TargetLowering.h
@@ -258,6 +258,13 @@
   EVT getShiftAmountTy(EVT LHSTy, const DataLayout &DL,
                        bool LegalTypes = true) const;
 
+  /// Returns the ISD extension opcode (ZERO_EXTEND, SIGN_EXTEND or ANY_EXTEND)
+  /// used to widen a shift amount from FromShAmtVT to ToShAmtVT.
+  virtual ISD::NodeType getExtendForShiftAmount(EVT FromShAmtVT,
+                                                EVT ToShAmtVT) const {
+    return ISD::ZERO_EXTEND;
+  }
+
   /// Returns the type to be used for the index operand of:
   /// ISD::INSERT_VECTOR_ELT, ISD::EXTRACT_VECTOR_ELT,
   /// ISD::INSERT_SUBVECTOR, and ISD::EXTRACT_SUBVECTOR
Index: lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
===================================================================
--- lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -2825,9 +2825,9 @@
 void SelectionDAGBuilder::visitShift(const User &I, unsigned Opcode) {
   SDValue Op1 = getValue(I.getOperand(0));
   SDValue Op2 = getValue(I.getOperand(1));
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
 
-  EVT ShiftTy = DAG.getTargetLoweringInfo().getShiftAmountTy(
-      Op1.getValueType(), DAG.getDataLayout());
+  EVT ShiftTy = TLI.getShiftAmountTy(Op1.getValueType(), DAG.getDataLayout());
 
   // Coerce the shift amount to the right type if we can.
   if (!I.getType()->isVectorTy() && Op2.getValueType() != ShiftTy) {
@@ -2836,8 +2836,11 @@
     SDLoc DL = getCurSDLoc();
 
     // If the operand is smaller than the shift count type, promote it.
-    if (ShiftSize > Op2Size)
-      Op2 = DAG.getNode(ISD::ZERO_EXTEND, DL, ShiftTy, Op2);
+    if (ShiftSize > Op2Size) {
+      ISD::NodeType ExtendKind =
+          TLI.getExtendForShiftAmount(Op2.getValueType(), ShiftTy);
+      Op2 = DAG.getNode(ExtendKind, DL, ShiftTy, Op2);
+    }
 
     // If the operand is larger than the shift count type but the shift
     // count type has enough bits to represent any shift value, truncate
Index: lib/Target/RISCV/RISCVISelLowering.h
===================================================================
--- lib/Target/RISCV/RISCVISelLowering.h
+++ lib/Target/RISCV/RISCVISelLowering.h
@@ -43,6 +43,8 @@
   explicit RISCVTargetLowering(const TargetMachine &TM,
                                const RISCVSubtarget &STI);
 
+  ISD::NodeType getExtendForShiftAmount(EVT FromShAmtVT,
+                                        EVT ToShAmtVT) const override;
   bool getTgtMemIntrinsic(IntrinsicInfo &Info, const CallInst &I,
                           MachineFunction &MF,
                           unsigned Intrinsic) const override;
Index: lib/Target/RISCV/RISCVISelLowering.cpp
===================================================================
--- lib/Target/RISCV/RISCVISelLowering.cpp
+++ lib/Target/RISCV/RISCVISelLowering.cpp
@@ -155,6 +155,20 @@
   setMinimumJumpTableEntries(INT_MAX);
 }
 
+ISD::NodeType
+RISCVTargetLowering::getExtendForShiftAmount(EVT FromShAmtVT,
+                                             EVT ToShAmtVT) const {
+  // RISC-V shift instructions read only the lower 5 or lower 6 bits of the
+  // shift amount, so ANY_EXTEND is safe whenever the type being extended is
+  // at least that many bits wide. Be conservative: use the widest amount the
+  // subtarget's shifts can read (6 bits if 64-bit, else 5), since the widest
+  // supported shift instruction could still be selected.
+  unsigned WidestShAmt = Subtarget.is64Bit() ? 6 : 5;
+  if (FromShAmtVT.getSizeInBits() >= WidestShAmt)
+    return ISD::ANY_EXTEND;
+  return ISD::ZERO_EXTEND;
+}
+
 EVT RISCVTargetLowering::getSetCCResultType(const DataLayout &DL, LLVMContext &,
                                             EVT VT) const {
   if (!VT.isVector())
Index: test/CodeGen/RISCV/alu16.ll
===================================================================
--- test/CodeGen/RISCV/alu16.ll
+++ test/CodeGen/RISCV/alu16.ll
@@ -122,9 +122,6 @@
 define i16 @sll(i16 %a, i16 %b) nounwind {
 ; RV32I-LABEL: sll:
 ; RV32I: # %bb.0:
-; RV32I-NEXT: lui a2, 16
-; RV32I-NEXT: addi a2, a2, -1
-; RV32I-NEXT: and a1, a1, a2
 ; RV32I-NEXT: sll a0, a0, a1
 ; RV32I-NEXT: ret
   %1 = shl i16 %a, %b
@@ -173,7 +170,6 @@
 ; RV32I: # %bb.0:
 ; RV32I-NEXT: lui a2, 16
 ; RV32I-NEXT: addi a2, a2, -1
-; RV32I-NEXT: and a1, a1, a2
 ; RV32I-NEXT: and a0, a0, a2
 ; RV32I-NEXT: srl a0, a0, a1
 ; RV32I-NEXT: ret
@@ -184,9 +180,6 @@
 define i16 @sra(i16 %a, i16 %b) nounwind {
 ; RV32I-LABEL: sra:
 ; RV32I: # %bb.0:
-; RV32I-NEXT: lui a2, 16
-; RV32I-NEXT: addi a2, a2, -1
-; RV32I-NEXT: and a1, a1, a2
 ; RV32I-NEXT: slli a0, a0, 16
 ; RV32I-NEXT: srai a0, a0, 16
 ; RV32I-NEXT: sra a0, a0, a1
Index: test/CodeGen/RISCV/alu8.ll
===================================================================
--- test/CodeGen/RISCV/alu8.ll
+++ test/CodeGen/RISCV/alu8.ll
@@ -114,7 +114,6 @@
 define i8 @sll(i8 %a, i8 %b) nounwind {
 ; RV32I-LABEL: sll:
 ; RV32I: # %bb.0:
-; RV32I-NEXT: andi a1, a1, 255
 ; RV32I-NEXT: sll a0, a0, a1
 ; RV32I-NEXT: ret
   %1 = shl i8 %a, %b
@@ -159,7 +158,6 @@
 define i8 @srl(i8 %a, i8 %b) nounwind {
 ; RV32I-LABEL: srl:
 ; RV32I: # %bb.0:
-; RV32I-NEXT: andi a1, a1, 255
 ; RV32I-NEXT: andi a0, a0, 255
 ; RV32I-NEXT: srl a0, a0, a1
 ; RV32I-NEXT: ret
@@ -170,7 +168,6 @@
 define i8 @sra(i8 %a, i8 %b) nounwind {
 ; RV32I-LABEL: sra:
 ; RV32I: # %bb.0:
-; RV32I-NEXT: andi a1, a1, 255
 ; RV32I-NEXT: slli a0, a0, 24
 ; RV32I-NEXT: srai a0, a0, 24
 ; RV32I-NEXT: sra a0, a0, a1