[ConstantFolding] Add ConstantFoldBinaryOpOperands and ConstantFoldCastOperand

Move the binary-operator and cast folding logic out of ConstantFoldInstOperands
into two new public entry points and make ConstantFoldInstOperands dispatch to
them. Callers in InstructionSimplify and MachineFunction that only ever fold a
binary operator or a cast now call the narrower APIs directly; these do not
take a TargetLibraryInfo.

Apply from the LLVM source root with `patch -p0`. Context indentation follows
the LLVM coding style; if a hunk is rejected only because of whitespace, retry
with `patch -p0 -l`.

Index: include/llvm/Analysis/ConstantFolding.h
===================================================================
--- include/llvm/Analysis/ConstantFolding.h
+++ include/llvm/Analysis/ConstantFolding.h
@@ -65,6 +65,17 @@
                                 Constant *RHS, const DataLayout &DL,
                                 const TargetLibraryInfo *TLI = nullptr);
 
+/// \brief Attempt to constant fold a binary operation with the specified
+/// operands. If it fails, it returns a constant expression of the specified
+/// operands.
+Constant *ConstantFoldBinaryOpOperands(unsigned Opcode, Constant *LHS,
+                                       Constant *RHS, const DataLayout &DL);
+
+/// \brief Attempt to constant fold a cast with the specified operand. If it
+/// fails, it returns a constant expression of the specified operand.
+Constant *ConstantFoldCastOperand(unsigned Opcode, Constant *C, Type *DestTy,
+                                  const DataLayout &DL);
+
 /// ConstantFoldInsertValueInstruction - Attempt to constant fold an insertvalue
 /// instruction with the specified operands and indices.  The constant result is
 /// returned if successful; if not, null is returned.
Index: lib/Analysis/ConstantFolding.cpp
===================================================================
--- lib/Analysis/ConstantFolding.cpp
+++ lib/Analysis/ConstantFolding.cpp
@@ -1010,14 +1010,11 @@
                                          const DataLayout &DL,
                                          const TargetLibraryInfo *TLI) {
   // Handle easy binops first.
-  if (Instruction::isBinaryOp(Opcode)) {
-    if (isa<ConstantExpr>(Ops[0]) || isa<ConstantExpr>(Ops[1])) {
-      if (Constant *C = SymbolicallyEvaluateBinop(Opcode, Ops[0], Ops[1], DL))
-        return C;
-    }
+  if (Instruction::isBinaryOp(Opcode))
+    return ConstantFoldBinaryOpOperands(Opcode, Ops[0], Ops[1], DL);
 
-    return ConstantExpr::get(Opcode, Ops[0], Ops[1]);
-  }
+  if (Instruction::isCast(Opcode))
+    return ConstantFoldCastOperand(Opcode, Ops[0], DestTy, DL);
 
   switch (Opcode) {
   default: return nullptr;
@@ -1028,58 +1025,6 @@
       if (canConstantFoldCallTo(F))
         return ConstantFoldCall(F, Ops.slice(0, Ops.size() - 1), TLI);
     return nullptr;
-  case Instruction::PtrToInt:
-    // If the input is a inttoptr, eliminate the pair.  This requires knowing
-    // the width of a pointer, so it can't be done in ConstantExpr::getCast.
-    if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Ops[0])) {
-      if (CE->getOpcode() == Instruction::IntToPtr) {
-        Constant *Input = CE->getOperand(0);
-        unsigned InWidth = Input->getType()->getScalarSizeInBits();
-        unsigned PtrWidth = DL.getPointerTypeSizeInBits(CE->getType());
-        if (PtrWidth < InWidth) {
-          Constant *Mask =
-            ConstantInt::get(CE->getContext(),
-                             APInt::getLowBitsSet(InWidth, PtrWidth));
-          Input = ConstantExpr::getAnd(Input, Mask);
-        }
-        // Do a zext or trunc to get to the dest size.
-        return ConstantExpr::getIntegerCast(Input, DestTy, false);
-      }
-    }
-    return ConstantExpr::getCast(Opcode, Ops[0], DestTy);
-  case Instruction::IntToPtr:
-    // If the input is a ptrtoint, turn the pair into a ptr to ptr bitcast if
-    // the int size is >= the ptr size and the address spaces are the same.
-    // This requires knowing the width of a pointer, so it can't be done in
-    // ConstantExpr::getCast.
-    if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Ops[0])) {
-      if (CE->getOpcode() == Instruction::PtrToInt) {
-        Constant *SrcPtr = CE->getOperand(0);
-        unsigned SrcPtrSize = DL.getPointerTypeSizeInBits(SrcPtr->getType());
-        unsigned MidIntSize = CE->getType()->getScalarSizeInBits();
-
-        if (MidIntSize >= SrcPtrSize) {
-          unsigned SrcAS = SrcPtr->getType()->getPointerAddressSpace();
-          if (SrcAS == DestTy->getPointerAddressSpace())
-            return FoldBitCast(CE->getOperand(0), DestTy, DL);
-        }
-      }
-    }
-
-    return ConstantExpr::getCast(Opcode, Ops[0], DestTy);
-  case Instruction::Trunc:
-  case Instruction::ZExt:
-  case Instruction::SExt:
-  case Instruction::FPTrunc:
-  case Instruction::FPExt:
-  case Instruction::UIToFP:
-  case Instruction::SIToFP:
-  case Instruction::FPToUI:
-  case Instruction::FPToSI:
-  case Instruction::AddrSpaceCast:
-      return ConstantExpr::getCast(Opcode, Ops[0], DestTy);
-  case Instruction::BitCast:
-    return FoldBitCast(Ops[0], DestTy, DL);
   case Instruction::Select:
     return ConstantExpr::getSelect(Ops[0], Ops[1], Ops[2]);
   case Instruction::ExtractElement:
@@ -1176,14 +1121,84 @@
           Predicate, CE0->getOperand(1), Ops1, DL, TLI);
       unsigned OpC = Predicate == ICmpInst::ICMP_EQ ?
         Instruction::And : Instruction::Or;
-      Constant *Ops[] = { LHS, RHS };
-      return ConstantFoldInstOperands(OpC, LHS->getType(), Ops, DL, TLI);
+      return ConstantFoldBinaryOpOperands(OpC, LHS, RHS, DL);
     }
   }
 
   return ConstantExpr::getCompare(Predicate, Ops0, Ops1);
 }
 
+Constant *llvm::ConstantFoldBinaryOpOperands(unsigned Opcode, Constant *LHS,
+                                             Constant *RHS,
+                                             const DataLayout &DL) {
+  assert(Instruction::isBinaryOp(Opcode));
+  if (isa<ConstantExpr>(LHS) || isa<ConstantExpr>(RHS))
+    if (Constant *C = SymbolicallyEvaluateBinop(Opcode, LHS, RHS, DL))
+      return C;
+
+  return ConstantExpr::get(Opcode, LHS, RHS);
+}
+
+Constant *llvm::ConstantFoldCastOperand(unsigned Opcode, Constant *C,
+                                        Type *DestTy, const DataLayout &DL) {
+  switch (Opcode) {
+  default:
+    llvm_unreachable("Missing case");
+  case Instruction::PtrToInt:
+    // If the input is a inttoptr, eliminate the pair.  This requires knowing
+    // the width of a pointer, so it can't be done in ConstantExpr::getCast.
+    if (ConstantExpr *CE = dyn_cast<ConstantExpr>(C)) {
+      if (CE->getOpcode() == Instruction::IntToPtr) {
+        Constant *Input = CE->getOperand(0);
+        unsigned InWidth = Input->getType()->getScalarSizeInBits();
+        unsigned PtrWidth = DL.getPointerTypeSizeInBits(CE->getType());
+        if (PtrWidth < InWidth) {
+          Constant *Mask =
+            ConstantInt::get(CE->getContext(),
+                             APInt::getLowBitsSet(InWidth, PtrWidth));
+          Input = ConstantExpr::getAnd(Input, Mask);
+        }
+        // Do a zext or trunc to get to the dest size.
+        return ConstantExpr::getIntegerCast(Input, DestTy, false);
+      }
+    }
+    return ConstantExpr::getCast(Opcode, C, DestTy);
+  case Instruction::IntToPtr:
+    // If the input is a ptrtoint, turn the pair into a ptr to ptr bitcast if
+    // the int size is >= the ptr size and the address spaces are the same.
+    // This requires knowing the width of a pointer, so it can't be done in
+    // ConstantExpr::getCast.
+    if (ConstantExpr *CE = dyn_cast<ConstantExpr>(C)) {
+      if (CE->getOpcode() == Instruction::PtrToInt) {
+        Constant *SrcPtr = CE->getOperand(0);
+        unsigned SrcPtrSize = DL.getPointerTypeSizeInBits(SrcPtr->getType());
+        unsigned MidIntSize = CE->getType()->getScalarSizeInBits();
+
+        if (MidIntSize >= SrcPtrSize) {
+          unsigned SrcAS = SrcPtr->getType()->getPointerAddressSpace();
+          if (SrcAS == DestTy->getPointerAddressSpace())
+            return FoldBitCast(CE->getOperand(0), DestTy, DL);
+        }
+      }
+    }
+
+    return ConstantExpr::getCast(Opcode, C, DestTy);
+  case Instruction::Trunc:
+  case Instruction::ZExt:
+  case Instruction::SExt:
+  case Instruction::FPTrunc:
+  case Instruction::FPExt:
+  case Instruction::UIToFP:
+  case Instruction::SIToFP:
+  case Instruction::FPToUI:
+  case Instruction::FPToSI:
+  case Instruction::AddrSpaceCast:
+      return ConstantExpr::getCast(Opcode, C, DestTy);
+  case Instruction::BitCast:
+    return FoldBitCast(C, DestTy, DL);
+  }
+}
+
 /// Given a constant and a getelementptr constantexpr, return the constant value
 /// being addressed by the constant expression, or null if something is funny
 /// and we can't decide.
Index: lib/Analysis/InstructionSimplify.cpp
===================================================================
--- lib/Analysis/InstructionSimplify.cpp
+++ lib/Analysis/InstructionSimplify.cpp
@@ -528,11 +528,8 @@
 static Value *SimplifyAddInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW,
                               const Query &Q, unsigned MaxRecurse) {
   if (Constant *CLHS = dyn_cast<Constant>(Op0)) {
-    if (Constant *CRHS = dyn_cast<Constant>(Op1)) {
-      Constant *Ops[] = { CLHS, CRHS };
-      return ConstantFoldInstOperands(Instruction::Add, CLHS->getType(), Ops,
-                                      Q.DL, Q.TLI);
-    }
+    if (Constant *CRHS = dyn_cast<Constant>(Op1))
+      return ConstantFoldBinaryOpOperands(Instruction::Add, CLHS, CRHS, Q.DL);
 
     // Canonicalize the constant to the RHS.
     std::swap(Op0, Op1);
@@ -660,11 +657,8 @@
 static Value *SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW,
                               const Query &Q, unsigned MaxRecurse) {
   if (Constant *CLHS = dyn_cast<Constant>(Op0))
-    if (Constant *CRHS = dyn_cast<Constant>(Op1)) {
-      Constant *Ops[] = { CLHS, CRHS };
-      return ConstantFoldInstOperands(Instruction::Sub, CLHS->getType(),
-                                      Ops, Q.DL, Q.TLI);
-    }
+    if (Constant *CRHS = dyn_cast<Constant>(Op1))
+      return ConstantFoldBinaryOpOperands(Instruction::Sub, CLHS, CRHS, Q.DL);
 
   // X - undef -> undef
   // undef - X -> undef
@@ -787,11 +781,8 @@
 static Value *SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF,
                                const Query &Q, unsigned MaxRecurse) {
   if (Constant *CLHS = dyn_cast<Constant>(Op0)) {
-    if (Constant *CRHS = dyn_cast<Constant>(Op1)) {
-      Constant *Ops[] = { CLHS, CRHS };
-      return ConstantFoldInstOperands(Instruction::FAdd, CLHS->getType(),
-                                      Ops, Q.DL, Q.TLI);
-    }
+    if (Constant *CRHS = dyn_cast<Constant>(Op1))
+      return ConstantFoldBinaryOpOperands(Instruction::FAdd, CLHS, CRHS, Q.DL);
 
     // Canonicalize the constant to the RHS.
     std::swap(Op0, Op1);
@@ -829,11 +820,8 @@
 static Value *SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF,
                                const Query &Q, unsigned MaxRecurse) {
   if (Constant *CLHS = dyn_cast<Constant>(Op0)) {
-    if (Constant *CRHS = dyn_cast<Constant>(Op1)) {
-      Constant *Ops[] = { CLHS, CRHS };
-      return ConstantFoldInstOperands(Instruction::FSub, CLHS->getType(),
-                                      Ops, Q.DL, Q.TLI);
-    }
+    if (Constant *CRHS = dyn_cast<Constant>(Op1))
+      return ConstantFoldBinaryOpOperands(Instruction::FSub, CLHS, CRHS, Q.DL);
   }
 
   // fsub X, 0 ==> X
@@ -867,11 +855,8 @@
                                const Query &Q,
                                unsigned MaxRecurse) {
  if (Constant *CLHS = dyn_cast<Constant>(Op0)) {
-  if (Constant *CRHS = dyn_cast<Constant>(Op1)) {
-    Constant *Ops[] = { CLHS, CRHS };
-    return ConstantFoldInstOperands(Instruction::FMul, CLHS->getType(),
-                                    Ops, Q.DL, Q.TLI);
-  }
+  if (Constant *CRHS = dyn_cast<Constant>(Op1))
+    return ConstantFoldBinaryOpOperands(Instruction::FMul, CLHS, CRHS, Q.DL);
 
   // Canonicalize the constant to the RHS.
   std::swap(Op0, Op1);
@@ -893,11 +878,8 @@
 static Value *SimplifyMulInst(Value *Op0, Value *Op1, const Query &Q,
                               unsigned MaxRecurse) {
   if (Constant *CLHS = dyn_cast<Constant>(Op0)) {
-    if (Constant *CRHS = dyn_cast<Constant>(Op1)) {
-      Constant *Ops[] = { CLHS, CRHS };
-      return ConstantFoldInstOperands(Instruction::Mul, CLHS->getType(),
-                                      Ops, Q.DL, Q.TLI);
-    }
+    if (Constant *CRHS = dyn_cast<Constant>(Op1))
+      return ConstantFoldBinaryOpOperands(Instruction::Mul, CLHS, CRHS, Q.DL);
 
     // Canonicalize the constant to the RHS.
     std::swap(Op0, Op1);
@@ -992,12 +974,9 @@
 /// If not, this returns null.
 static Value *SimplifyDiv(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1,
                           const Query &Q, unsigned MaxRecurse) {
-  if (Constant *C0 = dyn_cast<Constant>(Op0)) {
-    if (Constant *C1 = dyn_cast<Constant>(Op1)) {
-      Constant *Ops[] = { C0, C1 };
-      return ConstantFoldInstOperands(Opcode, C0->getType(), Ops, Q.DL, Q.TLI);
-    }
-  }
+  if (Constant *C0 = dyn_cast<Constant>(Op0))
+    if (Constant *C1 = dyn_cast<Constant>(Op1))
+      return ConstantFoldBinaryOpOperands(Opcode, C0, C1, Q.DL);
 
   bool isSigned = Opcode == Instruction::SDiv;
 
@@ -1157,12 +1136,9 @@
 /// If not, this returns null.
 static Value *SimplifyRem(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1,
                           const Query &Q, unsigned MaxRecurse) {
-  if (Constant *C0 = dyn_cast<Constant>(Op0)) {
-    if (Constant *C1 = dyn_cast<Constant>(Op1)) {
-      Constant *Ops[] = { C0, C1 };
-      return ConstantFoldInstOperands(Opcode, C0->getType(), Ops, Q.DL, Q.TLI);
-    }
-  }
+  if (Constant *C0 = dyn_cast<Constant>(Op0))
+    if (Constant *C1 = dyn_cast<Constant>(Op1))
+      return ConstantFoldBinaryOpOperands(Opcode, C0, C1, Q.DL);
 
   // X % undef -> undef
   if (match(Op1, m_Undef()))
@@ -1309,12 +1285,9 @@
 /// If not, this returns null.
 static Value *SimplifyShift(unsigned Opcode, Value *Op0, Value *Op1,
                             const Query &Q, unsigned MaxRecurse) {
-  if (Constant *C0 = dyn_cast<Constant>(Op0)) {
-    if (Constant *C1 = dyn_cast<Constant>(Op1)) {
-      Constant *Ops[] = { C0, C1 };
-      return ConstantFoldInstOperands(Opcode, C0->getType(), Ops, Q.DL, Q.TLI);
-    }
-  }
+  if (Constant *C0 = dyn_cast<Constant>(Op0))
+    if (Constant *C1 = dyn_cast<Constant>(Op1))
+      return ConstantFoldBinaryOpOperands(Opcode, C0, C1, Q.DL);
 
   // 0 shift by X -> 0
   if (match(Op0, m_Zero()))
@@ -1558,11 +1531,8 @@
 static Value *SimplifyAndInst(Value *Op0, Value *Op1, const Query &Q,
                               unsigned MaxRecurse) {
   if (Constant *CLHS = dyn_cast<Constant>(Op0)) {
-    if (Constant *CRHS = dyn_cast<Constant>(Op1)) {
-      Constant *Ops[] = { CLHS, CRHS };
-      return ConstantFoldInstOperands(Instruction::And, CLHS->getType(),
-                                      Ops, Q.DL, Q.TLI);
-    }
+    if (Constant *CRHS = dyn_cast<Constant>(Op1))
+      return ConstantFoldBinaryOpOperands(Instruction::And, CLHS, CRHS, Q.DL);
 
     // Canonicalize the constant to the RHS.
     std::swap(Op0, Op1);
@@ -1717,11 +1687,8 @@
 static Value *SimplifyOrInst(Value *Op0, Value *Op1, const Query &Q,
                              unsigned MaxRecurse) {
   if (Constant *CLHS = dyn_cast<Constant>(Op0)) {
-    if (Constant *CRHS = dyn_cast<Constant>(Op1)) {
-      Constant *Ops[] = { CLHS, CRHS };
-      return ConstantFoldInstOperands(Instruction::Or, CLHS->getType(),
-                                      Ops, Q.DL, Q.TLI);
-    }
+    if (Constant *CRHS = dyn_cast<Constant>(Op1))
+      return ConstantFoldBinaryOpOperands(Instruction::Or, CLHS, CRHS, Q.DL);
 
     // Canonicalize the constant to the RHS.
     std::swap(Op0, Op1);
@@ -1853,11 +1820,8 @@
 static Value *SimplifyXorInst(Value *Op0, Value *Op1, const Query &Q,
                               unsigned MaxRecurse) {
   if (Constant *CLHS = dyn_cast<Constant>(Op0)) {
-    if (Constant *CRHS = dyn_cast<Constant>(Op1)) {
-      Constant *Ops[] = { CLHS, CRHS };
-      return ConstantFoldInstOperands(Instruction::Xor, CLHS->getType(),
-                                      Ops, Q.DL, Q.TLI);
-    }
+    if (Constant *CRHS = dyn_cast<Constant>(Op1))
+      return ConstantFoldBinaryOpOperands(Instruction::Xor, CLHS, CRHS, Q.DL);
 
     // Canonicalize the constant to the RHS.
     std::swap(Op0, Op1);
@@ -3675,7 +3639,7 @@
 
 static Value *SimplifyTruncInst(Value *Op, Type *Ty, const Query &Q, unsigned) {
   if (Constant *C = dyn_cast<Constant>(Op))
-    return ConstantFoldInstOperands(Instruction::Trunc, Ty, C, Q.DL, Q.TLI);
+    return ConstantFoldCastOperand(Instruction::Trunc, C, Ty, Q.DL);
 
   return nullptr;
 }
@@ -3730,11 +3694,8 @@
   case Instruction::Xor: return SimplifyXorInst(LHS, RHS, Q, MaxRecurse);
   default:
     if (Constant *CLHS = dyn_cast<Constant>(LHS))
-      if (Constant *CRHS = dyn_cast<Constant>(RHS)) {
-        Constant *COps[] = {CLHS, CRHS};
-        return ConstantFoldInstOperands(Opcode, LHS->getType(), COps, Q.DL,
-                                        Q.TLI);
-      }
+      if (Constant *CRHS = dyn_cast<Constant>(RHS))
+        return ConstantFoldBinaryOpOperands(Opcode, CLHS, CRHS, Q.DL);
 
     // If the operation is associative, try some generic simplifications.
     if (Instruction::isAssociative(Opcode))
Index: lib/CodeGen/MachineFunction.cpp
===================================================================
--- lib/CodeGen/MachineFunction.cpp
+++ lib/CodeGen/MachineFunction.cpp
@@ -895,17 +895,17 @@
   // the constant folding APIs to do this so that we get the benefit of
   // DataLayout.
   if (isa<PointerType>(A->getType()))
-    A = ConstantFoldInstOperands(Instruction::PtrToInt, IntTy,
-                                 const_cast<Constant *>(A), DL);
+    A = ConstantFoldCastOperand(Instruction::PtrToInt,
+                                const_cast<Constant *>(A), IntTy, DL);
   else if (A->getType() != IntTy)
-    A = ConstantFoldInstOperands(Instruction::BitCast, IntTy,
-                                 const_cast<Constant *>(A), DL);
+    A = ConstantFoldCastOperand(Instruction::BitCast, const_cast<Constant *>(A),
+                                IntTy, DL);
   if (isa<PointerType>(B->getType()))
-    B = ConstantFoldInstOperands(Instruction::PtrToInt, IntTy,
-                                 const_cast<Constant *>(B), DL);
+    B = ConstantFoldCastOperand(Instruction::PtrToInt,
+                                const_cast<Constant *>(B), IntTy, DL);
   else if (B->getType() != IntTy)
-    B = ConstantFoldInstOperands(Instruction::BitCast, IntTy,
-                                 const_cast<Constant *>(B), DL);
+    B = ConstantFoldCastOperand(Instruction::BitCast, const_cast<Constant *>(B),
+                                IntTy, DL);
 
   return A == B;
 }