Index: include/llvm/Analysis/TargetTransformInfoImpl.h =================================================================== --- include/llvm/Analysis/TargetTransformInfoImpl.h +++ include/llvm/Analysis/TargetTransformInfoImpl.h @@ -54,15 +54,6 @@ case Instruction::GetElementPtr: llvm_unreachable("Use getGEPCost for GEP operations!"); - case Instruction::BitCast: - assert(OpTy && "Cast instructions must provide the operand type"); - if (Ty == OpTy || (Ty->isPointerTy() && OpTy->isPointerTy())) - // Identity and pointer-to-pointer casts are free. - return TTI::TCC_Free; - - // Otherwise, the default basic cost is used. - return TTI::TCC_Basic; - case Instruction::FDiv: case Instruction::FRem: case Instruction::SDiv: @@ -71,35 +62,11 @@ case Instruction::URem: return TTI::TCC_Expensive; - case Instruction::IntToPtr: { - // An inttoptr cast is free so long as the input is a legal integer type - // which doesn't contain values outside the range of a pointer. - unsigned OpSize = OpTy->getScalarSizeInBits(); - if (DL.isLegalInteger(OpSize) && - OpSize <= DL.getPointerTypeSizeInBits(Ty)) - return TTI::TCC_Free; - - // Otherwise it's not a no-op. - return TTI::TCC_Basic; - } - case Instruction::PtrToInt: { - // A ptrtoint cast is free so long as the result is large enough to store - // the pointer, and a legal integer type. - unsigned DestSize = Ty->getScalarSizeInBits(); - if (DL.isLegalInteger(DestSize) && - DestSize >= DL.getPointerTypeSizeInBits(OpTy)) - return TTI::TCC_Free; - - // Otherwise it's not a no-op. - return TTI::TCC_Basic; - } + case Instruction::BitCast: + case Instruction::IntToPtr: + case Instruction::PtrToInt: case Instruction::Trunc: - // trunc to a native type is free (assuming the target has compare and - // shift-right of the same width). - if (DL.isLegalInteger(DL.getTypeSizeInBits(Ty))) - return TTI::TCC_Free; - - return TTI::TCC_Basic; + return getCastInstrCost(Opcode, Ty, OpTy, nullptr); } } @@ -385,8 +352,50 @@ return 1; } - unsigned getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, - const Instruction *I) { return 1; } + unsigned getCastInstrCost(unsigned Opcode, Type *Ty, Type *OpTy, + const Instruction *I) { + switch(Opcode) { + default: + return TTI::TCC_Basic; + case Instruction::BitCast: + assert(OpTy && "Cast instructions must provide the operand type"); + if (Ty == OpTy || (Ty->isPointerTy() && OpTy->isPointerTy())) + // Identity and pointer-to-pointer casts are free. + return TTI::TCC_Free; + + // Otherwise, the default basic cost is used. + return TTI::TCC_Basic; + case Instruction::IntToPtr: { + // An inttoptr cast is free so long as the input is a legal integer type + // which doesn't contain values outside the range of a pointer. + unsigned OpSize = OpTy->getScalarSizeInBits(); + if (DL.isLegalInteger(OpSize) && + OpSize <= DL.getPointerTypeSizeInBits(Ty)) + return TTI::TCC_Free; + + // Otherwise it's not a no-op. + return TTI::TCC_Basic; + } + case Instruction::PtrToInt: { + // A ptrtoint cast is free so long as the result is large enough to store + // the pointer, and a legal integer type. + unsigned DestSize = Ty->getScalarSizeInBits(); + if (DL.isLegalInteger(DestSize) && + DestSize >= DL.getPointerTypeSizeInBits(OpTy)) + return TTI::TCC_Free; + + // Otherwise it's not a no-op. + return TTI::TCC_Basic; + } + case Instruction::Trunc: + // trunc to a native type is free (assuming the target has compare and + // shift-right of the same width). + if (DL.isLegalInteger(DL.getTypeSizeInBits(Ty))) + return TTI::TCC_Free; + + return TTI::TCC_Basic; + } + } unsigned getExtractWithExtendCost(unsigned Opcode, Type *Dst, VectorType *VecTy, unsigned Index) { @@ -733,13 +742,14 @@ } unsigned getUserCost(const User *U, ArrayRef Operands) { + T* This = static_cast(this); if (isa(U)) return TTI::TCC_Free; // Model all PHI nodes as free. if (const GEPOperator *GEP = dyn_cast(U)) { - return static_cast(this)->getGEPCost(GEP->getSourceElementType(), - GEP->getPointerOperand(), - Operands.drop_front()); + return This->getGEPCost(GEP->getSourceElementType(), + GEP->getPointerOperand(), + Operands.drop_front()); } if (auto CS = ImmutableCallSite(U)) { @@ -747,12 +757,11 @@ if (!F) { // Just use the called value type. Type *FTy = CS.getCalledValue()->getType()->getPointerElementType(); - return static_cast(this) - ->getCallCost(cast(FTy), CS.arg_size()); + return This->getCallCost(cast(FTy), CS.arg_size()); } SmallVector Arguments(CS.arg_begin(), CS.arg_end()); - return static_cast(this)->getCallCost(F, Arguments); + return This->getCallCost(F, Arguments); } if (const CastInst *CI = dyn_cast(U)) { @@ -762,11 +771,12 @@ if (isa(CI->getOperand(0))) return TTI::TCC_Free; if (isa(CI) || isa(CI) || isa(CI)) - return static_cast(this)->getExtCost(CI, Operands.back()); + return This->getExtCost(CI, Operands.back()); + return This->getCastInstrCost(CI->getOpcode(), U->getType(), + U->getOperand(0)->getType(), CI); } - return static_cast(this)->getOperationCost( - Operator::getOpcode(U), U->getType(), + return This->getOperationCost(Operator::getOpcode(U), U->getType(), U->getNumOperands() == 1 ? U->getOperand(0)->getType() : nullptr); } }; Index: test/Transforms/SimplifyCFG/ARM/select-trunc-i64.ll =================================================================== --- /dev/null +++ test/Transforms/SimplifyCFG/ARM/select-trunc-i64.ll @@ -0,0 +1,25 @@ +; RUN: opt -S -simplifycfg -mtriple=arm < %s | FileCheck %s +target datalayout = "e-m:e-p:32:32-i64:64-v128:64:128-a:0:32-n32-S64" + +; CHECK-LABEL: select_trunc_i64 +; CHECK-NOT: br +; CHECK: select +; CHECK: select +define arm_aapcscc i32 @select_trunc_i64(i32 %a, i32 %b) { +entry: + %conv = sext i32 %a to i64 + %conv1 = sext i32 %b to i64 + %add = add nsw i64 %conv1, %conv + %cmp = icmp sgt i64 %add, 2147483647 + br i1 %cmp, label %cond.end7, label %cond.false + +cond.false: ; preds = %entry + %0 = icmp sgt i64 %add, -2147483648 + %cond = select i1 %0, i64 %add, i64 -2147483648 + %extract.t = trunc i64 %cond to i32 + br label %cond.end7 + +cond.end7: ; preds = %cond.false, %entry + %cond8.off0 = phi i32 [ 2147483647, %entry ], [ %extract.t, %cond.false ] + ret i32 %cond8.off0 +}