Skip to content

Commit b799a62

Browse files
committed Jun 14, 2016
[X86] Reduce the width of multiplication when its operands are extended from i8 or i16
For <N x i32> type mul, pmuludq will be used for targets without SSE41, which often introduces many extra pack and unpack instructions in vectorized loop body because pmuludq generates <N/2 x i64> type value. However when the operands of <N x i32> mul are extended from smaller size values like i8 and i16, the type of mul may be shrunk to use pmullw + pmulhw/pmulhuw instead of pmuludq, which generates better code. For targets with SSE41, pmulld is supported so no shrinking is needed. Differential Revision: http://reviews.llvm.org/D20931 llvm-svn: 272694
1 parent 07c229c commit b799a62

File tree

3 files changed

+1074
-3
lines changed

3 files changed

+1074
-3
lines changed
 

‎llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -670,6 +670,8 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
670670
case ISD::ADD:
671671
case ISD::SUB:
672672
case ISD::MUL:
673+
case ISD::MULHS:
674+
case ISD::MULHU:
673675
case ISD::FADD:
674676
case ISD::FSUB:
675677
case ISD::FMUL:

‎llvm/lib/Target/X86/X86ISelLowering.cpp

+208-3
Original file line numberDiff line numberDiff line change
@@ -26962,18 +26962,223 @@ static SDValue combineCMov(SDNode *N, SelectionDAG &DAG,
2696226962
return SDValue();
2696326963
}
2696426964

26965+
/// Narrowing strategies for an <N x i32> vector multiply whose operands are
/// known to fit in 8 or 16 bits. The mode is selected by canReduceVMulWidth.
enum ShrinkMode {
  MULS8,  ///< Both operands in [-128, 127]: pmullw, then sign-extend to i32.
  MULU8,  ///< Both operands in [0, 255]: pmullw, then zero-extend to i32.
  MULS16, ///< Both operands in [-32768, 32767]: pmullw + pmulhw.
  MULU16  ///< Both operands in [0, 65535]: pmullw + pmulhuw.
};
26967+
26968+
/// Decide whether the two-operand vector multiply \p N with i32 elements can
/// be performed on i16 lanes instead, and if so which ShrinkMode to use.
///
/// For each operand it computes how many leading sign bits are known and
/// whether the operand is known non-negative:
///  - ANY_EXTEND from i8/i16: the extended bits are unspecified, so they may
///    be treated as sign or zero bits (25 / 17 sign bits, non-negative).
///  - BUILD_VECTOR: every non-undef element must be an integer constant; the
///    minimum sign-bit count over the elements is used.
///  - anything else: DAG.ComputeNumSignBits, non-negative only for
///    ZERO_EXTEND.
///
/// \returns true and sets \p Mode when a narrower mode applies; returns false
/// (leaving \p Mode unchanged) otherwise.
static bool canReduceVMulWidth(SDNode *N, SelectionDAG &DAG, ShrinkMode &Mode) {
  EVT VT = N->getOperand(0).getValueType();
  if (VT.getScalarSizeInBits() != 32)
    return false;

  assert(N->getNumOperands() == 2 && "NumOperands of Mul are 2");

  // Running summary over both operands.
  unsigned MinSignBits = ~0U;
  bool AllPositive = true;

  for (unsigned OpIdx = 0; OpIdx != 2; ++OpIdx) {
    SDValue Opd = N->getOperand(OpIdx);
    unsigned OpSignBits = 1;
    bool OpPositive = false;

    switch (Opd.getOpcode()) {
    case ISD::ANY_EXTEND: {
      // DAG.ComputeNumSignBits returns 1 for ISD::ANY_EXTEND; since the high
      // bits are unspecified, assume the widest useful sign/zero run instead.
      EVT SrcEltVT = Opd.getOperand(0).getValueType().getVectorElementType();
      if (SrcEltVT == MVT::i8)
        OpSignBits = 25;
      else if (SrcEltVT == MVT::i16)
        OpSignBits = 17;
      else
        return false;
      OpPositive = true;
      break;
    }
    case ISD::BUILD_VECTOR:
      // Every element must be an integer constant (undef is ignored); keep
      // the narrowest range that covers all of them.
      OpSignBits = 32;
      OpPositive = true;
      for (const SDValue &Elt : Opd.getNode()->op_values()) {
        if (Elt.isUndef())
          continue;
        const auto *C = dyn_cast<ConstantSDNode>(Elt);
        if (!C)
          return false;
        const APInt &Val = C->getAPIntValue();
        OpPositive &= !Val.isNegative();
        OpSignBits = std::min(OpSignBits, Val.getNumSignBits());
      }
      break;
    default:
      OpSignBits = DAG.ComputeNumSignBits(Opd);
      OpPositive = Opd.getOpcode() == ISD::ZERO_EXTEND;
      break;
    }

    MinSignBits = std::min(MinSignBits, OpSignBits);
    AllPositive &= OpPositive;
  }

  // Pick the narrowest mode whose value range covers both operands.
  if (MinSignBits >= 25) { // -128 ~ 127
    Mode = MULS8;
    return true;
  }
  if (AllPositive && MinSignBits >= 24) { // 0 ~ 255
    Mode = MULU8;
    return true;
  }
  if (MinSignBits >= 17) { // -32768 ~ 32767
    Mode = MULS16;
    return true;
  }
  if (AllPositive && MinSignBits >= 16) { // 0 ~ 65535
    Mode = MULU16;
    return true;
  }
  return false;
}
27033+
27034+
/// When the operands of vector mul are extended from smaller size values,
27035+
/// like i8 and i16, the type of mul may be shrinked to generate more
27036+
/// efficient code. Two typical patterns are handled:
27037+
/// Pattern1:
27038+
/// %2 = sext/zext <N x i8> %1 to <N x i32>
27039+
/// %4 = sext/zext <N x i8> %3 to <N x i32>
27040+
// or %4 = build_vector <N x i32> %C1, ..., %CN (%C1..%CN are constants)
27041+
/// %5 = mul <N x i32> %2, %4
27042+
///
27043+
/// Pattern2:
27044+
/// %2 = zext/sext <N x i16> %1 to <N x i32>
27045+
/// %4 = zext/sext <N x i16> %3 to <N x i32>
27046+
/// or %4 = build_vector <N x i32> %C1, ..., %CN (%C1..%CN are constants)
27047+
/// %5 = mul <N x i32> %2, %4
27048+
///
27049+
/// There are four mul shrinking modes:
27050+
/// If %2 == sext32(trunc8(%2)), i.e., the scalar value range of %2 is
27051+
/// -128 to 128, and the scalar value range of %4 is also -128 to 128,
27052+
/// generate pmullw+sext32 for it (MULS8 mode).
27053+
/// If %2 == zext32(trunc8(%2)), i.e., the scalar value range of %2 is
27054+
/// 0 to 255, and the scalar value range of %4 is also 0 to 255,
27055+
/// generate pmullw+zext32 for it (MULU8 mode).
27056+
/// If %2 == sext32(trunc16(%2)), i.e., the scalar value range of %2 is
27057+
/// -32768 to 32767, and the scalar value range of %4 is also -32768 to 32767,
27058+
/// generate pmullw+pmulhw for it (MULS16 mode).
27059+
/// If %2 == zext32(trunc16(%2)), i.e., the scalar value range of %2 is
27060+
/// 0 to 65535, and the scalar value range of %4 is also 0 to 65535,
27061+
/// generate pmullw+pmulhuw for it (MULU16 mode).
27062+
static SDValue reduceVMULWidth(SDNode *N, SelectionDAG &DAG,
27063+
const X86Subtarget &Subtarget) {
27064+
// pmulld is supported since SSE41. It is better to use pmulld
27065+
// instead of pmullw+pmulhw.
27066+
if (Subtarget.hasSSE41())
27067+
return SDValue();
27068+
27069+
ShrinkMode Mode;
27070+
if (!canReduceVMulWidth(N, DAG, Mode))
27071+
return SDValue();
27072+
27073+
SDLoc DL(N);
27074+
SDValue N0 = N->getOperand(0);
27075+
SDValue N1 = N->getOperand(1);
27076+
EVT VT = N->getOperand(0).getValueType();
27077+
unsigned RegSize = 128;
27078+
MVT OpsVT = MVT::getVectorVT(MVT::i16, RegSize / 16);
27079+
EVT ReducedVT =
27080+
EVT::getVectorVT(*DAG.getContext(), MVT::i16, VT.getVectorNumElements());
27081+
// Shrink the operands of mul.
27082+
SDValue NewN0 = DAG.getNode(ISD::TRUNCATE, DL, ReducedVT, N0);
27083+
SDValue NewN1 = DAG.getNode(ISD::TRUNCATE, DL, ReducedVT, N1);
27084+
27085+
if (VT.getVectorNumElements() >= OpsVT.getVectorNumElements()) {
27086+
// Generate the lower part of mul: pmullw. For MULU8/MULS8, only the
27087+
// lower part is needed.
27088+
SDValue MulLo = DAG.getNode(ISD::MUL, DL, ReducedVT, NewN0, NewN1);
27089+
if (Mode == MULU8 || Mode == MULS8) {
27090+
return DAG.getNode((Mode == MULU8) ? ISD::ZERO_EXTEND : ISD::SIGN_EXTEND,
27091+
DL, VT, MulLo);
27092+
} else {
27093+
MVT ResVT = MVT::getVectorVT(MVT::i32, VT.getVectorNumElements() / 2);
27094+
// Generate the higher part of mul: pmulhw/pmulhuw. For MULU16/MULS16,
27095+
// the higher part is also needed.
27096+
SDValue MulHi = DAG.getNode(Mode == MULS16 ? ISD::MULHS : ISD::MULHU, DL,
27097+
ReducedVT, NewN0, NewN1);
27098+
27099+
// Repack the lower part and higher part result of mul into a wider
27100+
// result.
27101+
// Generate shuffle functioning as punpcklwd.
27102+
SmallVector<int, 16> ShuffleMask(VT.getVectorNumElements());
27103+
for (unsigned i = 0; i < VT.getVectorNumElements() / 2; i++) {
27104+
ShuffleMask[2 * i] = i;
27105+
ShuffleMask[2 * i + 1] = i + VT.getVectorNumElements();
27106+
}
27107+
SDValue ResLo =
27108+
DAG.getVectorShuffle(ReducedVT, DL, MulLo, MulHi, &ShuffleMask[0]);
27109+
ResLo = DAG.getNode(ISD::BITCAST, DL, ResVT, ResLo);
27110+
// Generate shuffle functioning as punpckhwd.
27111+
for (unsigned i = 0; i < VT.getVectorNumElements() / 2; i++) {
27112+
ShuffleMask[2 * i] = i + VT.getVectorNumElements() / 2;
27113+
ShuffleMask[2 * i + 1] = i + VT.getVectorNumElements() * 3 / 2;
27114+
}
27115+
SDValue ResHi =
27116+
DAG.getVectorShuffle(ReducedVT, DL, MulLo, MulHi, &ShuffleMask[0]);
27117+
ResHi = DAG.getNode(ISD::BITCAST, DL, ResVT, ResHi);
27118+
return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, ResLo, ResHi);
27119+
}
27120+
} else {
27121+
// When VT.getVectorNumElements() < OpsVT.getVectorNumElements(), we want
27122+
// to legalize the mul explicitly because implicit legalization for type
27123+
// <4 x i16> to <4 x i32> sometimes involves unnecessary unpack
27124+
// instructions which will not exist when we explicitly legalize it by
27125+
// extending <4 x i16> to <8 x i16> (concatenating the <4 x i16> val with
27126+
// <4 x i16> undef).
27127+
//
27128+
// Legalize the operands of mul.
27129+
SmallVector<SDValue, 16> Ops(RegSize / ReducedVT.getSizeInBits(),
27130+
DAG.getUNDEF(ReducedVT));
27131+
Ops[0] = NewN0;
27132+
NewN0 = DAG.getNode(ISD::CONCAT_VECTORS, DL, OpsVT, Ops);
27133+
Ops[0] = NewN1;
27134+
NewN1 = DAG.getNode(ISD::CONCAT_VECTORS, DL, OpsVT, Ops);
27135+
27136+
if (Mode == MULU8 || Mode == MULS8) {
27137+
// Generate lower part of mul: pmullw. For MULU8/MULS8, only the lower
27138+
// part is needed.
27139+
SDValue Mul = DAG.getNode(ISD::MUL, DL, OpsVT, NewN0, NewN1);
27140+
27141+
// convert the type of mul result to VT.
27142+
MVT ResVT = MVT::getVectorVT(MVT::i32, RegSize / 32);
27143+
SDValue Res = DAG.getNode(Mode == MULU8 ? ISD::ZERO_EXTEND_VECTOR_INREG
27144+
: ISD::SIGN_EXTEND_VECTOR_INREG,
27145+
DL, ResVT, Mul);
27146+
return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Res,
27147+
DAG.getIntPtrConstant(0, DL));
27148+
} else {
27149+
// Generate the lower and higher part of mul: pmulhw/pmulhuw. For
27150+
// MULU16/MULS16, both parts are needed.
27151+
SDValue MulLo = DAG.getNode(ISD::MUL, DL, OpsVT, NewN0, NewN1);
27152+
SDValue MulHi = DAG.getNode(Mode == MULS16 ? ISD::MULHS : ISD::MULHU, DL,
27153+
OpsVT, NewN0, NewN1);
27154+
27155+
// Repack the lower part and higher part result of mul into a wider
27156+
// result. Make sure the type of mul result is VT.
27157+
MVT ResVT = MVT::getVectorVT(MVT::i32, RegSize / 32);
27158+
SDValue Res = DAG.getNode(X86ISD::UNPCKL, DL, OpsVT, MulLo, MulHi);
27159+
Res = DAG.getNode(ISD::BITCAST, DL, ResVT, Res);
27160+
return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Res,
27161+
DAG.getIntPtrConstant(0, DL));
27162+
}
27163+
}
27164+
}
27165+
2696527166
/// Optimize a single multiply with constant into two operations in order to
2696627167
/// implement it with two cheaper instructions, e.g. LEA + SHL, LEA + LEA.
2696727168
static SDValue combineMul(SDNode *N, SelectionDAG &DAG,
26968-
TargetLowering::DAGCombinerInfo &DCI) {
27169+
TargetLowering::DAGCombinerInfo &DCI,
27170+
const X86Subtarget &Subtarget) {
27171+
EVT VT = N->getValueType(0);
27172+
if (DCI.isBeforeLegalize() && VT.isVector())
27173+
return reduceVMULWidth(N, DAG, Subtarget);
27174+
2696927175
// An imul is usually smaller than the alternative sequence.
2697027176
if (DAG.getMachineFunction().getFunction()->optForMinSize())
2697127177
return SDValue();
2697227178

2697327179
if (DCI.isBeforeLegalize() || DCI.isCalledByLegalizer())
2697427180
return SDValue();
2697527181

26976-
EVT VT = N->getValueType(0);
2697727182
if (VT != MVT::i64 && VT != MVT::i32)
2697827183
return SDValue();
2697927184

@@ -30268,7 +30473,7 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
3026830473
case ISD::ADD: return combineAdd(N, DAG, Subtarget);
3026930474
case ISD::SUB: return combineSub(N, DAG, Subtarget);
3027030475
case X86ISD::ADC: return combineADC(N, DAG, DCI);
30271-
case ISD::MUL: return combineMul(N, DAG, DCI);
30476+
case ISD::MUL: return combineMul(N, DAG, DCI, Subtarget);
3027230477
case ISD::SHL:
3027330478
case ISD::SRA:
3027430479
case ISD::SRL: return combineShift(N, DAG, DCI, Subtarget);

0 commit comments

Comments
 (0)
Please sign in to comment.