Index: include/llvm/IR/Constants.h
===================================================================
--- include/llvm/IR/Constants.h
+++ include/llvm/IR/Constants.h
@@ -761,6 +761,10 @@
   /// i32/i64/float/double) and must be a ConstantFP or ConstantInt.
   static Constant *getSplat(unsigned NumElts, Constant *Elt);
 
+  /// Returns true if this is a splat constant, meaning that all elements have
+  /// the same value.
+  bool isSplat() const;
+
   /// If this is a splat constant, meaning that all of the elements have the
   /// same value, return that value. Otherwise return NULL.
   Constant *getSplatValue() const;
Index: lib/IR/Constants.cpp
===================================================================
--- lib/IR/Constants.cpp
+++ lib/IR/Constants.cpp
@@ -44,9 +44,10 @@
 
   // Equivalent for a vector of -0.0's.
   if (const ConstantDataVector *CV = dyn_cast<ConstantDataVector>(this))
-    if (ConstantFP *SplatCFP = dyn_cast_or_null<ConstantFP>(CV->getSplatValue()))
-      if (SplatCFP && SplatCFP->isZero() && SplatCFP->isNegative())
-        return true;
+    if (CV->isSplat())
+      if (CV->getElementType()->isFloatingPointTy())
+        if (CV->getElementAsAPFloat(0).isNegZero())
+          return true;
 
   if (const ConstantVector *CV = dyn_cast<ConstantVector>(this))
     if (ConstantFP *SplatCFP = dyn_cast_or_null<ConstantFP>(CV->getSplatValue()))
@@ -70,9 +71,10 @@
 
   // Equivalent for a vector of -0.0's.
   if (const ConstantDataVector *CV = dyn_cast<ConstantDataVector>(this))
-    if (ConstantFP *SplatCFP = dyn_cast_or_null<ConstantFP>(CV->getSplatValue()))
-      if (SplatCFP && SplatCFP->isZero())
-        return true;
+    if (CV->isSplat())
+      if (CV->getElementType()->isFloatingPointTy())
+        if (CV->getElementAsAPFloat(0).isZero())
+          return true;
 
   if (const ConstantVector *CV = dyn_cast<ConstantVector>(this))
     if (ConstantFP *SplatCFP = dyn_cast_or_null<ConstantFP>(CV->getSplatValue()))
@@ -113,9 +115,14 @@
       return Splat->isAllOnesValue();
 
   // Check for constant vectors which are splats of -1 values.
-  if (const ConstantDataVector *CV = dyn_cast<ConstantDataVector>(this))
-    if (Constant *Splat = CV->getSplatValue())
-      return Splat->isAllOnesValue();
+  if (const ConstantDataVector *CV = dyn_cast<ConstantDataVector>(this)) {
+    if (CV->isSplat()) {
+      Type *Ty = CV->getElementType();
+      if (Ty->isFloatingPointTy())
+        return CV->getElementAsAPFloat(0).bitcastToAPInt().isAllOnesValue();
+      return CV->getElementAsInteger(0) == maxUIntN(Ty->getIntegerBitWidth());
+    }
+  }
 
   return false;
 }
@@ -135,9 +142,14 @@
       return Splat->isOneValue();
 
   // Check for constant vectors which are splats of 1 values.
-  if (const ConstantDataVector *CV = dyn_cast<ConstantDataVector>(this))
-    if (Constant *Splat = CV->getSplatValue())
-      return Splat->isOneValue();
+  if (const ConstantDataVector *CV = dyn_cast<ConstantDataVector>(this)) {
+    if (CV->isSplat()) {
+      Type *Ty = CV->getElementType();
+      if (Ty->isFloatingPointTy())
+        return CV->getElementAsAPFloat(0).bitcastToAPInt().isOneValue();
+      return CV->getElementAsInteger(0) == 1;
+    }
+  }
 
   return false;
 }
@@ -157,9 +169,15 @@
       return Splat->isMinSignedValue();
 
   // Check for constant vectors which are splats of INT_MIN values.
-  if (const ConstantDataVector *CV = dyn_cast<ConstantDataVector>(this))
-    if (Constant *Splat = CV->getSplatValue())
-      return Splat->isMinSignedValue();
+  if (const ConstantDataVector *CV = dyn_cast<ConstantDataVector>(this)) {
+    if (CV->isSplat()) {
+      Type *Ty = CV->getElementType();
+      if (Ty->isFloatingPointTy())
+        return CV->getElementAsAPFloat(0).bitcastToAPInt().isMinSignedValue();
+      return CV->getElementAsInteger(0) ==
+             -(uint64_t)minIntN(Ty->getIntegerBitWidth());
+    }
+  }
 
   return false;
 }
@@ -179,9 +197,15 @@
       return Splat->isNotMinSignedValue();
 
   // Check for constant vectors which are splats of INT_MIN values.
-  if (const ConstantDataVector *CV = dyn_cast<ConstantDataVector>(this))
-    if (Constant *Splat = CV->getSplatValue())
-      return Splat->isNotMinSignedValue();
+  if (const ConstantDataVector *CV = dyn_cast<ConstantDataVector>(this)) {
+    if (CV->isSplat()) {
+      Type *Ty = CV->getElementType();
+      if (Ty->isFloatingPointTy())
+        return !CV->getElementAsAPFloat(0).bitcastToAPInt().isMinSignedValue();
+      return CV->getElementAsInteger(0) !=
+             -(uint64_t)minIntN(Ty->getIntegerBitWidth());
+    }
+  }
 
   // It *may* contain INT_MIN, we can't tell.
   return false;
@@ -2627,17 +2651,21 @@
   return Str.drop_back().find(0) == StringRef::npos;
 }
 
-Constant *ConstantDataVector::getSplatValue() const {
+bool ConstantDataVector::isSplat() const {
   const char *Base = getRawDataValues().data();
 
   // Compare elements 1+ to the 0'th element.
   unsigned EltSize = getElementByteSize();
   for (unsigned i = 1, e = getNumElements(); i != e; ++i)
     if (memcmp(Base, Base+i*EltSize, EltSize))
-      return nullptr;
+      return false;
 
+  return true;
+}
+
+Constant *ConstantDataVector::getSplatValue() const {
   // If they're all the same, return the 0th one as a representative.
-  return getElementAsConstant(0);
+  return isSplat() ? getElementAsConstant(0) : nullptr;
 }
 
 //===----------------------------------------------------------------------===//