diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -5968,6 +5968,20 @@
                                                DAGCombinerInfo &DCI) const {
   SelectionDAG &DAG = DCI.DAG;
 
+  // Helper to call SimplifyDemandedBits on an operand of N where only some low
+  // bits are demanded. N will be added to the Worklist if it was not deleted.
+  // Caller should return SDValue(N, 0) if this returns true.
+  auto SimplifyDemandedLowBitsHelper = [&](unsigned OpNo, unsigned LowBits) {
+    SDValue Op = N->getOperand(OpNo);
+    APInt Mask = APInt::getLowBitsSet(Op.getValueSizeInBits(), LowBits);
+    if (!SimplifyDemandedBits(Op, Mask, DCI))
+      return false;
+
+    if (N->getOpcode() != ISD::DELETED_NODE)
+      DCI.AddToWorklist(N);
+    return true;
+  };
+
   switch (N->getOpcode()) {
   default:
     break;
@@ -6019,136 +6033,81 @@
   case RISCVISD::ROLW:
   case RISCVISD::RORW: {
     // Only the lower 32 bits of LHS and lower 5 bits of RHS are read.
-    SDValue LHS = N->getOperand(0);
-    SDValue RHS = N->getOperand(1);
-    APInt LHSMask = APInt::getLowBitsSet(LHS.getValueSizeInBits(), 32);
-    APInt RHSMask = APInt::getLowBitsSet(RHS.getValueSizeInBits(), 5);
-    if (SimplifyDemandedBits(N->getOperand(0), LHSMask, DCI) ||
-        SimplifyDemandedBits(N->getOperand(1), RHSMask, DCI)) {
-      if (N->getOpcode() != ISD::DELETED_NODE)
-        DCI.AddToWorklist(N);
+    if (SimplifyDemandedLowBitsHelper(0, 32) ||
+        SimplifyDemandedLowBitsHelper(1, 5))
       return SDValue(N, 0);
-    }
     break;
   }
   case RISCVISD::CLZW:
   case RISCVISD::CTZW: {
     // Only the lower 32 bits of the first operand are read
-    SDValue Op0 = N->getOperand(0);
-    APInt Mask = APInt::getLowBitsSet(Op0.getValueSizeInBits(), 32);
-    if (SimplifyDemandedBits(Op0, Mask, DCI)) {
-      if (N->getOpcode() != ISD::DELETED_NODE)
-        DCI.AddToWorklist(N);
+    if (SimplifyDemandedLowBitsHelper(0, 32))
       return SDValue(N, 0);
-    }
     break;
   }
   case RISCVISD::FSL:
   case RISCVISD::FSR: {
-    // Only the lower log2(Bitwidth)+1 bits of the the shift amount are read.
-    SDValue ShAmt = N->getOperand(2);
-    unsigned BitWidth = ShAmt.getValueSizeInBits();
+    // Only the lower log2(Bitwidth)+1 bits of the shift amount are read.
+    unsigned BitWidth = N->getOperand(2).getValueSizeInBits();
     assert(isPowerOf2_32(BitWidth) && "Unexpected bit width");
-    APInt ShAmtMask(BitWidth, (BitWidth * 2) - 1);
-    if (SimplifyDemandedBits(ShAmt, ShAmtMask, DCI)) {
-      if (N->getOpcode() != ISD::DELETED_NODE)
-        DCI.AddToWorklist(N);
+    if (SimplifyDemandedLowBitsHelper(2, Log2_32(BitWidth) + 1))
       return SDValue(N, 0);
-    }
     break;
   }
   case RISCVISD::FSLW:
   case RISCVISD::FSRW: {
     // Only the lower 32 bits of Values and lower 6 bits of shift amount are
     // read.
-    SDValue Op0 = N->getOperand(0);
-    SDValue Op1 = N->getOperand(1);
-    SDValue ShAmt = N->getOperand(2);
-    APInt OpMask = APInt::getLowBitsSet(Op0.getValueSizeInBits(), 32);
-    APInt ShAmtMask = APInt::getLowBitsSet(ShAmt.getValueSizeInBits(), 6);
-    if (SimplifyDemandedBits(Op0, OpMask, DCI) ||
-        SimplifyDemandedBits(Op1, OpMask, DCI) ||
-        SimplifyDemandedBits(ShAmt, ShAmtMask, DCI)) {
-      if (N->getOpcode() != ISD::DELETED_NODE)
-        DCI.AddToWorklist(N);
+    if (SimplifyDemandedLowBitsHelper(0, 32) ||
+        SimplifyDemandedLowBitsHelper(1, 32) ||
+        SimplifyDemandedLowBitsHelper(2, 6))
       return SDValue(N, 0);
-    }
     break;
   }
   case RISCVISD::GREV:
   case RISCVISD::GORC: {
-    // Only the lower log2(Bitwidth) bits of the the shift amount are read.
-    SDValue ShAmt = N->getOperand(1);
-    unsigned BitWidth = ShAmt.getValueSizeInBits();
+    // Only the lower log2(Bitwidth) bits of the shift amount are read.
+    unsigned BitWidth = N->getOperand(1).getValueSizeInBits();
     assert(isPowerOf2_32(BitWidth) && "Unexpected bit width");
-    APInt ShAmtMask(BitWidth, BitWidth - 1);
-    if (SimplifyDemandedBits(ShAmt, ShAmtMask, DCI)) {
-      if (N->getOpcode() != ISD::DELETED_NODE)
-        DCI.AddToWorklist(N);
+    if (SimplifyDemandedLowBitsHelper(1, Log2_32(BitWidth)))
       return SDValue(N, 0);
-    }
 
     return combineGREVI_GORCI(N, DCI.DAG);
   }
   case RISCVISD::GREVW:
   case RISCVISD::GORCW: {
     // Only the lower 32 bits of LHS and lower 5 bits of RHS are read.
-    SDValue LHS = N->getOperand(0);
-    SDValue RHS = N->getOperand(1);
-    APInt LHSMask = APInt::getLowBitsSet(LHS.getValueSizeInBits(), 32);
-    APInt RHSMask = APInt::getLowBitsSet(RHS.getValueSizeInBits(), 5);
-    if (SimplifyDemandedBits(LHS, LHSMask, DCI) ||
-        SimplifyDemandedBits(RHS, RHSMask, DCI)) {
-      if (N->getOpcode() != ISD::DELETED_NODE)
-        DCI.AddToWorklist(N);
+    if (SimplifyDemandedLowBitsHelper(0, 32) ||
+        SimplifyDemandedLowBitsHelper(1, 5))
       return SDValue(N, 0);
-    }
 
     return combineGREVI_GORCI(N, DCI.DAG);
   }
   case RISCVISD::SHFL:
   case RISCVISD::UNSHFL: {
-    // Only the lower log2(Bitwidth) bits of the the shift amount are read.
-    SDValue ShAmt = N->getOperand(1);
-    unsigned BitWidth = ShAmt.getValueSizeInBits();
+    // Only the lower log2(Bitwidth)-1 bits of the shift amount are read.
+    unsigned BitWidth = N->getOperand(1).getValueSizeInBits();
     assert(isPowerOf2_32(BitWidth) && "Unexpected bit width");
-    APInt ShAmtMask(BitWidth, (BitWidth / 2) - 1);
-    if (SimplifyDemandedBits(ShAmt, ShAmtMask, DCI)) {
-      if (N->getOpcode() != ISD::DELETED_NODE)
-        DCI.AddToWorklist(N);
+    if (SimplifyDemandedLowBitsHelper(1, Log2_32(BitWidth) - 1))
       return SDValue(N, 0);
-    }
 
     break;
   }
   case RISCVISD::SHFLW:
   case RISCVISD::UNSHFLW: {
-    // Only the lower 32 bits of LHS and lower 5 bits of RHS are read.
-    SDValue LHS = N->getOperand(0);
-    SDValue RHS = N->getOperand(1);
-    APInt LHSMask = APInt::getLowBitsSet(LHS.getValueSizeInBits(), 32);
-    APInt RHSMask = APInt::getLowBitsSet(RHS.getValueSizeInBits(), 4);
-    if (SimplifyDemandedBits(LHS, LHSMask, DCI) ||
-        SimplifyDemandedBits(RHS, RHSMask, DCI)) {
-      if (N->getOpcode() != ISD::DELETED_NODE)
-        DCI.AddToWorklist(N);
+    // Only the lower 32 bits of LHS and lower 4 bits of RHS are read.
+    if (SimplifyDemandedLowBitsHelper(0, 32) ||
+        SimplifyDemandedLowBitsHelper(1, 4))
       return SDValue(N, 0);
-    }
 
     break;
   }
   case RISCVISD::BCOMPRESSW:
   case RISCVISD::BDECOMPRESSW: {
     // Only the lower 32 bits of LHS and RHS are read.
-    SDValue LHS = N->getOperand(0);
-    SDValue RHS = N->getOperand(1);
-    APInt Mask = APInt::getLowBitsSet(LHS.getValueSizeInBits(), 32);
-    if (SimplifyDemandedBits(LHS, Mask, DCI) ||
-        SimplifyDemandedBits(RHS, Mask, DCI)) {
-      if (N->getOpcode() != ISD::DELETED_NODE)
-        DCI.AddToWorklist(N);
+    if (SimplifyDemandedLowBitsHelper(0, 32) ||
+        SimplifyDemandedLowBitsHelper(1, 32))
       return SDValue(N, 0);
-    }
 
     break;
   }