@@ -26962,18 +26962,223 @@ static SDValue combineCMov(SDNode *N, SelectionDAG &DAG,
26962
26962
return SDValue();
26963
26963
}
26964
26964
26965
/// Ways in which a vector multiply on 32-bit elements can be shrunk, chosen
/// from the value range that both operands are known to fit in.
enum ShrinkMode {
  MULS8,  ///< Both operands lie in the signed 8-bit range (-128 ~ 127).
  MULU8,  ///< Both operands lie in the unsigned 8-bit range (0 ~ 255).
  MULS16, ///< Both operands lie in the signed 16-bit range (-32768 ~ 32767).
  MULU16  ///< Both operands lie in the unsigned 16-bit range (0 ~ 65535).
};
26967
+
26968
+ static bool canReduceVMulWidth(SDNode *N, SelectionDAG &DAG, ShrinkMode &Mode) {
26969
+ EVT VT = N->getOperand(0).getValueType();
26970
+ if (VT.getScalarSizeInBits() != 32)
26971
+ return false;
26972
+
26973
+ assert(N->getNumOperands() == 2 && "NumOperands of Mul are 2");
26974
+ unsigned SignBits[2] = {1, 1};
26975
+ bool IsPositive[2] = {false, false};
26976
+ for (unsigned i = 0; i < 2; i++) {
26977
+ SDValue Opd = N->getOperand(i);
26978
+
26979
+ // DAG.ComputeNumSignBits return 1 for ISD::ANY_EXTEND, so we need to
26980
+ // compute signbits for it separately.
26981
+ if (Opd.getOpcode() == ISD::ANY_EXTEND) {
26982
+ // For anyextend, it is safe to assume an appropriate number of leading
26983
+ // sign/zero bits.
26984
+ if (Opd.getOperand(0).getValueType().getVectorElementType() == MVT::i8)
26985
+ SignBits[i] = 25;
26986
+ else if (Opd.getOperand(0).getValueType().getVectorElementType() ==
26987
+ MVT::i16)
26988
+ SignBits[i] = 17;
26989
+ else
26990
+ return false;
26991
+ IsPositive[i] = true;
26992
+ } else if (Opd.getOpcode() == ISD::BUILD_VECTOR) {
26993
+ // All the operands of BUILD_VECTOR need to be int constant.
26994
+ // Find the smallest value range which all the operands belong to.
26995
+ SignBits[i] = 32;
26996
+ IsPositive[i] = true;
26997
+ for (const SDValue &SubOp : Opd.getNode()->op_values()) {
26998
+ if (SubOp.isUndef())
26999
+ continue;
27000
+ auto *CN = dyn_cast<ConstantSDNode>(SubOp);
27001
+ if (!CN)
27002
+ return false;
27003
+ APInt IntVal = CN->getAPIntValue();
27004
+ if (IntVal.isNegative())
27005
+ IsPositive[i] = false;
27006
+ SignBits[i] = std::min(SignBits[i], IntVal.getNumSignBits());
27007
+ }
27008
+ } else {
27009
+ SignBits[i] = DAG.ComputeNumSignBits(Opd);
27010
+ if (Opd.getOpcode() == ISD::ZERO_EXTEND)
27011
+ IsPositive[i] = true;
27012
+ }
27013
+ }
27014
+
27015
+ bool AllPositive = IsPositive[0] && IsPositive[1];
27016
+ unsigned MinSignBits = std::min(SignBits[0], SignBits[1]);
27017
+ // When ranges are from -128 ~ 127, use MULS8 mode.
27018
+ if (MinSignBits >= 25)
27019
+ Mode = MULS8;
27020
+ // When ranges are from 0 ~ 255, use MULU8 mode.
27021
+ else if (AllPositive && MinSignBits >= 24)
27022
+ Mode = MULU8;
27023
+ // When ranges are from -32768 ~ 32767, use MULS16 mode.
27024
+ else if (MinSignBits >= 17)
27025
+ Mode = MULS16;
27026
+ // When ranges are from 0 ~ 65535, use MULU16 mode.
27027
+ else if (AllPositive && MinSignBits >= 16)
27028
+ Mode = MULU16;
27029
+ else
27030
+ return false;
27031
+ return true;
27032
+ }
27033
+
27034
+ /// When the operands of vector mul are extended from smaller size values,
27035
+ /// like i8 and i16, the type of mul may be shrinked to generate more
27036
+ /// efficient code. Two typical patterns are handled:
27037
+ /// Pattern1:
27038
+ /// %2 = sext/zext <N x i8> %1 to <N x i32>
27039
+ /// %4 = sext/zext <N x i8> %3 to <N x i32>
27040
+ // or %4 = build_vector <N x i32> %C1, ..., %CN (%C1..%CN are constants)
27041
+ /// %5 = mul <N x i32> %2, %4
27042
+ ///
27043
+ /// Pattern2:
27044
+ /// %2 = zext/sext <N x i16> %1 to <N x i32>
27045
+ /// %4 = zext/sext <N x i16> %3 to <N x i32>
27046
+ /// or %4 = build_vector <N x i32> %C1, ..., %CN (%C1..%CN are constants)
27047
+ /// %5 = mul <N x i32> %2, %4
27048
+ ///
27049
+ /// There are four mul shrinking modes:
27050
+ /// If %2 == sext32(trunc8(%2)), i.e., the scalar value range of %2 is
27051
+ /// -128 to 128, and the scalar value range of %4 is also -128 to 128,
27052
+ /// generate pmullw+sext32 for it (MULS8 mode).
27053
+ /// If %2 == zext32(trunc8(%2)), i.e., the scalar value range of %2 is
27054
+ /// 0 to 255, and the scalar value range of %4 is also 0 to 255,
27055
+ /// generate pmullw+zext32 for it (MULU8 mode).
27056
+ /// If %2 == sext32(trunc16(%2)), i.e., the scalar value range of %2 is
27057
+ /// -32768 to 32767, and the scalar value range of %4 is also -32768 to 32767,
27058
+ /// generate pmullw+pmulhw for it (MULS16 mode).
27059
+ /// If %2 == zext32(trunc16(%2)), i.e., the scalar value range of %2 is
27060
+ /// 0 to 65535, and the scalar value range of %4 is also 0 to 65535,
27061
+ /// generate pmullw+pmulhuw for it (MULU16 mode).
27062
+ static SDValue reduceVMULWidth(SDNode *N, SelectionDAG &DAG,
27063
+ const X86Subtarget &Subtarget) {
27064
+ // pmulld is supported since SSE41. It is better to use pmulld
27065
+ // instead of pmullw+pmulhw.
27066
+ if (Subtarget.hasSSE41())
27067
+ return SDValue();
27068
+
27069
+ ShrinkMode Mode;
27070
+ if (!canReduceVMulWidth(N, DAG, Mode))
27071
+ return SDValue();
27072
+
27073
+ SDLoc DL(N);
27074
+ SDValue N0 = N->getOperand(0);
27075
+ SDValue N1 = N->getOperand(1);
27076
+ EVT VT = N->getOperand(0).getValueType();
27077
+ unsigned RegSize = 128;
27078
+ MVT OpsVT = MVT::getVectorVT(MVT::i16, RegSize / 16);
27079
+ EVT ReducedVT =
27080
+ EVT::getVectorVT(*DAG.getContext(), MVT::i16, VT.getVectorNumElements());
27081
+ // Shrink the operands of mul.
27082
+ SDValue NewN0 = DAG.getNode(ISD::TRUNCATE, DL, ReducedVT, N0);
27083
+ SDValue NewN1 = DAG.getNode(ISD::TRUNCATE, DL, ReducedVT, N1);
27084
+
27085
+ if (VT.getVectorNumElements() >= OpsVT.getVectorNumElements()) {
27086
+ // Generate the lower part of mul: pmullw. For MULU8/MULS8, only the
27087
+ // lower part is needed.
27088
+ SDValue MulLo = DAG.getNode(ISD::MUL, DL, ReducedVT, NewN0, NewN1);
27089
+ if (Mode == MULU8 || Mode == MULS8) {
27090
+ return DAG.getNode((Mode == MULU8) ? ISD::ZERO_EXTEND : ISD::SIGN_EXTEND,
27091
+ DL, VT, MulLo);
27092
+ } else {
27093
+ MVT ResVT = MVT::getVectorVT(MVT::i32, VT.getVectorNumElements() / 2);
27094
+ // Generate the higher part of mul: pmulhw/pmulhuw. For MULU16/MULS16,
27095
+ // the higher part is also needed.
27096
+ SDValue MulHi = DAG.getNode(Mode == MULS16 ? ISD::MULHS : ISD::MULHU, DL,
27097
+ ReducedVT, NewN0, NewN1);
27098
+
27099
+ // Repack the lower part and higher part result of mul into a wider
27100
+ // result.
27101
+ // Generate shuffle functioning as punpcklwd.
27102
+ SmallVector<int, 16> ShuffleMask(VT.getVectorNumElements());
27103
+ for (unsigned i = 0; i < VT.getVectorNumElements() / 2; i++) {
27104
+ ShuffleMask[2 * i] = i;
27105
+ ShuffleMask[2 * i + 1] = i + VT.getVectorNumElements();
27106
+ }
27107
+ SDValue ResLo =
27108
+ DAG.getVectorShuffle(ReducedVT, DL, MulLo, MulHi, &ShuffleMask[0]);
27109
+ ResLo = DAG.getNode(ISD::BITCAST, DL, ResVT, ResLo);
27110
+ // Generate shuffle functioning as punpckhwd.
27111
+ for (unsigned i = 0; i < VT.getVectorNumElements() / 2; i++) {
27112
+ ShuffleMask[2 * i] = i + VT.getVectorNumElements() / 2;
27113
+ ShuffleMask[2 * i + 1] = i + VT.getVectorNumElements() * 3 / 2;
27114
+ }
27115
+ SDValue ResHi =
27116
+ DAG.getVectorShuffle(ReducedVT, DL, MulLo, MulHi, &ShuffleMask[0]);
27117
+ ResHi = DAG.getNode(ISD::BITCAST, DL, ResVT, ResHi);
27118
+ return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, ResLo, ResHi);
27119
+ }
27120
+ } else {
27121
+ // When VT.getVectorNumElements() < OpsVT.getVectorNumElements(), we want
27122
+ // to legalize the mul explicitly because implicit legalization for type
27123
+ // <4 x i16> to <4 x i32> sometimes involves unnecessary unpack
27124
+ // instructions which will not exist when we explicitly legalize it by
27125
+ // extending <4 x i16> to <8 x i16> (concatenating the <4 x i16> val with
27126
+ // <4 x i16> undef).
27127
+ //
27128
+ // Legalize the operands of mul.
27129
+ SmallVector<SDValue, 16> Ops(RegSize / ReducedVT.getSizeInBits(),
27130
+ DAG.getUNDEF(ReducedVT));
27131
+ Ops[0] = NewN0;
27132
+ NewN0 = DAG.getNode(ISD::CONCAT_VECTORS, DL, OpsVT, Ops);
27133
+ Ops[0] = NewN1;
27134
+ NewN1 = DAG.getNode(ISD::CONCAT_VECTORS, DL, OpsVT, Ops);
27135
+
27136
+ if (Mode == MULU8 || Mode == MULS8) {
27137
+ // Generate lower part of mul: pmullw. For MULU8/MULS8, only the lower
27138
+ // part is needed.
27139
+ SDValue Mul = DAG.getNode(ISD::MUL, DL, OpsVT, NewN0, NewN1);
27140
+
27141
+ // convert the type of mul result to VT.
27142
+ MVT ResVT = MVT::getVectorVT(MVT::i32, RegSize / 32);
27143
+ SDValue Res = DAG.getNode(Mode == MULU8 ? ISD::ZERO_EXTEND_VECTOR_INREG
27144
+ : ISD::SIGN_EXTEND_VECTOR_INREG,
27145
+ DL, ResVT, Mul);
27146
+ return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Res,
27147
+ DAG.getIntPtrConstant(0, DL));
27148
+ } else {
27149
+ // Generate the lower and higher part of mul: pmulhw/pmulhuw. For
27150
+ // MULU16/MULS16, both parts are needed.
27151
+ SDValue MulLo = DAG.getNode(ISD::MUL, DL, OpsVT, NewN0, NewN1);
27152
+ SDValue MulHi = DAG.getNode(Mode == MULS16 ? ISD::MULHS : ISD::MULHU, DL,
27153
+ OpsVT, NewN0, NewN1);
27154
+
27155
+ // Repack the lower part and higher part result of mul into a wider
27156
+ // result. Make sure the type of mul result is VT.
27157
+ MVT ResVT = MVT::getVectorVT(MVT::i32, RegSize / 32);
27158
+ SDValue Res = DAG.getNode(X86ISD::UNPCKL, DL, OpsVT, MulLo, MulHi);
27159
+ Res = DAG.getNode(ISD::BITCAST, DL, ResVT, Res);
27160
+ return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Res,
27161
+ DAG.getIntPtrConstant(0, DL));
27162
+ }
27163
+ }
27164
+ }
27165
+
26965
27166
/// Optimize a single multiply with constant into two operations in order to
26966
27167
/// implement it with two cheaper instructions, e.g. LEA + SHL, LEA + LEA.
26967
27168
static SDValue combineMul(SDNode *N, SelectionDAG &DAG,
26968
- TargetLowering::DAGCombinerInfo &DCI) {
27169
+ TargetLowering::DAGCombinerInfo &DCI,
27170
+ const X86Subtarget &Subtarget) {
27171
+ EVT VT = N->getValueType(0);
27172
+ if (DCI.isBeforeLegalize() && VT.isVector())
27173
+ return reduceVMULWidth(N, DAG, Subtarget);
27174
+
26969
27175
// An imul is usually smaller than the alternative sequence.
26970
27176
if (DAG.getMachineFunction().getFunction()->optForMinSize())
26971
27177
return SDValue();
26972
27178
26973
27179
if (DCI.isBeforeLegalize() || DCI.isCalledByLegalizer())
26974
27180
return SDValue();
26975
27181
26976
- EVT VT = N->getValueType(0);
26977
27182
if (VT != MVT::i64 && VT != MVT::i32)
26978
27183
return SDValue();
26979
27184
@@ -30268,7 +30473,7 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
30268
30473
case ISD::ADD: return combineAdd(N, DAG, Subtarget);
30269
30474
case ISD::SUB: return combineSub(N, DAG, Subtarget);
30270
30475
case X86ISD::ADC: return combineADC(N, DAG, DCI);
30271
- case ISD::MUL: return combineMul(N, DAG, DCI);
30476
+ case ISD::MUL: return combineMul(N, DAG, DCI, Subtarget );
30272
30477
case ISD::SHL:
30273
30478
case ISD::SRA:
30274
30479
case ISD::SRL: return combineShift(N, DAG, DCI, Subtarget);
0 commit comments