diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -8532,6 +8532,10 @@
 
 // Combine (sra (shl X, 32), 32 - C) -> (shl (sext_inreg X, i32), C)
 // FIXME: Should this be a generic combine? There's a similar combine on X86.
+//
+// Also try these folds where an add or sub is in the middle.
+// (sra (add (shl X, 32), C1), 32 - C) -> (shl (sext_inreg (add X, C1)), C)
+// (sra (sub C1, (shl X, 32)), 32 - C) -> (shl (sext_inreg (sub C1, X)), C)
 static SDValue performSRACombine(SDNode *N, SelectionDAG &DAG,
                                  const RISCVSubtarget &Subtarget) {
   assert(N->getOpcode() == ISD::SRA && "Unexpected opcode");
@@ -8539,21 +8543,63 @@
   if (N->getValueType(0) != MVT::i64 || !Subtarget.is64Bit())
     return SDValue();
 
-  auto *C = dyn_cast<ConstantSDNode>(N->getOperand(1));
-  if (!C || C->getZExtValue() >= 32)
+  auto *ShAmtC = dyn_cast<ConstantSDNode>(N->getOperand(1));
+  if (!ShAmtC || ShAmtC->getZExtValue() > 32)
     return SDValue();
 
   SDValue N0 = N->getOperand(0);
-  if (N0.getOpcode() != ISD::SHL || !N0.hasOneUse() ||
-      !isa<ConstantSDNode>(N0.getOperand(1)) ||
-      N0.getConstantOperandVal(1) != 32)
+
+  SDValue Shl;
+  ConstantSDNode *AddC = nullptr;
+
+  // We might have an ADD or SUB between the SRA and SHL.
+  bool IsAdd = N0.getOpcode() == ISD::ADD;
+  if ((IsAdd || N0.getOpcode() == ISD::SUB)) {
+    if (!N0.hasOneUse())
+      return SDValue();
+    // Other operand needs to be a constant we can modify.
+    AddC = dyn_cast<ConstantSDNode>(N0.getOperand(IsAdd ? 1 : 0));
+    if (!AddC)
+      return SDValue();
+
+    // AddC needs to have at least 32 trailing zeros.
+    if (AddC->getAPIntValue().countTrailingZeros() < 32)
+      return SDValue();
+
+    Shl = N0.getOperand(IsAdd ? 0 : 1);
+  } else {
+    // Not an ADD or SUB.
+    Shl = N0;
+  }
+
+  // Look for a shift left by 32.
+  if (Shl.getOpcode() != ISD::SHL || !Shl.hasOneUse() ||
+      !isa<ConstantSDNode>(Shl.getOperand(1)) ||
+      Shl.getConstantOperandVal(1) != 32)
     return SDValue();
 
   SDLoc DL(N);
-  SDValue SExt = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, MVT::i64,
-                             N0.getOperand(0), DAG.getValueType(MVT::i32));
-  return DAG.getNode(ISD::SHL, DL, MVT::i64, SExt,
-                     DAG.getConstant(32 - C->getZExtValue(), DL, MVT::i64));
+  SDValue In = Shl.getOperand(0);
+
+  // If we looked through an ADD or SUB, we need to rebuild it with the shifted
+  // constant.
+  if (AddC) {
+    SDValue ShiftedAddC =
+        DAG.getConstant(AddC->getAPIntValue().lshr(32), DL, MVT::i64);
+    if (IsAdd)
+      In = DAG.getNode(ISD::ADD, DL, MVT::i64, In, ShiftedAddC);
+    else
+      In = DAG.getNode(ISD::SUB, DL, MVT::i64, ShiftedAddC, In);
+  }
+
+  SDValue SExt = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, MVT::i64, In,
+                             DAG.getValueType(MVT::i32));
+  if (ShAmtC->getZExtValue() == 32)
+    return SExt;
+
+  return DAG.getNode(
+      ISD::SHL, DL, MVT::i64, SExt,
+      DAG.getConstant(32 - ShAmtC->getZExtValue(), DL, MVT::i64));
 }
 
 SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
diff --git a/llvm/test/CodeGen/RISCV/rv64i-shift-sext.ll b/llvm/test/CodeGen/RISCV/rv64i-shift-sext.ll
--- a/llvm/test/CodeGen/RISCV/rv64i-shift-sext.ll
+++ b/llvm/test/CodeGen/RISCV/rv64i-shift-sext.ll
@@ -84,11 +84,7 @@
 define i64 @test7(i32* %0, i64 %1) {
 ; RV64I-LABEL: test7:
 ; RV64I:       # %bb.0:
-; RV64I-NEXT:    slli a0, a1, 32
-; RV64I-NEXT:    li a1, 1
-; RV64I-NEXT:    slli a1, a1, 32
-; RV64I-NEXT:    add a0, a0, a1
-; RV64I-NEXT:    srai a0, a0, 32
+; RV64I-NEXT:    addiw a0, a1, 1
 ; RV64I-NEXT:    ret
   %3 = shl i64 %1, 32
   %4 = add i64 %3, 4294967296
@@ -102,11 +98,8 @@
 define i64 @test8(i32* %0, i64 %1) {
 ; RV64I-LABEL: test8:
 ; RV64I:       # %bb.0:
-; RV64I-NEXT:    slli a0, a1, 32
-; RV64I-NEXT:    li a1, 1
-; RV64I-NEXT:    slli a1, a1, 32
-; RV64I-NEXT:    sub a0, a1, a0
-; RV64I-NEXT:    srai a0, a0, 32
+; RV64I-NEXT:    li a0, 1
+; RV64I-NEXT:    subw a0, a0, a1
 ; RV64I-NEXT:    ret
   %3 = mul i64 %1, -4294967296
   %4 = add i64 %3, 4294967296
@@ -119,11 +112,10 @@
 define signext i32 @test9(i32* %0, i64 %1) {
 ; RV64I-LABEL: test9:
 ; RV64I:       # %bb.0:
-; RV64I-NEXT:    slli a1, a1, 32
-; RV64I-NEXT:    lui a2, 4097
-; RV64I-NEXT:    slli a2, a2, 20
-; RV64I-NEXT:    add a1, a1, a2
-; RV64I-NEXT:    srai a1, a1, 30
+; RV64I-NEXT:    lui a2, 1
+; RV64I-NEXT:    addiw a2, a2, 1
+; RV64I-NEXT:    addw a1, a1, a2
+; RV64I-NEXT:    slli a1, a1, 2
 ; RV64I-NEXT:    add a0, a0, a1
 ; RV64I-NEXT:    lw a0, 0(a0)
 ; RV64I-NEXT:    ret
@@ -140,12 +132,10 @@
 define signext i32 @test10(i32* %0, i64 %1) {
 ; RV64I-LABEL: test10:
 ; RV64I:       # %bb.0:
-; RV64I-NEXT:    slli a1, a1, 32
 ; RV64I-NEXT:    lui a2, 30141
 ; RV64I-NEXT:    addiw a2, a2, -747
-; RV64I-NEXT:    slli a2, a2, 32
-; RV64I-NEXT:    sub a1, a2, a1
-; RV64I-NEXT:    srai a1, a1, 30
+; RV64I-NEXT:    subw a1, a2, a1
+; RV64I-NEXT:    slli a1, a1, 2
 ; RV64I-NEXT:    add a0, a0, a1
 ; RV64I-NEXT:    lw a0, 0(a0)
 ; RV64I-NEXT:    ret
@@ -160,11 +150,8 @@
 define i64 @test11(i32* %0, i64 %1) {
 ; RV64I-LABEL: test11:
 ; RV64I:       # %bb.0:
-; RV64I-NEXT:    slli a0, a1, 32
-; RV64I-NEXT:    li a1, -1
-; RV64I-NEXT:    slli a1, a1, 63
-; RV64I-NEXT:    sub a0, a1, a0
-; RV64I-NEXT:    srai a0, a0, 32
+; RV64I-NEXT:    lui a0, 524288
+; RV64I-NEXT:    subw a0, a0, a1
 ; RV64I-NEXT:    ret
   %3 = mul i64 %1, -4294967296
   %4 = add i64 %3, 9223372036854775808 ;0x8000'0000'0000'0000
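
For reviewers, a standalone sketch of the arithmetic behind the new ADD fold follows. It is not part of the patch; beforeFold/afterFold and the sampled inputs are illustrative names chosen here, and it uses the constants from test7 (AddC = 1 << 32, shift amount 32). It assumes >> on a negative int64_t is an arithmetic shift, which holds on the targets of interest.

// Standalone sanity check (not part of the patch): the pre-fold DAG pattern and
// the folded form compute the same value for the test7 constants.
#include <cassert>
#include <cstdint>
#include <cstdio>

// Original pattern: (sra (add (shl X, 32), 1 << 32), 32).
static int64_t beforeFold(int64_t X) {
  uint64_t Shl = static_cast<uint64_t>(X) << 32;
  int64_t Add = static_cast<int64_t>(Shl + (uint64_t{1} << 32));
  return Add >> 32; // assumed arithmetic shift on int64_t
}

// Folded pattern: (sext_inreg (add X, AddC >> 32), i32), i.e. addiw on RV64.
static int64_t afterFold(int64_t X) {
  uint32_t Lo = static_cast<uint32_t>(X) + 1;
  return static_cast<int64_t>(static_cast<int32_t>(Lo));
}

int main() {
  const int64_t Inputs[] = {0, 1, -1, 0x7fffffff, -0x80000000LL,
                            0x123456789abcdef0};
  for (int64_t X : Inputs)
    assert(beforeFold(X) == afterFold(X));
  std::puts("fold matches on sampled inputs");
}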