diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -5413,6 +5413,19 @@ } } + // Fold (mul step_vector(C0), C1) to (step_vector(C0 * C1)). + // (shl step_vector(C0), C1) -> (step_vector(C0 << C1)) + if ((Opcode == ISD::MUL || Opcode == ISD::SHL) && + Ops[0].getOpcode() == ISD::STEP_VECTOR) { + APInt RHSVal; + if (ISD::isConstantSplatVector(Ops[1].getNode(), RHSVal)) { + APInt NewStep = Opcode == ISD::MUL + ? Ops[0].getConstantOperandAPInt(0) * RHSVal + : Ops[0].getConstantOperandAPInt(0) << RHSVal; + return getStepVector(DL, VT, NewStep); + } + } + auto IsScalarOrSameVectorSize = [NumElts](const SDValue &Op) { return !Op.getValueType().isVector() || Op.getValueType().getVectorElementCount() == NumElts;