Index: lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp =================================================================== --- lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp +++ lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp @@ -24,6 +24,7 @@ #include "llvm/IR/MDBuilder.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" +#include "llvm/Transforms/Utils/LoopUtils.h" using namespace llvm; #define DEBUG_TYPE "instcombine" @@ -1015,15 +1016,43 @@ // but it would not be valid if we transformed it to load from null // unconditionally. // + // load bitcast (select (Cond, &V1, &V2)) --> select(Cond, load bitcast + // &V1, load bitcast &V2) iff this load is part of store load, bitcast. + // It is used to solve the problem with load/store canonicalization. + auto *BC = dyn_cast(Op); + if (BC) { + Op = BC->getOperand(0); + if (!Op->hasOneUse()) + return nullptr; + // Load instruction must have single user. + if (!LI.hasOneUse()) + return nullptr; + // The user must be StoreInst. + auto *USI = dyn_cast(*LI.materialized_user_begin()); + if (!USI) + return nullptr; + // The pointer operand must be BitCastInst. + if (!isa(USI->getPointerOperand())) + return nullptr; + } + if (SelectInst *SI = dyn_cast(Op)) { // load (select (Cond, &V1, &V2)) --> select(Cond, load &V1, load &V2). unsigned Align = LI.getAlignment(); if (isSafeToLoadUnconditionally(SI->getOperand(1), Align, DL, SI) && isSafeToLoadUnconditionally(SI->getOperand(2), Align, DL, SI)) { - LoadInst *V1 = Builder.CreateLoad(SI->getOperand(1), - SI->getOperand(1)->getName()+".val"); - LoadInst *V2 = Builder.CreateLoad(SI->getOperand(2), - SI->getOperand(2)->getName()+".val"); + Value *AddrLHS = SI->getOperand(1); + Value *AddrRHS = SI->getOperand(2); + if (BC) { + AddrLHS = Builder.CreateBitCast(AddrLHS, BC->getDestTy(), + AddrLHS->getName() + ".cast"); + propagateIRFlags(AddrLHS, BC); + AddrRHS = Builder.CreateBitCast(AddrRHS, BC->getDestTy(), + AddrRHS->getName() + ".cast"); + propagateIRFlags(AddrRHS, BC); + } + LoadInst *V1 = Builder.CreateLoad(AddrLHS, AddrLHS->getName() + ".val"); + LoadInst *V2 = Builder.CreateLoad(AddrRHS, AddrRHS->getName() + ".val"); assert(LI.isUnordered() && "implied by above"); V1->setAlignment(Align); V1->setAtomic(LI.getOrdering(), LI.getSyncScopeID()); @@ -1035,14 +1064,26 @@ // load (select (cond, null, P)) -> load P if (isa(SI->getOperand(1)) && LI.getPointerAddressSpace() == 0) { - LI.setOperand(0, SI->getOperand(2)); + Value *Operand = SI->getOperand(2); + if (BC) { + Operand = Builder.CreateBitCast(Operand, BC->getDestTy(), + Operand->getName() + ".cast"); + propagateIRFlags(Operand, BC); + } + LI.setOperand(0, Operand); return &LI; } // load (select (cond, P, null)) -> load P if (isa(SI->getOperand(2)) && LI.getPointerAddressSpace() == 0) { - LI.setOperand(0, SI->getOperand(1)); + Value *Operand = SI->getOperand(1); + if (BC) { + Operand = Builder.CreateBitCast(Operand, BC->getDestTy(), + Operand->getName() + ".cast"); + propagateIRFlags(Operand, BC); + } + LI.setOperand(0, Operand); return &LI; } } Index: test/Transforms/InstCombine/load-bitcast-select.ll =================================================================== --- test/Transforms/InstCombine/load-bitcast-select.ll +++ test/Transforms/InstCombine/load-bitcast-select.ll @@ -21,11 +21,8 @@ ; CHECK-NEXT: [[TMP1:%.*]] = load float, float* [[ARRAYIDX]], align 4 ; CHECK-NEXT: [[TMP2:%.*]] = load float, float* [[ARRAYIDX2]], align 4 ; CHECK-NEXT: [[CMP_I:%.*]] = fcmp fast olt float [[TMP1]], [[TMP2]] -; CHECK-NEXT: [[__B___A_I:%.*]] = select i1 [[CMP_I]], float* [[ARRAYIDX2]], float* [[ARRAYIDX]] -; CHECK-NEXT: [[TMP3:%.*]] = bitcast float* [[__B___A_I]] to i32* -; CHECK-NEXT: [[TMP4:%.*]] = load i32, i32* [[TMP3]], align 4 -; CHECK-NEXT: [[TMP5:%.*]] = bitcast float* [[ARRAYIDX]] to i32* -; CHECK-NEXT: store i32 [[TMP4]], i32* [[TMP5]], align 4 +; CHECK-NEXT: [[DOTV:%.*]] = select i1 [[CMP_I]], float [[TMP2]], float [[TMP1]] +; CHECK-NEXT: store float [[DOTV]], float* [[ARRAYIDX]], align 4 ; CHECK-NEXT: [[INC]] = add nuw nsw i32 [[I_0]], 1 ; CHECK-NEXT: br label [[FOR_COND]] ;