Speed up handling of splat vector constants.

* ConstantDataVector lazily caches whether it is a splat: isSplat() now
  computes the answer once via the private isSplatData() and stores it in
  two mutable bit-fields (IsSplatSet / IsSplat).
* ConstantFold: bitcasts, casts, unary ops, binary ops and compares whose
  vector operands are splats are folded on the single splatted element and
  re-splatted, instead of element by element. A splat integer divisor of
  zero folds the whole div/rem to undef. A shufflevector with a fixed-length
  zeroinitializer mask folds to a splat of element 0 of the first operand.
* A zeroinitializer shuffle mask is short-circuited in
  ShuffleVectorInst::getShuffleMask, in getShuffleDemandedElts
  (ValueTracking), and in InstCombine's demanded-vector-elements
  simplification when every result element is demanded.

diff --git a/llvm/include/llvm/IR/Constants.h b/llvm/include/llvm/IR/Constants.h
--- a/llvm/include/llvm/IR/Constants.h
+++ b/llvm/include/llvm/IR/Constants.h
@@ -766,7 +766,12 @@
   friend class ConstantDataSequential;
 
   explicit ConstantDataVector(Type *ty, const char *Data)
-      : ConstantDataSequential(ty, ConstantDataVectorVal, Data) {}
+      : ConstantDataSequential(ty, ConstantDataVectorVal, Data),
+        IsSplatSet(false) {}
+  // Cache whether or not the constant is a splat.
+  mutable bool IsSplatSet : 1;
+  mutable bool IsSplat : 1;
+  bool isSplatData() const;
 
 public:
   ConstantDataVector(const ConstantDataVector &) = delete;
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -174,7 +174,13 @@
   int NumElts = Shuf->getOperand(0)->getType()->getVectorNumElements();
   int NumMaskElts = Shuf->getMask()->getType()->getVectorNumElements();
   DemandedLHS = DemandedRHS = APInt::getNullValue(NumElts);
-
+  if (DemandedElts.isNullValue())
+    return true;
+  // Simple case of a shuffle with zeroinitializer.
+  if (isa<ConstantAggregateZero>(Shuf->getMask())) {
+    DemandedLHS.setBit(0);
+    return true;
+  }
   for (int i = 0; i != NumMaskElts; ++i) {
     if (!DemandedElts[i])
       continue;
diff --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp
--- a/llvm/lib/IR/ConstantFold.cpp
+++ b/llvm/lib/IR/ConstantFold.cpp
@@ -60,6 +60,11 @@
     return nullptr;
 
   Type *DstEltTy = DstTy->getElementType();
+  // Fast path for splatted constants.
+  if (Constant *Splat = CV->getSplatValue()) {
+    return ConstantVector::getSplat(DstTy->getVectorElementCount(),
+                                    ConstantExpr::getBitCast(Splat, DstEltTy));
+  }
 
   SmallVector<Constant*, 16> Result;
   Type *Ty = IntegerType::get(CV->getContext(), 32);
@@ -577,9 +582,15 @@
   if ((isa<ConstantVector>(V) || isa<ConstantDataVector>(V)) &&
       DestTy->isVectorTy() &&
       DestTy->getVectorNumElements() == V->getType()->getVectorNumElements()) {
-    SmallVector<Constant*, 16> res;
     VectorType *DestVecTy = cast<VectorType>(DestTy);
     Type *DstEltTy = DestVecTy->getElementType();
+    // Fast path for splatted constants.
+    if (Constant *Splat = V->getSplatValue()) {
+      return ConstantVector::getSplat(
+          DestTy->getVectorElementCount(),
+          ConstantExpr::getCast(opc, Splat, DstEltTy));
+    }
+    SmallVector<Constant*, 16> res;
     Type *Ty = IntegerType::get(V->getContext(), 32);
     for (unsigned i = 0, e = V->getType()->getVectorNumElements(); i != e; ++i) {
       Constant *C =
@@ -878,6 +889,14 @@
   // Don't break the bitcode reader hack.
   if (isa<ConstantExpr>(Mask)) return nullptr;
 
+  // If the mask is all zeros this is a splat, no need to go through all
+  // elements.
+  if (isa<ConstantAggregateZero>(Mask) && !MaskEltCount.Scalable) {
+    Type *Ty = IntegerType::get(V1->getContext(), 32);
+    Constant *Elt =
+        ConstantExpr::getExtractElement(V1, ConstantInt::get(Ty, 0));
+    return ConstantVector::getSplat(MaskEltCount, Elt);
+  }
   // Do not iterate on scalable vector. The num of elements is unknown at
   // compile-time.
   VectorType *ValTy = cast<VectorType>(V1->getType());
@@ -993,10 +1012,15 @@
     // compile-time.
     if (IsScalableVector)
       return nullptr;
 
+    Type *Ty = IntegerType::get(VTy->getContext(), 32);
+    // Fast path for splatted constants.
+    if (Constant *Splat = C->getSplatValue()) {
+      Constant *Elt = ConstantExpr::get(Opcode, Splat);
+      return ConstantVector::getSplat(VTy->getElementCount(), Elt);
+    }
     // Fold each element and create a vector constant from those constants.
     SmallVector<Constant*, 16> Result;
-    Type *Ty = IntegerType::get(VTy->getContext(), 32);
     for (unsigned i = 0, e = VTy->getNumElements(); i != e; ++i) {
       Constant *ExtractIdx = ConstantInt::get(Ty, i);
       Constant *Elt = ConstantExpr::getExtractElement(C, ExtractIdx);
@@ -1357,6 +1381,16 @@
     // compile-time.
     if (IsScalableVector)
       return nullptr;
 
+    // Fast path for splatted constants.
+    if (Constant *C2Splat = C2->getSplatValue()) {
+      if (Instruction::isIntDivRem(Opcode) && C2Splat->isNullValue())
+        return UndefValue::get(VTy);
+      if (Constant *C1Splat = C1->getSplatValue()) {
+        return ConstantVector::getSplat(
+            VTy->getVectorElementCount(),
+            ConstantExpr::get(Opcode, C1Splat, C2Splat));
+      }
+    }
     // Fold each element and create a vector constant from those constants.
     SmallVector<Constant*, 16> Result;
@@ -1975,6 +2009,12 @@
     // compile-time.
     if (C1->getType()->getVectorIsScalable())
       return nullptr;
 
+    // Fast path for splatted constants.
+    if (Constant *C1Splat = C1->getSplatValue())
+      if (Constant *C2Splat = C2->getSplatValue())
+        return ConstantVector::getSplat(
+            C1->getType()->getVectorElementCount(),
+            ConstantExpr::getCompare(pred, C1Splat, C2Splat));
     // If we can constant fold the comparison of each element, constant fold
     // the whole vector comparison.
diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp
--- a/llvm/lib/IR/Constants.cpp
+++ b/llvm/lib/IR/Constants.cpp
@@ -2891,7 +2891,7 @@
   return Str.drop_back().find(0) == StringRef::npos;
 }
 
-bool ConstantDataVector::isSplat() const {
+bool ConstantDataVector::isSplatData() const {
   const char *Base = getRawDataValues().data();
 
   // Compare elements 1+ to the 0'th element.
@@ -2903,6 +2903,14 @@
   return true;
 }
 
+bool ConstantDataVector::isSplat() const {
+  if (!IsSplatSet) {
+    IsSplatSet = true;
+    IsSplat = isSplatData();
+  }
+  return IsSplat;
+}
+
 Constant *ConstantDataVector::getSplatValue() const {
   // If they're all the same, return the 0th one as a representative.
   return isSplat() ? getElementAsConstant(0) : nullptr;
diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -1958,7 +1958,11 @@
   assert(!Mask->getType()->getVectorElementCount().Scalable &&
          "Length of scalable vectors unknown at compile time");
   unsigned NumElts = Mask->getType()->getVectorNumElements();
-
+  if (isa<ConstantAggregateZero>(Mask)) {
+    Result.resize(NumElts, 0);
+    return;
+  }
+  Result.reserve(NumElts);
   if (auto *CDS = dyn_cast<ConstantDataSequential>(Mask)) {
     for (unsigned i = 0; i != NumElts; ++i)
       Result.push_back(CDS->getElementAsInteger(i));
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
@@ -1387,6 +1387,24 @@
            "Expected shuffle operands to have same type");
     unsigned OpWidth = Shuffle->getOperand(0)->getType()->getVectorNumElements();
 
+    // Handle trivial case of a splat. Only check the first element of LHS
+    // operand.
+    if (isa<ConstantAggregateZero>(Shuffle->getMask()) &&
+        DemandedElts.isAllOnesValue()) {
+      if (!isa<UndefValue>(I->getOperand(1))) {
+        I->setOperand(1, UndefValue::get(I->getOperand(1)->getType()));
+        MadeChange = true;
+      }
+      APInt LeftDemanded(OpWidth, 1);
+      APInt LHSUndefElts(OpWidth, 0);
+      simplifyAndSetOp(I, 0, LeftDemanded, LHSUndefElts);
+      if (LHSUndefElts[0])
+        UndefElts = EltMask;
+      else
+        UndefElts.clearAllBits();
+      break;
+    }
+
     APInt LeftDemanded(OpWidth, 0), RightDemanded(OpWidth, 0);
     for (unsigned i = 0; i < VWidth; i++) {
       if (DemandedElts[i]) {