diff --git a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
--- a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
@@ -4818,35 +4818,59 @@
   LLT DstTy = MRI.getType(Src);
   LLT InsertTy = MRI.getType(InsertSrc);
 
-  if (InsertTy.isScalar() &&
-      (DstTy.isScalar() ||
-       (DstTy.isVector() && DstTy.getElementType() == InsertTy))) {
-    LLT IntDstTy = DstTy;
-    if (!DstTy.isScalar()) {
-      IntDstTy = LLT::scalar(DstTy.getSizeInBits());
-      Src = MIRBuilder.buildBitcast(IntDstTy, Src).getReg(0);
-    }
+  if (InsertTy.isVector() ||
+      (DstTy.isVector() && DstTy.getElementType() != InsertTy))
+    return UnableToLegalize;
 
-    Register ExtInsSrc = MIRBuilder.buildZExt(IntDstTy, InsertSrc).getReg(0);
-    if (Offset != 0) {
-      auto ShiftAmt = MIRBuilder.buildConstant(IntDstTy, Offset);
-      ExtInsSrc = MIRBuilder.buildShl(IntDstTy, ExtInsSrc, ShiftAmt).getReg(0);
-    }
+  // The lowering below round-trips pointer operands through integers
+  // (G_PTRTOINT / G_INTTOPTR), which is not valid for pointers in a
+  // non-integral address space. A vector-of-pointers destination is covered
+  // by the InsertTy check because its element type must equal InsertTy.
+  const DataLayout &DL = MIRBuilder.getDataLayout();
+  if ((DstTy.isPointer() &&
+       DL.isNonIntegralAddressSpace(DstTy.getAddressSpace())) ||
+      (InsertTy.isPointer() &&
+       DL.isNonIntegralAddressSpace(InsertTy.getAddressSpace()))) {
+    LLVM_DEBUG(dbgs() << "Not casting non-integral address space pointer\n");
+    return UnableToLegalize;
+  }
+
+  // Operate on an integer as wide as the destination: buildCast emits a
+  // G_BITCAST for vectors and a G_PTRTOINT for pointers.
+  LLT IntDstTy = DstTy;
 
-    APInt MaskVal = APInt::getBitsSetWithWrap(DstTy.getSizeInBits(),
-                                              Offset + InsertTy.getSizeInBits(),
-                                              Offset);
+  if (!DstTy.isScalar()) {
+    IntDstTy = LLT::scalar(DstTy.getSizeInBits());
+    Src = MIRBuilder.buildCast(IntDstTy, Src).getReg(0);
+  }
 
-    auto Mask = MIRBuilder.buildConstant(IntDstTy, MaskVal);
-    auto MaskedSrc = MIRBuilder.buildAnd(IntDstTy, Src, Mask);
-    auto Or = MIRBuilder.buildOr(IntDstTy, MaskedSrc, ExtInsSrc);
+  // Vectors were rejected above, so a non-scalar InsertTy is a pointer.
+  if (!InsertTy.isScalar()) {
+    const LLT IntInsertTy = LLT::scalar(InsertTy.getSizeInBits());
+    InsertSrc = MIRBuilder.buildPtrToInt(IntInsertTy, InsertSrc).getReg(0);
+  }
 
-    MIRBuilder.buildBitcast(Dst, Or);
-    MI.eraseFromParent();
-    return Legalized;
+  // Zero-extend the inserted bits and shift them to their bit offset.
+  Register ExtInsSrc = MIRBuilder.buildZExt(IntDstTy, InsertSrc).getReg(0);
+  if (Offset != 0) {
+    auto ShiftAmt = MIRBuilder.buildConstant(IntDstTy, Offset);
+    ExtInsSrc = MIRBuilder.buildShl(IntDstTy, ExtInsSrc, ShiftAmt).getReg(0);
   }
 
-  return UnableToLegalize;
+  // The mask keeps every bit of Src outside [Offset, Offset + InsertSize).
+  APInt MaskVal = APInt::getBitsSetWithWrap(DstTy.getSizeInBits(),
+                                            Offset + InsertTy.getSizeInBits(),
+                                            Offset);
+
+  auto Mask = MIRBuilder.buildConstant(IntDstTy, MaskVal);
+  auto MaskedSrc = MIRBuilder.buildAnd(IntDstTy, Src, Mask);
+  auto Or = MIRBuilder.buildOr(IntDstTy, MaskedSrc, ExtInsSrc);
+
+  // Convert back to the destination type; buildCast emits a plain COPY when
+  // the destination is already a scalar.
+  MIRBuilder.buildCast(Dst, Or);
+  MI.eraseFromParent();
+  return Legalized;
 }
 
 LegalizerHelper::LegalizeResult
diff --git a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-insert.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-insert.mir
--- a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-insert.mir
+++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-insert.mir
@@ -1295,8 +1295,8 @@
     ; CHECK: [[C1:%[0-9]+]]:_(s32) = G_CONSTANT i32 -65536
     ; CHECK: [[AND1:%[0-9]+]]:_(s32) = G_AND [[COPY1]], [[C1]]
     ; CHECK: [[OR:%[0-9]+]]:_(s32) = G_OR [[AND1]], [[AND]]
-    ; CHECK: [[BITCAST:%[0-9]+]]:_(s32) = G_BITCAST [[OR]](s32)
-    ; CHECK: $vgpr0 = COPY [[BITCAST]](s32)
+    ; CHECK: [[COPY3:%[0-9]+]]:_(s32) = COPY [[OR]](s32)
+    ; CHECK: $vgpr0 = COPY [[COPY3]](s32)
     %0:_(s32) = COPY $vgpr0
     %1:_(s32) = COPY $vgpr1
     %2:_(s16) = G_TRUNC %1
@@ -1321,8 +1321,8 @@
     ; CHECK: [[C2:%[0-9]+]]:_(s32) = G_CONSTANT i32 -131071
     ; CHECK: [[AND1:%[0-9]+]]:_(s32) = G_AND [[COPY1]], [[C2]]
     ; CHECK: [[OR:%[0-9]+]]:_(s32) = G_OR [[AND1]], [[SHL]]
-    ; CHECK: [[BITCAST:%[0-9]+]]:_(s32) = G_BITCAST [[OR]](s32)
-    ; CHECK: $vgpr0 = COPY [[BITCAST]](s32)
+    ; CHECK: [[COPY3:%[0-9]+]]:_(s32) = COPY [[OR]](s32)
+    ; CHECK: $vgpr0 = COPY [[COPY3]](s32)
     %0:_(s32) = COPY $vgpr0
     %1:_(s32) = COPY $vgpr1
     %2:_(s16) = G_TRUNC %1
@@ -1347,8 +1347,8 @@
     ; CHECK: [[C2:%[0-9]+]]:_(s32) = G_CONSTANT i32 -16776961
     ; CHECK: [[AND1:%[0-9]+]]:_(s32) = G_AND [[COPY1]], [[C2]]
     ; CHECK: [[OR:%[0-9]+]]:_(s32) = G_OR [[AND1]], [[SHL]]
-    ; CHECK: [[BITCAST:%[0-9]+]]:_(s32) = G_BITCAST [[OR]](s32)
-    ; CHECK: $vgpr0 = COPY [[BITCAST]](s32)
+    ; CHECK: [[COPY3:%[0-9]+]]:_(s32) = COPY [[OR]](s32)
+    ; CHECK: $vgpr0 = COPY [[COPY3]](s32)
     %0:_(s32) = COPY $vgpr0
     %1:_(s32) = COPY $vgpr1
     %2:_(s16) = G_TRUNC %1
@@ -1372,8 +1372,8 @@
     ; CHECK: [[SHL:%[0-9]+]]:_(s32) = G_SHL [[AND]], [[C1]](s32)
     ; CHECK: [[AND1:%[0-9]+]]:_(s32) = G_AND [[COPY1]], [[C]]
     ; CHECK: [[OR:%[0-9]+]]:_(s32) = G_OR [[AND1]], [[SHL]]
-    ; CHECK: [[BITCAST:%[0-9]+]]:_(s32) = G_BITCAST [[OR]](s32)
-    ; CHECK: $vgpr0 = COPY [[BITCAST]](s32)
+    ; CHECK: [[COPY3:%[0-9]+]]:_(s32) = COPY [[OR]](s32)
+    ; CHECK: $vgpr0 = COPY [[COPY3]](s32)
     %0:_(s32) = COPY $vgpr0
     %1:_(s32) = COPY $vgpr1
     %2:_(s16) = G_TRUNC %1
diff --git a/llvm/unittests/CodeGen/GlobalISel/LegalizerHelperTest.cpp b/llvm/unittests/CodeGen/GlobalISel/LegalizerHelperTest.cpp
--- a/llvm/unittests/CodeGen/GlobalISel/LegalizerHelperTest.cpp
+++ b/llvm/unittests/CodeGen/GlobalISel/LegalizerHelperTest.cpp
@@ -2302,4 +2302,104 @@
   // Check
   EXPECT_TRUE(CheckMachineFunction(*MF, CheckStr)) << *MF;
 }
+
+TEST_F(GISelMITest, LowerInsert) {
+  setUp();
+  if (!TM)
+    return;
+
+  // Declare your legalization info
+  DefineLegalizerInfo(A, { getActionDefinitionsBuilder(G_INSERT).lower(); });
+
+  LLT S32{LLT::scalar(32)};
+  LLT S64{LLT::scalar(64)};
+  LLT P0{LLT::pointer(0, 64)};
+  LLT P1{LLT::pointer(1, 32)};
+  LLT V2S32{LLT::vector(2, 32)};
+
+  auto TruncS32 = B.buildTrunc(S32, Copies[0]);
+  auto IntToPtrP0 = B.buildIntToPtr(P0, Copies[0]);
+  auto IntToPtrP1 = B.buildIntToPtr(P1, TruncS32);
+  auto BitcastV2S32 = B.buildBitcast(V2S32, Copies[0]);
+
+  auto Insert_S64_S32 = B.buildInsert(S64, Copies[0], TruncS32, 0);
+  auto Insert_S64_P1 = B.buildInsert(S64, Copies[0], IntToPtrP1, 8);
+  auto Insert_P0_S32 = B.buildInsert(P0, IntToPtrP0, TruncS32, 16);
+  auto Insert_P0_P1 = B.buildInsert(P0, IntToPtrP0, IntToPtrP1, 4);
+  auto Insert_V2S32_S32 = B.buildInsert(V2S32, BitcastV2S32, TruncS32, 32);
+  auto Insert_V2S32_P1 = B.buildInsert(V2S32, BitcastV2S32, IntToPtrP1, 0);
+
+  AInfo Info(MF->getSubtarget());
+  DummyGISelObserver Observer;
+  LegalizerHelper Helper(*MF, Info, Observer, B);
+
+  EXPECT_EQ(LegalizerHelper::LegalizeResult::Legalized,
+            Helper.lower(*Insert_S64_S32, 0, LLT{}));
+
+  EXPECT_EQ(LegalizerHelper::LegalizeResult::Legalized,
+            Helper.lower(*Insert_S64_P1, 0, LLT{}));
+
+  EXPECT_EQ(LegalizerHelper::LegalizeResult::Legalized,
+            Helper.lower(*Insert_P0_S32, 0, LLT{}));
+
+  EXPECT_EQ(LegalizerHelper::LegalizeResult::Legalized,
+            Helper.lower(*Insert_P0_P1, 0, LLT{}));
+
+  EXPECT_EQ(LegalizerHelper::LegalizeResult::Legalized,
+            Helper.lower(*Insert_V2S32_S32, 0, LLT{}));
+
+  EXPECT_EQ(LegalizerHelper::LegalizeResult::UnableToLegalize,
+            Helper.lower(*Insert_V2S32_P1, 0, LLT{}));
+
+  const auto *CheckStr = R"(
+  CHECK: [[S64:%[0-9]+]]:_(s64) = COPY
+  CHECK: [[S32:%[0-9]+]]:_(s32) = G_TRUNC [[S64]]
+  CHECK: [[P0:%[0-9]+]]:_(p0) = G_INTTOPTR [[S64]]
+  CHECK: [[P1:%[0-9]+]]:_(p1) = G_INTTOPTR [[S32]]
+  CHECK: [[V2S32:%[0-9]+]]:_(<2 x s32>) = G_BITCAST [[S64]]
+  CHECK: [[ZEXT:%[0-9]+]]:_(s64) = G_ZEXT [[S32]]
+  CHECK: [[C:%[0-9]+]]:_(s64) = G_CONSTANT
+  CHECK: [[AND:%[0-9]+]]:_(s64) = G_AND [[S64]]:_, [[C]]:_
+  CHECK: [[OR:%[0-9]+]]:_(s64) = G_OR [[AND]]:_, [[ZEXT]]:_
+
+  CHECK: [[PTRTOINT:%[0-9]+]]:_(s32) = G_PTRTOINT [[P1]]
+  CHECK: [[ZEXT:%[0-9]+]]:_(s64) = G_ZEXT [[PTRTOINT]]
+  CHECK: [[C:%[0-9]+]]:_(s64) = G_CONSTANT
+  CHECK: [[SHL:%[0-9]+]]:_(s64) = G_SHL [[ZEXT]]:_, [[C]]:_(s64)
+  CHECK: [[C:%[0-9]+]]:_(s64) = G_CONSTANT
+  CHECK: [[AND:%[0-9]+]]:_(s64) = G_AND [[S64]]:_, [[C]]:_
+  CHECK: [[OR:%[0-9]+]]:_(s64) = G_OR [[AND]]:_, [[SHL]]:_
+
+  CHECK: [[PTRTOINT:%[0-9]+]]:_(s64) = G_PTRTOINT [[P0]]
+  CHECK: [[ZEXT:%[0-9]+]]:_(s64) = G_ZEXT [[S32]]
+  CHECK: [[C:%[0-9]+]]:_(s64) = G_CONSTANT
+  CHECK: [[SHL:%[0-9]+]]:_(s64) = G_SHL [[ZEXT]]:_, [[C]]:_(s64)
+  CHECK: [[C:%[0-9]+]]:_(s64) = G_CONSTANT
+  CHECK: [[AND:%[0-9]+]]:_(s64) = G_AND [[PTRTOINT]]:_, [[C]]:_
+  CHECK: [[OR:%[0-9]+]]:_(s64) = G_OR [[AND]]:_, [[SHL]]:_
+  CHECK: [[INTTOPTR:%[0-9]+]]:_(p0) = G_INTTOPTR [[OR]]
+
+  CHECK: [[PTRTOINT:%[0-9]+]]:_(s64) = G_PTRTOINT [[P0]]
+  CHECK: [[PTRTOINT1:%[0-9]+]]:_(s32) = G_PTRTOINT [[P1]]
+  CHECK: [[ZEXT:%[0-9]+]]:_(s64) = G_ZEXT [[PTRTOINT1]]
+  CHECK: [[C:%[0-9]+]]:_(s64) = G_CONSTANT
+  CHECK: [[SHL:%[0-9]+]]:_(s64) = G_SHL [[ZEXT]]:_, [[C]]:_(s64)
+  CHECK: [[C:%[0-9]+]]:_(s64) = G_CONSTANT
+  CHECK: [[AND:%[0-9]+]]:_(s64) = G_AND [[PTRTOINT]]:_, [[C]]:_
+  CHECK: [[OR:%[0-9]+]]:_(s64) = G_OR [[AND]]:_, [[SHL]]:_
+  CHECK: [[INTTOPTR:%[0-9]+]]:_(p0) = G_INTTOPTR [[OR]]
+
+  CHECK: [[BITCAST:%[0-9]+]]:_(s64) = G_BITCAST [[V2S32]]
+  CHECK: [[ZEXT:%[0-9]+]]:_(s64) = G_ZEXT [[S32]]
+  CHECK: [[C:%[0-9]+]]:_(s64) = G_CONSTANT
+  CHECK: [[SHL:%[0-9]+]]:_(s64) = G_SHL [[ZEXT]]:_, [[C]]:_(s64)
+  CHECK: [[C:%[0-9]+]]:_(s64) = G_CONSTANT
+  CHECK: [[AND:%[0-9]+]]:_(s64) = G_AND [[BITCAST]]:_, [[C]]:_
+  CHECK: [[OR:%[0-9]+]]:_(s64) = G_OR [[AND]]:_, [[SHL]]:_
+  CHECK: [[BITCAST:%[0-9]+]]:_(<2 x s32>) = G_BITCAST [[OR]]
+  )";
+
+  // Check
+  EXPECT_TRUE(CheckMachineFunction(*MF, CheckStr)) << *MF;
+}
 } // namespace