diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h --- a/llvm/lib/Target/RISCV/RISCVISelLowering.h +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h @@ -51,9 +51,19 @@ FMV_X_ANYEXTW_RV64, // READ_CYCLE_WIDE - A read of the 64-bit cycle CSR on a 32-bit target // (returns (Lo, Hi)). It takes a chain operand. - READ_CYCLE_WIDE + READ_CYCLE_WIDE, + // Generalized Reverse and Generalized Or-Combine - directly matching the + // semantics of the named RISC-V instructions. Lowered as custom nodes as + // TableGen chokes when faced with commutative permutations in deeply-nested + // DAGs. Each node takes an input operand and a TargetConstant immediate + // shift amount, and outputs a bit-manipulated version of input. All operands + // are of type XLenVT. + GREVI, + GREVIW, + GORCI, + GORCIW, }; -} +} // namespace RISCVISD class RISCVTargetLowering : public TargetLowering { const RISCVSubtarget &Subtarget; diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -259,6 +259,11 @@ // We can use any register for comparisons setHasMultipleConditionRegisters(); + + if (Subtarget.hasStdExtZbp()) { + setTargetDAGCombine(ISD::OR); + setTargetDAGCombine(ISD::SIGN_EXTEND_INREG); + } } EVT RISCVTargetLowering::getSetCCResultType(const DataLayout &DL, LLVMContext &, @@ -1031,6 +1036,164 @@ } } +// A structure to hold one of the bit-manipulation patterns below.
Together, a +// SHL and non-SHL pattern may form a bit-manipulation pair on a single source: +// (or (and (shl x, 1), 0xAAAAAAAA), +// (and (srl x, 1), 0x55555555)) +struct RISCVBitmanipPat { + SDValue Op; + unsigned ShAmt; + bool IsSHL; + + bool formsPairWith(const RISCVBitmanipPat &Other) const { + return Op == Other.Op && ShAmt == Other.ShAmt && IsSHL != Other.IsSHL; + } +}; + +// Matches any of the following bit-manipulation patterns: +// (and (shl x, 1), (0x55555555 << 1)) +// (and (srl x, 1), 0x55555555) +// (shl (and x, 0x55555555), 1) +// (srl (and x, (0x55555555 << 1)), 1) +// where the shift amount and mask may vary thus: +// [1] = 0x55555555 / 0xAAAAAAAA +// [2] = 0x33333333 / 0xCCCCCCCC +// [4] = 0x0F0F0F0F / 0xF0F0F0F0 +// [8] = 0x00FF00FF / 0xFF00FF00 +// [16] = 0x0000FFFF / 0xFFFF0000 +// [32] = 0x00000000FFFFFFFF / 0xFFFFFFFF00000000 (for RV64) +// When IsW is set, the pattern is matched for a W instruction: only the low 32 +// bits of the shift's source and of its result are significant. +static Optional<RISCVBitmanipPat> matchRISCVBitmanipPat(SDValue Op, bool IsW) { + Optional<uint64_t> Mask; + // Optionally consume a mask around the shift operation. + if (Op.getOpcode() == ISD::AND && isa<ConstantSDNode>(Op.getOperand(1))) { + Mask = Op.getConstantOperandVal(1); + Op = Op.getOperand(0); + } + if (Op.getOpcode() != ISD::SHL && Op.getOpcode() != ISD::SRL) + return None; + bool IsSHL = Op.getOpcode() == ISD::SHL; + + if (!isa<ConstantSDNode>(Op.getOperand(1))) + return None; + auto ShAmt = Op.getConstantOperandVal(1); + + if (!isPowerOf2_64(ShAmt)) + return None; + + // These are the unshifted masks which we use to match bit-manipulation + // patterns. They may be shifted left in certain circumstances. + static const uint64_t BitmanipMasks[] = { + 0x5555555555555555ULL, 0x3333333333333333ULL, 0x0F0F0F0F0F0F0F0FULL, + 0x00FF00FF00FF00FFULL, 0x0000FFFF0000FFFFULL, 0x00000000FFFFFFFFULL, + }; + + unsigned MaskIdx = Log2_64(ShAmt); + if (MaskIdx >= array_lengthof(BitmanipMasks)) + return None; + + auto Src = Op.getOperand(0); + + // When matching the W forms, we only expect a 32-bit mask. + unsigned Width = Op.getValueType() == MVT::i64 && !IsW ?
64 : 32; + // Reject shifts by the full width or more: a W form only sees the low 32 + // bits of its input, so a shift by 32 has no W equivalent. + if (ShAmt >= Width) + return None; + auto ExpMask = BitmanipMasks[MaskIdx] & maskTrailingOnes<uint64_t>(Width); + + // The expected mask is shifted left when the AND is found around SHL + // patterns. + // ((x >> 1) & 0x55555555) + // ((x << 1) & 0xAAAAAAAA) + bool SHLExpMask = IsSHL; + + if (!Mask) { + // Sometimes LLVM keeps the mask as an operand of the shift, typically when + // the mask is all ones: consume that now. + if (Src.getOpcode() == ISD::AND && isa<ConstantSDNode>(Src.getOperand(1))) { + Mask = Src.getConstantOperandVal(1); + Src = Src.getOperand(0); + // The expected mask is now in fact shifted left for SRL, so reverse the + // decision. + // ((x & 0xAAAAAAAA) >> 1) + // ((x & 0x55555555) << 1) + SHLExpMask = !SHLExpMask; + } else { + // An unmasked SRL is only equivalent at full width. For the W forms the + // shift is performed on all 64 bits, so bits 32 and up of x would leak + // into the low 32 bits of the result. + if (IsW && !IsSHL) + return None; + // Use a default shifted mask of all-ones if there's no AND, truncated + // down to the expected width. This simplifies the logic later on. + Mask = maskTrailingOnes<uint64_t>(Width); + *Mask &= (IsSHL ? *Mask << ShAmt : *Mask >> ShAmt); + } + } + + if (SHLExpMask) + ExpMask <<= ShAmt; + + if (Mask != ExpMask) + return None; + + return RISCVBitmanipPat{Src, static_cast<unsigned>(ShAmt), IsSHL}; +} + +// Match the following pattern as a GREVI(W) operation +// (or (BITMANIP_SHL x), (BITMANIP_SRL x)) +static SDValue combineORToGREV(SDValue Op, bool IsW, SelectionDAG &DAG, + const RISCVSubtarget &Subtarget) { + if (Op.getValueType() == Subtarget.getXLenVT()) { + auto LHS = matchRISCVBitmanipPat(Op.getOperand(0), IsW); + auto RHS = matchRISCVBitmanipPat(Op.getOperand(1), IsW); + if (LHS && RHS && LHS->formsPairWith(*RHS)) { + SDLoc DL(Op); + return DAG.getNode( + IsW ?
RISCVISD::GREVIW : RISCVISD::GREVI, DL, Op.getValueType(), + LHS->Op, + DAG.getTargetConstant(LHS->ShAmt, DL, Subtarget.getXLenVT())); + } + } + return SDValue(); +} + +// Match the following pattern as a GORCI(W) operation +// (or (or (BITMANIP_SHL x), x), +// (BITMANIP_SRL x)) +static SDValue combineORToGORC(SDValue Op, bool IsW, SelectionDAG &DAG, + const RISCVSubtarget &Subtarget) { + if (Op.getValueType() == Subtarget.getXLenVT()) { + SDValue Op0 = Op.getOperand(0); + SDValue Op1 = Op.getOperand(1); + // OR is commutable so canonicalize its OR operand to the left + if (Op0.getOpcode() != ISD::OR && Op1.getOpcode() == ISD::OR) + std::swap(Op0, Op1); + if (Op0.getOpcode() != ISD::OR) + return SDValue(); + SDValue OrOp0 = Op0.getOperand(0); + SDValue OrOp1 = Op0.getOperand(1); + auto LHS = matchRISCVBitmanipPat(OrOp0, IsW); + // OR is commutable so swap the operands and try again: x might have been + // on the left + if (!LHS) { + std::swap(OrOp0, OrOp1); + LHS = matchRISCVBitmanipPat(OrOp0, IsW); + } + auto RHS = matchRISCVBitmanipPat(Op1, IsW); + if (LHS && RHS && LHS->formsPairWith(*RHS) && LHS->Op == OrOp1) { + SDLoc DL(Op); + return DAG.getNode( + IsW ?
RISCVISD::GORCIW : RISCVISD::GORCI, DL, Op.getValueType(), + LHS->Op, + DAG.getTargetConstant(LHS->ShAmt, DL, Subtarget.getXLenVT())); + } + } + return SDValue(); +} + SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N, DAGCombinerInfo &DCI) const { SelectionDAG &DAG = DCI.DAG; @@ -1126,6 +1289,26 @@ return DAG.getNode(ISD::AND, DL, MVT::i64, NewFMV, DAG.getConstant(~SignBit, DL, MVT::i64)); } + case ISD::OR: + if (auto GREV = combineORToGREV(SDValue(N, 0), + /*IsW*/ false, DCI.DAG, Subtarget)) + return GREV; + if (auto GORC = combineORToGORC(SDValue(N, 0), + /*IsW*/ false, DCI.DAG, Subtarget)) + return GORC; + break; + case ISD::SIGN_EXTEND_INREG: + if (Subtarget.is64Bit() && + cast<VTSDNode>(N->getOperand(1))->getVT() == MVT::i32 && + N->getOperand(0).getOpcode() == ISD::OR) { + if (auto GREV = combineORToGREV(N->getOperand(0), /*IsW*/ true, DCI.DAG, + Subtarget)) + return GREV; + if (auto GORC = combineORToGORC(N->getOperand(0), /*IsW*/ true, DCI.DAG, + Subtarget)) + return GORC; + } + break; } return SDValue(); @@ -2641,6 +2824,14 @@ return "RISCVISD::FMV_X_ANYEXTW_RV64"; case RISCVISD::READ_CYCLE_WIDE: return "RISCVISD::READ_CYCLE_WIDE"; + case RISCVISD::GREVI: + return "RISCVISD::GREVI"; + case RISCVISD::GREVIW: + return "RISCVISD::GREVIW"; + case RISCVISD::GORCI: + return "RISCVISD::GORCI"; + case RISCVISD::GORCIW: + return "RISCVISD::GORCIW"; } return nullptr; } diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoB.td b/llvm/lib/Target/RISCV/RISCVInstrInfoB.td --- a/llvm/lib/Target/RISCV/RISCVInstrInfoB.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoB.td @@ -709,63 +709,20 @@ def : Pat<(and (srl GPR:$rs1, uimmlog2xlen:$shamt), (XLenVT 1)), (SBEXTI GPR:$rs1, uimmlog2xlen:$shamt)>; -let Predicates = [HasStdExtZbp, IsRV32] in { -def : Pat<(or (or (and (srl GPR:$rs1, (i32 1)), (i32 0x55555555)), GPR:$rs1), - (and (shl GPR:$rs1, (i32 1)), (i32 0xAAAAAAAA))), - (GORCI GPR:$rs1, (i32 1))>; -def : Pat<(or (or (and (srl GPR:$rs1, (i32 2)), (i32 0x33333333)), GPR:$rs1), -
(and (shl GPR:$rs1, (i32 2)), (i32 0xCCCCCCCC))), - (GORCI GPR:$rs1, (i32 2))>; -def : Pat<(or (or (and (srl GPR:$rs1, (i32 4)), (i32 0x0F0F0F0F)), GPR:$rs1), - (and (shl GPR:$rs1, (i32 4)), (i32 0xF0F0F0F0))), - (GORCI GPR:$rs1, (i32 4))>; -def : Pat<(or (or (and (srl GPR:$rs1, (i32 8)), (i32 0x00FF00FF)), GPR:$rs1), - (and (shl GPR:$rs1, (i32 8)), (i32 0xFF00FF00))), - (GORCI GPR:$rs1, (i32 8))>; -def : Pat<(or (or (srl GPR:$rs1, (i32 16)), GPR:$rs1), - (shl GPR:$rs1, (i32 16))), - (GORCI GPR:$rs1, (i32 16))>; -} // Predicates = [HasStdExtZbp, IsRV32] +def SDT_RISCVGREVGORC : SDTypeProfile<1, 2, [SDTCisVT<0, XLenVT>, + SDTCisSameAs<0, 1>, + SDTCisSameAs<1, 2>]>; +def riscv_grevi : SDNode<"RISCVISD::GREVI", SDT_RISCVGREVGORC, []>; +def riscv_greviw : SDNode<"RISCVISD::GREVIW", SDT_RISCVGREVGORC, []>; +def riscv_gorci : SDNode<"RISCVISD::GORCI", SDT_RISCVGREVGORC, []>; +def riscv_gorciw : SDNode<"RISCVISD::GORCIW", SDT_RISCVGREVGORC, []>; -let Predicates = [HasStdExtZbp, IsRV64] in { -def : Pat<(or (or (and (srl GPR:$rs1, (i64 1)), (i64 0x5555555555555555)), - GPR:$rs1), - (and (shl GPR:$rs1, (i64 1)), (i64 0xAAAAAAAAAAAAAAAA))), - (GORCI GPR:$rs1, (i64 1))>; -def : Pat<(or (or (and (srl GPR:$rs1, (i64 2)), (i64 0x3333333333333333)), - GPR:$rs1), - (and (shl GPR:$rs1, (i64 2)), (i64 0xCCCCCCCCCCCCCCCC))), - (GORCI GPR:$rs1, (i64 2))>; -def : Pat<(or (or (and (srl GPR:$rs1, (i64 4)), (i64 0x0F0F0F0F0F0F0F0F)), - GPR:$rs1), - (and (shl GPR:$rs1, (i64 4)), (i64 0xF0F0F0F0F0F0F0F0))), - (GORCI GPR:$rs1, (i64 4))>; -def : Pat<(or (or (and (srl GPR:$rs1, (i64 8)), (i64 0x00FF00FF00FF00FF)), - GPR:$rs1), - (and (shl GPR:$rs1, (i64 8)), (i64 0xFF00FF00FF00FF00))), - (GORCI GPR:$rs1, (i64 8))>; -def : Pat<(or (or (and (srl GPR:$rs1, (i64 16)), (i64 0x0000FFFF0000FFFF)), - GPR:$rs1), - (and (shl GPR:$rs1, (i64 16)), (i64 0xFFFF0000FFFF0000))), - (GORCI GPR:$rs1, (i64 16))>; -def : Pat<(or (or (srl GPR:$rs1, (i64 32)), GPR:$rs1), - (shl GPR:$rs1, (i64 32))), - (GORCI
GPR:$rs1, (i64 32))>; -} // Predicates = [HasStdExtZbp, IsRV64] +let Predicates = [HasStdExtZbp] in { +def : Pat<(riscv_grevi GPR:$rs1, timm:$shamt), (GREVI GPR:$rs1, timm:$shamt)>; +def : Pat<(riscv_gorci GPR:$rs1, timm:$shamt), (GORCI GPR:$rs1, timm:$shamt)>; +} // Predicates = [HasStdExtZbp] let Predicates = [HasStdExtZbp, IsRV32] in { -def : Pat<(or (and (shl GPR:$rs1, (i32 1)), (i32 0xAAAAAAAA)), - (and (srl GPR:$rs1, (i32 1)), (i32 0x55555555))), - (GREVI GPR:$rs1, (i32 1))>; -def : Pat<(or (and (shl GPR:$rs1, (i32 2)), (i32 0xCCCCCCCC)), - (and (srl GPR:$rs1, (i32 2)), (i32 0x33333333))), - (GREVI GPR:$rs1, (i32 2))>; -def : Pat<(or (and (shl GPR:$rs1, (i32 4)), (i32 0xF0F0F0F0)), - (and (srl GPR:$rs1, (i32 4)), (i32 0x0F0F0F0F))), - (GREVI GPR:$rs1, (i32 4))>; -def : Pat<(or (and (shl GPR:$rs1, (i32 8)), (i32 0xFF00FF00)), - (and (srl GPR:$rs1, (i32 8)), (i32 0x00FF00FF))), - (GREVI GPR:$rs1, (i32 8))>; def : Pat<(rotr (bswap GPR:$rs1), (i32 16)), (GREVI GPR:$rs1, (i32 8))>; // FIXME: Is grev better than rori?
def : Pat<(rotl GPR:$rs1, (i32 16)), (GREVI GPR:$rs1, (i32 16))>; @@ -775,21 +732,6 @@ } // Predicates = [HasStdExtZbp, IsRV32] let Predicates = [HasStdExtZbp, IsRV64] in { -def : Pat<(or (and (shl GPR:$rs1, (i64 1)), (i64 0xAAAAAAAAAAAAAAAA)), - (and (srl GPR:$rs1, (i64 1)), (i64 0x5555555555555555))), - (GREVI GPR:$rs1, (i64 1))>; -def : Pat<(or (and (shl GPR:$rs1, (i64 2)), (i64 0xCCCCCCCCCCCCCCCC)), - (and (srl GPR:$rs1, (i64 2)), (i64 0x3333333333333333))), - (GREVI GPR:$rs1, (i64 2))>; -def : Pat<(or (and (shl GPR:$rs1, (i64 4)), (i64 0xF0F0F0F0F0F0F0F0)), - (and (srl GPR:$rs1, (i64 4)), (i64 0x0F0F0F0F0F0F0F0F))), - (GREVI GPR:$rs1, (i64 4))>; -def : Pat<(or (and (shl GPR:$rs1, (i64 8)), (i64 0xFF00FF00FF00FF00)), - (and (srl GPR:$rs1, (i64 8)), (i64 0x00FF00FF00FF00FF))), - (GREVI GPR:$rs1, (i64 8))>; -def : Pat<(or (and (shl GPR:$rs1, (i64 16)), (i64 0xFFFF0000FFFF0000)), - (and (srl GPR:$rs1, (i64 16)), (i64 0x0000FFFF0000FFFF))), - (GREVI GPR:$rs1, (i64 16))>; // FIXME: Is grev better than rori?
def : Pat<(rotl GPR:$rs1, (i64 32)), (GREVI GPR:$rs1, (i64 32))>; def : Pat<(rotr GPR:$rs1, (i64 32)), (GREVI GPR:$rs1, (i64 32))>; @@ -962,55 +904,9 @@ (RORIW GPR:$rs1, uimmlog2xlen:$shamt)>; let Predicates = [HasStdExtZbp, IsRV64] in { -def : Pat<(sext_inreg (or (or (and (srl GPR:$rs1, (i64 1)), (i64 0x55555555)), - GPR:$rs1), - (and (shl GPR:$rs1, (i64 1)), (i64 0xAAAAAAAA))), - i32), - (GORCIW GPR:$rs1, (i64 1))>; -def : Pat<(sext_inreg (or (or (and (srl GPR:$rs1, (i64 2)), (i64 0x33333333)), - GPR:$rs1), - (and (shl GPR:$rs1, (i64 2)), (i64 0xCCCCCCCC))), - i32), - (GORCIW GPR:$rs1, (i64 2))>; -def : Pat<(sext_inreg (or (or (and (srl GPR:$rs1, (i64 4)), (i64 0x0F0F0F0F)), - GPR:$rs1), - (and (shl GPR:$rs1, (i64 4)), (i64 0xF0F0F0F0))), - i32), - (GORCIW GPR:$rs1, (i64 4))>; -def : Pat<(sext_inreg (or (or (and (srl GPR:$rs1, (i64 8)), (i64 0x00FF00FF)), - GPR:$rs1), - (and (shl GPR:$rs1, (i64 8)), (i64 0xFF00FF00))), - i32), - (GORCIW GPR:$rs1, (i64 8))>; -def : Pat<(sext_inreg (or (or (and (srl GPR:$rs1, (i64 16)), (i64 0x0000FFFF)), - GPR:$rs1), - (and (shl GPR:$rs1, (i64 16)), (i64 0xFFFF0000))), - i32), - (GORCIW GPR:$rs1, (i64 16))>; -def : Pat<(sext_inreg (or (or (srl (and GPR:$rs1, (i64 0xFFFF0000)), (i64 16)), - GPR:$rs1), - (shl GPR:$rs1, (i64 16))), i32), - (GORCIW GPR:$rs1, (i64 16))>; - -def : Pat<(sext_inreg (or (and (shl GPR:$rs1, (i64 1)), (i64 0xAAAAAAAA)), - (and (srl GPR:$rs1, (i64 1)), (i64 0x55555555))), - i32), - (GREVIW GPR:$rs1, (i64 1))>; -def : Pat<(sext_inreg (or (and (shl GPR:$rs1, (i64 2)), (i64 0xCCCCCCCC)), - (and (srl GPR:$rs1, (i64 2)), (i64 0x33333333))), - i32), - (GREVIW GPR:$rs1, (i64 2))>; -def : Pat<(sext_inreg (or (and (shl GPR:$rs1, (i64 4)), (i64 0xF0F0F0F0)), - (and (srl GPR:$rs1, (i64 4)), (i64 0x0F0F0F0F))), - i32), - (GREVIW GPR:$rs1, (i64 4))>; -def : Pat<(sext_inreg (or (and (shl GPR:$rs1, (i64 8)), (i64 0xFF00FF00)), - (and (srl GPR:$rs1, (i64 8)), (i64 0x00FF00FF))), - i32), - (GREVIW GPR:$rs1, (i64 8))>; -def :
Pat<(sext_inreg (or (shl GPR:$rs1, (i64 16)), - (srl (and GPR:$rs1, 0xFFFF0000), (i64 16))), i32), - (GREVIW GPR:$rs1, (i64 16))>; +def : Pat<(riscv_greviw GPR:$rs1, timm:$shamt), (GREVIW GPR:$rs1, timm:$shamt)>; +def : Pat<(riscv_gorciw GPR:$rs1, timm:$shamt), (GORCIW GPR:$rs1, timm:$shamt)>; + def : Pat<(sra (bswap GPR:$rs1), (i64 32)), (GREVIW GPR:$rs1, (i64 24))>; def : Pat<(sra (bitreverse GPR:$rs1), (i64 32)), (GREVIW GPR:$rs1, (i64 31))>; } // Predicates = [HasStdExtZbp, IsRV64]