Skip to content

Commit c3b496d

Browse files
committed Feb 20, 2019
[SDAG] Support vector UMULO/SMULO
Second part of https://bugs.llvm.org/show_bug.cgi?id=40442.

This adds an extra UnrollVectorOverflowOp() method to SDAG, because the general UnrollVectorOp() method can't deal with multiple results. Additionally, we need to expand UMULO/SMULO during vector op legalization, as it may result in unrolling, which may need additional type legalization.

Differential Revision: https://reviews.llvm.org/D57997
llvm-svn: 354513
1 parent 1e25464 commit c3b496d

File tree

10 files changed

+5904
-22
lines changed

10 files changed

+5904
-22
lines changed
 

Diff for: llvm/include/llvm/CodeGen/SelectionDAG.h

+5
Original file line number | Diff line number | Diff line change
@@ -1535,6 +1535,11 @@ class SelectionDAG {
15351535
/// vector op and fill the end of the resulting vector with UNDEFS.
15361536
SDValue UnrollVectorOp(SDNode *N, unsigned ResNE = 0);
15371537

1538+
/// Like UnrollVectorOp(), but for the [US](ADD|SUB|MUL)O family of opcodes.
1539+
/// This is a separate function because those opcodes have two results.
1540+
std::pair<SDValue, SDValue> UnrollVectorOverflowOp(SDNode *N,
1541+
unsigned ResNE = 0);
1542+
15381543
/// Return true if loads are next to each other and can be
15391544
/// merged. Check that both are nonvolatile and if LD is loading
15401545
/// 'Bytes' bytes from a location that is 'Dist' units away from the

Diff for: llvm/include/llvm/CodeGen/TargetLowering.h

+4-3
Original file line number | Diff line number | Diff line change
@@ -3888,9 +3888,10 @@ class TargetLowering : public TargetLoweringBase {
38883888
/// integers as its arguments.
38893889
SDValue expandFixedPointMul(SDNode *Node, SelectionDAG &DAG) const;
38903890

3891-
/// Method for building the DAG expansion of ISD::[US]MULO, returning the two
3892-
/// result values as a pair.
3893-
std::pair<SDValue, SDValue> expandMULO(SDNode *Node, SelectionDAG &DAG) const;
3891+
/// Method for building the DAG expansion of ISD::[US]MULO. Returns whether
3892+
/// expansion was successful and populates the Result and Overflow arguments.
3893+
bool expandMULO(SDNode *Node, SDValue &Result, SDValue &Overflow,
3894+
SelectionDAG &DAG) const;
38943895

38953896
//===--------------------------------------------------------------------===//
38963897
// Instruction Emitting Hooks

Diff for: llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp

+5-3
Original file line number | Diff line number | Diff line change
@@ -3321,9 +3321,11 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
33213321
}
33223322
case ISD::UMULO:
33233323
case ISD::SMULO: {
3324-
auto Pair = TLI.expandMULO(Node, DAG);
3325-
Results.push_back(Pair.first);
3326-
Results.push_back(Pair.second);
3324+
SDValue Result, Overflow;
3325+
if (TLI.expandMULO(Node, Result, Overflow, DAG)) {
3326+
Results.push_back(Result);
3327+
Results.push_back(Overflow);
3328+
}
33273329
break;
33283330
}
33293331
case ISD::BUILD_PAIR: {

Diff for: llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp

+16
Original file line number | Diff line number | Diff line change
@@ -140,6 +140,7 @@ class VectorLegalizer {
140140
SDValue ExpandFunnelShift(SDValue Op);
141141
SDValue ExpandROT(SDValue Op);
142142
SDValue ExpandFMINNUM_FMAXNUM(SDValue Op);
143+
SDValue ExpandMULO(SDValue Op);
143144
SDValue ExpandAddSubSat(SDValue Op);
144145
SDValue ExpandFixedPointMul(SDValue Op);
145146
SDValue ExpandStrictFPOp(SDValue Op);
@@ -418,6 +419,8 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
418419
case ISD::UMAX:
419420
case ISD::SMUL_LOHI:
420421
case ISD::UMUL_LOHI:
422+
case ISD::SMULO:
423+
case ISD::UMULO:
421424
case ISD::FCANONICALIZE:
422425
case ISD::SADDSAT:
423426
case ISD::UADDSAT:
@@ -779,6 +782,9 @@ SDValue VectorLegalizer::Expand(SDValue Op) {
779782
case ISD::FMINNUM:
780783
case ISD::FMAXNUM:
781784
return ExpandFMINNUM_FMAXNUM(Op);
785+
case ISD::UMULO:
786+
case ISD::SMULO:
787+
return ExpandMULO(Op);
782788
case ISD::USUBSAT:
783789
case ISD::SSUBSAT:
784790
case ISD::UADDSAT:
@@ -1216,6 +1222,16 @@ SDValue VectorLegalizer::ExpandFMINNUM_FMAXNUM(SDValue Op) {
12161222
return DAG.UnrollVectorOp(Op.getNode());
12171223
}
12181224

1225+
SDValue VectorLegalizer::ExpandMULO(SDValue Op) {
1226+
SDValue Result, Overflow;
1227+
if (!TLI.expandMULO(Op.getNode(), Result, Overflow, DAG))
1228+
std::tie(Result, Overflow) = DAG.UnrollVectorOverflowOp(Op.getNode());
1229+
1230+
AddLegalizedOperand(Op.getValue(0), Result);
1231+
AddLegalizedOperand(Op.getValue(1), Overflow);
1232+
return Op.getResNo() ? Overflow : Result;
1233+
}
1234+
12191235
SDValue VectorLegalizer::ExpandAddSubSat(SDValue Op) {
12201236
if (SDValue Expanded = TLI.expandAddSubSat(Op.getNode(), DAG))
12211237
return Expanded;

Diff for: llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

+44
Original file line number | Diff line number | Diff line change
@@ -8918,6 +8918,50 @@ SDValue SelectionDAG::UnrollVectorOp(SDNode *N, unsigned ResNE) {
89188918
return getBuildVector(VecVT, dl, Scalars);
89198919
}
89208920

8921+
std::pair<SDValue, SDValue> SelectionDAG::UnrollVectorOverflowOp(
8922+
SDNode *N, unsigned ResNE) {
8923+
unsigned Opcode = N->getOpcode();
8924+
assert((Opcode == ISD::UADDO || Opcode == ISD::SADDO ||
8925+
Opcode == ISD::USUBO || Opcode == ISD::SSUBO ||
8926+
Opcode == ISD::UMULO || Opcode == ISD::SMULO) &&
8927+
"Expected an overflow opcode");
8928+
8929+
EVT ResVT = N->getValueType(0);
8930+
EVT OvVT = N->getValueType(1);
8931+
EVT ResEltVT = ResVT.getVectorElementType();
8932+
EVT OvEltVT = OvVT.getVectorElementType();
8933+
SDLoc dl(N);
8934+
8935+
// If ResNE is 0, fully unroll the vector op.
8936+
unsigned NE = ResVT.getVectorNumElements();
8937+
if (ResNE == 0)
8938+
ResNE = NE;
8939+
else if (NE > ResNE)
8940+
NE = ResNE;
8941+
8942+
SmallVector<SDValue, 8> LHSScalars;
8943+
SmallVector<SDValue, 8> RHSScalars;
8944+
ExtractVectorElements(N->getOperand(0), LHSScalars, 0, NE);
8945+
ExtractVectorElements(N->getOperand(1), RHSScalars, 0, NE);
8946+
8947+
SDVTList VTs = getVTList(ResEltVT, OvEltVT);
8948+
SmallVector<SDValue, 8> ResScalars;
8949+
SmallVector<SDValue, 8> OvScalars;
8950+
for (unsigned i = 0; i < NE; ++i) {
8951+
SDValue Res = getNode(Opcode, dl, VTs, LHSScalars[i], RHSScalars[i]);
8952+
ResScalars.push_back(Res);
8953+
OvScalars.push_back(SDValue(Res.getNode(), 1));
8954+
}
8955+
8956+
ResScalars.append(ResNE - NE, getUNDEF(ResEltVT));
8957+
OvScalars.append(ResNE - NE, getUNDEF(OvEltVT));
8958+
8959+
EVT NewResVT = EVT::getVectorVT(*getContext(), ResEltVT, ResNE);
8960+
EVT NewOvVT = EVT::getVectorVT(*getContext(), OvEltVT, ResNE);
8961+
return std::make_pair(getBuildVector(NewResVT, dl, ResScalars),
8962+
getBuildVector(NewOvVT, dl, OvScalars));
8963+
}
8964+
89218965
bool SelectionDAG::areNonVolatileConsecutiveLoads(LoadSDNode *LD,
89228966
LoadSDNode *Base,
89238967
unsigned Bytes,

Diff for: llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

+24-16
Original file line number | Diff line number | Diff line change
@@ -5522,11 +5522,15 @@ TargetLowering::expandFixedPointMul(SDNode *Node, SelectionDAG &DAG) const {
55225522
DAG.getConstant(Scale, dl, ShiftTy));
55235523
}
55245524

5525-
std::pair<SDValue, SDValue> TargetLowering::expandMULO(
5526-
SDNode *Node, SelectionDAG &DAG) const {
5525+
bool TargetLowering::expandMULO(SDNode *Node, SDValue &Result,
5526+
SDValue &Overflow, SelectionDAG &DAG) const {
55275527
SDLoc dl(Node);
55285528
EVT VT = Node->getValueType(0);
5529-
EVT WideVT = EVT::getIntegerVT(*DAG.getContext(), VT.getSizeInBits() * 2);
5529+
EVT WideVT = EVT::getIntegerVT(*DAG.getContext(), VT.getScalarSizeInBits() * 2);
5530+
if (VT.isVector())
5531+
WideVT = EVT::getVectorVT(*DAG.getContext(), WideVT,
5532+
VT.getVectorNumElements());
5533+
55305534
SDValue LHS = Node->getOperand(0);
55315535
SDValue RHS = Node->getOperand(1);
55325536
SDValue BottomHalf;
@@ -5546,11 +5550,15 @@ std::pair<SDValue, SDValue> TargetLowering::expandMULO(
55465550
LHS = DAG.getNode(Ops[isSigned][2], dl, WideVT, LHS);
55475551
RHS = DAG.getNode(Ops[isSigned][2], dl, WideVT, RHS);
55485552
SDValue Mul = DAG.getNode(ISD::MUL, dl, WideVT, LHS, RHS);
5549-
BottomHalf = DAG.getNode(ISD::EXTRACT_ELEMENT, dl, VT, Mul,
5550-
DAG.getIntPtrConstant(0, dl));
5551-
TopHalf = DAG.getNode(ISD::EXTRACT_ELEMENT, dl, VT, Mul,
5552-
DAG.getIntPtrConstant(1, dl));
5553+
BottomHalf = DAG.getNode(ISD::TRUNCATE, dl, VT, Mul);
5554+
SDValue ShiftAmt = DAG.getConstant(VT.getScalarSizeInBits(), dl,
5555+
getShiftAmountTy(WideVT, DAG.getDataLayout()));
5556+
TopHalf = DAG.getNode(ISD::TRUNCATE, dl, VT,
5557+
DAG.getNode(ISD::SRL, dl, WideVT, Mul, ShiftAmt));
55535558
} else {
5559+
if (VT.isVector())
5560+
return false;
5561+
55545562
// We can fall back to a libcall with an illegal type for the MUL if we
55555563
// have a libcall big enough.
55565564
// Also, we can fall back to a division in some cases, but that's a big
@@ -5618,24 +5626,24 @@ std::pair<SDValue, SDValue> TargetLowering::expandMULO(
56185626
}
56195627

56205628
EVT SetCCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
5629+
Result = BottomHalf;
56215630
if (isSigned) {
56225631
SDValue ShiftAmt = DAG.getConstant(
5623-
VT.getSizeInBits() - 1, dl,
5632+
VT.getScalarSizeInBits() - 1, dl,
56245633
getShiftAmountTy(BottomHalf.getValueType(), DAG.getDataLayout()));
56255634
SDValue Sign = DAG.getNode(ISD::SRA, dl, VT, BottomHalf, ShiftAmt);
5626-
TopHalf = DAG.getSetCC(dl, SetCCVT, TopHalf, Sign, ISD::SETNE);
5635+
Overflow = DAG.getSetCC(dl, SetCCVT, TopHalf, Sign, ISD::SETNE);
56275636
} else {
5628-
TopHalf = DAG.getSetCC(dl, SetCCVT, TopHalf,
5629-
DAG.getConstant(0, dl, VT), ISD::SETNE);
5637+
Overflow = DAG.getSetCC(dl, SetCCVT, TopHalf,
5638+
DAG.getConstant(0, dl, VT), ISD::SETNE);
56305639
}
56315640

56325641
// Truncate the result if SetCC returns a larger type than needed.
56335642
EVT RType = Node->getValueType(1);
5634-
if (RType.getSizeInBits() < TopHalf.getValueSizeInBits())
5635-
TopHalf = DAG.getNode(ISD::TRUNCATE, dl, RType, TopHalf);
5643+
if (RType.getSizeInBits() < Overflow.getValueSizeInBits())
5644+
Overflow = DAG.getNode(ISD::TRUNCATE, dl, RType, Overflow);
56365645

5637-
assert(RType.getSizeInBits() == TopHalf.getValueSizeInBits() &&
5646+
assert(RType.getSizeInBits() == Overflow.getValueSizeInBits() &&
56385647
"Unexpected result type for S/UMULO legalization");
5639-
5640-
return std::make_pair(BottomHalf, TopHalf);
5648+
return true;
56415649
}

Diff for: llvm/lib/Target/X86/X86ISelLowering.cpp

+2
Original file line number | Diff line number | Diff line change
@@ -20184,6 +20184,8 @@ static SDValue LowerXALUO(SDValue Op, SelectionDAG &DAG) {
2018420184
std::tie(Value, Overflow) = getX86XALUOOp(Cond, Op, DAG);
2018520185

2018620186
SDValue SetCC = getSETCC(Cond, Overflow, DL, DAG);
20187+
if (Op->getValueType(1) != MVT::i8)
20188+
SetCC = DAG.getNode(ISD::ZERO_EXTEND, DL, Op->getValueType(1), SetCC);
2018720189
return DAG.getNode(ISD::MERGE_VALUES, DL, Op->getVTList(), Value, SetCC);
2018820190
}
2018920191

Diff for: llvm/test/CodeGen/AArch64/vec_umulo.ll

+401
Large diffs are not rendered by default.

Diff for: llvm/test/CodeGen/X86/vec_smulo.ll

+2,770
Large diffs are not rendered by default.

Diff for: llvm/test/CodeGen/X86/vec_umulo.ll

+2,633
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)
Please sign in to comment.