diff --git a/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp b/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp
--- a/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp
+++ b/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp
@@ -29,10 +29,12 @@
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/ConstantFolding.h"
 #include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/ValueTracking.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Instruction.h"
+#include "llvm/Support/KnownBits.h"
 
 using namespace llvm;
 
@@ -61,6 +63,7 @@
   case Instruction::And:
   case Instruction::Or:
   case Instruction::Xor:
+  case Instruction::Shl:
     Ops.push_back(I->getOperand(0));
     Ops.push_back(I->getOperand(1));
     break;
@@ -127,6 +130,7 @@
     case Instruction::And:
     case Instruction::Or:
     case Instruction::Xor:
+    case Instruction::Shl:
    case Instruction::Select: {
      SmallVector<Value *, 2> Operands;
      getRelevantOperands(I, Operands);
@@ -137,7 +141,7 @@
      // TODO: Can handle more cases here:
      // 1. shufflevector, extractelement, insertelement
      // 2. udiv, urem
-      // 3. shl, lshr, ashr
+      // 3. lshr, ashr
      // 4. phi node(and loop handling)
      // ...
      return false;
@@ -270,6 +274,23 @@
   unsigned OrigBitWidth =
       CurrentTruncInst->getOperand(0)->getType()->getScalarSizeInBits();
 
+  // Initialize MinBitWidth for `shl` instructions with the minimum number
+  // that is greater than shift amount (i.e. shift amount + 1).
+  // Also normalize MinBitWidth not to be greater than source bitwidth.
+  for (auto &Itr : InstInfoMap) {
+    Instruction *I = Itr.first;
+    if (I->getOpcode() == Instruction::Shl) {
+      KnownBits KnownRHS = computeKnownBits(I->getOperand(1), DL);
+      const unsigned SrcBitWidth = KnownRHS.getBitWidth();
+      unsigned MinBitWidth =
+          KnownRHS.getMaxValue().uadd_sat(APInt(SrcBitWidth, 1)).getZExtValue();
+      MinBitWidth = std::min(MinBitWidth, SrcBitWidth);
+      if (MinBitWidth >= OrigBitWidth)
+        return nullptr;
+      Itr.second.MinBitWidth = MinBitWidth;
+    }
+  }
+
   // Calculate minimum allowed bit-width allowed for shrinking the currently
   // visited truncate's operand.
   unsigned MinBitWidth = getMinBitWidth();
@@ -356,7 +377,8 @@
     case Instruction::Mul:
     case Instruction::And:
     case Instruction::Or:
-    case Instruction::Xor: {
+    case Instruction::Xor:
+    case Instruction::Shl: {
      Value *LHS = getReducedOperand(I->getOperand(0), SclTy);
      Value *RHS = getReducedOperand(I->getOperand(1), SclTy);
      Res = Builder.CreateBinOp((Instruction::BinaryOps)Opc, LHS, RHS);
diff --git a/llvm/test/Transforms/AggressiveInstCombine/trunc_shifts.ll b/llvm/test/Transforms/AggressiveInstCombine/trunc_shifts.ll
--- a/llvm/test/Transforms/AggressiveInstCombine/trunc_shifts.ll
+++ b/llvm/test/Transforms/AggressiveInstCombine/trunc_shifts.ll
@@ -3,10 +3,9 @@
 
 define i16 @shl_1(i8 %x) {
 ; CHECK-LABEL: @shl_1(
-; CHECK-NEXT:    [[ZEXT:%.*]] = zext i8 [[X:%.*]] to i32
-; CHECK-NEXT:    [[SHL:%.*]] = shl i32 [[ZEXT]], 1
-; CHECK-NEXT:    [[TRUNC:%.*]] = trunc i32 [[SHL]] to i16
-; CHECK-NEXT:    ret i16 [[TRUNC]]
+; CHECK-NEXT:    [[ZEXT:%.*]] = zext i8 [[X:%.*]] to i16
+; CHECK-NEXT:    [[SHL:%.*]] = shl i16 [[ZEXT]], 1
+; CHECK-NEXT:    ret i16 [[SHL]]
 ;
   %zext = zext i8 %x to i32
   %shl = shl i32 %zext, 1
@@ -16,10 +15,9 @@
 
 define i16 @shl_15(i8 %x) {
 ; CHECK-LABEL: @shl_15(
-; CHECK-NEXT:    [[ZEXT:%.*]] = zext i8 [[X:%.*]] to i32
-; CHECK-NEXT:    [[SHL:%.*]] = shl i32 [[ZEXT]], 15
-; CHECK-NEXT:    [[TRUNC:%.*]] = trunc i32 [[SHL]] to i16
-; CHECK-NEXT:    ret i16 [[TRUNC]]
+; CHECK-NEXT:    [[ZEXT:%.*]] = zext i8 [[X:%.*]] to i16
+; CHECK-NEXT:    [[SHL:%.*]] = shl i16 [[ZEXT]], 15
+; CHECK-NEXT:    ret i16 [[SHL]]
 ;
   %zext = zext i8 %x to i32
   %shl = shl i32 %zext, 15
@@ -61,12 +59,11 @@
 
 define i16 @shl_var_bounded_shift_amount(i8 %x, i8 %y) {
 ; CHECK-LABEL: @shl_var_bounded_shift_amount(
-; CHECK-NEXT:    [[ZEXT_X:%.*]] = zext i8 [[X:%.*]] to i32
-; CHECK-NEXT:    [[ZEXT_Y:%.*]] = zext i8 [[Y:%.*]] to i32
-; CHECK-NEXT:    [[AND:%.*]] = and i32 [[ZEXT_Y]], 15
-; CHECK-NEXT:    [[SHL:%.*]] = shl i32 [[ZEXT_X]], [[AND]]
-; CHECK-NEXT:    [[TRUNC:%.*]] = trunc i32 [[SHL]] to i16
-; CHECK-NEXT:    ret i16 [[TRUNC]]
+; CHECK-NEXT:    [[ZEXT_X:%.*]] = zext i8 [[X:%.*]] to i16
+; CHECK-NEXT:    [[ZEXT_Y:%.*]] = zext i8 [[Y:%.*]] to i16
+; CHECK-NEXT:    [[AND:%.*]] = and i16 [[ZEXT_Y]], 15
+; CHECK-NEXT:    [[SHL:%.*]] = shl i16 [[ZEXT_X]], [[AND]]
+; CHECK-NEXT:    ret i16 [[SHL]]
 ;
   %zext.x = zext i8 %x to i32
   %zext.y = zext i8 %y to i32
@@ -78,10 +75,9 @@
 
 define <2 x i16> @shl_vector(<2 x i8> %x) {
 ; CHECK-LABEL: @shl_vector(
-; CHECK-NEXT:    [[Z:%.*]] = zext <2 x i8> [[X:%.*]] to <2 x i32>
-; CHECK-NEXT:    [[S:%.*]] = shl <2 x i32> [[Z]], <i32 1, i32 15>
-; CHECK-NEXT:    [[T:%.*]] = trunc <2 x i32> [[S]] to <2 x i16>
-; CHECK-NEXT:    ret <2 x i16> [[T]]
+; CHECK-NEXT:    [[Z:%.*]] = zext <2 x i8> [[X:%.*]] to <2 x i16>
+; CHECK-NEXT:    [[S:%.*]] = shl <2 x i16> [[Z]], <i16 1, i16 15>
+; CHECK-NEXT:    ret <2 x i16> [[S]]
 ;
   %z = zext <2 x i8> %x to <2 x i32>
   %s = shl <2 x i32> %z, <i32 1, i32 15>
@@ -121,10 +117,9 @@
 
 define i16 @shl_nuw(i8 %x) {
 ; CHECK-LABEL: @shl_nuw(
-; CHECK-NEXT:    [[Z:%.*]] = zext i8 [[X:%.*]] to i32
-; CHECK-NEXT:    [[S:%.*]] = shl nuw i32 [[Z]], 15
-; CHECK-NEXT:    [[T:%.*]] = trunc i32 [[S]] to i16
-; CHECK-NEXT:    ret i16 [[T]]
+; CHECK-NEXT:    [[Z:%.*]] = zext i8 [[X:%.*]] to i16
+; CHECK-NEXT:    [[S:%.*]] = shl i16 [[Z]], 15
+; CHECK-NEXT:    ret i16 [[S]]
 ;
   %z = zext i8 %x to i32
   %s = shl nuw i32 %z, 15
@@ -134,10 +129,9 @@
 
 define i16 @shl_nsw(i8 %x) {
 ; CHECK-LABEL: @shl_nsw(
-; CHECK-NEXT:    [[Z:%.*]] = zext i8 [[X:%.*]] to i32
-; CHECK-NEXT:    [[S:%.*]] = shl nsw i32 [[Z]], 15
-; CHECK-NEXT:    [[T:%.*]] = trunc i32 [[S]] to i16
-; CHECK-NEXT:    ret i16 [[T]]
+; CHECK-NEXT:    [[Z:%.*]] = zext i8 [[X:%.*]] to i16
+; CHECK-NEXT:    [[S:%.*]] = shl i16 [[Z]], 15
+; CHECK-NEXT:    ret i16 [[S]]
 ;
   %z = zext i8 %x to i32
   %s = shl nsw i32 %z, 15