Index: llvm/include/llvm/IR/Constants.h =================================================================== --- llvm/include/llvm/IR/Constants.h +++ llvm/include/llvm/IR/Constants.h @@ -767,6 +767,9 @@ explicit ConstantDataVector(Type *ty, const char *Data) : ConstantDataSequential(ty, ConstantDataVectorVal, Data) {} + // Cache whether or not the constant is a splat. + mutable Optional IsSplat; + bool isSplatData() const; public: ConstantDataVector(const ConstantDataVector &) = delete; Index: llvm/lib/Analysis/ValueTracking.cpp =================================================================== --- llvm/lib/Analysis/ValueTracking.cpp +++ llvm/lib/Analysis/ValueTracking.cpp @@ -174,7 +174,13 @@ int NumElts = Shuf->getOperand(0)->getType()->getVectorNumElements(); int NumMaskElts = Shuf->getMask()->getType()->getVectorNumElements(); DemandedLHS = DemandedRHS = APInt::getNullValue(NumElts); - + if (DemandedElts.isNullValue()) + return true; + // Simple case of a shuffle with zeroinitializer. + if (isa(Shuf->getMask())) { + DemandedLHS.setBit(0); + return true; + } for (int i = 0; i != NumMaskElts; ++i) { if (!DemandedElts[i]) continue; Index: llvm/lib/IR/ConstantFold.cpp =================================================================== --- llvm/lib/IR/ConstantFold.cpp +++ llvm/lib/IR/ConstantFold.cpp @@ -60,6 +60,11 @@ return nullptr; Type *DstEltTy = DstTy->getElementType(); + // Fast path for splatted constants. + if (Constant *Splat = CV->getSplatValue()) { + return ConstantVector::getSplat(DstTy->getVectorElementCount(), + ConstantExpr::getBitCast(Splat, DstEltTy)); + } SmallVector Result; Type *Ty = IntegerType::get(CV->getContext(), 32); @@ -577,9 +582,15 @@ if ((isa(V) || isa(V)) && DestTy->isVectorTy() && DestTy->getVectorNumElements() == V->getType()->getVectorNumElements()) { - SmallVector res; VectorType *DestVecTy = cast(DestTy); Type *DstEltTy = DestVecTy->getElementType(); + // Fast path for splatted constants. 
+ if (Constant *Splat = V->getSplatValue()) { + return ConstantVector::getSplat( + DestTy->getVectorElementCount(), + ConstantExpr::getCast(opc, Splat, DstEltTy)); + } + SmallVector res; Type *Ty = IntegerType::get(V->getContext(), 32); for (unsigned i = 0, e = V->getType()->getVectorNumElements(); i != e; ++i) { Constant *C = @@ -878,6 +889,14 @@ // Don't break the bitcode reader hack. if (isa(Mask)) return nullptr; + // If the mask is all zeros this is a splat, no need to go through all + // elements. + if (isa(Mask) && !MaskEltCount.Scalable) { + Type *Ty = IntegerType::get(V1->getContext(), 32); + Constant *Elt = + ConstantExpr::getExtractElement(V1, ConstantInt::get(Ty, 0)); + return ConstantVector::getSplat(MaskEltCount, Elt); + } // Do not iterate on scalable vector. The num of elements is unknown at // compile-time. VectorType *ValTy = cast(V1->getType()); @@ -993,10 +1012,15 @@ // compile-time. if (IsScalableVector) return nullptr; + Type *Ty = IntegerType::get(VTy->getContext(), 32); + // Fast path for splatted constants. + if (Constant *Splat = C->getSplatValue()) { + Constant *Elt = ConstantExpr::get(Opcode, Splat); + return ConstantVector::getSplat(VTy->getElementCount(), Elt); + } // Fold each element and create a vector constant from those constants. SmallVector Result; - Type *Ty = IntegerType::get(VTy->getContext(), 32); for (unsigned i = 0, e = VTy->getNumElements(); i != e; ++i) { Constant *ExtractIdx = ConstantInt::get(Ty, i); Constant *Elt = ConstantExpr::getExtractElement(C, ExtractIdx); @@ -1357,6 +1381,16 @@ // compile-time. if (IsScalableVector) return nullptr; + // Fast path for splatted constants. 
+ Constant *C1Splat = C1->getSplatValue(); + Constant *C2Splat = C2->getSplatValue(); + if (C2Splat && Instruction::isIntDivRem(Opcode) && C2Splat->isNullValue()) + return UndefValue::get(VTy); + if (C1Splat && C2Splat) { + return ConstantVector::getSplat( + VTy->getVectorElementCount(), + ConstantExpr::get(Opcode, C1Splat, C2Splat)); + } // Fold each element and create a vector constant from those constants. SmallVector Result; @@ -1975,6 +2009,14 @@ // compile-time. if (C1->getType()->getVectorIsScalable()) return nullptr; + // Fast path for splatted constants. + Constant *C1Splat = C1->getSplatValue(); + Constant *C2Splat = C2->getSplatValue(); + if (C1Splat && C2Splat) { + return ConstantVector::getSplat( + C1->getType()->getVectorElementCount(), + ConstantExpr::getCompare(pred, C1Splat, C2Splat)); + } // If we can constant fold the comparison of each element, constant fold // the whole vector comparison. Index: llvm/lib/IR/Constants.cpp =================================================================== --- llvm/lib/IR/Constants.cpp +++ llvm/lib/IR/Constants.cpp @@ -2891,7 +2891,7 @@ return Str.drop_back().find(0) == StringRef::npos; } -bool ConstantDataVector::isSplat() const { +bool ConstantDataVector::isSplatData() const { const char *Base = getRawDataValues().data(); // Compare elements 1+ to the 0'th element. @@ -2903,6 +2903,12 @@ return true; } +bool ConstantDataVector::isSplat() const { + if(!IsSplat.hasValue()) + IsSplat = isSplatData(); + return IsSplat.getValue(); +} + Constant *ConstantDataVector::getSplatValue() const { // If they're all the same, return the 0th one as a representative. return isSplat() ? 
getElementAsConstant(0) : nullptr; Index: llvm/lib/IR/Instructions.cpp =================================================================== --- llvm/lib/IR/Instructions.cpp +++ llvm/lib/IR/Instructions.cpp @@ -1958,16 +1958,20 @@ assert(!Mask->getType()->getVectorElementCount().Scalable && "Length of scalable vectors unknown at compile time"); unsigned NumElts = Mask->getType()->getVectorNumElements(); - + if (isa(Mask)) { + // assign(), not resize(): resize() keeps elements already in Result. + Result.assign(NumElts, 0); + return; + } + Result.resize(NumElts); if (auto *CDS = dyn_cast(Mask)) { for (unsigned i = 0; i != NumElts; ++i) - Result.push_back(CDS->getElementAsInteger(i)); + Result[i] = CDS->getElementAsInteger(i); return; } for (unsigned i = 0; i != NumElts; ++i) { Constant *C = Mask->getAggregateElement(i); - Result.push_back(isa(C) ? -1 : - cast(C)->getZExtValue()); + Result[i] = isa(C) ? -1 : cast(C)->getZExtValue(); } } Index: llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp =================================================================== --- llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp +++ llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp @@ -1389,6 +1389,24 @@ "Expected shuffle operands to have same type"); unsigned OpWidth = Shuffle->getOperand(0)->getType()->getVectorNumElements(); + // Handle trivial case of a splat. Only check the first element of LHS + // operand. + if (isa(Shuffle->getMask()) && + DemandedElts.isAllOnesValue()) { + if (!isa(I->getOperand(1))) { + I->setOperand(1, UndefValue::get(I->getOperand(1)->getType())); + MadeChange = true; + } + APInt LeftDemanded(OpWidth, 1); + APInt LHSUndefElts(OpWidth, 0); + simplifyAndSetOp(I, 0, LeftDemanded, LHSUndefElts); + if (LHSUndefElts[0]) + UndefElts = EltMask; + else + UndefElts.clearAllBits(); + break; + } + APInt LeftDemanded(OpWidth, 0), RightDemanded(OpWidth, 0); for (unsigned i = 0; i < VWidth; i++) { if (DemandedElts[i]) {