Index: llvm/include/llvm/Analysis/ConstantFolding.h
===================================================================
--- llvm/include/llvm/Analysis/ConstantFolding.h
+++ llvm/include/llvm/Analysis/ConstantFolding.h
@@ -148,6 +148,12 @@
 Constant *ConstantFoldLoadFromConstPtr(Constant *C, Type *Ty,
                                        const DataLayout &DL);
 
+/// If C is a uniform value where all bits are the same (either all zero, all
+/// ones, all undef or all poison), return the corresponding uniform value in
+/// the new type. If the value is not uniform or the result cannot be
+/// represented, return null.
+Constant *ConstantFoldLoadFromUniformValue(Constant *C, Type *Ty);
+
 /// ConstantFoldLoadThroughGEPConstantExpr - Given a constant and a
 /// getelementptr constantexpr, return the constant value being addressed by the
 /// constant expression, or null if something is funny and we can't decide.
Index: llvm/lib/Analysis/ConstantFolding.cpp
===================================================================
--- llvm/lib/Analysis/ConstantFolding.cpp
+++ llvm/lib/Analysis/ConstantFolding.cpp
@@ -106,11 +106,8 @@
          "Invalid constantexpr bitcast!");
 
   // Catch the obvious splat cases.
-  if (C->isNullValue() && !DestTy->isX86_MMXTy() && !DestTy->isX86_AMXTy())
-    return Constant::getNullValue(DestTy);
-  if (C->isAllOnesValue() && !DestTy->isX86_MMXTy() && !DestTy->isX86_AMXTy() &&
-      !DestTy->isPtrOrPtrVectorTy()) // Don't get ones for ptr types!
-    return Constant::getAllOnesValue(DestTy);
+  if (Constant *Res = ConstantFoldLoadFromUniformValue(C, DestTy))
+    return Res;
 
   if (auto *VTy = dyn_cast<VectorType>(C->getType())) {
     // Handle a vector->scalar integer/fp cast.
@@ -362,16 +359,8 @@
 
     // Catch the obvious splat cases (since all-zeros can coerce non-integral
     // pointers legally).
-    if (C->isNullValue() && !DestTy->isX86_MMXTy() && !DestTy->isX86_AMXTy())
-      return Constant::getNullValue(DestTy);
-    if (C->isAllOnesValue() &&
-        (DestTy->isIntegerTy() || DestTy->isFloatingPointTy() ||
-         DestTy->isVectorTy()) &&
-        !DestTy->isX86_AMXTy() && !DestTy->isX86_MMXTy() &&
-        !DestTy->isPtrOrPtrVectorTy())
-      // Get ones when the input is trivial, but
-      // only for supported types inside getAllOnesValue.
-      return Constant::getAllOnesValue(DestTy);
+    if (Constant *Res = ConstantFoldLoadFromUniformValue(C, DestTy))
+      return Res;
 
     // If the type sizes are the same and a cast is legal, just directly
     // cast the constant.
@@ -704,16 +693,13 @@
                                                        Offset, DL))
         return Result;
 
-  // If this load comes from anywhere in a constant global, and if the global
-  // is all undef or zero, we know what it loads.
-  if (auto *GV = dyn_cast<GlobalVariable>(getUnderlyingObject(C))) {
-    if (GV->isConstant() && GV->hasDefinitiveInitializer()) {
-      if (GV->getInitializer()->isNullValue())
-        return Constant::getNullValue(Ty);
-      if (isa<UndefValue>(GV->getInitializer()))
-        return UndefValue::get(Ty);
-    }
-  }
+  // If this load comes from anywhere in a uniform constant global, the value
+  // is always the same, regardless of the loaded offset.
+  if (auto *GV = dyn_cast<GlobalVariable>(getUnderlyingObject(C)))
+    if (GV->isConstant() && GV->hasDefinitiveInitializer())
+      if (Constant *Res =
+              ConstantFoldLoadFromUniformValue(GV->getInitializer(), Ty))
+        return Res;
 
   return nullptr;
 }
@@ -724,6 +710,23 @@
   return ConstantFoldLoadFromConstPtr(C, Ty, Offset, DL);
 }
 
+Constant *llvm::ConstantFoldLoadFromUniformValue(Constant *C, Type *Ty) {
+  // PoisonValue derives from UndefValue, so poison must be checked first.
+  if (isa<PoisonValue>(C))
+    return PoisonValue::get(Ty);
+  if (isa<UndefValue>(C))
+    return UndefValue::get(Ty);
+  // Constant::getNullValue cannot build x86_mmx or x86_amx values.
+  if (C->isNullValue() && !Ty->isX86_MMXTy() && !Ty->isX86_AMXTy())
+    return Constant::getNullValue(Ty);
+  // Constant::getAllOnesValue only supports integer and FP types and vectors
+  // of them; in particular, pointers have no all-ones constant.
+  if (C->isAllOnesValue() &&
+      (Ty->isIntOrIntVectorTy() || Ty->isFPOrFPVectorTy()))
+    return Constant::getAllOnesValue(Ty);
+  return nullptr;
+}
+
 namespace {
 
 /// One of Op0/Op1 is a constant expression.
Index: llvm/lib/Transforms/IPO/GlobalOpt.cpp
===================================================================
--- llvm/lib/Transforms/IPO/GlobalOpt.cpp
+++ llvm/lib/Transforms/IPO/GlobalOpt.cpp
@@ -305,8 +305,9 @@
     else if (auto *LI = dyn_cast<LoadInst>(U)) {
-      // A load from zeroinitializer is always zeroinitializer, regardless of
-      // any applied offset.
-      if (Init->isNullValue()) {
-        LI->replaceAllUsesWith(Constant::getNullValue(LI->getType()));
+      // A load from a uniform value (all zeros, all ones, undef or poison) is
+      // always the same, regardless of any applied offset.
+      if (Constant *Res =
+              ConstantFoldLoadFromUniformValue(Init, LI->getType())) {
+        LI->replaceAllUsesWith(Res);
         EraseFromParent(LI);
         continue;
       }
Index: llvm/test/Transforms/GlobalOpt/x86_mmx_load.ll
===================================================================
--- /dev/null
+++ llvm/test/Transforms/GlobalOpt/x86_mmx_load.ll
@@ -0,0 +1,13 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -S -globalopt < %s | FileCheck %s
+
+@m64 = internal global <1 x i64> zeroinitializer
+
+; This should not try to create an x86_mmx null value.
+define i32 @load_mmx() {
+; CHECK-LABEL: @load_mmx(
+; CHECK-NEXT:    ret i32 0
+;
+  %temp = load x86_mmx, x86_mmx* bitcast (<1 x i64>* @m64 to x86_mmx*)
+  ret i32 0
+}
Index: llvm/test/Transforms/InstSimplify/ConstProp/loads.ll
===================================================================
--- llvm/test/Transforms/InstSimplify/ConstProp/loads.ll
+++ llvm/test/Transforms/InstSimplify/ConstProp/loads.ll
@@ -280,3 +280,16 @@
   %v = load { i64, i64 }, { i64, i64 }* @g3
   ret { i64, i64 } %v
 }
+
+@m64 = internal constant [2 x i64] zeroinitializer
+@idx = external global i32
+
+; This should not try to create an x86_mmx null value.
+define x86_mmx @load_mmx() {
+; CHECK-LABEL: @load_mmx(
+; CHECK-NEXT:    [[TEMP:%.*]] = load x86_mmx, x86_mmx* bitcast (i64* getelementptr ([2 x i64], [2 x i64]* @m64, i64 0, i64 ptrtoint (i32* @idx to i64)) to x86_mmx*), align 8
+; CHECK-NEXT:    ret x86_mmx [[TEMP]]
+;
+  %temp = load x86_mmx, x86_mmx* bitcast (i64* getelementptr ([2 x i64], [2 x i64]* @m64, i64 0, i64 ptrtoint (i32* @idx to i64)) to x86_mmx*)
+  ret x86_mmx %temp
+}