Index: include/llvm/Analysis/ValueTracking.h
===================================================================
--- include/llvm/Analysis/ValueTracking.h
+++ include/llvm/Analysis/ValueTracking.h
@@ -16,6 +16,7 @@
 #define LLVM_ANALYSIS_VALUETRACKING_H
 
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/IR/Instruction.h"
 #include "llvm/Support/DataTypes.h"
 
 namespace llvm {
@@ -275,7 +276,24 @@
   };
   /// Pattern match integer [SU]MIN, [SU]MAX and ABS idioms, returning the kind
   /// and providing the out parameter results if we successfully match.
-  SelectPatternFlavor matchSelectPattern(Value *V, Value *&LHS, Value *&RHS);
+  ///
+  /// If CastOp is not nullptr, also match MIN/MAX idioms where the type does
+  /// not match that of the original select. If this is the case, the cast
+  /// operation (one of Trunc, SExt, ZExt) that must be done to transform the
+  /// type of LHS and RHS into the type of V is returned in *CastOp. *CastOp is
+  /// only meaningful if a flavor other than SPF_UNKNOWN is returned and the
+  /// type of LHS differs from the type of V.
+  ///
+  /// For example:
+  ///   %1 = icmp slt i32 %a, 4
+  ///   %2 = sext i32 %a to i64
+  ///   %3 = select i1 %1, i64 %2, i64 4
+  ///
+  /// -> LHS = %a, RHS = i32 4, *CastOp = Instruction::SExt
+  ///
+  SelectPatternFlavor
+  matchSelectPattern(Value *V, Value *&LHS, Value *&RHS,
+                     Instruction::CastOps *CastOp = nullptr);
 
 } // end namespace llvm
 
Index: lib/Analysis/ValueTracking.cpp
===================================================================
--- lib/Analysis/ValueTracking.cpp
+++ lib/Analysis/ValueTracking.cpp
@@ -3205,20 +3205,10 @@
   return OverflowResult::MayOverflow;
 }
 
-SelectPatternFlavor llvm::matchSelectPattern(Value *V,
-                                             Value *&LHS, Value *&RHS) {
-  SelectInst *SI = dyn_cast<SelectInst>(V);
-  if (!SI) return SPF_UNKNOWN;
-
-  ICmpInst *ICI = dyn_cast<ICmpInst>(SI->getCondition());
-  if (!ICI) return SPF_UNKNOWN;
-
-  ICmpInst::Predicate Pred = ICI->getPredicate();
-  Value *CmpLHS = ICI->getOperand(0);
-  Value *CmpRHS = ICI->getOperand(1);
-  Value *TrueVal = SI->getTrueValue();
-  Value *FalseVal = SI->getFalseValue();
-
+static SelectPatternFlavor matchSelectPattern(ICmpInst::Predicate Pred,
+                                              Value *CmpLHS, Value *CmpRHS,
+                                              Value *TrueVal, Value *FalseVal,
+                                              Value *&LHS, Value *&RHS) {
   LHS = CmpLHS;
   RHS = CmpRHS;
 
@@ -3285,3 +3275,84 @@
 
   return SPF_UNKNOWN;
 }
+
+/// Helper for matchSelectPattern. V1 and V2 are the two arms of a select (in
+/// either order); V1 is expected to be a cast instruction and V2 a constant.
+/// On success, stores the cast opcode in *CastOp and returns the constant
+/// converted to the cast's source type, so that the select can be matched
+/// against the cast's operand instead. Returns nullptr if the operands do
+/// not have that shape, if the cast is not compatible with the signedness of
+/// CmpI, or if converting the constant would lose information.
+static Constant *lookThroughCast(ICmpInst *CmpI, Value *V1, Value *V2,
+                                 Instruction::CastOps *CastOp) {
+  CastInst *CI = dyn_cast<CastInst>(V1);
+  Constant *C = dyn_cast<Constant>(V2);
+  if (!CI || !C)
+    return nullptr;
+
+  Type *SrcTy = CI->getSrcTy();
+  Constant *CastedTo = nullptr;
+  switch (CI->getOpcode()) {
+  default:
+    return nullptr;
+  case Instruction::SExt:
+    // sext preserves signed order, so only look through it for signed
+    // comparisons.
+    if (!CmpI->isSigned())
+      return nullptr;
+    CastedTo = ConstantExpr::getTrunc(C, SrcTy);
+    break;
+  case Instruction::ZExt:
+    // zext preserves unsigned order; it is not valid for signed comparisons.
+    if (CmpI->isSigned())
+      return nullptr;
+    CastedTo = ConstantExpr::getTrunc(C, SrcTy);
+    break;
+  case Instruction::Trunc:
+    // Widen the constant in the way that matches the comparison.
+    CastedTo = ConstantExpr::getIntegerCast(C, SrcTy, CmpI->isSigned());
+    break;
+  }
+
+  // The conversion is only valid if re-applying the cast gives back the
+  // original constant. For example, for "sext i32 %a to i64" against
+  // i64 4294967300, truncating gives i32 4, but sext(i32 4) is i64 4, so
+  // treating the select as a min/max of %a and 4 would be wrong.
+  if (ConstantExpr::getCast(CI->getOpcode(), CastedTo, C->getType()) != C)
+    return nullptr;
+
+  *CastOp = CI->getOpcode();
+  return CastedTo;
+}
+
+SelectPatternFlavor llvm::matchSelectPattern(Value *V,
+                                             Value *&LHS, Value *&RHS,
+                                             Instruction::CastOps *CastOp) {
+  SelectInst *SI = dyn_cast<SelectInst>(V);
+  if (!SI) return SPF_UNKNOWN;
+
+  ICmpInst *ICI = dyn_cast<ICmpInst>(SI->getCondition());
+  if (!ICI) return SPF_UNKNOWN;
+
+  ICmpInst::Predicate Pred = ICI->getPredicate();
+  Value *CmpLHS = ICI->getOperand(0);
+  Value *CmpRHS = ICI->getOperand(1);
+  Value *TrueVal = SI->getTrueValue();
+  Value *FalseVal = SI->getFalseValue();
+
+  // Deal with type mismatches: if one arm of the select is a cast and the
+  // other a constant, match against the cast's operand and the constant
+  // converted to the cast's source type instead.
+  if (CastOp && CmpLHS->getType() != TrueVal->getType()) {
+    if (Constant *C = lookThroughCast(ICI, TrueVal, FalseVal, CastOp))
+      return ::matchSelectPattern(Pred, CmpLHS, CmpRHS,
+                                  cast<CastInst>(TrueVal)->getOperand(0), C,
+                                  LHS, RHS);
+    if (Constant *C = lookThroughCast(ICI, FalseVal, TrueVal, CastOp))
+      return ::matchSelectPattern(Pred, CmpLHS, CmpRHS,
+                                  C, cast<CastInst>(FalseVal)->getOperand(0),
+                                  LHS, RHS);
+  }
+  return ::matchSelectPattern(Pred, CmpLHS, CmpRHS, TrueVal, FalseVal,
+                              LHS, RHS);
+}