diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -940,6 +940,8 @@
 
   setTargetDAGCombine({ISD::INTRINSIC_WO_CHAIN, ISD::ADD, ISD::SUB, ISD::AND,
                        ISD::OR, ISD::XOR});
+  if (Subtarget.is64Bit())
+    setTargetDAGCombine(ISD::SRA);
 
   if (Subtarget.hasStdExtF())
     setTargetDAGCombine({ISD::FADD, ISD::FMAXNUM, ISD::FMINNUM});
@@ -8527,6 +8529,39 @@
   return Opcode;
 }
 
+
+// Fold (sra (shl X, 32), 32 - C) -> (shl (sext_inreg X, i32), C).
+//
+// Shifting X left by 32 and then arithmetic-shifting right by less than 32
+// leaves the sign-extended low 32 bits of X shifted left by the difference,
+// so the pair is rebuilt as SIGN_EXTEND_INREG followed by a single SHL.
+// Returns an empty SDValue when the pattern does not match.
+static SDValue performSRACombine(SDNode *N, SelectionDAG &DAG,
+                                 const RISCVSubtarget &Subtarget) {
+  assert(N->getOpcode() == ISD::SRA && "Unexpected opcode");
+
+  if (N->getValueType(0) != MVT::i64 || !Subtarget.is64Bit())
+    return SDValue();
+
+  // Only constant right-shift amounts below 32 are handled.
+  auto *C = dyn_cast<ConstantSDNode>(N->getOperand(1));
+  if (!C || C->getZExtValue() >= 32)
+    return SDValue();
+
+  // The shifted value must be a single-use (shl X, 32).
+  SDValue N0 = N->getOperand(0);
+  if (N0.getOpcode() != ISD::SHL || !N0.hasOneUse() ||
+      !isa<ConstantSDNode>(N0.getOperand(1)) ||
+      N0.getConstantOperandVal(1) != 32)
+    return SDValue();
+
+  SDLoc DL(N);
+  SDValue SExt = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, MVT::i64,
+                             N0.getOperand(0), DAG.getValueType(MVT::i32));
+  return DAG.getNode(ISD::SHL, DL, MVT::i64, SExt,
+                     DAG.getConstant(32 - C->getZExtValue(), DL, MVT::i64));
+}
+
 SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
                                                DAGCombinerInfo &DCI) const {
   SelectionDAG &DAG = DCI.DAG;
@@ -9003,6 +9038,9 @@
     break;
   }
   case ISD::SRA:
+    if (SDValue V = performSRACombine(N, DAG, Subtarget))
+      return V;
+    LLVM_FALLTHROUGH;
   case ISD::SRL:
   case ISD::SHL: {
     SDValue ShAmt = N->getOperand(1);
diff --git a/llvm/test/CodeGen/RISCV/rv64i-shift-sext.ll b/llvm/test/CodeGen/RISCV/rv64i-shift-sext.ll
--- a/llvm/test/CodeGen/RISCV/rv64i-shift-sext.ll
+++ b/llvm/test/CodeGen/RISCV/rv64i-shift-sext.ll
@@ -7,8 +7,8 @@
 define i64 @test1(i64 %a) nounwind {
 ; RV64I-LABEL: test1:
 ; RV64I: # %bb.0:
-; RV64I-NEXT: slli a0, a0, 32
-; RV64I-NEXT: srai a0, a0, 30
+; RV64I-NEXT: sext.w a0, a0
+; RV64I-NEXT: slli a0, a0, 2
 ; RV64I-NEXT: ret
   %1 = shl i64 %a, 32
   %2 = ashr i64 %1, 30
@@ -18,8 +18,7 @@
 define i64 @test2(i32 signext %a) nounwind {
 ; RV64I-LABEL: test2:
 ; RV64I: # %bb.0:
-; RV64I-NEXT: slli a0, a0, 32
-; RV64I-NEXT: srai a0, a0, 29
+; RV64I-NEXT: slli a0, a0, 3
 ; RV64I-NEXT: ret
   %1 = zext i32 %a to i64
   %2 = shl i64 %1, 32
@@ -31,8 +30,7 @@
 ; RV64I-LABEL: test3:
 ; RV64I: # %bb.0:
 ; RV64I-NEXT: lw a0, 0(a0)
-; RV64I-NEXT: slli a0, a0, 32
-; RV64I-NEXT: srai a0, a0, 28
+; RV64I-NEXT: slli a0, a0, 4
 ; RV64I-NEXT: ret
   %1 = load i32, i32* %a
   %2 = zext i32 %1 to i64
@@ -45,8 +43,7 @@
 ; RV64I-LABEL: test4:
 ; RV64I: # %bb.0:
 ; RV64I-NEXT: addw a0, a0, a1
-; RV64I-NEXT: slli a0, a0, 32
-; RV64I-NEXT: srai a0, a0, 2
+; RV64I-NEXT: slli a0, a0, 30
 ; RV64I-NEXT: ret
   %1 = add i32 %a, %b
   %2 = zext i32 %1 to i64
@@ -59,8 +56,7 @@
 ; RV64I-LABEL: test5:
 ; RV64I: # %bb.0:
 ; RV64I-NEXT: xor a0, a0, a1
-; RV64I-NEXT: slli a0, a0, 32
-; RV64I-NEXT: srai a0, a0, 1
+; RV64I-NEXT: slli a0, a0, 31
 ; RV64I-NEXT: ret
   %1 = xor i32 %a, %b
   %2 = zext i32 %1 to i64
@@ -73,8 +69,7 @@
 ; RV64I-LABEL: test6:
 ; RV64I: # %bb.0:
 ; RV64I-NEXT: sllw a0, a0, a1
-; RV64I-NEXT: slli a0, a0, 32
-; RV64I-NEXT: srai a0, a0, 16
+; RV64I-NEXT: slli a0, a0, 16
 ; RV64I-NEXT: ret
   %1 = shl i32 %a, %b
   %2 = zext i32 %1 to i64