diff --git a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp --- a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp +++ b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp @@ -611,7 +611,7 @@ LoadInst *RootInsert = nullptr; bool FoundRoot = false; uint64_t LoadSize = 0; - Value *Shift = nullptr; + const APInt *Shift = nullptr; Type *ZextType; AAMDNodes AATags; }; @@ -621,7 +621,7 @@ // (ZExt(L1) << shift1) | ZExt(L2) -> ZExt(L3) static bool foldLoadsRecursive(Value *V, LoadOps &LOps, const DataLayout &DL, AliasAnalysis &AA) { - Value *ShAmt2 = nullptr; + const APInt *ShAmt2 = nullptr; Value *X; Instruction *L1, *L2; @@ -629,7 +629,7 @@ if (match(V, m_OneUse(m_c_Or( m_Value(X), m_OneUse(m_Shl(m_OneUse(m_ZExt(m_OneUse(m_Instruction(L2)))), - m_Value(ShAmt2)))))) || + m_APInt(ShAmt2)))))) || match(V, m_OneUse(m_Or(m_Value(X), m_OneUse(m_ZExt(m_OneUse(m_Instruction(L2)))))))) { if (!foldLoadsRecursive(X, LOps, DL, AA) && LOps.FoundRoot) @@ -640,11 +640,11 @@ // Check if the pattern has loads LoadInst *LI1 = LOps.Root; - Value *ShAmt1 = LOps.Shift; + const APInt *ShAmt1 = LOps.Shift; if (LOps.FoundRoot == false && (match(X, m_OneUse(m_ZExt(m_Instruction(L1)))) || match(X, m_OneUse(m_Shl(m_OneUse(m_ZExt(m_OneUse(m_Instruction(L1)))), - m_Value(ShAmt1)))))) { + m_APInt(ShAmt1)))))) { LI1 = dyn_cast(L1); } LoadInst *LI2 = dyn_cast(L2); @@ -719,12 +719,11 @@ std::swap(ShAmt1, ShAmt2); // Find Shifts values. - const APInt *Temp; uint64_t Shift1 = 0, Shift2 = 0; - if (ShAmt1 && match(ShAmt1, m_APInt(Temp))) - Shift1 = Temp->getZExtValue(); - if (ShAmt2 && match(ShAmt2, m_APInt(Temp))) - Shift2 = Temp->getZExtValue(); + if (ShAmt1) + Shift1 = ShAmt1->getZExtValue(); + if (ShAmt2) + Shift2 = ShAmt2->getZExtValue(); // First load is always LI1. This is where we put the new load. // Use the merged load size available from LI1 for forward loads. @@ -816,7 +815,7 @@ // Check if shift needed. We need to shift with the amount of load1 // shift if not zero. if (LOps.Shift) - NewOp = Builder.CreateShl(NewOp, LOps.Shift); + NewOp = Builder.CreateShl(NewOp, ConstantInt::get(I.getContext(), *LOps.Shift)); I.replaceAllUsesWith(NewOp); return true; diff --git a/llvm/test/Transforms/AggressiveInstCombine/X86/or-load.ll b/llvm/test/Transforms/AggressiveInstCombine/X86/or-load.ll --- a/llvm/test/Transforms/AggressiveInstCombine/X86/or-load.ll +++ b/llvm/test/Transforms/AggressiveInstCombine/X86/or-load.ll @@ -2253,3 +2253,53 @@ %o3 = or i32 %o2, %e1 ret i32 %o3 } + +define i64 @loadCombine_nonConstShift1(ptr %arg, i8 %b) { +; ALL-LABEL: @loadCombine_nonConstShift1( +; ALL-NEXT: [[G1:%.*]] = getelementptr i8, ptr [[ARG:%.*]], i64 1 +; ALL-NEXT: [[LD0:%.*]] = load i8, ptr [[ARG]], align 1 +; ALL-NEXT: [[LD1:%.*]] = load i8, ptr [[G1]], align 1 +; ALL-NEXT: [[Z0:%.*]] = zext i8 [[LD0]] to i64 +; ALL-NEXT: [[Z1:%.*]] = zext i8 [[LD1]] to i64 +; ALL-NEXT: [[Z6:%.*]] = zext i8 [[B:%.*]] to i64 +; ALL-NEXT: [[S0:%.*]] = shl i64 [[Z0]], [[Z6]] +; ALL-NEXT: [[S1:%.*]] = shl i64 [[Z1]], 8 +; ALL-NEXT: [[O7:%.*]] = or i64 [[S0]], [[S1]] +; ALL-NEXT: ret i64 [[O7]] +; + %g1 = getelementptr i8, ptr %arg, i64 1 + %ld0 = load i8, ptr %arg, align 1 + %ld1 = load i8, ptr %g1, align 1 + %z0 = zext i8 %ld0 to i64 + %z1 = zext i8 %ld1 to i64 + %z6 = zext i8 %b to i64 + %s0 = shl i64 %z0, %z6 + %s1 = shl i64 %z1, 8 + %o7 = or i64 %s0, %s1 + ret i64 %o7 +} + +define i64 @loadCombine_nonConstShift2(ptr %arg, i8 %b) { +; ALL-LABEL: @loadCombine_nonConstShift2( +; ALL-NEXT: [[G1:%.*]] = getelementptr i8, ptr [[ARG:%.*]], i64 1 +; ALL-NEXT: [[LD0:%.*]] = load i8, ptr [[ARG]], align 1 +; ALL-NEXT: [[LD1:%.*]] = load i8, ptr [[G1]], align 1 +; ALL-NEXT: [[Z0:%.*]] = zext i8 [[LD0]] to i64 +; ALL-NEXT: [[Z1:%.*]] = zext i8 [[LD1]] to i64 +; ALL-NEXT: [[Z6:%.*]] = zext i8 [[B:%.*]] to i64 +; ALL-NEXT: [[S0:%.*]] = shl i64 [[Z0]], [[Z6]] +; ALL-NEXT: [[S1:%.*]] = shl i64 [[Z1]], 8 +; ALL-NEXT: [[O7:%.*]] = or i64 [[S1]], [[S0]] +; ALL-NEXT: ret i64 [[O7]] +; + %g1 = getelementptr i8, ptr %arg, i64 1 + %ld0 = load i8, ptr %arg, align 1 + %ld1 = load i8, ptr %g1, align 1 + %z0 = zext i8 %ld0 to i64 + %z1 = zext i8 %ld1 to i64 + %z6 = zext i8 %b to i64 + %s0 = shl i64 %z0, %z6 + %s1 = shl i64 %z1, 8 + %o7 = or i64 %s1, %s0 + ret i64 %o7 +}