diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -718,6 +718,71 @@
     break;
   }
+  case ISD::MUL: {
+    // Special case for calculating (mul (and X, C2), C1) where the full product
+    // fits in XLen bits. We can shift X left by the number of leading zeros in
+    // C2 and shift C1 left by XLen-lzcnt(C2). This will ensure the final
+    // product has XLen trailing zeros, putting it in the output of MULHU. This
+    // can avoid materializing a constant in a register for C2.
+
+    // RHS should be a constant.
+    auto *N1C = dyn_cast<ConstantSDNode>(Node->getOperand(1));
+    if (!N1C || !N1C->hasOneUse())
+      break;
+
+    // LHS should be an AND with constant.
+    SDValue N0 = Node->getOperand(0);
+    if (N0.getOpcode() != ISD::AND || !isa<ConstantSDNode>(N0.getOperand(1)))
+      break;
+
+    uint64_t C2 = cast<ConstantSDNode>(N0.getOperand(1))->getZExtValue();
+
+    // Constant should be a mask.
+    if (!isMask_64(C2))
+      break;
+
+    // This should be the only use of the AND unless we will use
+    // (SRLI (SLLI X, 32), 32). We don't use a shift pair for other AND
+    // constants.
+    if (!N0.hasOneUse() && C2 != UINT64_C(0xFFFFFFFF))
+      break;
+
+    // If this can be an ANDI, ZEXT.H or ZEXT.W we don't need to do this
+    // optimization.
+    if (isInt<12>(C2) ||
+        (C2 == UINT64_C(0xFFFF) &&
+         (Subtarget->hasStdExtZbb() || Subtarget->hasStdExtZbp())) ||
+        (C2 == UINT64_C(0xFFFFFFFF) && Subtarget->hasStdExtZba()))
+      break;
+
+    // We need to shift left the AND input and C1 by a total of XLen bits.
+
+    // How far left do we need to shift the AND input?
+    unsigned XLen = Subtarget->getXLen();
+    unsigned LeadingZeros = XLen - (64 - countLeadingZeros(C2));
+
+    // The constant gets shifted by the remaining amount unless that would
+    // shift bits out.
+    uint64_t C1 = N1C->getZExtValue();
+    unsigned ConstantShift = XLen - LeadingZeros;
+    if (ConstantShift > (XLen - (64 - countLeadingZeros(C1))))
+      break;
+
+    uint64_t ShiftedC1 = C1 << ConstantShift;
+    // If this is RV32, we need to sign extend the constant.
+    if (XLen == 32)
+      ShiftedC1 = SignExtend64(ShiftedC1, 32);
+
+    // Create (mulhu (slli X, lzcnt(C2)), C1 << (XLen - lzcnt(C2))).
+    SDNode *Imm = selectImm(CurDAG, DL, ShiftedC1, *Subtarget);
+    SDNode *SLLI =
+        CurDAG->getMachineNode(RISCV::SLLI, DL, VT, N0.getOperand(0),
+                               CurDAG->getTargetConstant(LeadingZeros, DL, VT));
+    SDNode *MULHU = CurDAG->getMachineNode(RISCV::MULHU, DL, VT,
+                                           SDValue(SLLI, 0), SDValue(Imm, 0));
+    ReplaceNode(Node, MULHU);
+    return;
+  }
   case ISD::INTRINSIC_WO_CHAIN: {
     unsigned IntNo = Node->getConstantOperandVal(0);
     switch (IntNo) {
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoM.td b/llvm/lib/Target/RISCV/RISCVInstrInfoM.td
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoM.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoM.td
@@ -96,14 +96,6 @@
           (REMW GPR:$rs1, GPR:$rs2)>;
 } // Predicates = [HasStdExtM, IsRV64]
 
-// Pattern to detect constants with no more than 32 active bits that can't
-// be materialized with lui+addiw.
-def uimm32_not_simm32 : PatLeaf<(XLenVT GPR:$a), [{
-  auto *C = dyn_cast<ConstantSDNode>(N);
-  return C && C->hasOneUse() && isUInt<32>(C->getZExtValue()) &&
-         !isInt<32>(C->getSExtValue());
-}]>;
-
 let Predicates = [HasStdExtM, IsRV64, NotHasStdExtZba] in {
 // Special case for calculating the full 64-bit product of a 32x32 unsigned
 // multiply where the inputs aren't known to be zero extended. We can shift the
@@ -111,9 +103,4 @@
 // zeroing the upper 32 bits.
 def : Pat<(i64 (mul (and GPR:$rs1, 0xffffffff), (and GPR:$rs2, 0xffffffff))),
           (MULHU (SLLI GPR:$rs1, 32), (SLLI GPR:$rs2, 32))>;
-// The RHS could also be a constant that is hard to materialize. By shifting
-// left we can allow constant materialization to use LUI+ADDIW via
-// hasAllWUsers.
-def : Pat<(i64 (mul (and GPR:$rs1, 0xffffffff), uimm32_not_simm32:$rs2)),
-          (MULHU (SLLI GPR:$rs1, 32), (SLLI GPR:$rs2, 32))>;
 } // Predicates = [HasStdExtM, IsRV64, NotHasStdExtZba]
diff --git a/llvm/test/CodeGen/RISCV/div.ll b/llvm/test/CodeGen/RISCV/div.ll
--- a/llvm/test/CodeGen/RISCV/div.ll
+++ b/llvm/test/CodeGen/RISCV/div.ll
@@ -498,12 +498,9 @@
 ;
 ; RV32IM-LABEL: udiv16_constant:
 ; RV32IM:       # %bb.0:
-; RV32IM-NEXT:    lui a1, 16
-; RV32IM-NEXT:    addi a1, a1, -1
-; RV32IM-NEXT:    and a0, a0, a1
-; RV32IM-NEXT:    lui a1, 13
-; RV32IM-NEXT:    addi a1, a1, -819
-; RV32IM-NEXT:    mul a0, a0, a1
+; RV32IM-NEXT:    slli a0, a0, 16
+; RV32IM-NEXT:    lui a1, 838864
+; RV32IM-NEXT:    mulhu a0, a0, a1
 ; RV32IM-NEXT:    srli a0, a0, 18
 ; RV32IM-NEXT:    ret
 ;
@@ -522,12 +519,10 @@
 ;
 ; RV64IM-LABEL: udiv16_constant:
 ; RV64IM:       # %bb.0:
-; RV64IM-NEXT:    lui a1, 16
-; RV64IM-NEXT:    addiw a1, a1, -1
-; RV64IM-NEXT:    and a0, a0, a1
-; RV64IM-NEXT:    lui a1, 13
-; RV64IM-NEXT:    addiw a1, a1, -819
-; RV64IM-NEXT:    mul a0, a0, a1
+; RV64IM-NEXT:    lui a1, 52429
+; RV64IM-NEXT:    slli a1, a1, 4
+; RV64IM-NEXT:    slli a0, a0, 48
+; RV64IM-NEXT:    mulhu a0, a0, a1
 ; RV64IM-NEXT:    srli a0, a0, 18
 ; RV64IM-NEXT:    ret
   %1 = udiv i16 %a, 5
diff --git a/llvm/test/CodeGen/RISCV/pr51206.ll b/llvm/test/CodeGen/RISCV/pr51206.ll
--- a/llvm/test/CodeGen/RISCV/pr51206.ll
+++ b/llvm/test/CodeGen/RISCV/pr51206.ll
@@ -21,12 +21,10 @@
 ; CHECK-NEXT:    lui a2, %hi(global.1)
 ; CHECK-NEXT:    sw a0, %lo(global.1)(a2)
 ; CHECK-NEXT:    mul a0, a0, a1
-; CHECK-NEXT:    lui a1, 16
-; CHECK-NEXT:    addiw a1, a1, -1
-; CHECK-NEXT:    and a1, a0, a1
-; CHECK-NEXT:    lui a2, 13
-; CHECK-NEXT:    addiw a2, a2, -819
-; CHECK-NEXT:    mul a1, a1, a2
+; CHECK-NEXT:    slli a1, a0, 48
+; CHECK-NEXT:    lui a2, 52429
+; CHECK-NEXT:    slli a2, a2, 4
+; CHECK-NEXT:    mulhu a1, a1, a2
 ; CHECK-NEXT:    srli a1, a1, 18
 ; CHECK-NEXT:    lui a2, %hi(global.3)
 ; CHECK-NEXT:    li a3, 5
diff --git a/llvm/test/CodeGen/RISCV/urem-lkk.ll b/llvm/test/CodeGen/RISCV/urem-lkk.ll
--- a/llvm/test/CodeGen/RISCV/urem-lkk.ll
+++ b/llvm/test/CodeGen/RISCV/urem-lkk.ll
@@ -48,10 +48,10 @@
 ; RV64IM-LABEL: fold_urem_positive_odd:
 ; RV64IM:       # %bb.0:
 ; RV64IM-NEXT:    slli a1, a0, 32
-; RV64IM-NEXT:    srli a1, a1, 32
 ; RV64IM-NEXT:    lui a2, 364242
 ; RV64IM-NEXT:    addiw a2, a2, 777
-; RV64IM-NEXT:    mul a1, a1, a2
+; RV64IM-NEXT:    slli a2, a2, 32
+; RV64IM-NEXT:    mulhu a1, a1, a2
 ; RV64IM-NEXT:    srli a1, a1, 32
 ; RV64IM-NEXT:    subw a2, a0, a1
 ; RV64IM-NEXT:    srliw a2, a2, 1
@@ -179,10 +179,10 @@
 ; RV64IM-LABEL: combine_urem_udiv:
 ; RV64IM:       # %bb.0:
 ; RV64IM-NEXT:    slli a1, a0, 32
-; RV64IM-NEXT:    srli a1, a1, 32
 ; RV64IM-NEXT:    lui a2, 364242
 ; RV64IM-NEXT:    addiw a2, a2, 777
-; RV64IM-NEXT:    mul a1, a1, a2
+; RV64IM-NEXT:    slli a2, a2, 32
+; RV64IM-NEXT:    mulhu a1, a1, a2
 ; RV64IM-NEXT:    srli a1, a1, 32
 ; RV64IM-NEXT:    subw a2, a0, a1
 ; RV64IM-NEXT:    srliw a2, a2, 1
diff --git a/llvm/test/CodeGen/RISCV/xaluo.ll b/llvm/test/CodeGen/RISCV/xaluo.ll
--- a/llvm/test/CodeGen/RISCV/xaluo.ll
+++ b/llvm/test/CodeGen/RISCV/xaluo.ll
@@ -1201,10 +1201,10 @@
 ;
 ; RV64-LABEL: umulo2.i32:
 ; RV64:       # %bb.0: # %entry
-; RV64-NEXT:    slli a0, a0, 32
-; RV64-NEXT:    srli a0, a0, 32
 ; RV64-NEXT:    li a2, 13
-; RV64-NEXT:    mul a2, a0, a2
+; RV64-NEXT:    slli a2, a2, 32
+; RV64-NEXT:    slli a0, a0, 32
+; RV64-NEXT:    mulhu a2, a0, a2
 ; RV64-NEXT:    srli a0, a2, 32
 ; RV64-NEXT:    snez a0, a0
 ; RV64-NEXT:    sw a2, 0(a1)
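Note (reviewer illustration, not part of the patch): the new ISD::MUL handling replaces (mul (and X, C2), C1) with (mulhu (slli X, lzcnt(C2)), C1 << (XLen - lzcnt(C2))). Since the two shift amounts add up to XLen, the full product lands entirely in the upper XLen bits, which MULHU returns, so the AND mask never needs to be materialized. The standalone C++ sketch below checks this identity for the constants from the udiv16_constant test (C2 = 0xFFFF, C1 = 0xCCCD) with XLen = 64; the helper name mulhu64 and the exhaustive loop are assumptions of this sketch, not code from the patch.

    // Sketch only: verifies (X & 0xFFFF) * 0xCCCD == MULHU(X << 48, 0xCCCD << 16)
    // for all 16-bit X, using __uint128_t to model the 64x64->128 multiply.
    #include <cassert>
    #include <cstdint>

    // Models RISC-V MULHU on RV64: upper 64 bits of the 128-bit product.
    static uint64_t mulhu64(uint64_t a, uint64_t b) {
      return (uint64_t)(((__uint128_t)a * b) >> 64);
    }

    int main() {
      const uint64_t C2 = 0xFFFF, C1 = 0xCCCD;       // mask and multiplier
      const unsigned LeadingZeros = 48;               // 64 - bits(C2)
      const unsigned ConstantShift = 64 - LeadingZeros;
      for (uint64_t x = 0; x < (1u << 16); ++x) {
        uint64_t ref = (x & C2) * C1;                 // (mul (and X, C2), C1)
        uint64_t opt = mulhu64(x << LeadingZeros,     // (mulhu (slli X, 48),
                               C1 << ConstantShift);  //        0xCCCD0000)
        assert(ref == opt && "full product must land in the MULHU result");
      }
      return 0;
    }

The same reasoning explains the TableGen pattern the patch keeps: for a 32x32 unsigned multiply of values that are not known to be zero extended, shifting both operands left by 32 and using MULHU yields the full 64-bit product without explicit zero extension.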