Index: lib/Transforms/InstCombine/InstCombineAndOrXor.cpp =================================================================== --- lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -1620,6 +1620,30 @@ /// Given an OR instruction, check to see if this is a bswap or bitreverse /// idiom. If so, insert the new intrinsic and return it. Instruction *InstCombiner::MatchBSwapOrBitReverse(BinaryOperator &I) { + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + + // Look through zero extends. + if (Instruction *Ext = dyn_cast(Op0)) + Op0 = Ext->getOperand(0); + + if (Instruction *Ext = dyn_cast(Op1)) + Op1 = Ext->getOperand(0); + + // (A | B) | C and A | (B | C) -> bswap if possible. + bool OrOfOrs = match(Op0, m_Or(m_Value(), m_Value())) || + match(Op1, m_Or(m_Value(), m_Value())); + + // (A >> B) | (C << D) and (A << B) | (B >> C) -> bswap if possible. + bool OrOfShifts = match(Op0, m_LogicalShift(m_Value(), m_Value())) && + match(Op1, m_LogicalShift(m_Value(), m_Value())); + + // (A & B) | (C & D) -> bswap if possible. + bool OrOfAnds = match(Op0, m_And(m_Value(), m_Value())) && + match(Op1, m_And(m_Value(), m_Value())); + + if (!OrOfOrs && !OrOfShifts && !OrOfAnds) + return nullptr; + SmallVector Insts; if (!recognizeBitReverseOrBSwapIdiom(&I, true, false, Insts)) return nullptr; @@ -2162,23 +2186,13 @@ return NV; } + // Given an OR instruction, check to see if this is a bswap. + if (Instruction *BSwap = MatchBSwapOrBitReverse(I)) + return BSwap; + Value *A = nullptr, *B = nullptr; ConstantInt *C1 = nullptr, *C2 = nullptr; - // (A | B) | C and A | (B | C) -> bswap if possible. - bool OrOfOrs = match(Op0, m_Or(m_Value(), m_Value())) || - match(Op1, m_Or(m_Value(), m_Value())); - // (A >> B) | (C << D) and (A << B) | (B >> C) -> bswap if possible. - bool OrOfShifts = match(Op0, m_LogicalShift(m_Value(), m_Value())) && - match(Op1, m_LogicalShift(m_Value(), m_Value())); - // (A & B) | (C & D) -> bswap if possible. - bool OrOfAnds = match(Op0, m_And(m_Value(), m_Value())) && - match(Op1, m_And(m_Value(), m_Value())); - - if (OrOfOrs || OrOfShifts || OrOfAnds) - if (Instruction *BSwap = MatchBSwapOrBitReverse(I)) - return BSwap; - // (X^C)|Y -> (X|Y)^C iff Y&C == 0 if (Op0->hasOneUse() && match(Op0, m_Xor(m_Value(A), m_ConstantInt(C1))) && Index: lib/Transforms/Utils/Local.cpp =================================================================== --- lib/Transforms/Utils/Local.cpp +++ lib/Transforms/Utils/Local.cpp @@ -1773,7 +1773,26 @@ // If the AndMask is zero for this bit, clear the bit. if ((AndMask & Bit) == 0) Result->Provenance[i] = BitPart::Unset; + return Result; + } + + // If this is a zext instruction zero extend the result. + if (I->getOpcode() == Instruction::ZExt) { + auto &Res = collectBitParts(I->getOperand(0), MatchBSwaps, + MatchBitReversals, BPS); + if (!Res) + return Result; + Result = Res; + // Try and merge the two together. + Result = BitPart(Res->Provider, BitWidth); + + auto NarrowBitWidth = + cast(cast(I)->getSrcTy())->getBitWidth(); + for (unsigned i = 0; i < NarrowBitWidth; ++i) + Result->Provenance[i] = Res->Provenance[i]; + for (unsigned i = NarrowBitWidth; i < BitWidth; ++i) + Result->Provenance[i] = BitPart::Unset; return Result; } } @@ -1816,6 +1835,15 @@ return false; // Can't do vectors or integers > 128 bits. unsigned BW = ITy->getBitWidth(); + unsigned DemandedBW = BW; + IntegerType *DemandedTy = ITy; + if (I->hasOneUse()) { + if (TruncInst *Trunc = dyn_cast(I->user_back())) { + DemandedTy = cast(Trunc->getType()); + DemandedBW = DemandedTy->getBitWidth(); + } + } + // Try to find all the pieces corresponding to the bswap. std::map> BPS; auto Res = collectBitParts(I, MatchBSwaps, MatchBitReversals, BPS); @@ -1825,11 +1853,11 @@ // Now, is the bit permutation correct for a bswap or a bitreverse? We can // only byteswap values with an even number of bytes. - bool OKForBSwap = BW % 16 == 0, OKForBitReverse = true; - for (unsigned i = 0; i < BW; ++i) { - OKForBSwap &= bitTransformIsCorrectForBSwap(BitProvenance[i], i, BW); + bool OKForBSwap = DemandedBW % 16 == 0, OKForBitReverse = true; + for (unsigned i = 0; i < DemandedBW; ++i) { + OKForBSwap &= bitTransformIsCorrectForBSwap(BitProvenance[i], i, DemandedBW); OKForBitReverse &= - bitTransformIsCorrectForBitReverse(BitProvenance[i], i, BW); + bitTransformIsCorrectForBitReverse(BitProvenance[i], i, DemandedBW); } Intrinsic::ID Intrin; @@ -1840,6 +1868,23 @@ else return false; + if (ITy != DemandedTy) { + Function *F = Intrinsic::getDeclaration(I->getModule(), Intrin, DemandedTy); + Value *Provider = Res->Provider; + IntegerType *ProviderTy = cast(Provider->getType()); + // We may need to truncate the provider. + if (DemandedTy != ProviderTy) { + auto *Trunc = CastInst::Create(Instruction::Trunc, Provider, DemandedTy, "trunc", I); + InsertedInsts.push_back(Trunc); + Provider = Trunc; + } + auto *CI = CallInst::Create(F, Provider, "rev", I); + InsertedInsts.push_back(CI); + auto *ExtInst = CastInst::Create(Instruction::ZExt, CI, ITy, "zext", I); + InsertedInsts.push_back(ExtInst); + return true; + } + Function *F = Intrinsic::getDeclaration(I->getModule(), Intrin, ITy); InsertedInsts.push_back(CallInst::Create(F, Res->Provider, "rev", I)); return true; Index: test/Transforms/InstCombine/bswap.ll =================================================================== --- test/Transforms/InstCombine/bswap.ll +++ test/Transforms/InstCombine/bswap.ll @@ -97,3 +97,29 @@ %or6 = or i32 %shl3, %shr5 ret i32 %or6 } + +; CHECK-LABEL: @test8 +; CHECK: call i16 @llvm.bswap.i16(i16 %a) +define i16 @test8(i16 %a) { +entry: + %conv = zext i16 %a to i32 + %shr = lshr i16 %a, 8 + %shl = shl i32 %conv, 8 + %conv1 = zext i16 %shr to i32 + %or = or i32 %conv1, %shl + %conv2 = trunc i32 %or to i16 + ret i16 %conv2 +} + +; CHECK-LABEL: @test9 +; CHECK: trunc i32 %a to i16 +; CHECK: call i16 @llvm.bswap.i16(i16 %trunc) +define i16 @test9(i32 %a) { + %shr1 = lshr i32 %a, 8 + %and1 = and i32 %shr1, 255 + %and2 = shl i32 %a, 8 + %shl1 = and i32 %and2, 65280 + %or = or i32 %and1, %shl1 + %conv = trunc i32 %or to i16 + ret i16 %conv +}