RISCV: Select SRLI + SHnADD_UW for (add (and X, Mask), Y) on RV64 Zba.

Adds three instruction-selection patterns to RISCVInstrInfoZb.td, inside
the Predicates = [HasStdExtZba, IsRV64] block. Each matches an i64 add whose
first operand is X masked by 0x1FFFFFFFE, 0x3FFFFFFFC or 0x7FFFFFFF8
(32 consecutive one bits shifted left by 1, 2 or 3) and whose second
operand matches non_imm12:

  (add (and X, 0x1FFFFFFFE), Y) => (SH1ADD_UW (SRLI X, 1), Y)
  (add (and X, 0x3FFFFFFFC), Y) => (SH2ADD_UW (SRLI X, 2), Y)
  (add (and X, 0x7FFFFFFF8), Y) => (SH3ADD_UW (SRLI X, 3), Y)

Why this is equivalent: the pre-existing pattern directly above the new
ones selects SH3ADD_UW for (add (and (shl X, 3), 0x7FFFFFFFF), Y). That
shows SH3ADD_UW adds Y to the low 32 bits of its first operand,
zero-extended and shifted left by 3; SH1ADD_UW and SH2ADD_UW work the
same way with shifts of 1 and 2. Shifting X right by n first moves mask
bits [n+31:n] into bits [31:0]. SHnADD_UW then keeps exactly those 32
bits and shifts them back, which reproduces X & Mask. The n low bits
dropped by SRLI are ones the mask clears anyway.

Effect, per the updated rv64zba.ll checks for sh1adduw_ptrdiff,
sh2adduw_ptrdiff and sh3adduw_ptrdiff:
  * RV64I output is unchanged: li/slli/addi build the mask, then and, add.
  * RV64ZBA now emits srli + shNadd.uw instead.
Because the two outputs now differ, the shared CHECK lines are split into
RV64I and RV64ZBA blocks.

NOTE(review): assumes the RUN lines of rv64zba.ll already define the RV64I
and RV64ZBA check prefixes. They are not visible in this patch; confirm.

NOTE(review): the comment above sh1adduw_ptrdiff reads "This the IR ...
if take the difference ... cast is to unsigned". It is a diff context line,
so it is left untouched here. Fixing it needs its own -/+ lines and updated
hunk counts.

diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td b/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td --- a/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td @@ -1174,6 +1174,14 @@ (SH2ADD_UW GPR:$rs1, GPR:$rs2)>; def : Pat<(i64 (add (and (shl GPR:$rs1, (i64 3)), 0x7FFFFFFFF), non_imm12:$rs2)), (SH3ADD_UW GPR:$rs1, GPR:$rs2)>; + +// Use SRLI to clear the LSBs and SHXADD_UW to mask and shift. +def : Pat<(i64 (add (and GPR:$rs1, 0x1FFFFFFFE), non_imm12:$rs2)), + (SH1ADD_UW (SRLI GPR:$rs1, 1), GPR:$rs2)>; +def : Pat<(i64 (add (and GPR:$rs1, 0x3FFFFFFFC), non_imm12:$rs2)), + (SH2ADD_UW (SRLI GPR:$rs1, 2), GPR:$rs2)>; +def : Pat<(i64 (add (and GPR:$rs1, 0x7FFFFFFF8), non_imm12:$rs2)), + (SH3ADD_UW (SRLI GPR:$rs1, 3), GPR:$rs2)>; } // Predicates = [HasStdExtZba, IsRV64] let Predicates = [HasStdExtZbcOrZbkc] in { diff --git a/llvm/test/CodeGen/RISCV/rv64zba.ll b/llvm/test/CodeGen/RISCV/rv64zba.ll --- a/llvm/test/CodeGen/RISCV/rv64zba.ll +++ b/llvm/test/CodeGen/RISCV/rv64zba.ll @@ -1167,15 +1167,22 @@ ; This the IR you get from InstCombine if take the difference of 2 pointers and ; cast is to unsigned before using as an index. 
define signext i16 @sh1adduw_ptrdiff(i64 %diff, i16* %baseptr) { -; CHECK-LABEL: sh1adduw_ptrdiff: -; CHECK: # %bb.0: -; CHECK-NEXT: li a2, 1 -; CHECK-NEXT: slli a2, a2, 33 -; CHECK-NEXT: addi a2, a2, -2 -; CHECK-NEXT: and a0, a0, a2 -; CHECK-NEXT: add a0, a1, a0 -; CHECK-NEXT: lh a0, 0(a0) -; CHECK-NEXT: ret +; RV64I-LABEL: sh1adduw_ptrdiff: +; RV64I: # %bb.0: +; RV64I-NEXT: li a2, 1 +; RV64I-NEXT: slli a2, a2, 33 +; RV64I-NEXT: addi a2, a2, -2 +; RV64I-NEXT: and a0, a0, a2 +; RV64I-NEXT: add a0, a1, a0 +; RV64I-NEXT: lh a0, 0(a0) +; RV64I-NEXT: ret +; +; RV64ZBA-LABEL: sh1adduw_ptrdiff: +; RV64ZBA: # %bb.0: +; RV64ZBA-NEXT: srli a0, a0, 1 +; RV64ZBA-NEXT: sh1add.uw a0, a0, a1 +; RV64ZBA-NEXT: lh a0, 0(a0) +; RV64ZBA-NEXT: ret %ptrdiff = lshr exact i64 %diff, 1 %cast = and i64 %ptrdiff, 4294967295 %ptr = getelementptr inbounds i16, i16* %baseptr, i64 %cast @@ -1184,15 +1191,22 @@ } define signext i32 @sh2adduw_ptrdiff(i64 %diff, i32* %baseptr) { -; CHECK-LABEL: sh2adduw_ptrdiff: -; CHECK: # %bb.0: -; CHECK-NEXT: li a2, 1 -; CHECK-NEXT: slli a2, a2, 34 -; CHECK-NEXT: addi a2, a2, -4 -; CHECK-NEXT: and a0, a0, a2 -; CHECK-NEXT: add a0, a1, a0 -; CHECK-NEXT: lw a0, 0(a0) -; CHECK-NEXT: ret +; RV64I-LABEL: sh2adduw_ptrdiff: +; RV64I: # %bb.0: +; RV64I-NEXT: li a2, 1 +; RV64I-NEXT: slli a2, a2, 34 +; RV64I-NEXT: addi a2, a2, -4 +; RV64I-NEXT: and a0, a0, a2 +; RV64I-NEXT: add a0, a1, a0 +; RV64I-NEXT: lw a0, 0(a0) +; RV64I-NEXT: ret +; +; RV64ZBA-LABEL: sh2adduw_ptrdiff: +; RV64ZBA: # %bb.0: +; RV64ZBA-NEXT: srli a0, a0, 2 +; RV64ZBA-NEXT: sh2add.uw a0, a0, a1 +; RV64ZBA-NEXT: lw a0, 0(a0) +; RV64ZBA-NEXT: ret %ptrdiff = lshr exact i64 %diff, 2 %cast = and i64 %ptrdiff, 4294967295 %ptr = getelementptr inbounds i32, i32* %baseptr, i64 %cast @@ -1201,15 +1215,22 @@ } define i64 @sh3adduw_ptrdiff(i64 %diff, i64* %baseptr) { -; CHECK-LABEL: sh3adduw_ptrdiff: -; CHECK: # %bb.0: -; CHECK-NEXT: li a2, 1 -; CHECK-NEXT: slli a2, a2, 35 -; CHECK-NEXT: addi a2, a2, -8 -; 
CHECK-NEXT: and a0, a0, a2 -; CHECK-NEXT: add a0, a1, a0 -; CHECK-NEXT: ld a0, 0(a0) -; CHECK-NEXT: ret +; RV64I-LABEL: sh3adduw_ptrdiff: +; RV64I: # %bb.0: +; RV64I-NEXT: li a2, 1 +; RV64I-NEXT: slli a2, a2, 35 +; RV64I-NEXT: addi a2, a2, -8 +; RV64I-NEXT: and a0, a0, a2 +; RV64I-NEXT: add a0, a1, a0 +; RV64I-NEXT: ld a0, 0(a0) +; RV64I-NEXT: ret +; +; RV64ZBA-LABEL: sh3adduw_ptrdiff: +; RV64ZBA: # %bb.0: +; RV64ZBA-NEXT: srli a0, a0, 3 +; RV64ZBA-NEXT: sh3add.uw a0, a0, a1 +; RV64ZBA-NEXT: ld a0, 0(a0) +; RV64ZBA-NEXT: ret %ptrdiff = lshr exact i64 %diff, 3 %cast = and i64 %ptrdiff, 4294967295 %ptr = getelementptr inbounds i64, i64* %baseptr, i64 %cast