diff --git a/llvm/include/llvm/IR/Constants.h b/llvm/include/llvm/IR/Constants.h --- a/llvm/include/llvm/IR/Constants.h +++ b/llvm/include/llvm/IR/Constants.h @@ -766,7 +766,11 @@ friend class ConstantDataSequential; explicit ConstantDataVector(Type *ty, const char *Data) - : ConstantDataSequential(ty, ConstantDataVectorVal, Data) {} + : ConstantDataSequential(ty, ConstantDataVectorVal, Data) { + IsSplat = isSplatData(ty, Data); + } + bool IsSplat = false; + bool isSplatData(Type *ty, const char *Data) const; public: ConstantDataVector(const ConstantDataVector &) = delete; diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp --- a/llvm/lib/Analysis/ValueTracking.cpp +++ b/llvm/lib/Analysis/ValueTracking.cpp @@ -174,7 +174,11 @@ int NumElts = Shuf->getOperand(0)->getType()->getVectorNumElements(); int NumMaskElts = Shuf->getMask()->getType()->getVectorNumElements(); DemandedLHS = DemandedRHS = APInt::getNullValue(NumElts); - + // Simple case of a shuffle with zeroinitializer. + if (isa(Shuf->getMask())) { + DemandedLHS.setBit(0); + return true; + } for (int i = 0; i != NumMaskElts; ++i) { if (!DemandedElts[i]) continue; diff --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp --- a/llvm/lib/IR/ConstantFold.cpp +++ b/llvm/lib/IR/ConstantFold.cpp @@ -61,6 +61,11 @@ Type *DstEltTy = DstTy->getElementType(); + if (Constant *Splat = CV->getSplatValue()) { + return ConstantVector::getSplat(DstTy->getVectorElementCount(), + ConstantExpr::getBitCast(Splat, DstEltTy)); + } + SmallVector Result; Type *Ty = IntegerType::get(CV->getContext(), 32); for (unsigned i = 0; i != NumElts; ++i) { @@ -577,9 +582,14 @@ if ((isa(V) || isa(V)) && DestTy->isVectorTy() && DestTy->getVectorNumElements() == V->getType()->getVectorNumElements()) { - SmallVector res; VectorType *DestVecTy = cast(DestTy); Type *DstEltTy = DestVecTy->getElementType(); + if (Constant *Splat = V->getSplatValue()) { + return ConstantVector::getSplat( + DestTy->getVectorElementCount(), + ConstantExpr::getCast(opc, Splat, DstEltTy)); + } + SmallVector res; Type *Ty = IntegerType::get(V->getContext(), 32); for (unsigned i = 0, e = V->getType()->getVectorNumElements(); i != e; ++i) { Constant *C = @@ -878,6 +888,12 @@ // Don't break the bitcode reader hack. if (isa(Mask)) return nullptr; + if (isa(Mask) && !MaskEltCount.Scalable) { + Type *Ty = IntegerType::get(V1->getContext(), 32); + Constant *Elt = + ConstantExpr::getExtractElement(V1, ConstantInt::get(Ty, 0)); + return ConstantVector::getSplat(MaskEltCount, Elt); + } // Do not iterate on scalable vector. The num of elements is unknown at // compile-time. VectorType *ValTy = cast(V1->getType()); @@ -993,10 +1009,15 @@ // compile-time. if (IsScalableVector) return nullptr; + Type *Ty = IntegerType::get(VTy->getContext(), 32); + + if (Constant *Splat = C->getSplatValue()) { + Constant *Elt = ConstantExpr::get(Opcode, Splat); + return ConstantVector::getSplat(VTy->getElementCount(), Elt); + } // Fold each element and create a vector constant from those constants. SmallVector Result; - Type *Ty = IntegerType::get(VTy->getContext(), 32); for (unsigned i = 0, e = VTy->getNumElements(); i != e; ++i) { Constant *ExtractIdx = ConstantInt::get(Ty, i); Constant *Elt = ConstantExpr::getExtractElement(C, ExtractIdx); @@ -1357,6 +1378,16 @@ // compile-time. if (IsScalableVector) return nullptr; + // Fast path for splatted constants. + Constant *C1Splat = C1->getSplatValue(); + Constant *C2Splat = C2->getSplatValue(); + if (C1Splat && C2Splat) { + if (Instruction::isIntDivRem(Opcode) && C2Splat->isNullValue()) + return UndefValue::get(VTy); + return ConstantVector::getSplat( + VTy->getVectorElementCount(), + ConstantExpr::get(Opcode, C1Splat, C2Splat)); + } // Fold each element and create a vector constant from those constants. SmallVector Result; @@ -1975,6 +2006,14 @@ // compile-time. if (C1->getType()->getVectorIsScalable()) return nullptr; + // Fast path for splatted constants. + Constant *C1Splat = C1->getSplatValue(); + Constant *C2Splat = C2->getSplatValue(); + if (C1Splat && C2Splat) { + return ConstantVector::getSplat( + C1->getType()->getVectorElementCount(), + ConstantExpr::getCompare(pred, C1Splat, C2Splat)); + } // If we can constant fold the comparison of each element, constant fold // the whole vector comparison. diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp --- a/llvm/lib/IR/Constants.cpp +++ b/llvm/lib/IR/Constants.cpp @@ -2891,18 +2891,20 @@ return Str.drop_back().find(0) == StringRef::npos; } -bool ConstantDataVector::isSplat() const { - const char *Base = getRawDataValues().data(); +bool ConstantDataVector::isSplatData(Type *Ty, const char *Data) const { + const char *Base = Data; // Compare elements 1+ to the 0'th element. - unsigned EltSize = getElementByteSize(); - for (unsigned i = 1, e = getNumElements(); i != e; ++i) + unsigned EltSize = Ty->getVectorElementType()->getPrimitiveSizeInBits() / 8; + for (unsigned i = 1, e = Ty->getVectorNumElements(); i != e; ++i) if (memcmp(Base, Base+i*EltSize, EltSize)) return false; return true; } +bool ConstantDataVector::isSplat() const { return IsSplat; } + Constant *ConstantDataVector::getSplatValue() const { // If they're all the same, return the 0th one as a representative. return isSplat() ? getElementAsConstant(0) : nullptr; diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp --- a/llvm/lib/IR/Instructions.cpp +++ b/llvm/lib/IR/Instructions.cpp @@ -1958,16 +1958,19 @@ assert(!Mask->getType()->getVectorElementCount().Scalable && "Length of scalable vectors unknown at compile time"); unsigned NumElts = Mask->getType()->getVectorNumElements(); - + if (isa(Mask)) { + Result.resize(NumElts, 0); + return; + } + Result.resize(NumElts); if (auto *CDS = dyn_cast(Mask)) { for (unsigned i = 0; i != NumElts; ++i) - Result.push_back(CDS->getElementAsInteger(i)); + Result[i] = CDS->getElementAsInteger(i); return; } for (unsigned i = 0; i != NumElts; ++i) { Constant *C = Mask->getAggregateElement(i); - Result.push_back(isa(C) ? -1 : - cast(C)->getZExtValue()); + Result[i] = isa(C) ? -1 : cast(C)->getZExtValue(); } } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp @@ -1389,6 +1389,24 @@ "Expected shuffle operands to have same type"); unsigned OpWidth = Shuffle->getOperand(0)->getType()->getVectorNumElements(); + // Handle trivial case of a splat. Only check the first element of LHS + // operand. + if (isa(Shuffle->getMask()) && + DemandedElts.isAllOnesValue()) { + if (!isa(I->getOperand(1))) { + I->setOperand(1, UndefValue::get(I->getOperand(1)->getType())); + MadeChange = true; + } + APInt LeftDemanded(OpWidth, 1); + APInt LHSUndefElts(OpWidth, 0); + simplifyAndSetOp(I, 0, LeftDemanded, LHSUndefElts); + if (LHSUndefElts[0]) + UndefElts = EltMask; + else + UndefElts.clearAllBits(); + break; + } + APInt LeftDemanded(OpWidth, 0), RightDemanded(OpWidth, 0); for (unsigned i = 0; i < VWidth; i++) { if (DemandedElts[i]) {