Index: llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp =================================================================== --- llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp +++ llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp @@ -36,6 +36,24 @@ #define DEBUG_TYPE "aggressive-instcombine" +// This function returns true if Value V is a constant or if it's a type +// extension node. +static bool isConstOrExt(Value *V) { + if (isa(V)) + return true; + + if (Instruction *I = dyn_cast(V)) { + switch(I->getOpcode()) { + case Instruction::ZExt: + case Instruction::SExt: + return true; + default: + return false; + } + } + return false; +} + /// Given an instruction and a container, it fills all the relevant operands of /// that instruction, with respect to the Trunc expression dag optimizaton. static void getRelevantOperands(Instruction *I, SmallVectorImpl &Ops) { @@ -53,13 +71,22 @@ case Instruction::And: case Instruction::Or: case Instruction::Xor: + case Instruction::ICmp: Ops.push_back(I->getOperand(0)); Ops.push_back(I->getOperand(1)); break; - case Instruction::Select: + case Instruction::Select: { + Value *Op0 = I->getOperand(0); Ops.push_back(I->getOperand(1)); Ops.push_back(I->getOperand(2)); + // If the condition is a compare instruction whose operands are each a type + // extension or a constant, so that it can be shrunk without losing + // information, add the condition to the relevant operands as well. 
+ if (CmpInst *C = dyn_cast(Op0)) + if (isConstOrExt(C->getOperand(0)) && isConstOrExt(C->getOperand(1))) + Ops.push_back(Op0); break; + } default: llvm_unreachable("Unreachable!"); } @@ -119,7 +146,8 @@ case Instruction::And: case Instruction::Or: case Instruction::Xor: - case Instruction::Select: { + case Instruction::Select: + case Instruction::ICmp: { SmallVector Operands; getRelevantOperands(I, Operands); for (Value *Operand : Operands) @@ -139,6 +167,18 @@ return true; } +// Get the minimum number of bits needed for the given constant. +static unsigned getConstMinBitWidth(bool IsSigned, ConstantInt *C) { + // If the const value is signed and negative, count the leading ones. + APInt Val = C->getValue(); + if (IsSigned && Val.isNegative()) + return Val.getBitWidth() - Val.countLeadingOnes() + 1; + + // Otherwise, count leading zeroes. + auto MinBits = Val.getBitWidth() - Val.countLeadingZeros(); + return IsSigned ? MinBits + 1 : MinBits; +} + unsigned TruncInstCombine::getMinBitWidth() { SmallVector Worklist; SmallVector Stack; @@ -180,6 +220,13 @@ if (auto *IOp = dyn_cast(Operand)) Info.MinBitWidth = std::max(Info.MinBitWidth, InstInfoMap[IOp].MinBitWidth); + else if (auto *C = dyn_cast(Operand)) { + // In case of Cmp instruction, make sure the constant can be truncated + // without losing information. + if (CmpInst *Cmp = dyn_cast(I)) + Info.MinBitWidth = std::max( + Info.MinBitWidth, getConstMinBitWidth(Cmp->isSigned(), C)); + } continue; } @@ -193,14 +240,27 @@ for (auto *Operand : Operands) if (auto *IOp = dyn_cast(Operand)) { - // If we already calculated the minimum bit-width for this valid - // bit-width, or for a smaller valid bit-width, then just keep the - // answer we already calculated. 
- unsigned IOpBitwidth = InstInfoMap.lookup(IOp).ValidBitWidth; - if (IOpBitwidth >= ValidBitWidth) - continue; - InstInfoMap[IOp].ValidBitWidth = ValidBitWidth; - Worklist.push_back(IOp); + if (isa(I)) { + // A Cmp instruction effectively resets the valid-bits analysis + // for its operands: it does not continue the same calculation + // chain but rather starts a new chain of its own. + switch (IOp->getOpcode()) { + case Instruction::SExt: + case Instruction::ZExt: + InstInfoMap[IOp].ValidBitWidth = + cast(IOp)->getSrcTy()->getScalarSizeInBits(); + break; + } + } else { + // If we already calculated the minimum bit-width for this valid + // bit-width, or for a smaller valid bit-width, then just keep the + // answer we already calculated. + unsigned IOpBitwidth = InstInfoMap.lookup(IOp).ValidBitWidth; + if (IOpBitwidth >= ValidBitWidth) + continue; + InstInfoMap[IOp].ValidBitWidth = ValidBitWidth; + Worklist.push_back(IOp); + } } } unsigned MinBitWidth = InstInfoMap.lookup(cast(Src)).MinBitWidth; @@ -358,11 +418,21 @@ } case Instruction::Select: { Value *Op0 = I->getOperand(0); + if (ICmpInst *C = dyn_cast(Op0)) + if (isConstOrExt(C->getOperand(0)) && isConstOrExt(C->getOperand(1))) + Op0 = getReducedOperand(Op0, SclTy); Value *LHS = getReducedOperand(I->getOperand(1), SclTy); Value *RHS = getReducedOperand(I->getOperand(2), SclTy); Res = Builder.CreateSelect(Op0, LHS, RHS); break; } + case Instruction::ICmp: { + auto ICmp = cast(I); + Value *LHS = getReducedOperand(ICmp->getOperand(0), SclTy); + Value *RHS = getReducedOperand(ICmp->getOperand(1), SclTy); + Res = Builder.CreateICmp(ICmp->getPredicate(), LHS, RHS); + break; + } default: llvm_unreachable("Unhandled instruction"); } Index: llvm/test/Transforms/AggressiveInstCombine/trunc_select_cmp.ll =================================================================== --- llvm/test/Transforms/AggressiveInstCombine/trunc_select_cmp.ll +++ llvm/test/Transforms/AggressiveInstCombine/trunc_select_cmp.ll @@ -5,11 +5,10 
@@ define dso_local i16 @cmp_select_sext_const(i8 %a) { ; CHECK-LABEL: @cmp_select_sext_const( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[CONV:%.*]] = sext i8 [[A:%.*]] to i32 -; CHECK-NEXT: [[CMP:%.*]] = icmp slt i32 [[CONV]], 109 -; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i32 109, i32 [[CONV]] -; CHECK-NEXT: [[CONV4:%.*]] = trunc i32 [[COND]] to i16 -; CHECK-NEXT: ret i16 [[CONV4]] +; CHECK-NEXT: [[CONV:%.*]] = sext i8 [[A:%.*]] to i16 +; CHECK-NEXT: [[CMP:%.*]] = icmp slt i16 [[CONV]], 109 +; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i16 109, i16 [[CONV]] +; CHECK-NEXT: ret i16 [[COND]] ; entry: %conv = sext i8 %a to i32 @@ -22,12 +21,11 @@ define dso_local i16 @cmp_select_sext(i8 %a, i8 %b) { ; CHECK-LABEL: @cmp_select_sext( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[CONV:%.*]] = sext i8 [[A:%.*]] to i32 -; CHECK-NEXT: [[CONV2:%.*]] = sext i8 [[B:%.*]] to i32 -; CHECK-NEXT: [[CMP:%.*]] = icmp slt i32 [[CONV]], [[CONV2]] -; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i32 [[CONV2]], i32 [[CONV]] -; CHECK-NEXT: [[CONV4:%.*]] = trunc i32 [[COND]] to i16 -; CHECK-NEXT: ret i16 [[CONV4]] +; CHECK-NEXT: [[CONV:%.*]] = sext i8 [[A:%.*]] to i16 +; CHECK-NEXT: [[CONV2:%.*]] = sext i8 [[B:%.*]] to i16 +; CHECK-NEXT: [[CMP:%.*]] = icmp slt i16 [[CONV]], [[CONV2]] +; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i16 [[CONV2]], i16 [[CONV]] +; CHECK-NEXT: ret i16 [[COND]] ; entry: %conv = sext i8 %a to i32 @@ -41,12 +39,11 @@ define dso_local i16 @cmp_select_zext(i8 %a, i8 %b) { ; CHECK-LABEL: @cmp_select_zext( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[CONV:%.*]] = zext i8 [[A:%.*]] to i32 -; CHECK-NEXT: [[CONV2:%.*]] = zext i8 [[B:%.*]] to i32 -; CHECK-NEXT: [[CMP:%.*]] = icmp slt i32 [[CONV]], [[CONV2]] -; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i32 [[CONV2]], i32 [[CONV]] -; CHECK-NEXT: [[CONV4:%.*]] = trunc i32 [[COND]] to i16 -; CHECK-NEXT: ret i16 [[CONV4]] +; CHECK-NEXT: [[CONV:%.*]] = zext i8 [[A:%.*]] to i16 +; CHECK-NEXT: [[CONV2:%.*]] = zext i8 [[B:%.*]] to 
i16 +; CHECK-NEXT: [[CMP:%.*]] = icmp slt i16 [[CONV]], [[CONV2]] +; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i16 [[CONV2]], i16 [[CONV]] +; CHECK-NEXT: ret i16 [[COND]] ; entry: %conv = zext i8 %a to i32 @@ -60,12 +57,11 @@ define dso_local i16 @cmp_select_zext_sext(i8 %a, i8 %b) { ; CHECK-LABEL: @cmp_select_zext_sext( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[CONV:%.*]] = zext i8 [[A:%.*]] to i32 -; CHECK-NEXT: [[CONV2:%.*]] = sext i8 [[B:%.*]] to i32 -; CHECK-NEXT: [[CMP:%.*]] = icmp slt i32 [[CONV]], [[CONV2]] -; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i32 [[CONV2]], i32 [[CONV]] -; CHECK-NEXT: [[CONV4:%.*]] = trunc i32 [[COND]] to i16 -; CHECK-NEXT: ret i16 [[CONV4]] +; CHECK-NEXT: [[CONV:%.*]] = zext i8 [[A:%.*]] to i16 +; CHECK-NEXT: [[CONV2:%.*]] = sext i8 [[B:%.*]] to i16 +; CHECK-NEXT: [[CMP:%.*]] = icmp slt i16 [[CONV]], [[CONV2]] +; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i16 [[CONV2]], i16 [[CONV]] +; CHECK-NEXT: ret i16 [[COND]] ; entry: %conv = zext i8 %a to i32 @@ -79,12 +75,10 @@ define dso_local i16 @cmp_select_zext_sext_diffOrigTy(i8 %a, i16 %b) { ; CHECK-LABEL: @cmp_select_zext_sext_diffOrigTy( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[CONV:%.*]] = zext i8 [[A:%.*]] to i32 -; CHECK-NEXT: [[CONV2:%.*]] = sext i16 [[B:%.*]] to i32 -; CHECK-NEXT: [[CMP:%.*]] = icmp slt i32 [[CONV]], [[CONV2]] -; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i32 [[CONV2]], i32 [[CONV]] -; CHECK-NEXT: [[CONV4:%.*]] = trunc i32 [[COND]] to i16 -; CHECK-NEXT: ret i16 [[CONV4]] +; CHECK-NEXT: [[CONV:%.*]] = zext i8 [[A:%.*]] to i16 +; CHECK-NEXT: [[CMP:%.*]] = icmp slt i16 [[CONV]], [[B:%.*]] +; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i16 [[B]], i16 [[CONV]] +; CHECK-NEXT: ret i16 [[COND]] ; entry: %conv = zext i8 %a to i32 @@ -98,12 +92,11 @@ define dso_local i16 @my_abs_sext(i8 %a) { ; CHECK-LABEL: @my_abs_sext( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[CONV:%.*]] = sext i8 [[A:%.*]] to i32 -; CHECK-NEXT: [[CMP:%.*]] = icmp slt i32 [[CONV]], 0 -; 
CHECK-NEXT: [[SUB:%.*]] = sub nsw i32 0, [[CONV]] -; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i32 [[SUB]], i32 [[CONV]] -; CHECK-NEXT: [[CONV4:%.*]] = trunc i32 [[COND]] to i16 -; CHECK-NEXT: ret i16 [[CONV4]] +; CHECK-NEXT: [[CONV:%.*]] = sext i8 [[A:%.*]] to i16 +; CHECK-NEXT: [[CMP:%.*]] = icmp slt i16 [[CONV]], 0 +; CHECK-NEXT: [[SUB:%.*]] = sub i16 0, [[CONV]] +; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i16 [[SUB]], i16 [[CONV]] +; CHECK-NEXT: ret i16 [[COND]] ; entry: %conv = sext i8 %a to i32 @@ -117,12 +110,11 @@ define dso_local i16 @my_abs_zext(i8 %a) { ; CHECK-LABEL: @my_abs_zext( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[CONV:%.*]] = zext i8 [[A:%.*]] to i32 -; CHECK-NEXT: [[CMP:%.*]] = icmp slt i32 [[CONV]], 0 -; CHECK-NEXT: [[SUB:%.*]] = sub nsw i32 0, [[CONV]] -; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i32 [[SUB]], i32 [[CONV]] -; CHECK-NEXT: [[CONV4:%.*]] = trunc i32 [[COND]] to i16 -; CHECK-NEXT: ret i16 [[CONV4]] +; CHECK-NEXT: [[CONV:%.*]] = zext i8 [[A:%.*]] to i16 +; CHECK-NEXT: [[CMP:%.*]] = icmp slt i16 [[CONV]], 0 +; CHECK-NEXT: [[SUB:%.*]] = sub i16 0, [[CONV]] +; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i16 [[SUB]], i16 [[CONV]] +; CHECK-NEXT: ret i16 [[COND]] ; entry: %conv = zext i8 %a to i32 @@ -182,11 +174,10 @@ define i16 @cmp_select_unsigned_const_i16Const(i8 %a) { ; CHECK-LABEL: @cmp_select_unsigned_const_i16Const( -; CHECK-NEXT: [[CONV:%.*]] = zext i8 [[A:%.*]] to i32 -; CHECK-NEXT: [[CMP:%.*]] = icmp ult i32 [[CONV]], 32768 -; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i32 32768, i32 [[CONV]] -; CHECK-NEXT: [[CONV4:%.*]] = trunc i32 [[COND]] to i16 -; CHECK-NEXT: ret i16 [[CONV4]] +; CHECK-NEXT: [[CONV:%.*]] = zext i8 [[A:%.*]] to i16 +; CHECK-NEXT: [[CMP:%.*]] = icmp ult i16 [[CONV]], -32768 +; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i16 -32768, i16 [[CONV]] +; CHECK-NEXT: ret i16 [[COND]] ; %conv = zext i8 %a to i32 %cmp = icmp ult i32 %conv, 32768