diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h
@@ -70,6 +70,11 @@
     return selectSHXADDOp(N, ShAmt, Val);
   }
 
+  bool selectSHXADD_UWOp(SDValue N, unsigned ShAmt, SDValue &Val);
+  template <unsigned ShAmt> bool selectSHXADD_UWOp(SDValue N, SDValue &Val) {
+    return selectSHXADD_UWOp(N, ShAmt, Val);
+  }
+
   bool hasAllNBitUsers(SDNode *Node, unsigned Bits) const;
   bool hasAllHUsers(SDNode *Node) const { return hasAllNBitUsers(Node, 16); }
   bool hasAllWUsers(SDNode *Node) const { return hasAllNBitUsers(Node, 32); }
diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -2234,6 +2234,43 @@
   return false;
 }
 
+/// Look for various patterns that can be done with a SHL that can be folded
+/// into a SHXADD_UW. \p ShAmt contains 1, 2, or 3 and is set based on which
+/// SHXADD_UW we are trying to match.
+bool RISCVDAGToDAGISel::selectSHXADD_UWOp(SDValue N, unsigned ShAmt,
+                                          SDValue &Val) {
+  if (N.getOpcode() == ISD::AND && isa<ConstantSDNode>(N.getOperand(1)) &&
+      N.hasOneUse()) {
+    SDValue N0 = N.getOperand(0);
+    if (N0.getOpcode() == ISD::SHL && isa<ConstantSDNode>(N0.getOperand(1)) &&
+        N0.hasOneUse()) {
+      uint64_t Mask = N.getConstantOperandVal(1);
+      unsigned C2 = N0.getConstantOperandVal(1);
+
+      Mask &= maskTrailingZeros<uint64_t>(C2);
+
+      // Look for (and (shl y, c2), c1) where c1 is a shifted mask with
+      // 32-ShAmt leading zeros and c2 trailing zeros. We can use SLLI by
+      // c2-ShAmt followed by SHXADD_UW with ShAmt for the X amount.
+      if (isShiftedMask_64(Mask)) {
+        unsigned Leading = countLeadingZeros(Mask);
+        unsigned Trailing = countTrailingZeros(Mask);
+        if (Leading == 32 - ShAmt && Trailing == C2 && Trailing > ShAmt) {
+          SDLoc DL(N);
+          EVT VT = N.getValueType();
+          Val = SDValue(CurDAG->getMachineNode(
+                            RISCV::SLLI, DL, VT, N0.getOperand(0),
+                            CurDAG->getTargetConstant(C2 - ShAmt, DL, VT)),
+                        0);
+          return true;
+        }
+      }
+    }
+  }
+
+  return false;
+}
+
 // Return true if all users of this SDNode* only consume the lower \p Bits.
 // This can be used to form W instructions for add/sub/mul/shl even when the
 // root isn't a sext_inreg. This can allow the ADDW/SUBW/MULW/SLLIW to CSE if
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td b/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td
@@ -257,6 +257,10 @@
 def sh2add_op : ComplexPattern<XLenVT, 2, "selectSHXADDOp<2>", [], [], 6>;
 def sh3add_op : ComplexPattern<XLenVT, 3, "selectSHXADDOp<3>", [], [], 6>;
 
+def sh1add_uw_op : ComplexPattern<XLenVT, 1, "selectSHXADD_UWOp<1>", [], [], 6>;
+def sh2add_uw_op : ComplexPattern<XLenVT, 2, "selectSHXADD_UWOp<2>", [], [], 6>;
+def sh3add_uw_op : ComplexPattern<XLenVT, 3, "selectSHXADD_UWOp<3>", [], [], 6>;
+
 //===----------------------------------------------------------------------===//
 // Instruction class templates
 //===----------------------------------------------------------------------===//
@@ -771,6 +775,14 @@
 def : Pat<(i64 (add (and (shl GPR:$rs1, (i64 3)), 0x7FFFFFFFF), non_imm12:$rs2)),
           (SH3ADD_UW GPR:$rs1, GPR:$rs2)>;
 
+// More complex cases use a ComplexPattern.
+def : Pat<(add sh1add_uw_op:$rs1, non_imm12:$rs2),
+          (SH1ADD_UW sh1add_uw_op:$rs1, GPR:$rs2)>;
+def : Pat<(add sh2add_uw_op:$rs1, non_imm12:$rs2),
+          (SH2ADD_UW sh2add_uw_op:$rs1, GPR:$rs2)>;
+def : Pat<(add sh3add_uw_op:$rs1, non_imm12:$rs2),
+          (SH3ADD_UW sh3add_uw_op:$rs1, GPR:$rs2)>;
+
 def : Pat<(i64 (add (and GPR:$rs1, 0xFFFFFFFE), non_imm12:$rs2)),
           (SH1ADD (SRLIW GPR:$rs1, 1), GPR:$rs2)>;
 def : Pat<(i64 (add (and GPR:$rs1, 0xFFFFFFFC), non_imm12:$rs2)),
diff --git a/llvm/test/CodeGen/RISCV/rv64zba.ll b/llvm/test/CodeGen/RISCV/rv64zba.ll
--- a/llvm/test/CodeGen/RISCV/rv64zba.ll
+++ b/llvm/test/CodeGen/RISCV/rv64zba.ll
@@ -1716,3 +1716,69 @@
   %5 = load i64, i64* %4, align 8
   ret i64 %5
 }
+
+define signext i16 @shl_2_sh1add(i16* %0, i32 signext %1) {
+; RV64I-LABEL: shl_2_sh1add:
+; RV64I:       # %bb.0:
+; RV64I-NEXT:    slli a1, a1, 34
+; RV64I-NEXT:    srli a1, a1, 31
+; RV64I-NEXT:    add a0, a0, a1
+; RV64I-NEXT:    lh a0, 0(a0)
+; RV64I-NEXT:    ret
+;
+; RV64ZBA-LABEL: shl_2_sh1add:
+; RV64ZBA:       # %bb.0:
+; RV64ZBA-NEXT:    slli a1, a1, 2
+; RV64ZBA-NEXT:    sh1add.uw a0, a1, a0
+; RV64ZBA-NEXT:    lh a0, 0(a0)
+; RV64ZBA-NEXT:    ret
+  %3 = shl i32 %1, 2
+  %4 = zext i32 %3 to i64
+  %5 = getelementptr inbounds i16, i16* %0, i64 %4
+  %6 = load i16, i16* %5, align 2
+  ret i16 %6
+}
+
+define signext i32 @shl_16_sh2add(i32* %0, i32 signext %1) {
+; RV64I-LABEL: shl_16_sh2add:
+; RV64I:       # %bb.0:
+; RV64I-NEXT:    slli a1, a1, 48
+; RV64I-NEXT:    srli a1, a1, 30
+; RV64I-NEXT:    add a0, a0, a1
+; RV64I-NEXT:    lw a0, 0(a0)
+; RV64I-NEXT:    ret
+;
+; RV64ZBA-LABEL: shl_16_sh2add:
+; RV64ZBA:       # %bb.0:
+; RV64ZBA-NEXT:    slli a1, a1, 16
+; RV64ZBA-NEXT:    sh2add.uw a0, a1, a0
+; RV64ZBA-NEXT:    lw a0, 0(a0)
+; RV64ZBA-NEXT:    ret
+  %3 = shl i32 %1, 16
+  %4 = zext i32 %3 to i64
+  %5 = getelementptr inbounds i32, i32* %0, i64 %4
+  %6 = load i32, i32* %5, align 4
+  ret i32 %6
+}
+
+define i64 @shl_31_sh3add(i64* %0, i32 signext %1) {
+; RV64I-LABEL: shl_31_sh3add:
+; RV64I:       # %bb.0:
+; RV64I-NEXT:    slli a1, a1, 63
+; RV64I-NEXT:    srli a1, a1, 29
+; RV64I-NEXT:    add a0, a0, a1
+; RV64I-NEXT:    ld a0, 0(a0)
+; RV64I-NEXT:    ret
+;
+; RV64ZBA-LABEL: shl_31_sh3add:
+; RV64ZBA:       # %bb.0:
+; RV64ZBA-NEXT:    slli a1, a1, 31
+; RV64ZBA-NEXT:    sh3add.uw a0, a1, a0
+; RV64ZBA-NEXT:    ld a0, 0(a0)
+; RV64ZBA-NEXT:    ret
+  %3 = shl i32 %1, 31
+  %4 = zext i32 %3 to i64
+  %5 = getelementptr inbounds i64, i64* %0, i64 %4
+  %6 = load i64, i64* %5, align 8
+  ret i64 %6
+}
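
A quick standalone way to see why the ComplexPattern's SLLI followed by SHxADD_UW is sound is to check the address arithmetic of the shl_16_sh2add test by hand. The sketch below is illustrative only and is not part of the patch: shNaddUW is a made-up model of the sh{N}add.uw semantics (not an LLVM API), and Base is a stand-in for the pointer argument in a0. The reference value follows the IR (32-bit shl, zext to i64, GEP scaling by 4 bytes); the selected value follows the emitted slli a1, a1, 16 / sh2add.uw a0, a1, a0 sequence.

// Illustrative only; not part of the patch.
#include <cassert>
#include <cstdint>

// Hypothetical helper modelling sh{N}add.uw: (zext32(rs1) << N) + rs2.
static uint64_t shNaddUW(unsigned N, uint64_t Rs1, uint64_t Rs2) {
  return ((Rs1 & 0xFFFFFFFFull) << N) + Rs2;
}

int main() {
  const uint64_t Base = 0x10000000; // stand-in for the i32* argument in a0
  const int32_t Inputs[] = {0, 1, 0x1234, -1, INT32_MIN, INT32_MAX};
  for (int32_t X : Inputs) {
    // IR: %3 = shl i32 %1, 16 ; %4 = zext i32 %3 to i64 ; GEP i32 => *4.
    uint32_t Shifted = (uint32_t)X << 16;
    uint64_t Ref = Base + (uint64_t)Shifted * 4;
    // Selected code: slli a1, a1, 16 ; sh2add.uw a0, a1, a0
    // (a1 holds the sign-extended i32 argument).
    uint64_t A1 = (uint64_t)(int64_t)X << 16;
    uint64_t Sel = shNaddUW(2, A1, Base);
    assert(Ref == Sel);
  }
  return 0;
}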