Index: llvm/include/llvm/CodeGen/GlobalISel/LegalizerHelper.h =================================================================== --- llvm/include/llvm/CodeGen/GlobalISel/LegalizerHelper.h +++ llvm/include/llvm/CodeGen/GlobalISel/LegalizerHelper.h @@ -74,6 +74,9 @@ /// precision, ignoring the unused bits). LegalizeResult widenScalar(MachineInstr &MI, unsigned TypeIdx, LLT WideTy); + /// Legalize an instruction by replacing the value type. + LegalizeResult bitcast(MachineInstr &MI, unsigned TypeIdx, LLT Ty); + /// Legalize an instruction by splitting it into simpler parts, hopefully /// understood by the target. LegalizeResult lower(MachineInstr &MI, unsigned TypeIdx, LLT Ty); @@ -128,6 +131,14 @@ /// original vector type, and replacing the vreg of the operand in place. void moreElementsVectorSrc(MachineInstr &MI, LLT MoreTy, unsigned OpIdx); + /// Legalize a single operand \p OpIdx of the machine instruction \p MI as a + /// use by inserting a G_BITCAST to \p CastTy. + void bitcastSrc(MachineInstr &MI, LLT CastTy, unsigned OpIdx); + + /// Legalize a single operand \p OpIdx of the machine instruction \p MI as a + /// def by inserting a G_BITCAST from \p CastTy. + void bitcastDst(MachineInstr &MI, LLT CastTy, unsigned OpIdx); + private: LegalizeResult widenScalarMergeValues(MachineInstr &MI, unsigned TypeIdx, LLT WideTy); Index: llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h =================================================================== --- llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h +++ llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h @@ -68,6 +68,9 @@ /// the first two results. MoreElements, + /// Perform the operation on a different, but equivalently sized type. + Bitcast, + /// The operation itself must be expressed in terms of simpler actions on /// this target. E.g. a SREM replaced by an SDIV and subtraction. 
Lower, Index: llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp =================================================================== --- llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp +++ llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp @@ -120,6 +120,9 @@ case WidenScalar: LLVM_DEBUG(dbgs() << ".. Widen scalar\n"); return widenScalar(MI, Step.TypeIdx, Step.NewType); + case Bitcast: + LLVM_DEBUG(dbgs() << ".. Bitcast type\n"); + return bitcast(MI, Step.TypeIdx, Step.NewType); case Lower: LLVM_DEBUG(dbgs() << ".. Lower\n"); return lower(MI, Step.TypeIdx, Step.NewType); @@ -1251,6 +1254,19 @@ MO.setReg(MoreReg); } +void LegalizerHelper::bitcastSrc(MachineInstr &MI, LLT CastTy, unsigned OpIdx) { + MachineOperand &Op = MI.getOperand(OpIdx); + Op.setReg(MIRBuilder.buildBitcast(CastTy, Op).getReg(0)); +} + +void LegalizerHelper::bitcastDst(MachineInstr &MI, LLT CastTy, unsigned OpIdx) { + MachineOperand &MO = MI.getOperand(OpIdx); + Register CastDst = MRI.createGenericVirtualRegister(CastTy); + MIRBuilder.setInsertPt(MIRBuilder.getMBB(), ++MIRBuilder.getInsertPt()); + MIRBuilder.buildBitcast(MO, CastDst); + MO.setReg(CastDst); +} + LegalizerHelper::LegalizeResult LegalizerHelper::widenScalarMergeValues(MachineInstr &MI, unsigned TypeIdx, LLT WideTy) { @@ -2104,6 +2120,61 @@ return UnableToLegalize; } +LegalizerHelper::LegalizeResult +LegalizerHelper::bitcast(MachineInstr &MI, unsigned TypeIdx, LLT CastTy) { + MIRBuilder.setInstr(MI); + + switch (MI.getOpcode()) { + case TargetOpcode::G_LOAD: { + if (TypeIdx != 0) + return UnableToLegalize; + + Observer.changingInstr(MI); + bitcastDst(MI, CastTy, 0); + Observer.changedInstr(MI); + return Legalized; + } + case TargetOpcode::G_STORE: { + if (TypeIdx != 0) + return UnableToLegalize; + + Observer.changingInstr(MI); + bitcastSrc(MI, CastTy, 0); + Observer.changedInstr(MI); + return Legalized; + } + case TargetOpcode::G_SELECT: { + if (TypeIdx != 0) + return UnableToLegalize; + + if 
(MRI.getType(MI.getOperand(1).getReg()).isVector()) { + LLVM_DEBUG( + dbgs() << "bitcast action not implemented for vector select\n"); + return UnableToLegalize; + } + + Observer.changingInstr(MI); + bitcastSrc(MI, CastTy, 2); + bitcastSrc(MI, CastTy, 3); + bitcastDst(MI, CastTy, 0); + Observer.changedInstr(MI); + return Legalized; + } + case TargetOpcode::G_AND: + case TargetOpcode::G_OR: + case TargetOpcode::G_XOR: { + Observer.changingInstr(MI); + bitcastSrc(MI, CastTy, 1); + bitcastSrc(MI, CastTy, 2); + bitcastDst(MI, CastTy, 0); + Observer.changedInstr(MI); + return Legalized; + } + default: + return UnableToLegalize; + } +} + LegalizerHelper::LegalizeResult LegalizerHelper::lower(MachineInstr &MI, unsigned TypeIdx, LLT Ty) { using namespace TargetOpcode; Index: llvm/lib/CodeGen/GlobalISel/LegalizerInfo.cpp =================================================================== --- llvm/lib/CodeGen/GlobalISel/LegalizerInfo.cpp +++ llvm/lib/CodeGen/GlobalISel/LegalizerInfo.cpp @@ -59,6 +59,9 @@ case MoreElements: OS << "MoreElements"; break; + case Bitcast: + OS << "Bitcast"; + break; case Lower: OS << "Lower"; break; @@ -173,6 +176,9 @@ return true; } + case Bitcast: { + return OldTy != NewTy && OldTy.getSizeInBits() == NewTy.getSizeInBits(); + } default: return true; } @@ -575,6 +581,7 @@ LegalizeAction Action = Vec[VecIdx].second; switch (Action) { case Legal: + case Bitcast: case Lower: case Libcall: case Custom: Index: llvm/unittests/CodeGen/GlobalISel/LegalizerHelperTest.cpp =================================================================== --- llvm/unittests/CodeGen/GlobalISel/LegalizerHelperTest.cpp +++ llvm/unittests/CodeGen/GlobalISel/LegalizerHelperTest.cpp @@ -2565,4 +2565,161 @@ EXPECT_TRUE(CheckMachineFunction(*MF, CheckStr)) << *MF; } +TEST_F(AArch64GISelMITest, BitcastLoad) { + setUp(); + if (!TM) + return; + + LLT P0 = LLT::pointer(0, 64); + LLT S32 = LLT::scalar(32); + LLT V4S8 = LLT::vector(4, 8); + auto Ptr = B.buildUndef(P0); + + 
DefineLegalizerInfo(A, {}); + + MachineMemOperand *MMO = B.getMF().getMachineMemOperand( + MachinePointerInfo(), MachineMemOperand::MOLoad, 4, 4); + auto Load = B.buildLoad(V4S8, Ptr, *MMO); + + AInfo Info(MF->getSubtarget()); + DummyGISelObserver Observer; + LegalizerHelper Helper(*MF, Info, Observer, B); + EXPECT_EQ(LegalizerHelper::LegalizeResult::Legalized, + Helper.bitcast(*Load, 0, S32)); + + auto CheckStr = R"( + CHECK: [[PTR:%[0-9]+]]:_(p0) = G_IMPLICIT_DEF + CHECK: [[LOAD:%[0-9]+]]:_(s32) = G_LOAD + CHECK: [[CAST:%[0-9]+]]:_(<4 x s8>) = G_BITCAST [[LOAD]] + + )"; + + // Check + EXPECT_TRUE(CheckMachineFunction(*MF, CheckStr)) << *MF; +} + +TEST_F(AArch64GISelMITest, BitcastStore) { + setUp(); + if (!TM) + return; + + LLT P0 = LLT::pointer(0, 64); + LLT S32 = LLT::scalar(32); + LLT V4S8 = LLT::vector(4, 8); + auto Ptr = B.buildUndef(P0); + + DefineLegalizerInfo(A, {}); + + MachineMemOperand *MMO = B.getMF().getMachineMemOperand( + MachinePointerInfo(), MachineMemOperand::MOStore, 4, 4); + auto Val = B.buildUndef(V4S8); + auto Store = B.buildStore(Val, Ptr, *MMO); + + AInfo Info(MF->getSubtarget()); + DummyGISelObserver Observer; + LegalizerHelper Helper(*MF, Info, Observer, B); + EXPECT_EQ(LegalizerHelper::LegalizeResult::Legalized, + Helper.bitcast(*Store, 0, S32)); + + auto CheckStr = R"( + CHECK: [[VAL:%[0-9]+]]:_(<4 x s8>) = G_IMPLICIT_DEF + CHECK: [[CAST:%[0-9]+]]:_(s32) = G_BITCAST [[VAL]] + CHECK: G_STORE [[CAST]] + )"; + + // Check + EXPECT_TRUE(CheckMachineFunction(*MF, CheckStr)) << *MF; +} + +TEST_F(AArch64GISelMITest, BitcastSelect) { + setUp(); + if (!TM) + return; + + LLT S1 = LLT::scalar(1); + LLT S32 = LLT::scalar(32); + LLT V4S8 = LLT::vector(4, 8); + + DefineLegalizerInfo(A, {}); + + auto Cond = B.buildUndef(S1); + auto Val0 = B.buildConstant(V4S8, 123); + auto Val1 = B.buildConstant(V4S8, 99); + + auto Select = B.buildSelect(V4S8, Cond, Val0, Val1); + + AInfo Info(MF->getSubtarget()); + DummyGISelObserver Observer; + LegalizerHelper 
Helper(*MF, Info, Observer, B); + EXPECT_EQ(LegalizerHelper::LegalizeResult::Legalized, + Helper.bitcast(*Select, 0, S32)); + + auto CheckStr = R"( + CHECK: [[VAL0:%[0-9]+]]:_(<4 x s8>) = G_BUILD_VECTOR + CHECK: [[VAL1:%[0-9]+]]:_(<4 x s8>) = G_BUILD_VECTOR + CHECK: [[CAST0:%[0-9]+]]:_(s32) = G_BITCAST [[VAL0]] + CHECK: [[CAST1:%[0-9]+]]:_(s32) = G_BITCAST [[VAL1]] + CHECK: [[SELECT:%[0-9]+]]:_(s32) = G_SELECT %{{[0-9]+}}:_(s1), [[CAST0]]:_, [[CAST1]]:_ + CHECK: [[CAST2:%[0-9]+]]:_(<4 x s8>) = G_BITCAST [[SELECT]] + )"; + + // Check + EXPECT_TRUE(CheckMachineFunction(*MF, CheckStr)) << *MF; + + // Bitcast is not implemented for selects with a vector condition + auto VCond = B.buildUndef(LLT::vector(4, 1)); + auto VSelect = B.buildSelect(V4S8, VCond, Val0, Val1); + EXPECT_EQ(LegalizerHelper::LegalizeResult::UnableToLegalize, + Helper.bitcast(*VSelect, 0, S32)); + EXPECT_EQ(LegalizerHelper::LegalizeResult::UnableToLegalize, + Helper.bitcast(*VSelect, 1, LLT::scalar(4))); +} + +TEST_F(AArch64GISelMITest, BitcastBitOps) { + setUp(); + if (!TM) + return; + + LLT S32 = LLT::scalar(32); + LLT V4S8 = LLT::vector(4, 8); + + DefineLegalizerInfo(A, {}); + + auto Val0 = B.buildConstant(V4S8, 123); + auto Val1 = B.buildConstant(V4S8, 99); + auto And = B.buildAnd(V4S8, Val0, Val1); + auto Or = B.buildOr(V4S8, Val0, Val1); + auto Xor = B.buildXor(V4S8, Val0, Val1); + + AInfo Info(MF->getSubtarget()); + DummyGISelObserver Observer; + LegalizerHelper Helper(*MF, Info, Observer, B); + EXPECT_EQ(LegalizerHelper::LegalizeResult::Legalized, + Helper.bitcast(*And, 0, S32)); + EXPECT_EQ(LegalizerHelper::LegalizeResult::Legalized, + Helper.bitcast(*Or, 0, S32)); + EXPECT_EQ(LegalizerHelper::LegalizeResult::Legalized, + Helper.bitcast(*Xor, 0, S32)); + + auto CheckStr = R"( + CHECK: [[VAL0:%[0-9]+]]:_(<4 x s8>) = G_BUILD_VECTOR + CHECK: [[VAL1:%[0-9]+]]:_(<4 x s8>) = G_BUILD_VECTOR + CHECK: [[CAST0:%[0-9]+]]:_(s32) = G_BITCAST [[VAL0]] + CHECK: [[CAST1:%[0-9]+]]:_(s32) = G_BITCAST [[VAL1]] + CHECK: [[AND:%[0-9]+]]:_(s32) = G_AND 
[[CAST0]]:_, [[CAST1]]:_ + CHECK: [[CAST_AND:%[0-9]+]]:_(<4 x s8>) = G_BITCAST [[AND]] + CHECK: [[CAST2:%[0-9]+]]:_(s32) = G_BITCAST [[VAL0]] + CHECK: [[CAST3:%[0-9]+]]:_(s32) = G_BITCAST [[VAL1]] + CHECK: [[OR:%[0-9]+]]:_(s32) = G_OR [[CAST2]]:_, [[CAST3]]:_ + CHECK: [[CAST_OR:%[0-9]+]]:_(<4 x s8>) = G_BITCAST [[OR]] + CHECK: [[CAST4:%[0-9]+]]:_(s32) = G_BITCAST [[VAL0]] + CHECK: [[CAST5:%[0-9]+]]:_(s32) = G_BITCAST [[VAL1]] + CHECK: [[XOR:%[0-9]+]]:_(s32) = G_XOR [[CAST4]]:_, [[CAST5]]:_ + CHECK: [[CAST_XOR:%[0-9]+]]:_(<4 x s8>) = G_BITCAST [[XOR]] + )"; + + // Check + EXPECT_TRUE(CheckMachineFunction(*MF, CheckStr)) << *MF; +} + } // namespace