Index: llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp =================================================================== --- llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp +++ llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp @@ -36,6 +36,25 @@ #define DEBUG_TYPE "aggressive-instcombine" +// This function returns true if Value V is a constant or if it's a type +// extension node. +static bool isConstOrExt(Value *V) { + if (isa(V)) + return true; + + if (Instruction *I = dyn_cast(V)) { + switch(I->getOpcode()) { + case Instruction::ZExt: + case Instruction::SExt: + return true; + default: + return false; + } + } + return false; +} + + /// Given an instruction and a container, it fills all the relevant operands of /// that instruction, with respect to the Trunc expression dag optimizaton. static void getRelevantOperands(Instruction *I, SmallVectorImpl &Ops) { @@ -53,9 +72,22 @@ case Instruction::And: case Instruction::Or: case Instruction::Xor: + case Instruction::ICmp: Ops.push_back(I->getOperand(0)); Ops.push_back(I->getOperand(1)); break; + case Instruction::Select: { + Value *Op0 = I->getOperand(0); + Ops.push_back(I->getOperand(1)); + Ops.push_back(I->getOperand(2)); + // In case the condition is a compare instruction whose operands are both + // a type extension/truncate or a constant, that can be shrunk without + // losing information in the compare instruction, add it as well. 
+ if (CmpInst *C = dyn_cast(Op0)) + if (isConstOrExt(C->getOperand(0)) && isConstOrExt(C->getOperand(1))) + Ops.push_back(Op0); + break; + } default: llvm_unreachable("Unreachable!"); } @@ -114,7 +146,9 @@ case Instruction::Mul: case Instruction::And: case Instruction::Or: - case Instruction::Xor: { + case Instruction::Xor: + case Instruction::Select: + case Instruction::ICmp: { SmallVector Operands; getRelevantOperands(I, Operands); for (Value *Operand : Operands) @@ -123,7 +157,7 @@ } default: // TODO: Can handle more cases here: - // 1. select, shufflevector, extractelement, insertelement + // 1. shufflevector, extractelement, insertelement // 2. udiv, urem // 3. shl, lshr, ashr // 4. phi node(and loop handling) @@ -188,14 +222,27 @@ for (auto *Operand : Operands) if (auto *IOp = dyn_cast(Operand)) { - // If we already calculated the minimum bit-width for this valid - // bit-width, or for a smaller valid bit-width, then just keep the - // answer we already calculated. - unsigned IOpBitwidth = InstInfoMap.lookup(IOp).ValidBitWidth; - if (IOpBitwidth >= ValidBitWidth) - continue; - InstInfoMap[IOp].ValidBitWidth = std::max(ValidBitWidth, IOpBitwidth); + if (isa(I)) { + // A Cmp instruction resets the valid bits analysis for its + // operands, as it does not continue with the same calculation chain + // but rather creates a new chain of its own. + switch (IOp->getOpcode()) { + case Instruction::SExt: + case Instruction::ZExt: + InstInfoMap[IOp].ValidBitWidth = + cast(IOp)->getSrcTy()->getScalarSizeInBits(); + break; + } + } else { + // If we already calculated the minimum bit-width for this valid + // bit-width, or for a smaller valid bit-width, then just keep the + // answer we already calculated. 
+ unsigned IOpBitwidth = InstInfoMap.lookup(IOp).ValidBitWidth; + if (IOpBitwidth >= ValidBitWidth) + continue; + InstInfoMap[IOp].ValidBitWidth = ValidBitWidth; Worklist.push_back(IOp); + } } } unsigned MinBitWidth = InstInfoMap.lookup(cast(Src)).MinBitWidth; @@ -351,6 +398,23 @@ Res = Builder.CreateBinOp((Instruction::BinaryOps)Opc, LHS, RHS); break; } + case Instruction::Select: { + Value *Op0 = I->getOperand(0); + if (ICmpInst *C = dyn_cast(Op0)) + if (isConstOrExt(C->getOperand(0)) && isConstOrExt(C->getOperand(1))) + Op0 = getReducedOperand(Op0, SclTy); + Value *LHS = getReducedOperand(I->getOperand(1), SclTy); + Value *RHS = getReducedOperand(I->getOperand(2), SclTy); + Res = Builder.CreateSelect(Op0, LHS, RHS); + break; + } + case Instruction::ICmp: { + auto ICmp = cast(I); + Value *LHS = getReducedOperand(ICmp->getOperand(0), SclTy); + Value *RHS = getReducedOperand(ICmp->getOperand(1), SclTy); + Res = Builder.CreateICmp(ICmp->getPredicate(), LHS, RHS); + break; + } default: llvm_unreachable("Unhandled instruction"); }