diff --git a/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp b/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp
--- a/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp
+++ b/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp
@@ -65,6 +65,7 @@
   case Instruction::Xor:
   case Instruction::Shl:
   case Instruction::LShr:
+  case Instruction::AShr:
     Ops.push_back(I->getOperand(0));
     Ops.push_back(I->getOperand(1));
     break;
@@ -133,6 +134,7 @@
   case Instruction::Xor:
   case Instruction::Shl:
   case Instruction::LShr:
+  case Instruction::AShr:
   case Instruction::Select: {
     SmallVector<Value *, 2> Operands;
     getRelevantOperands(I, Operands);
@@ -143,8 +145,7 @@
   // TODO: Can handle more cases here:
   // 1. shufflevector, extractelement, insertelement
   // 2. udiv, urem
-  // 3. ashr
-  // 4. phi node(and loop handling)
+  // 3. phi node(and loop handling)
   // ...
   return false;
 }
@@ -278,13 +279,14 @@
   // Initialize MinBitWidth for shift instructions with the minimum number
   // that is greater than shift amount (i.e. shift amount + 1). For `lshr`
-  // adjust MinBitWidth so that all potentially truncated bits of
+  // and `ashr` adjust MinBitWidth so that all potentially truncated bits of
   // the value-to-be-shifted are zeros.
   // Also normalize MinBitWidth not to be greater than source bitwidth.
   for (auto &Itr : InstInfoMap) {
     Instruction *I = Itr.first;
     if (I->getOpcode() == Instruction::Shl ||
-        I->getOpcode() == Instruction::LShr) {
+        I->getOpcode() == Instruction::LShr ||
+        I->getOpcode() == Instruction::AShr) {
       KnownBits KnownRHS = computeKnownBits(I->getOperand(1), DL);
       const unsigned SrcBitWidth = KnownRHS.getBitWidth();
       unsigned MinBitWidth = KnownRHS.getMaxValue()
@@ -292,7 +294,8 @@
                                  .getLimitedValue(SrcBitWidth);
       if (MinBitWidth >= OrigBitWidth)
         return nullptr;
-      if (I->getOpcode() == Instruction::LShr) {
+      if (I->getOpcode() == Instruction::LShr ||
+          I->getOpcode() == Instruction::AShr) {
         KnownBits KnownLHS = computeKnownBits(I->getOperand(0), DL);
         MinBitWidth =
             std::max(MinBitWidth, KnownLHS.getMaxValue().getActiveBits());
@@ -390,14 +393,18 @@
   case Instruction::And:
   case Instruction::Or:
   case Instruction::Xor:
-  case Instruction::Shl:
-  case Instruction::LShr: {
+  case Instruction::Shl: {
     Value *LHS = getReducedOperand(I->getOperand(0), SclTy);
     Value *RHS = getReducedOperand(I->getOperand(1), SclTy);
     Res = Builder.CreateBinOp((Instruction::BinaryOps)Opc, LHS, RHS);
+    break;
+  }
+  case Instruction::LShr:
+  case Instruction::AShr: {
+    Value *LHS = getReducedOperand(I->getOperand(0), SclTy);
+    Value *RHS = getReducedOperand(I->getOperand(1), SclTy);
     // Preserve `exact` flag since truncation doesn't change exactness
-    if (Opc == Instruction::LShr)
-      cast<Instruction>(Res)->setIsExact(I->isExact());
+    Res = Builder.CreateLShr(LHS, RHS, I->getName(), I->isExact());
     break;
   }
   case Instruction::Select: {
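A note for readers, not part of the patch: the hunks above extend the pass's width-reduction rule from `shl`/`lshr` to `ashr`. Below is a minimal standalone C++ sketch of that rule; `minBitWidthForShift`, `maxShiftAmt`, and `lhsActiveBits` are hypothetical stand-ins for the pass's KnownBits queries on the shift amount and on the shifted value.

#include <algorithm>
#include <cstdio>

// Sketch: smallest width a shift can be reduced to. The caller gives up
// when the result is not smaller than the original bit width.
unsigned minBitWidthForShift(bool isRightShift, unsigned maxShiftAmt,
                             unsigned lhsActiveBits) {
  // Any shift needs at least maxShiftAmt + 1 bits, otherwise it shifts
  // out everything (and would be poison in IR).
  unsigned minBitWidth = maxShiftAmt + 1;
  // For lshr/ashr, every bit that truncation would drop must already be
  // zero, so the reduced width must also cover all set bits of the
  // shifted value. For ashr this leaves the reduced sign bit zero, which
  // is why the reduced form can be emitted as lshr (see the last hunk).
  if (isRightShift)
    minBitWidth = std::max(minBitWidth, lhsActiveBits);
  return minBitWidth;
}

int main() {
  // ashr of a zext'd i8 (at most 8 active bits) by an amount masked to
  // at most 15, as in the tests below: reducible from i32 to i16.
  printf("%u\n", minBitWidthForShift(true, 15, 8)); // prints 16
}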
diff --git a/llvm/test/Transforms/AggressiveInstCombine/trunc_shifts.ll b/llvm/test/Transforms/AggressiveInstCombine/trunc_shifts.ll
--- a/llvm/test/Transforms/AggressiveInstCombine/trunc_shifts.ll
+++ b/llvm/test/Transforms/AggressiveInstCombine/trunc_shifts.ll
@@ -420,10 +420,8 @@
 define i16 @ashr_15(i16 %x) {
 ; CHECK-LABEL: @ashr_15(
-; CHECK-NEXT: [[ZEXT:%.*]] = zext i16 [[X:%.*]] to i32
-; CHECK-NEXT: [[ASHR:%.*]] = ashr i32 [[ZEXT]], 15
-; CHECK-NEXT: [[TRUNC:%.*]] = trunc i32 [[ASHR]] to i16
-; CHECK-NEXT: ret i16 [[TRUNC]]
+; CHECK-NEXT: [[ASHR:%.*]] = lshr i16 [[X:%.*]], 15
+; CHECK-NEXT: ret i16 [[ASHR]]
 ;
   %zext = zext i16 %x to i32
   %ashr = ashr i32 %zext, 15
@@ -469,14 +467,13 @@
 define i16 @ashr_var_bounded_shift_amount(i8 %x, i8 %amt) {
 ; CHECK-LABEL: @ashr_var_bounded_shift_amount(
-; CHECK-NEXT: [[Z:%.*]] = zext i8 [[X:%.*]] to i32
-; CHECK-NEXT: [[ZA:%.*]] = zext i8 [[AMT:%.*]] to i32
-; CHECK-NEXT: [[ZA2:%.*]] = and i32 [[ZA]], 15
-; CHECK-NEXT: [[S:%.*]] = ashr i32 [[Z]], [[ZA2]]
-; CHECK-NEXT: [[A:%.*]] = add i32 [[S]], [[Z]]
-; CHECK-NEXT: [[S2:%.*]] = ashr i32 [[A]], 2
-; CHECK-NEXT: [[T:%.*]] = trunc i32 [[S2]] to i16
-; CHECK-NEXT: ret i16 [[T]]
+; CHECK-NEXT: [[Z:%.*]] = zext i8 [[X:%.*]] to i16
+; CHECK-NEXT: [[ZA:%.*]] = zext i8 [[AMT:%.*]] to i16
+; CHECK-NEXT: [[ZA2:%.*]] = and i16 [[ZA]], 15
+; CHECK-NEXT: [[S:%.*]] = lshr i16 [[Z]], [[ZA2]]
+; CHECK-NEXT: [[A:%.*]] = add i16 [[S]], [[Z]]
+; CHECK-NEXT: [[S2:%.*]] = lshr i16 [[A]], 2
+; CHECK-NEXT: ret i16 [[S2]]
 ;
   %z = zext i8 %x to i32
   %za = zext i8 %amt to i32
@@ -509,16 +506,15 @@
 define void @ashr_big_dag(i16* %a, i8 %b, i8 %c) {
 ; CHECK-LABEL: @ashr_big_dag(
-; CHECK-NEXT: [[ZEXT1:%.*]] = zext i8 [[B:%.*]] to i32
-; CHECK-NEXT: [[ZEXT2:%.*]] = zext i8 [[C:%.*]] to i32
-; CHECK-NEXT: [[ADD1:%.*]] = add i32 [[ZEXT1]], [[ZEXT2]]
-; CHECK-NEXT: [[SFT1:%.*]] = and i32 [[ADD1]], 15
-; CHECK-NEXT: [[SHR1:%.*]] = ashr i32 [[ADD1]], [[SFT1]]
-; CHECK-NEXT: [[ADD2:%.*]] = add i32 [[ADD1]], [[SHR1]]
-; CHECK-NEXT: [[SFT2:%.*]] = and i32 [[ADD2]], 7
-; CHECK-NEXT: [[SHR2:%.*]] = ashr i32 [[ADD2]], [[SFT2]]
-; CHECK-NEXT: [[TRUNC:%.*]] = trunc i32 [[SHR2]] to i16
-; CHECK-NEXT: store i16 [[TRUNC]], i16* [[A:%.*]], align 2
+; CHECK-NEXT: [[ZEXT1:%.*]] = zext i8 [[B:%.*]] to i16
+; CHECK-NEXT: [[ZEXT2:%.*]] = zext i8 [[C:%.*]] to i16
+; CHECK-NEXT: [[ADD1:%.*]] = add i16 [[ZEXT1]], [[ZEXT2]]
+; CHECK-NEXT: [[SFT1:%.*]] = and i16 [[ADD1]], 15
+; CHECK-NEXT: [[SHR1:%.*]] = lshr i16 [[ADD1]], [[SFT1]]
+; CHECK-NEXT: [[ADD2:%.*]] = add i16 [[ADD1]], [[SHR1]]
+; CHECK-NEXT: [[SFT2:%.*]] = and i16 [[ADD2]], 7
+; CHECK-NEXT: [[SHR2:%.*]] = lshr i16 [[ADD2]], [[SFT2]]
+; CHECK-NEXT: store i16 [[SHR2]], i16* [[A:%.*]], align 2
 ; CHECK-NEXT: ret void
 ;
   %zext1 = zext i8 %b to i32
@@ -538,10 +534,8 @@
 ; CHECK-LABEL: @ashr_smaller_bitwidth(
 ; CHECK-NEXT: [[ZEXT:%.*]] = zext i8 [[X:%.*]] to i16
 ; CHECK-NEXT: [[ASHR:%.*]] = ashr i16 [[ZEXT]], 1
-; CHECK-NEXT: [[ZEXT2:%.*]] = zext i16 [[ASHR]] to i32
-; CHECK-NEXT: [[ASHR2:%.*]] = ashr i32 [[ZEXT2]], 2
-; CHECK-NEXT: [[TRUNC:%.*]] = trunc i32 [[ASHR2]] to i16
-; CHECK-NEXT: ret i16 [[TRUNC]]
+; CHECK-NEXT: [[ASHR2:%.*]] = lshr i16 [[ASHR]], 2
+; CHECK-NEXT: ret i16 [[ASHR2]]
 ;
   %zext = zext i8 %x to i16
   %ashr = ashr i16 %zext, 1
@@ -553,12 +547,10 @@
 define i16 @ashr_larger_bitwidth(i8 %x) {
 ; CHECK-LABEL: @ashr_larger_bitwidth(
-; CHECK-NEXT: [[ZEXT:%.*]] = zext i8 [[X:%.*]] to i64
-; CHECK-NEXT: [[ASHR:%.*]] = ashr i64 [[ZEXT]], 1
-; CHECK-NEXT: [[ZEXT2:%.*]] = trunc i64 [[ASHR]] to i32
-; CHECK-NEXT: [[AND:%.*]] = ashr i32 [[ZEXT2]], 2
-; CHECK-NEXT: [[TRUNC:%.*]] = trunc i32 [[AND]] to i16
-; CHECK-NEXT: ret i16 [[TRUNC]]
+; CHECK-NEXT: [[ZEXT:%.*]] = zext i8 [[X:%.*]] to i16
+; CHECK-NEXT: [[ASHR:%.*]] = lshr i16 [[ZEXT]], 1
+; CHECK-NEXT: [[AND:%.*]] = lshr i16 [[ASHR]], 2
+; CHECK-NEXT: ret i16 [[AND]]
 ;
   %zext = zext i8 %x to i64
   %ashr = ashr i64 %zext, 1
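Again as reader context rather than patch content: every reduced `ashr` above is emitted as `lshr`. That is sound because the pass only reduces when the truncated-away high bits, including the new sign bit, are known zero, and a non-negative value shifts identically under arithmetic and logical right shift. A quick standalone check, with plain C++ integers standing in for the i16 IR values:

#include <cassert>
#include <cstdint>

int main() {
  // For every i16 value with a clear sign bit, arithmetic and logical
  // right shift agree for all in-range shift amounts.
  for (uint32_t x = 0; x <= 0x7FFF; x += 0x111) {
    for (unsigned s = 0; s < 16; ++s) {
      uint16_t logical = static_cast<uint16_t>(x >> s);
      int16_t arithmetic =
          static_cast<int16_t>(static_cast<int16_t>(x) >> s);
      assert(logical == static_cast<uint16_t>(arithmetic));
    }
  }
  return 0;
}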
@@ -587,13 +579,12 @@
 define <2 x i16> @ashr_vector(<2 x i8> %x) {
 ; CHECK-LABEL: @ashr_vector(
-; CHECK-NEXT: [[Z:%.*]] = zext <2 x i8> [[X:%.*]] to <2 x i32>
-; CHECK-NEXT: [[ZA:%.*]] = and <2 x i32> [[Z]],
-; CHECK-NEXT: [[S:%.*]] = ashr <2 x i32> [[Z]], [[ZA]]
-; CHECK-NEXT: [[A:%.*]] = add <2 x i32> [[S]], [[Z]]
-; CHECK-NEXT: [[S2:%.*]] = ashr <2 x i32> [[A]],
-; CHECK-NEXT: [[T:%.*]] = trunc <2 x i32> [[S2]] to <2 x i16>
-; CHECK-NEXT: ret <2 x i16> [[T]]
+; CHECK-NEXT: [[Z:%.*]] = zext <2 x i8> [[X:%.*]] to <2 x i16>
+; CHECK-NEXT: [[ZA:%.*]] = and <2 x i16> [[Z]],
+; CHECK-NEXT: [[S:%.*]] = lshr <2 x i16> [[Z]], [[ZA]]
+; CHECK-NEXT: [[A:%.*]] = add <2 x i16> [[S]], [[Z]]
+; CHECK-NEXT: [[S2:%.*]] = lshr <2 x i16> [[A]],
+; CHECK-NEXT: ret <2 x i16> [[S2]]
 ;
   %z = zext <2 x i8> %x to <2 x i32>
   %za = and <2 x i32> %z,
@@ -648,11 +639,9 @@
 define i16 @ashr_exact(i16 %x) {
 ; CHECK-LABEL: @ashr_exact(
-; CHECK-NEXT: [[ZEXT:%.*]] = zext i16 [[X:%.*]] to i32
-; CHECK-NEXT: [[AND:%.*]] = and i32 [[ZEXT]], 32767
-; CHECK-NEXT: [[ASHR:%.*]] = ashr exact i32 [[AND]], 15
-; CHECK-NEXT: [[TRUNC:%.*]] = trunc i32 [[ASHR]] to i16
-; CHECK-NEXT: ret i16 [[TRUNC]]
+; CHECK-NEXT: [[AND:%.*]] = and i16 [[X:%.*]], 32767
+; CHECK-NEXT: [[ASHR:%.*]] = lshr exact i16 [[AND]], 15
+; CHECK-NEXT: ret i16 [[ASHR]]
 ;
   %zext = zext i16 %x to i32
   %and = and i32 %zext, 32767
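One last contextual note: @ashr_exact exercises the `Builder.CreateLShr(LHS, RHS, I->getName(), I->isExact())` path, which carries the `exact` flag over to the narrowed instruction. The patch comment ("truncation doesn't change exactness") holds because `exact` only constrains the low bits that the shift discards, while the width reduction only removes known-zero high bits. A standalone sketch of that argument, with illustrative values (0x6000, shift by 13) not taken from the tests:

#include <cassert>
#include <cstdint>

int main() {
  uint32_t wide = 0x6000; // a zext'd value: high bits are zero
  unsigned shamt = 13;
  // `exact` means no set bit is shifted out: the low shamt bits are zero.
  assert((wide & ((1u << shamt) - 1)) == 0);
  // Truncating to 16 bits only drops high bits, so the same low bits are
  // still zero: the narrow shift is exact too, and the results match.
  uint16_t narrow = static_cast<uint16_t>(wide);
  assert((narrow & ((1u << shamt) - 1)) == 0);
  assert((narrow >> shamt) == (wide >> shamt));
  return 0;
}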