diff --git a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
@@ -526,7 +526,9 @@
 
 /// Returns true if instruction represent minmax pattern like:
 ///   select ((cmp load V1, load V2), V1, V2).
-static bool isMinMaxWithLoads(Value *V) {
+/// If the select/compare shape matches, LoadTy is set to the type of the
+/// compared values; it is only meaningful when this function returns true.
+static bool isMinMaxWithLoads(Value *V, Type *&LoadTy) {
   assert(V->getType()->isPointerTy() && "Expected pointer type.");
   // Ignore possible ty* to ixx* bitcast.
   V = peekThroughBitcast(V);
@@ -540,6 +542,7 @@
   if (!match(V, m_Select(m_Cmp(Pred, m_Instruction(L1), m_Instruction(L2)),
                          m_Value(LHS), m_Value(RHS))))
     return false;
+  LoadTy = L1->getType();
   return (match(L1, m_Load(m_Specific(LHS))) &&
           match(L2, m_Load(m_Specific(RHS)))) ||
          (match(L1, m_Load(m_Specific(RHS))) &&
@@ -585,13 +588,15 @@
   // size is a legal integer type.
   // Do not perform canonicalization if minmax pattern is found (to avoid
   // infinite loop).
+  Type *Dummy;
   if (!Ty->isIntegerTy() && Ty->isSized() &&
       !(Ty->isVectorTy() && Ty->getVectorIsScalable()) &&
       DL.isLegalInteger(DL.getTypeStoreSizeInBits(Ty)) &&
       DL.typeSizeEqualsStoreSize(Ty) &&
       !DL.isNonIntegralPointerType(Ty) &&
       !isMinMaxWithLoads(
-          peekThroughBitcast(LI.getPointerOperand(), /*OneUseOnly=*/true))) {
+          peekThroughBitcast(LI.getPointerOperand(), /*OneUseOnly=*/true),
+          Dummy)) {
     if (all_of(LI.users(), [&LI](User *U) {
           auto *SI = dyn_cast<StoreInst>(U);
           return SI && SI->getPointerOperand() != &LI &&
@@ -1323,7 +1328,20 @@
   auto *LI = cast<LoadInst>(SI.getValueOperand());
   if (!LI->getType()->isIntegerTy())
     return false;
-  if (!isMinMaxWithLoads(LoadAddr))
+  Type *CmpLoadTy;
+  if (!isMinMaxWithLoads(LoadAddr, CmpLoadTy))
+    return false;
+
+  // Make sure the type would actually change. With a chain of bitcasts the
+  // load may already have the compared type, and rewriting it would produce
+  // equivalent IR while still reporting a change.
+  if (LI->getType() == CmpLoadTy)
+    return false;
+
+  // Make sure we're not changing the size of the load/store.
+  const auto &DL = IC.getDataLayout();
+  if (DL.getTypeStoreSizeInBits(LI->getType()) !=
+      DL.getTypeStoreSizeInBits(CmpLoadTy))
     return false;
 
   if (!all_of(LI->users(), [LI, LoadAddr](User *U) {
@@ -1336,7 +1354,7 @@
 
   IC.Builder.SetInsertPoint(LI);
   LoadInst *NewLI = combineLoadToNewType(
-      IC, *LI, LoadAddr->getType()->getPointerElementType());
+      IC, *LI, CmpLoadTy);
   // Replace all the stores with stores of the newly loaded value.
   for (auto *UI : LI->users()) {
     auto *USI = cast<StoreInst>(UI);
diff --git a/llvm/test/Transforms/InstCombine/PR37526.ll b/llvm/test/Transforms/InstCombine/PR37526.ll
--- a/llvm/test/Transforms/InstCombine/PR37526.ll
+++ b/llvm/test/Transforms/InstCombine/PR37526.ll
@@ -3,11 +3,14 @@
 
 define void @PR37526(i32* %pz, i32* %px, i32* %py) {
 ; CHECK-LABEL: @PR37526(
+; CHECK-NEXT:    [[T1:%.*]] = bitcast i32* [[PZ:%.*]] to i64*
 ; CHECK-NEXT:    [[T2:%.*]] = load i32, i32* [[PY:%.*]], align 4
 ; CHECK-NEXT:    [[T3:%.*]] = load i32, i32* [[PX:%.*]], align 4
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i32 [[T2]], [[T3]]
-; CHECK-NEXT:    [[R1:%.*]] = select i1 [[CMP]], i32 [[T3]], i32 [[T2]]
-; CHECK-NEXT:    store i32 [[R1]], i32* [[PZ:%.*]], align 4
+; CHECK-NEXT:    [[SELECT:%.*]] = select i1 [[CMP]], i32* [[PX]], i32* [[PY]]
+; CHECK-NEXT:    [[BC:%.*]] = bitcast i32* [[SELECT]] to i64*
+; CHECK-NEXT:    [[R:%.*]] = load i64, i64* [[BC]], align 4
+; CHECK-NEXT:    store i64 [[R]], i64* [[T1]], align 4
 ; CHECK-NEXT:    ret void
 ;
   %t1 = bitcast i32* %pz to i64*