diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h
@@ -89,7 +89,7 @@
 #include "RISCVGenDAGISel.inc"

 private:
-  void doPeepholeLoadStoreADDI();
+  bool doPeepholeLoadStoreADDI(SDNode *Node);
 };

 namespace RISCV {
diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -108,7 +108,78 @@
 }

 void RISCVDAGToDAGISel::PostprocessISelDAG() {
-  doPeepholeLoadStoreADDI();
+  SelectionDAG::allnodes_iterator Position = CurDAG->allnodes_end();
+
+  bool MadeChange = false;
+  while (Position != CurDAG->allnodes_begin()) {
+    SDNode *N = &*--Position;
+    // Skip dead nodes and any non-machine opcodes.
+    if (N->use_empty() || !N->isMachineOpcode())
+      continue;
+
+    // Try to remove sext.w.
+    if (N->getMachineOpcode() == RISCV::ADDIW &&
+        isNullConstant(N->getOperand(1))) {
+      SDValue N0 = N->getOperand(0);
+      if (!N0.isMachineOpcode())
+        continue;
+
+      switch (N0.getMachineOpcode()) {
+      default:
+        break;
+      case RISCV::ADD:
+      case RISCV::ADDI:
+      case RISCV::SUB:
+      case RISCV::MUL:
+      case RISCV::SLLI: {
+        // Convert sext.w+add/sub/mul to their W instructions. This will create
+        // a new independent instruction. This improves latency.
+        unsigned Opc;
+        switch (N0.getMachineOpcode()) {
+        default:
+          llvm_unreachable("Unexpected opcode!");
+        case RISCV::ADD:  Opc = RISCV::ADDW;  break;
+        case RISCV::ADDI: Opc = RISCV::ADDIW; break;
+        case RISCV::SUB:  Opc = RISCV::SUBW;  break;
+        case RISCV::MUL:  Opc = RISCV::MULW;  break;
+        case RISCV::SLLI: Opc = RISCV::SLLIW; break;
+        }
+
+        SDValue N00 = N0.getOperand(0);
+        SDValue N01 = N0.getOperand(1);
+
+        // Shift amount needs to be uimm5.
+        if (N0.getMachineOpcode() == RISCV::SLLI &&
+            !isUInt<5>(cast<ConstantSDNode>(N01)->getSExtValue()))
+          break;
+
+        SDNode *Result =
+            CurDAG->getMachineNode(Opc, SDLoc(N), N->getValueType(0),
+                                   N00, N01);
+        ReplaceUses(N, Result);
+        MadeChange = true;
+        continue;
+      }
+      case RISCV::ADDW:
+      case RISCV::ADDIW:
+      case RISCV::SUBW:
+      case RISCV::MULW:
+      case RISCV::SLLIW:
+        // Result is already sign extended just remove the sext.w.
+        // NOTE: We only handle the nodes that are selected with hasAllWUsers.
+        ReplaceUses(N, N0.getNode());
+        MadeChange = true;
+        continue;
+      }
+
+      continue;
+    }
+
+    MadeChange |= doPeepholeLoadStoreADDI(N);
+  }
+
+  if (MadeChange)
+    CurDAG->RemoveDeadNodes();
 }

 static SDNode *selectImm(SelectionDAG *CurDAG, const SDLoc &DL, int64_t Imm,
@@ -1677,113 +1748,101 @@
 // (load (addi base, off1), off2) -> (load base, off1+off2)
 // (store val, (addi base, off1), off2) -> (store val, base, off1+off2)
 // This is possible when off1+off2 fits a 12-bit immediate.
-void RISCVDAGToDAGISel::doPeepholeLoadStoreADDI() {
-  SelectionDAG::allnodes_iterator Position(CurDAG->getRoot().getNode());
-  ++Position;
+bool RISCVDAGToDAGISel::doPeepholeLoadStoreADDI(SDNode *N) {
+  int OffsetOpIdx;
+  int BaseOpIdx;

-  while (Position != CurDAG->allnodes_begin()) {
-    SDNode *N = &*--Position;
-    // Skip dead nodes and any non-machine opcodes.
-    if (N->use_empty() || !N->isMachineOpcode())
-      continue;
+  // Only attempt this optimisation for I-type loads and S-type stores.
+  switch (N->getMachineOpcode()) {
+  default:
+    return false;
+  case RISCV::LB:
+  case RISCV::LH:
+  case RISCV::LW:
+  case RISCV::LBU:
+  case RISCV::LHU:
+  case RISCV::LWU:
+  case RISCV::LD:
+  case RISCV::FLH:
+  case RISCV::FLW:
+  case RISCV::FLD:
+    BaseOpIdx = 0;
+    OffsetOpIdx = 1;
+    break;
+  case RISCV::SB:
+  case RISCV::SH:
+  case RISCV::SW:
+  case RISCV::SD:
+  case RISCV::FSH:
+  case RISCV::FSW:
+  case RISCV::FSD:
+    BaseOpIdx = 1;
+    OffsetOpIdx = 2;
+    break;
+  }

-    int OffsetOpIdx;
-    int BaseOpIdx;
+  if (!isa<ConstantSDNode>(N->getOperand(OffsetOpIdx)))
+    return false;

-    // Only attempt this optimisation for I-type loads and S-type stores.
-    switch (N->getMachineOpcode()) {
-    default:
-      continue;
-    case RISCV::LB:
-    case RISCV::LH:
-    case RISCV::LW:
-    case RISCV::LBU:
-    case RISCV::LHU:
-    case RISCV::LWU:
-    case RISCV::LD:
-    case RISCV::FLH:
-    case RISCV::FLW:
-    case RISCV::FLD:
-      BaseOpIdx = 0;
-      OffsetOpIdx = 1;
-      break;
-    case RISCV::SB:
-    case RISCV::SH:
-    case RISCV::SW:
-    case RISCV::SD:
-    case RISCV::FSH:
-    case RISCV::FSW:
-    case RISCV::FSD:
-      BaseOpIdx = 1;
-      OffsetOpIdx = 2;
-      break;
-    }
+  SDValue Base = N->getOperand(BaseOpIdx);

-    if (!isa<ConstantSDNode>(N->getOperand(OffsetOpIdx)))
-      continue;
+  // If the base is an ADDI, we can merge it in to the load/store.
+  if (!Base.isMachineOpcode() || Base.getMachineOpcode() != RISCV::ADDI)
+    return false;

-    SDValue Base = N->getOperand(BaseOpIdx);
+  SDValue ImmOperand = Base.getOperand(1);
+  uint64_t Offset2 = N->getConstantOperandVal(OffsetOpIdx);

-    // If the base is an ADDI, we can merge it in to the load/store.
-    if (!Base.isMachineOpcode() || Base.getMachineOpcode() != RISCV::ADDI)
-      continue;
+  if (auto *Const = dyn_cast<ConstantSDNode>(ImmOperand)) {
+    int64_t Offset1 = Const->getSExtValue();
+    int64_t CombinedOffset = Offset1 + Offset2;
+    if (!isInt<12>(CombinedOffset))
+      return false;
+    ImmOperand = CurDAG->getTargetConstant(CombinedOffset, SDLoc(ImmOperand),
+                                           ImmOperand.getValueType());
+  } else if (auto *GA = dyn_cast<GlobalAddressSDNode>(ImmOperand)) {
+    // If the off1 in (addi base, off1) is a global variable's address (its
+    // low part, really), then we can rely on the alignment of that variable
+    // to provide a margin of safety before off1 can overflow the 12 bits.
+    // Check if off2 falls within that margin; if so off1+off2 can't overflow.
+    const DataLayout &DL = CurDAG->getDataLayout();
+    Align Alignment = GA->getGlobal()->getPointerAlignment(DL);
+    if (Offset2 != 0 && Alignment <= Offset2)
+      return false;
+    int64_t Offset1 = GA->getOffset();
+    int64_t CombinedOffset = Offset1 + Offset2;
+    ImmOperand = CurDAG->getTargetGlobalAddress(
+        GA->getGlobal(), SDLoc(ImmOperand), ImmOperand.getValueType(),
+        CombinedOffset, GA->getTargetFlags());
+  } else if (auto *CP = dyn_cast<ConstantPoolSDNode>(ImmOperand)) {
+    // Ditto.
+    Align Alignment = CP->getAlign();
+    if (Offset2 != 0 && Alignment <= Offset2)
+      return false;
+    int64_t Offset1 = CP->getOffset();
+    int64_t CombinedOffset = Offset1 + Offset2;
+    ImmOperand = CurDAG->getTargetConstantPool(
+        CP->getConstVal(), ImmOperand.getValueType(), CP->getAlign(),
+        CombinedOffset, CP->getTargetFlags());
+  } else {
+    return false;
+  }

-    SDValue ImmOperand = Base.getOperand(1);
-    uint64_t Offset2 = N->getConstantOperandVal(OffsetOpIdx);
+  LLVM_DEBUG(dbgs() << "Folding add-immediate into mem-op:\nBase: ");
+  LLVM_DEBUG(Base->dump(CurDAG));
+  LLVM_DEBUG(dbgs() << "\nN: ");
+  LLVM_DEBUG(N->dump(CurDAG));
+  LLVM_DEBUG(dbgs() << "\n");

-    if (auto *Const = dyn_cast<ConstantSDNode>(ImmOperand)) {
-      int64_t Offset1 = Const->getSExtValue();
-      int64_t CombinedOffset = Offset1 + Offset2;
-      if (!isInt<12>(CombinedOffset))
-        continue;
-      ImmOperand = CurDAG->getTargetConstant(CombinedOffset, SDLoc(ImmOperand),
-                                             ImmOperand.getValueType());
-    } else if (auto *GA = dyn_cast<GlobalAddressSDNode>(ImmOperand)) {
-      // If the off1 in (addi base, off1) is a global variable's address (its
-      // low part, really), then we can rely on the alignment of that variable
-      // to provide a margin of safety before off1 can overflow the 12 bits.
-      // Check if off2 falls within that margin; if so off1+off2 can't overflow.
-      const DataLayout &DL = CurDAG->getDataLayout();
-      Align Alignment = GA->getGlobal()->getPointerAlignment(DL);
-      if (Offset2 != 0 && Alignment <= Offset2)
-        continue;
-      int64_t Offset1 = GA->getOffset();
-      int64_t CombinedOffset = Offset1 + Offset2;
-      ImmOperand = CurDAG->getTargetGlobalAddress(
-          GA->getGlobal(), SDLoc(ImmOperand), ImmOperand.getValueType(),
-          CombinedOffset, GA->getTargetFlags());
-    } else if (auto *CP = dyn_cast<ConstantPoolSDNode>(ImmOperand)) {
-      // Ditto.
-      Align Alignment = CP->getAlign();
-      if (Offset2 != 0 && Alignment <= Offset2)
-        continue;
-      int64_t Offset1 = CP->getOffset();
-      int64_t CombinedOffset = Offset1 + Offset2;
-      ImmOperand = CurDAG->getTargetConstantPool(
-          CP->getConstVal(), ImmOperand.getValueType(), CP->getAlign(),
-          CombinedOffset, CP->getTargetFlags());
-    } else {
-      continue;
-    }
+  // Modify the offset operand of the load/store.
+  if (BaseOpIdx == 0) // Load
+    CurDAG->UpdateNodeOperands(N, Base.getOperand(0), ImmOperand,
+                               N->getOperand(2));
+  else // Store
+    CurDAG->UpdateNodeOperands(N, N->getOperand(0), Base.getOperand(0),
+                               ImmOperand, N->getOperand(3));

-    LLVM_DEBUG(dbgs() << "Folding add-immediate into mem-op:\nBase: ");
-    LLVM_DEBUG(Base->dump(CurDAG));
-    LLVM_DEBUG(dbgs() << "\nN: ");
-    LLVM_DEBUG(N->dump(CurDAG));
-    LLVM_DEBUG(dbgs() << "\n");
-
-    // Modify the offset operand of the load/store.
-    if (BaseOpIdx == 0) // Load
-      CurDAG->UpdateNodeOperands(N, Base.getOperand(0), ImmOperand,
-                                 N->getOperand(2));
-    else // Store
-      CurDAG->UpdateNodeOperands(N, N->getOperand(0), Base.getOperand(0),
-                                 ImmOperand, N->getOperand(3));
-
-    // The add-immediate may now be dead, in which case remove it.
-    if (Base.getNode()->use_empty())
-      CurDAG->RemoveDeadNode(Base.getNode());
-  }
+  return true;
 }

 // This pass converts a legalized DAG into a RISCV-specific DAG, ready
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.td b/llvm/lib/Target/RISCV/RISCVInstrInfo.td
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.td
@@ -1260,14 +1260,6 @@
 /// ALU operations

-def : Pat<(sext_inreg (add GPR:$rs1, GPR:$rs2), i32),
-          (ADDW GPR:$rs1, GPR:$rs2)>;
-def : Pat<(sext_inreg (add GPR:$rs1, simm12:$imm12), i32),
-          (ADDIW GPR:$rs1, simm12:$imm12)>;
-def : Pat<(sext_inreg (sub GPR:$rs1, GPR:$rs2), i32),
-          (SUBW GPR:$rs1, GPR:$rs2)>;
-def : Pat<(sext_inreg (shl GPR:$rs1, uimm5:$shamt), i32),
-          (SLLIW GPR:$rs1, uimm5:$shamt)>;
 def : Pat<(i64 (srl (and GPR:$rs1, 0xffffffff), uimm5:$shamt)),
           (SRLIW GPR:$rs1, uimm5:$shamt)>;
 def : Pat<(i64 (srl (shl GPR:$rs1, (i64 32)), uimm6gt32:$shamt)),
@@ -1328,7 +1320,8 @@
                 (AddiPairImmA GPR:$rs2))>;

 let Predicates = [IsRV64] in {
-def : Pat<(sext_inreg (add_oneuse GPR:$rs1, (AddiPair:$rs2)), i32),
+// Select W instructions if only the lower 32-bits of the result are used.
+def : Pat<(addw GPR:$rs1, (AddiPair:$rs2)),
           (ADDIW (ADDIW GPR:$rs1, (AddiPairImmB AddiPair:$rs2)),
                  (AddiPairImmA AddiPair:$rs2))>;
 }
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoM.td b/llvm/lib/Target/RISCV/RISCVInstrInfoM.td
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoM.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoM.td
@@ -77,9 +77,6 @@
 }]>;

 let Predicates = [HasStdExtM, IsRV64] in {
-def : Pat<(sext_inreg (mul GPR:$rs1, GPR:$rs2), i32),
-          (MULW GPR:$rs1, GPR:$rs2)>;
-
 // Select W instructions without sext_inreg if only the lower 32-bits of the
 // result are used.
 def : Pat<(mulw GPR:$rs1, GPR:$rs2), (MULW GPR:$rs1, GPR:$rs2)>;
@@ -114,11 +111,4 @@
 // still be better off shifting both left by 32.
 def : Pat<(i64 (mul (and GPR:$rs1, 0xffffffff), (and GPR:$rs2, 0xffffffff))),
           (MULHU (SLLI GPR:$rs1, 32), (SLLI GPR:$rs2, 32))>;
-// Prevent matching the first part of this pattern to mulw. The mul here has
-// additionals users or the ANDs would have been removed. The above pattern
-// will be used for the other users. If we form a mulw we'll keep the ANDs alive
-// and they'll still become SLLI+SRLI.
-def : Pat<(sext_inreg (mul (and GPR:$rs1, 0xffffffff),
-                           (and GPR:$rs2, 0xffffffff)), i32),
-          (ADDIW (MULHU (SLLI GPR:$rs1, 32), (SLLI GPR:$rs2, 32)), 0)>;
 } // Predicates = [HasStdExtM, IsRV64, NotHasStdExtZba]
diff --git a/llvm/test/CodeGen/RISCV/add-imm.ll b/llvm/test/CodeGen/RISCV/add-imm.ll
--- a/llvm/test/CodeGen/RISCV/add-imm.ll
+++ b/llvm/test/CodeGen/RISCV/add-imm.ll
@@ -178,9 +178,8 @@
 ;
 ; RV64I-LABEL: add32_sext_reject_on_rv64:
 ; RV64I:       # %bb.0:
-; RV64I-NEXT:    lui a1, 1
-; RV64I-NEXT:    addiw a1, a1, -1096
-; RV64I-NEXT:    addw a0, a0, a1
+; RV64I-NEXT:    addiw a0, a0, 1500
+; RV64I-NEXT:    addiw a0, a0, 1500
 ; RV64I-NEXT:    lui a1, %hi(gv0)
 ; RV64I-NEXT:    sw a0, %lo(gv0)(a1)
 ; RV64I-NEXT:    ret
diff --git a/llvm/test/CodeGen/RISCV/xaluo.ll b/llvm/test/CodeGen/RISCV/xaluo.ll
--- a/llvm/test/CodeGen/RISCV/xaluo.ll
+++ b/llvm/test/CodeGen/RISCV/xaluo.ll
@@ -883,12 +883,12 @@
 ; RV64ZBA-LABEL: smulo2.i32:
 ; RV64ZBA:       # %bb.0: # %entry
 ; RV64ZBA-NEXT:    sext.w a0, a0
-; RV64ZBA-NEXT:    addi a2, zero, 13
-; RV64ZBA-NEXT:    mul a3, a0, a2
-; RV64ZBA-NEXT:    mulw a0, a0, a2
-; RV64ZBA-NEXT:    xor a0, a0, a3
+; RV64ZBA-NEXT:    sh1add a2, a0, a0
+; RV64ZBA-NEXT:    sh2add a2, a2, a0
+; RV64ZBA-NEXT:    sext.w a0, a2
+; RV64ZBA-NEXT:    xor a0, a0, a2
 ; RV64ZBA-NEXT:    snez a0, a0
-; RV64ZBA-NEXT:    sw a3, 0(a1)
+; RV64ZBA-NEXT:    sw a2, 0(a1)
 ; RV64ZBA-NEXT:    ret
 entry:
   %t = call {i32, i1} @llvm.smul.with.overflow.i32(i32 %v1, i32 13)