diff --git a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp --- a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp +++ b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp @@ -4837,35 +4837,47 @@ LLT DstTy = MRI.getType(Src); LLT InsertTy = MRI.getType(InsertSrc); - if (InsertTy.isScalar() && - (DstTy.isScalar() || - (DstTy.isVector() && DstTy.getElementType() == InsertTy))) { - LLT IntDstTy = DstTy; - if (!DstTy.isScalar()) { - IntDstTy = LLT::scalar(DstTy.getSizeInBits()); - Src = MIRBuilder.buildBitcast(IntDstTy, Src).getReg(0); - } + if (InsertTy.isVector() || + (DstTy.isVector() && DstTy.getElementType() != InsertTy)) + return UnableToLegalize; - Register ExtInsSrc = MIRBuilder.buildZExt(IntDstTy, InsertSrc).getReg(0); - if (Offset != 0) { - auto ShiftAmt = MIRBuilder.buildConstant(IntDstTy, Offset); - ExtInsSrc = MIRBuilder.buildShl(IntDstTy, ExtInsSrc, ShiftAmt).getReg(0); - } + const DataLayout &DL = MIRBuilder.getDataLayout(); + if ((DstTy.isPointer() && + DL.isNonIntegralAddressSpace(DstTy.getAddressSpace())) || + (InsertTy.isPointer() && + DL.isNonIntegralAddressSpace(InsertTy.getAddressSpace()))) { + LLVM_DEBUG(dbgs() << "Not casting non-integral address space integer\n"); + return UnableToLegalize; + } - APInt MaskVal = APInt::getBitsSetWithWrap(DstTy.getSizeInBits(), - Offset + InsertTy.getSizeInBits(), - Offset); + LLT IntDstTy = DstTy; - auto Mask = MIRBuilder.buildConstant(IntDstTy, MaskVal); - auto MaskedSrc = MIRBuilder.buildAnd(IntDstTy, Src, Mask); - auto Or = MIRBuilder.buildOr(IntDstTy, MaskedSrc, ExtInsSrc); + if (!DstTy.isScalar()) { + IntDstTy = LLT::scalar(DstTy.getSizeInBits()); + Src = MIRBuilder.buildCast(IntDstTy, Src).getReg(0); + } - MIRBuilder.buildBitcast(Dst, Or); - MI.eraseFromParent(); - return Legalized; + if (!InsertTy.isScalar()) { + const LLT IntInsertTy = LLT::scalar(InsertTy.getSizeInBits()); + InsertSrc = MIRBuilder.buildPtrToInt(IntInsertTy, InsertSrc).getReg(0); } - return UnableToLegalize; + Register ExtInsSrc = MIRBuilder.buildZExt(IntDstTy, InsertSrc).getReg(0); + if (Offset != 0) { + auto ShiftAmt = MIRBuilder.buildConstant(IntDstTy, Offset); + ExtInsSrc = MIRBuilder.buildShl(IntDstTy, ExtInsSrc, ShiftAmt).getReg(0); + } + + APInt MaskVal = APInt::getBitsSetWithWrap( + DstTy.getSizeInBits(), Offset + InsertTy.getSizeInBits(), Offset); + + auto Mask = MIRBuilder.buildConstant(IntDstTy, MaskVal); + auto MaskedSrc = MIRBuilder.buildAnd(IntDstTy, Src, Mask); + auto Or = MIRBuilder.buildOr(IntDstTy, MaskedSrc, ExtInsSrc); + + MIRBuilder.buildCast(Dst, Or); + MI.eraseFromParent(); + return Legalized; } LegalizerHelper::LegalizeResult diff --git a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-insert.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-insert.mir --- a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-insert.mir +++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-insert.mir @@ -1295,8 +1295,8 @@ ; CHECK: [[C1:%[0-9]+]]:_(s32) = G_CONSTANT i32 -65536 ; CHECK: [[AND1:%[0-9]+]]:_(s32) = G_AND [[COPY1]], [[C1]] ; CHECK: [[OR:%[0-9]+]]:_(s32) = G_OR [[AND1]], [[AND]] - ; CHECK: [[BITCAST:%[0-9]+]]:_(s32) = G_BITCAST [[OR]](s32) - ; CHECK: $vgpr0 = COPY [[BITCAST]](s32) + ; CHECK: [[COPY3:%[0-9]+]]:_(s32) = COPY [[OR]](s32) + ; CHECK: $vgpr0 = COPY [[COPY3]](s32) %0:_(s32) = COPY $vgpr0 %1:_(s32) = COPY $vgpr1 %2:_(s16) = G_TRUNC %1 @@ -1321,8 +1321,8 @@ ; CHECK: [[C2:%[0-9]+]]:_(s32) = G_CONSTANT i32 -131071 ; CHECK: [[AND1:%[0-9]+]]:_(s32) = G_AND [[COPY1]], [[C2]] ; CHECK: [[OR:%[0-9]+]]:_(s32) = G_OR [[AND1]], [[SHL]] - ; CHECK: [[BITCAST:%[0-9]+]]:_(s32) = G_BITCAST [[OR]](s32) - ; CHECK: $vgpr0 = COPY [[BITCAST]](s32) + ; CHECK: [[COPY3:%[0-9]+]]:_(s32) = COPY [[OR]](s32) + ; CHECK: $vgpr0 = COPY [[COPY3]](s32) %0:_(s32) = COPY $vgpr0 %1:_(s32) = COPY $vgpr1 %2:_(s16) = G_TRUNC %1 @@ -1347,8 +1347,8 @@ ; CHECK: [[C2:%[0-9]+]]:_(s32) = G_CONSTANT i32 -16776961 ; CHECK: [[AND1:%[0-9]+]]:_(s32) = G_AND [[COPY1]], [[C2]] ; CHECK: [[OR:%[0-9]+]]:_(s32) = G_OR [[AND1]], [[SHL]] - ; CHECK: [[BITCAST:%[0-9]+]]:_(s32) = G_BITCAST [[OR]](s32) - ; CHECK: $vgpr0 = COPY [[BITCAST]](s32) + ; CHECK: [[COPY3:%[0-9]+]]:_(s32) = COPY [[OR]](s32) + ; CHECK: $vgpr0 = COPY [[COPY3]](s32) %0:_(s32) = COPY $vgpr0 %1:_(s32) = COPY $vgpr1 %2:_(s16) = G_TRUNC %1 @@ -1372,8 +1372,8 @@ ; CHECK: [[SHL:%[0-9]+]]:_(s32) = G_SHL [[AND]], [[C1]](s32) ; CHECK: [[AND1:%[0-9]+]]:_(s32) = G_AND [[COPY1]], [[C]] ; CHECK: [[OR:%[0-9]+]]:_(s32) = G_OR [[AND1]], [[SHL]] - ; CHECK: [[BITCAST:%[0-9]+]]:_(s32) = G_BITCAST [[OR]](s32) - ; CHECK: $vgpr0 = COPY [[BITCAST]](s32) + ; CHECK: [[COPY3:%[0-9]+]]:_(s32) = COPY [[OR]](s32) + ; CHECK: $vgpr0 = COPY [[COPY3]](s32) %0:_(s32) = COPY $vgpr0 %1:_(s32) = COPY $vgpr1 %2:_(s16) = G_TRUNC %1 diff --git a/llvm/unittests/CodeGen/GlobalISel/LegalizerHelperTest.cpp b/llvm/unittests/CodeGen/GlobalISel/LegalizerHelperTest.cpp --- a/llvm/unittests/CodeGen/GlobalISel/LegalizerHelperTest.cpp +++ b/llvm/unittests/CodeGen/GlobalISel/LegalizerHelperTest.cpp @@ -2341,4 +2341,104 @@ // Check EXPECT_TRUE(CheckMachineFunction(*MF, CheckStr)) << *MF; } + +TEST_F(GISelMITest, LowerInsert) { + setUp(); + if (!TM) + return; + + // Declare your legalization info + DefineLegalizerInfo(A, { getActionDefinitionsBuilder(G_INSERT).lower(); }); + + LLT S32{LLT::scalar(32)}; + LLT S64{LLT::scalar(64)}; + LLT P0{LLT::pointer(0, 64)}; + LLT P1{LLT::pointer(1, 32)}; + LLT V2S32{LLT::vector(2, 32)}; + + auto TruncS32 = B.buildTrunc(S32, Copies[0]); + auto IntToPtrP0 = B.buildIntToPtr(P0, Copies[0]); + auto IntToPtrP1 = B.buildIntToPtr(P1, TruncS32); + auto BitcastV2S32 = B.buildBitcast(V2S32, Copies[0]); + + auto InsertS64S32 = B.buildInsert(S64, Copies[0], TruncS32, 0); + auto InsertS64P1 = B.buildInsert(S64, Copies[0], IntToPtrP1, 8); + auto InsertP0S32 = B.buildInsert(P0, IntToPtrP0, TruncS32, 16); + auto InsertP0P1 = B.buildInsert(P0, IntToPtrP0, IntToPtrP1, 4); + auto InsertV2S32S32 = B.buildInsert(V2S32, BitcastV2S32, TruncS32, 32); + auto InsertV2S32P1 = B.buildInsert(V2S32, BitcastV2S32, IntToPtrP1, 0); + + AInfo Info(MF->getSubtarget()); + DummyGISelObserver Observer; + LegalizerHelper Helper(*MF, Info, Observer, B); + + EXPECT_EQ(LegalizerHelper::LegalizeResult::Legalized, + Helper.lower(*InsertS64S32, 0, LLT{})); + + EXPECT_EQ(LegalizerHelper::LegalizeResult::Legalized, + Helper.lower(*InsertS64P1, 0, LLT{})); + + EXPECT_EQ(LegalizerHelper::LegalizeResult::Legalized, + Helper.lower(*InsertP0S32, 0, LLT{})); + + EXPECT_EQ(LegalizerHelper::LegalizeResult::Legalized, + Helper.lower(*InsertP0P1, 0, LLT{})); + + EXPECT_EQ(LegalizerHelper::LegalizeResult::Legalized, + Helper.lower(*InsertV2S32S32, 0, LLT{})); + + EXPECT_EQ(LegalizerHelper::LegalizeResult::UnableToLegalize, + Helper.lower(*InsertV2S32P1, 0, LLT{})); + + const auto *CheckStr = R"( + CHECK: [[S64:%[0-9]+]]:_(s64) = COPY + CHECK: [[S32:%[0-9]+]]:_(s32) = G_TRUNC [[S64]] + CHECK: [[P0:%[0-9]+]]:_(p0) = G_INTTOPTR [[S64]] + CHECK: [[P1:%[0-9]+]]:_(p1) = G_INTTOPTR [[S32]] + CHECK: [[V2S32:%[0-9]+]]:_(<2 x s32>) = G_BITCAST [[S64]] + CHECK: [[ZEXT:%[0-9]+]]:_(s64) = G_ZEXT [[S32]] + CHECK: [[C:%[0-9]+]]:_(s64) = G_CONSTANT + CHECK: [[AND:%[0-9]+]]:_(s64) = G_AND [[S64]]:_, [[C]]:_ + CHECK: [[OR:%[0-9]+]]:_(s64) = G_OR [[AND]]:_, [[ZEXT]]:_ + + CHECK: [[PTRTOINT:%[0-9]+]]:_(s32) = G_PTRTOINT [[P1]] + CHECK: [[ZEXT:%[0-9]+]]:_(s64) = G_ZEXT [[PTRTOINT]] + CHECK: [[C:%[0-9]+]]:_(s64) = G_CONSTANT + CHECK: [[SHL:%[0-9]+]]:_(s64) = G_SHL [[ZEXT]]:_, [[C]]:_(s64) + CHECK: [[C:%[0-9]+]]:_(s64) = G_CONSTANT + CHECK: [[AND:%[0-9]+]]:_(s64) = G_AND [[S64]]:_, [[C]]:_ + CHECK: [[OR:%[0-9]+]]:_(s64) = G_OR [[AND]]:_, [[SHL]]:_ + + CHECK: [[PTRTOINT:%[0-9]+]]:_(s64) = G_PTRTOINT [[P0]] + CHECK: [[ZEXT:%[0-9]+]]:_(s64) = G_ZEXT [[S32]] + CHECK: [[C:%[0-9]+]]:_(s64) = G_CONSTANT + CHECK: [[SHL:%[0-9]+]]:_(s64) = G_SHL [[ZEXT]]:_, [[C]]:_(s64) + CHECK: [[C:%[0-9]+]]:_(s64) = G_CONSTANT + CHECK: [[AND:%[0-9]+]]:_(s64) = G_AND [[PTRTOINT]]:_, [[C]]:_ + CHECK: [[OR:%[0-9]+]]:_(s64) = G_OR [[AND]]:_, [[SHL]]:_ + CHECK: [[INTTOPTR:%[0-9]+]]:_(p0) = G_INTTOPTR [[OR]] + + CHECK: [[PTRTOINT:%[0-9]+]]:_(s64) = G_PTRTOINT [[P0]] + CHECK: [[PTRTOINT1:%[0-9]+]]:_(s32) = G_PTRTOINT [[P1]] + CHECK: [[ZEXT:%[0-9]+]]:_(s64) = G_ZEXT [[PTRTOINT1]] + CHECK: [[C:%[0-9]+]]:_(s64) = G_CONSTANT + CHECK: [[SHL:%[0-9]+]]:_(s64) = G_SHL [[ZEXT]]:_, [[C]]:_(s64) + CHECK: [[C:%[0-9]+]]:_(s64) = G_CONSTANT + CHECK: [[AND:%[0-9]+]]:_(s64) = G_AND [[PTRTOINT]]:_, [[C]]:_ + CHECK: [[OR:%[0-9]+]]:_(s64) = G_OR [[AND]]:_, [[SHL]]:_ + CHECK: [[INTTOPTR:%[0-9]+]]:_(p0) = G_INTTOPTR [[OR]] + + CHECK: [[BITCAST:%[0-9]+]]:_(s64) = G_BITCAST [[V2S32]] + CHECK: [[ZEXT:%[0-9]+]]:_(s64) = G_ZEXT [[S32]] + CHECK: [[C:%[0-9]+]]:_(s64) = G_CONSTANT + CHECK: [[SHL:%[0-9]+]]:_(s64) = G_SHL [[ZEXT]]:_, [[C]]:_(s64) + CHECK: [[C:%[0-9]+]]:_(s64) = G_CONSTANT + CHECK: [[AND:%[0-9]+]]:_(s64) = G_AND [[BITCAST]]:_, [[C]]:_ + CHECK: [[OR:%[0-9]+]]:_(s64) = G_OR [[AND]]:_, [[SHL]]:_ + CHECK: [[BITCAST:%[0-9]+]]:_(<2 x s32>) = G_BITCAST [[OR]] + )"; + + // Check + EXPECT_TRUE(CheckMachineFunction(*MF, CheckStr)) << *MF; +} } // namespace