diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -5964,6 +5964,24 @@ return SDValue(N, 0); } +// Helper to call SimplifyDemandedBits on operand OpNo of N when only its low +// LowBits bits are demanded. N will be added to the Worklist if it was not +// deleted. Caller should return SDValue(N, 0) if this returns true. +static bool SimplifyDemandedLowBitsHelper(SDNode *N, unsigned OpNo, + unsigned LowBits, + TargetLowering::DAGCombinerInfo &DCI, + const TargetLowering &TLI) { + SDValue Op = N->getOperand(OpNo); + APInt Mask = APInt::getLowBitsSet(Op.getValueSizeInBits(), LowBits); + if (TLI.SimplifyDemandedBits(Op, Mask, DCI)) { + if (N->getOpcode() != ISD::DELETED_NODE) + DCI.AddToWorklist(N); + return true; + } + + return false; +} + SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N, DAGCombinerInfo &DCI) const { SelectionDAG &DAG = DCI.DAG; @@ -6019,136 +6037,81 @@ case RISCVISD::ROLW: case RISCVISD::RORW: { // Only the lower 32 bits of LHS and lower 5 bits of RHS are read. 
- SDValue LHS = N->getOperand(0); - SDValue RHS = N->getOperand(1); - APInt LHSMask = APInt::getLowBitsSet(LHS.getValueSizeInBits(), 32); - APInt RHSMask = APInt::getLowBitsSet(RHS.getValueSizeInBits(), 5); - if (SimplifyDemandedBits(N->getOperand(0), LHSMask, DCI) || - SimplifyDemandedBits(N->getOperand(1), RHSMask, DCI)) { - if (N->getOpcode() != ISD::DELETED_NODE) - DCI.AddToWorklist(N); + if (SimplifyDemandedLowBitsHelper(N, 0, 32, DCI, *this) || + SimplifyDemandedLowBitsHelper(N, 1, 5, DCI, *this)) return SDValue(N, 0); - } break; } case RISCVISD::CLZW: case RISCVISD::CTZW: { // Only the lower 32 bits of the first operand are read - SDValue Op0 = N->getOperand(0); - APInt Mask = APInt::getLowBitsSet(Op0.getValueSizeInBits(), 32); - if (SimplifyDemandedBits(Op0, Mask, DCI)) { - if (N->getOpcode() != ISD::DELETED_NODE) - DCI.AddToWorklist(N); + if (SimplifyDemandedLowBitsHelper(N, 0, 32, DCI, *this)) return SDValue(N, 0); - } break; } case RISCVISD::FSL: case RISCVISD::FSR: { // Only the lower log2(Bitwidth)+1 bits of the the shift amount are read. - SDValue ShAmt = N->getOperand(2); - unsigned BitWidth = ShAmt.getValueSizeInBits(); + unsigned BitWidth = N->getOperand(2).getValueSizeInBits(); assert(isPowerOf2_32(BitWidth) && "Unexpected bit width"); - APInt ShAmtMask(BitWidth, (BitWidth * 2) - 1); - if (SimplifyDemandedBits(ShAmt, ShAmtMask, DCI)) { - if (N->getOpcode() != ISD::DELETED_NODE) - DCI.AddToWorklist(N); + if (SimplifyDemandedLowBitsHelper(N, 2, Log2_32(BitWidth) + 1, DCI, *this)) return SDValue(N, 0); - } break; } case RISCVISD::FSLW: case RISCVISD::FSRW: { // Only the lower 32 bits of Values and lower 6 bits of shift amount are // read. 
- SDValue Op0 = N->getOperand(0); - SDValue Op1 = N->getOperand(1); - SDValue ShAmt = N->getOperand(2); - APInt OpMask = APInt::getLowBitsSet(Op0.getValueSizeInBits(), 32); - APInt ShAmtMask = APInt::getLowBitsSet(ShAmt.getValueSizeInBits(), 6); - if (SimplifyDemandedBits(Op0, OpMask, DCI) || - SimplifyDemandedBits(Op1, OpMask, DCI) || - SimplifyDemandedBits(ShAmt, ShAmtMask, DCI)) { - if (N->getOpcode() != ISD::DELETED_NODE) - DCI.AddToWorklist(N); + if (SimplifyDemandedLowBitsHelper(N, 0, 32, DCI, *this) || + SimplifyDemandedLowBitsHelper(N, 1, 32, DCI, *this) || + SimplifyDemandedLowBitsHelper(N, 2, 6, DCI, *this)) return SDValue(N, 0); - } break; } case RISCVISD::GREV: case RISCVISD::GORC: { // Only the lower log2(Bitwidth) bits of the the shift amount are read. - SDValue ShAmt = N->getOperand(1); - unsigned BitWidth = ShAmt.getValueSizeInBits(); + unsigned BitWidth = N->getOperand(1).getValueSizeInBits(); assert(isPowerOf2_32(BitWidth) && "Unexpected bit width"); - APInt ShAmtMask(BitWidth, BitWidth - 1); - if (SimplifyDemandedBits(ShAmt, ShAmtMask, DCI)) { - if (N->getOpcode() != ISD::DELETED_NODE) - DCI.AddToWorklist(N); + if (SimplifyDemandedLowBitsHelper(N, 1, Log2_32(BitWidth), DCI, *this)) return SDValue(N, 0); - } return combineGREVI_GORCI(N, DCI.DAG); } case RISCVISD::GREVW: case RISCVISD::GORCW: { // Only the lower 32 bits of LHS and lower 5 bits of RHS are read. 
- SDValue LHS = N->getOperand(0); - SDValue RHS = N->getOperand(1); - APInt LHSMask = APInt::getLowBitsSet(LHS.getValueSizeInBits(), 32); - APInt RHSMask = APInt::getLowBitsSet(RHS.getValueSizeInBits(), 5); - if (SimplifyDemandedBits(LHS, LHSMask, DCI) || - SimplifyDemandedBits(RHS, RHSMask, DCI)) { - if (N->getOpcode() != ISD::DELETED_NODE) - DCI.AddToWorklist(N); + if (SimplifyDemandedLowBitsHelper(N, 0, 32, DCI, *this) || + SimplifyDemandedLowBitsHelper(N, 1, 5, DCI, *this)) return SDValue(N, 0); - } return combineGREVI_GORCI(N, DCI.DAG); } case RISCVISD::SHFL: case RISCVISD::UNSHFL: { - // Only the lower log2(Bitwidth) bits of the the shift amount are read. - SDValue ShAmt = N->getOperand(1); - unsigned BitWidth = ShAmt.getValueSizeInBits(); + // Only the lower log2(Bitwidth)-1 bits of the shift amount are read. + unsigned BitWidth = N->getOperand(1).getValueSizeInBits(); assert(isPowerOf2_32(BitWidth) && "Unexpected bit width"); - APInt ShAmtMask(BitWidth, (BitWidth / 2) - 1); - if (SimplifyDemandedBits(ShAmt, ShAmtMask, DCI)) { - if (N->getOpcode() != ISD::DELETED_NODE) - DCI.AddToWorklist(N); + if (SimplifyDemandedLowBitsHelper(N, 1, Log2_32(BitWidth) - 1, DCI, *this)) return SDValue(N, 0); - } break; } case RISCVISD::SHFLW: case RISCVISD::UNSHFLW: { - // Only the lower 32 bits of LHS and lower 5 bits of RHS are read. 
- SDValue LHS = N->getOperand(0); - SDValue RHS = N->getOperand(1); - APInt LHSMask = APInt::getLowBitsSet(LHS.getValueSizeInBits(), 32); - APInt RHSMask = APInt::getLowBitsSet(RHS.getValueSizeInBits(), 4); - if (SimplifyDemandedBits(LHS, LHSMask, DCI) || - SimplifyDemandedBits(RHS, RHSMask, DCI)) { - if (N->getOpcode() != ISD::DELETED_NODE) - DCI.AddToWorklist(N); + // Only the lower 32 bits of LHS and lower 4 bits of RHS are read. + if (SimplifyDemandedLowBitsHelper(N, 0, 32, DCI, *this) || + SimplifyDemandedLowBitsHelper(N, 1, 4, DCI, *this)) return SDValue(N, 0); - } break; } case RISCVISD::BCOMPRESSW: case RISCVISD::BDECOMPRESSW: { // Only the lower 32 bits of LHS and RHS are read. - SDValue LHS = N->getOperand(0); - SDValue RHS = N->getOperand(1); - APInt Mask = APInt::getLowBitsSet(LHS.getValueSizeInBits(), 32); - if (SimplifyDemandedBits(LHS, Mask, DCI) || - SimplifyDemandedBits(RHS, Mask, DCI)) { - if (N->getOpcode() != ISD::DELETED_NODE) - DCI.AddToWorklist(N); + if (SimplifyDemandedLowBitsHelper(N, 0, 32, DCI, *this) || + SimplifyDemandedLowBitsHelper(N, 1, 32, DCI, *this)) return SDValue(N, 0); - } break; }