diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -3946,15 +3946,6 @@
     return DAG.getVScale(SDLoc(N), VT, C0 * C1);
   }
 
-  // Fold (mul step_vector(C0), C1) to (step_vector(C0 * C1)).
-  APInt MulVal;
-  if (N0.getOpcode() == ISD::STEP_VECTOR)
-    if (ISD::isConstantSplatVector(N1.getNode(), MulVal)) {
-      const APInt &C0 = N0.getConstantOperandAPInt(0);
-      APInt NewStep = C0 * MulVal;
-      return DAG.getStepVector(SDLoc(N), VT, NewStep);
-    }
-
   // Fold ((mul x, 0/undef) -> 0,
   //       (mul x, 1) -> x) -> x)
   // -> and(x, mask)
@@ -8666,17 +8657,6 @@
     return DAG.getVScale(SDLoc(N), VT, C0 << C1);
   }
 
-  // Fold (shl step_vector(C0), C1) to (step_vector(C0 << C1)).
-  APInt ShlVal;
-  if (N0.getOpcode() == ISD::STEP_VECTOR)
-    if (ISD::isConstantSplatVector(N1.getNode(), ShlVal)) {
-      const APInt &C0 = N0.getConstantOperandAPInt(0);
-      if (ShlVal.ult(C0.getBitWidth())) {
-        APInt NewStep = C0 << ShlVal;
-        return DAG.getStepVector(SDLoc(N), VT, NewStep);
-      }
-    }
-
   return SDValue();
 }
 
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -5413,6 +5413,22 @@
     }
   }
 
+  // Fold (mul step_vector(C0), C1) to (step_vector(C0 * C1)).
+  // Fold (shl step_vector(C0), C1) to (step_vector(C0 << C1)).
+  if ((Opcode == ISD::MUL || Opcode == ISD::SHL) &&
+      Ops[0].getOpcode() == ISD::STEP_VECTOR) {
+    APInt RHSVal;
+    if (ISD::isConstantSplatVector(Ops[1].getNode(), RHSVal)) {
+      const APInt &C0 = Ops[0].getConstantOperandAPInt(0);
+      if (Opcode == ISD::MUL)
+        return getStepVector(DL, VT, C0 * RHSVal);
+      // Only fold a shift whose amount is less than the element bit width;
+      // an out-of-range shift amount is left unfolded.
+      if (RHSVal.ult(C0.getBitWidth()))
+        return getStepVector(DL, VT, C0 << RHSVal);
+    }
+  }
+
   auto IsScalarOrSameVectorSize = [NumElts](const SDValue &Op) {
     return !Op.getValueType().isVector() ||
            Op.getValueType().getVectorElementCount() == NumElts;