Index: include/llvm/IR/PatternMatch.h
===================================================================
--- include/llvm/IR/PatternMatch.h
+++ include/llvm/IR/PatternMatch.h
@@ -957,6 +957,26 @@
 }
 
 //===----------------------------------------------------------------------===//
+// Matcher for LoadInst classes
+//
+
+template <typename Op_t> struct LoadClass_match {
+  Op_t Op;
+
+  LoadClass_match(const Op_t &OpMatch) : Op(OpMatch) {}
+
+  template <typename OpTy> bool match(OpTy *V) {
+    if (auto *LI = dyn_cast<LoadInst>(V))
+      return Op.match(LI->getPointerOperand());
+    return false;
+  }
+};
+
+/// Matches a LoadInst whose pointer operand matches \p Op.
+template <typename OpTy> inline LoadClass_match<OpTy> m_Load(const OpTy &Op) {
+  return LoadClass_match<OpTy>(Op);
+}
+//===----------------------------------------------------------------------===//
 // Matchers for unary operators
 //
 
Index: lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
===================================================================
--- lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
+++ lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
@@ -22,9 +22,11 @@
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/MDBuilder.h"
+#include "llvm/IR/PatternMatch.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/Local.h"
 using namespace llvm;
+using namespace PatternMatch;
 
 #define DEBUG_TYPE "instcombine"
 
@@ -561,6 +563,28 @@
   return NewStore;
 }
 
+/// Returns true if \p V represents a minmax pattern like:
+///   select ((cmp load V1, load V2), V1, V2).
+static bool isMinMaxPattern(Value *V) {
+  assert(V->getType()->isPointerTy() && "Expected pointer type.");
+  // Ignore possible ty* to ixx* bitcast.
+  V = peekThroughBitcast(V);
+  // Check that select is select ((cmp load V1, load V2), V1, V2) - minmax
+  // pattern.
+  CmpInst::Predicate Pred;
+  Instruction *L1;
+  Instruction *L2;
+  Value *LHS;
+  Value *RHS;
+  if (!match(V, m_Select(m_Cmp(Pred, m_Instruction(L1), m_Instruction(L2)),
+                         m_Value(LHS), m_Value(RHS))))
+    return false;
+  return (match(L1, m_Load(m_Specific(LHS))) &&
+          match(L2, m_Load(m_Specific(RHS)))) ||
+         (match(L1, m_Load(m_Specific(RHS))) &&
+          match(L2, m_Load(m_Specific(LHS))));
+}
+
 /// \brief Combine loads to match the type of their uses' value after looking
 /// through intervening bitcasts.
 ///
@@ -598,10 +622,14 @@
   // integers instead of any other type. We only do this when the loaded type
   // is sized and has a size exactly the same as its store size and the store
   // size is a legal integer type.
+  // Do not perform canonicalization if a minmax pattern is found (to avoid
+  // an infinite loop).
   if (!Ty->isIntegerTy() && Ty->isSized() &&
       DL.isLegalInteger(DL.getTypeStoreSizeInBits(Ty)) &&
       DL.getTypeStoreSizeInBits(Ty) == DL.getTypeSizeInBits(Ty) &&
-      !DL.isNonIntegralPointerType(Ty)) {
+      !DL.isNonIntegralPointerType(Ty) &&
+      !isMinMaxPattern(
+          peekThroughBitcast(LI.getPointerOperand(), /*OneUseOnly=*/true))) {
     if (all_of(LI.users(), [&LI](User *U) {
           auto *SI = dyn_cast<StoreInst>(U);
           return SI && SI->getPointerOperand() != &LI &&
@@ -1298,6 +1326,30 @@
   return false;
 }
 
+/// Converts store (bitcast (load (bitcast (select ...)))) to
+/// store (load (select ...)), where select is minmax:
+/// select ((cmp load V1, load V2), V1, V2).
+static bool removeBitcastsFromLoadStoreOnMinMax(InstCombiner &IC,
+                                                StoreInst &SI) {
+  // Is the store address a bitcast? Its operand is not needed afterwards.
+  if (!match(SI.getPointerOperand(), m_BitCast(m_Value())))
+    return false;
+  // Is the stored value an integer load through a bitcast?
+  Value *LoadAddr;
+  if (!match(SI.getValueOperand(), m_Load(m_BitCast(m_Value(LoadAddr)))))
+    return false;
+  auto *LI = cast<LoadInst>(SI.getValueOperand());
+  if (!LI->getType()->isIntegerTy())
+    return false;
+  if (!isMinMaxPattern(LoadAddr))
+    return false;
+
+  LoadInst *NewLI = combineLoadToNewType(
+      IC, *LI, LoadAddr->getType()->getPointerElementType());
+  combineStoreToNewValue(IC, SI, NewLI);
+  return true;
+}
+
 Instruction *InstCombiner::visitStoreInst(StoreInst &SI) {
   Value *Val = SI.getOperand(0);
   Value *Ptr = SI.getOperand(1);
@@ -1322,6 +1374,10 @@
   if (unpackStoreToAggregate(*this, SI))
     return eraseInstFromFunction(SI);
 
+  // Try to decanonicalize store (bitcast (load (bitcast (select ...)))).
+  if (removeBitcastsFromLoadStoreOnMinMax(*this, SI))
+    return eraseInstFromFunction(SI);
+
   // Replace GEP indices if possible.
   if (Instruction *NewGEPI = replaceGEPIdxWithZero(*this, Ptr, SI)) {
     Worklist.Add(NewGEPI);
Index: test/Transforms/InstCombine/load-bitcast-select.ll
===================================================================
--- test/Transforms/InstCombine/load-bitcast-select.ll
+++ test/Transforms/InstCombine/load-bitcast-select.ll
@@ -21,11 +21,8 @@
 ; CHECK-NEXT:    [[TMP1:%.*]] = load float, float* [[ARRAYIDX]], align 4
 ; CHECK-NEXT:    [[TMP2:%.*]] = load float, float* [[ARRAYIDX2]], align 4
 ; CHECK-NEXT:    [[CMP_I:%.*]] = fcmp fast olt float [[TMP1]], [[TMP2]]
-; CHECK-NEXT:    [[__B___A_I:%.*]] = select i1 [[CMP_I]], float* [[ARRAYIDX2]], float* [[ARRAYIDX]]
-; CHECK-NEXT:    [[TMP3:%.*]] = bitcast float* [[__B___A_I]] to i32*
-; CHECK-NEXT:    [[TMP4:%.*]] = load i32, i32* [[TMP3]], align 4
-; CHECK-NEXT:    [[TMP5:%.*]] = bitcast float* [[ARRAYIDX]] to i32*
-; CHECK-NEXT:    store i32 [[TMP4]], i32* [[TMP5]], align 4
+; CHECK-NEXT:    [[TMP3:%.*]] = select i1 [[CMP_I]], float [[TMP2]], float [[TMP1]]
+; CHECK-NEXT:    store float [[TMP3]], float* [[ARRAYIDX]], align 4
 ; CHECK-NEXT:    [[INC]] = add nuw nsw i32 [[I_0]], 1
 ; CHECK-NEXT:    br label [[FOR_COND]]
 ;
@@ -91,11 +88,8 @@
 ; CHECK-NEXT:    [[LD1:%.*]] = load float, float* [[LOADADDR1:%.*]], align 4
 ; CHECK-NEXT:    [[LD2:%.*]] = load float, float* [[LOADADDR2:%.*]], align 4
 ; CHECK-NEXT:    [[COND:%.*]] = fcmp ogt float [[LD1]], [[LD2]]
-; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[COND]], float* [[LOADADDR1]], float* [[LOADADDR2]]
-; CHECK-NEXT:    [[INT_LOAD_ADDR:%.*]] = bitcast float* [[SEL]] to i32*
-; CHECK-NEXT:    [[LD:%.*]] = load i32, i32* [[INT_LOAD_ADDR]], align 4
-; CHECK-NEXT:    [[INT_STORE_ADDR:%.*]] = bitcast float* [[STOREADDR:%.*]] to i32*
-; CHECK-NEXT:    store i32 [[LD]], i32* [[INT_STORE_ADDR]], align 4
+; CHECK-NEXT:    [[LD3:%.*]] = select i1 [[COND]], float [[LD1]], float [[LD2]]
+; CHECK-NEXT:    store float [[LD3]], float* [[STOREADDR:%.*]], align 4
 ; CHECK-NEXT:    ret void
 ;
   %ld1 = load float, float* %loadaddr1, align 4