diff --git a/llvm/include/llvm/CodeGen/GlobalISel/LegalizerHelper.h b/llvm/include/llvm/CodeGen/GlobalISel/LegalizerHelper.h
--- a/llvm/include/llvm/CodeGen/GlobalISel/LegalizerHelper.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/LegalizerHelper.h
@@ -369,7 +369,7 @@
   LegalizeResult lowerReadWriteRegister(MachineInstr &MI);
   LegalizeResult lowerSMULH_UMULH(MachineInstr &MI);
   LegalizeResult lowerSelect(MachineInstr &MI);
-
+  LegalizeResult lowerFunnelShift(MachineInstr &MI);
 };
 
 /// Helper function that creates a libcall to the given \p Name using the given
diff --git a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
--- a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
@@ -3138,6 +3138,9 @@
   }
   case G_SELECT:
     return lowerSelect(MI);
+  case G_FSHL:
+  case G_FSHR:
+    return lowerFunnelShift(MI);
   }
 }
 
@@ -6226,4 +6229,56 @@
   MIRBuilder.buildOr(DstReg, NewOp1, NewOp2);
   MI.eraseFromParent();
   return Legalized;
-}
\ No newline at end of file
+}
+
+LegalizerHelper::LegalizeResult
+LegalizerHelper::lowerFunnelShift(MachineInstr &MI) {
+  Register DstReg = MI.getOperand(0).getReg();
+  Register XReg = MI.getOperand(1).getReg();
+  Register YReg = MI.getOperand(2).getReg();
+  Register ZReg = MI.getOperand(3).getReg();
+
+  const bool IsFSHL = MI.getOpcode() == TargetOpcode::G_FSHL;
+
+  // G_FSHL/G_FSHR give the shift amount its own type index, so Z may have
+  // a different type than the result and the two shifted operands.
+  const LLT Ty = MRI.getType(DstReg);
+  const LLT ShTy = MRI.getType(ZReg);
+  const unsigned BW = Ty.getScalarSizeInBits();
+
+  // Expand into two ordinary shifts and an or:
+  //   fshl: X << (Z % BW) | Y >> 1 >> (BW - 1 - (Z % BW))
+  //   fshr: X << 1 << (BW - 1 - (Z % BW)) | Y >> (Z % BW)
+  // Splitting the complementary shift into a shift by 1 and a shift by
+  // BW - 1 - (Z % BW) keeps every shift amount below BW, so no shift is
+  // out of range even when Z % BW == 0. MIRBuilder.buildConstant splats
+  // for vector types, so the same sequence also lowers vector operands.
+  auto Mask = MIRBuilder.buildConstant(ShTy, BW - 1);
+  MachineInstrBuilder ShAmt;
+  if (isPowerOf2_32(BW)) {
+    // Z % BW == Z & (BW - 1) for power-of-two widths. This avoids a
+    // G_UREM, which a target may have to expand further.
+    ShAmt = MIRBuilder.buildAnd(ShTy, ZReg, Mask);
+  } else {
+    auto BitWidthC = MIRBuilder.buildConstant(ShTy, BW);
+    ShAmt = MIRBuilder.buildInstr(TargetOpcode::G_UREM, {ShTy},
+                                  {ZReg, BitWidthC.getReg(0)});
+  }
+  auto InvShAmt = MIRBuilder.buildSub(ShTy, Mask, ShAmt);
+  auto One = MIRBuilder.buildConstant(ShTy, 1);
+
+  MachineInstrBuilder ShX, ShY;
+  if (IsFSHL) {
+    ShX = MIRBuilder.buildShl(Ty, XReg, ShAmt);
+    auto ShY1 = MIRBuilder.buildLShr(Ty, YReg, One);
+    ShY = MIRBuilder.buildLShr(Ty, ShY1, InvShAmt);
+  } else {
+    auto ShX1 = MIRBuilder.buildShl(Ty, XReg, One);
+    ShX = MIRBuilder.buildShl(Ty, ShX1, InvShAmt);
+    ShY = MIRBuilder.buildLShr(Ty, YReg, ShAmt);
+  }
+
+  MIRBuilder.buildOr(DstReg, ShX, ShY);
+  MI.eraseFromParent();
+  return Legalized;
+}
diff --git a/llvm/unittests/CodeGen/GlobalISel/LegalizerHelperTest.cpp b/llvm/unittests/CodeGen/GlobalISel/LegalizerHelperTest.cpp
--- a/llvm/unittests/CodeGen/GlobalISel/LegalizerHelperTest.cpp
+++ b/llvm/unittests/CodeGen/GlobalISel/LegalizerHelperTest.cpp
@@ -3179,4 +3179,136 @@
   EXPECT_TRUE(CheckMachineFunction(*MF, CheckStr)) << *MF;
 }
 
+// Test lower funnel shift left
+TEST_F(AArch64GISelMITest, lowerFunnelShiftLeft) {
+  setUp();
+  if (!TM)
+    return;
+
+  DefineLegalizerInfo(A, {});
+
+  LLT S64{LLT::scalar(64)};
+
+  auto FSHL = B.buildInstr(TargetOpcode::G_FSHL, {S64},
+                           {Copies[0], Copies[1], Copies[2]});
+
+  AInfo Info(MF->getSubtarget());
+  DummyGISelObserver Observer;
+  LegalizerHelper Helper(*MF, Info, Observer, B);
+
+  // Perform Legalization
+  B.setInsertPt(*EntryMBB, FSHL->getIterator());
+
+  EXPECT_EQ(LegalizerHelper::LegalizeResult::Legalized,
+            Helper.lower(*FSHL, 0, LLT::scalar(64)));
+
+  // The bit width is a power of two, so Z % 64 is computed as Z & 63.
+  const auto *CheckStr = R"(
+
+  CHECK: [[COPY0:%[0-9]+]]:_(s64) = COPY
+  CHECK: [[COPY1:%[0-9]+]]:_(s64) = COPY
+  CHECK: [[COPY2:%[0-9]+]]:_(s64) = COPY
+  CHECK: [[MASK:%[0-9]+]]:_(s64) = G_CONSTANT i64 63
+  CHECK: [[AND:%[0-9]+]]:_(s64) = G_AND [[COPY2]]:_, [[MASK]]:_
+  CHECK: [[SUB:%[0-9]+]]:_(s64) = G_SUB [[MASK]]:_, [[AND]]:_
+  CHECK: [[ONE:%[0-9]+]]:_(s64) = G_CONSTANT i64 1
+  CHECK: [[SHL:%[0-9]+]]:_(s64) = G_SHL [[COPY0]]:_, [[AND]]:_(s64)
+  CHECK: [[LSHR:%[0-9]+]]:_(s64) = G_LSHR [[COPY1]]:_, [[ONE]]:_(s64)
+  CHECK: [[LSHR2:%[0-9]+]]:_(s64) = G_LSHR [[LSHR]]:_, [[SUB]]:_(s64)
+  CHECK: [[OR:%[0-9]+]]:_(s64) = G_OR [[SHL]]:_, [[LSHR2]]:_
+  )";
+
+  // Check
+  EXPECT_TRUE(CheckMachineFunction(*MF, CheckStr)) << *MF;
+}
+
+// Test lower funnel shift right
+TEST_F(AArch64GISelMITest, lowerFunnelShiftRight) {
+  setUp();
+  if (!TM)
+    return;
+
+  DefineLegalizerInfo(A, {});
+
+  LLT S64{LLT::scalar(64)};
+
+  auto FSHR = B.buildInstr(TargetOpcode::G_FSHR, {S64},
+                           {Copies[0], Copies[1], Copies[2]});
+
+  AInfo Info(MF->getSubtarget());
+  DummyGISelObserver Observer;
+  LegalizerHelper Helper(*MF, Info, Observer, B);
+
+  // Perform Legalization
+  B.setInsertPt(*EntryMBB, FSHR->getIterator());
+
+  EXPECT_EQ(LegalizerHelper::LegalizeResult::Legalized,
+            Helper.lower(*FSHR, 0, LLT::scalar(64)));
+
+  // The bit width is a power of two, so Z % 64 is computed as Z & 63.
+  const auto *CheckStr = R"(
+
+  CHECK: [[COPY0:%[0-9]+]]:_(s64) = COPY
+  CHECK: [[COPY1:%[0-9]+]]:_(s64) = COPY
+  CHECK: [[COPY2:%[0-9]+]]:_(s64) = COPY
+  CHECK: [[MASK:%[0-9]+]]:_(s64) = G_CONSTANT i64 63
+  CHECK: [[AND:%[0-9]+]]:_(s64) = G_AND [[COPY2]]:_, [[MASK]]:_
+  CHECK: [[SUB:%[0-9]+]]:_(s64) = G_SUB [[MASK]]:_, [[AND]]:_
+  CHECK: [[ONE:%[0-9]+]]:_(s64) = G_CONSTANT i64 1
+  CHECK: [[SHL:%[0-9]+]]:_(s64) = G_SHL [[COPY0]]:_, [[ONE]]:_(s64)
+  CHECK: [[SHL2:%[0-9]+]]:_(s64) = G_SHL [[SHL]]:_, [[SUB]]:_(s64)
+  CHECK: [[LSHR:%[0-9]+]]:_(s64) = G_LSHR [[COPY1]]:_, [[AND]]:_(s64)
+  CHECK: [[OR:%[0-9]+]]:_(s64) = G_OR [[SHL2]]:_, [[LSHR]]:_
+  )";
+
+  // Check
+  EXPECT_TRUE(CheckMachineFunction(*MF, CheckStr)) << *MF;
+}
+
+// Test lower funnel shift with a shift amount narrower than the operands
+TEST_F(AArch64GISelMITest, lowerFunnelShiftNarrowAmount) {
+  setUp();
+  if (!TM)
+    return;
+
+  DefineLegalizerInfo(A, {});
+
+  LLT S32{LLT::scalar(32)};
+  LLT S64{LLT::scalar(64)};
+
+  auto Amt = B.buildTrunc(S32, Copies[2]);
+  auto FSHL = B.buildInstr(TargetOpcode::G_FSHL, {S64},
+                           {Copies[0], Copies[1], Amt});
+
+  AInfo Info(MF->getSubtarget());
+  DummyGISelObserver Observer;
+  LegalizerHelper Helper(*MF, Info, Observer, B);
+
+  // Perform Legalization
+  B.setInsertPt(*EntryMBB, FSHL->getIterator());
+
+  EXPECT_EQ(LegalizerHelper::LegalizeResult::Legalized,
+            Helper.lower(*FSHL, 0, S64));
+
+  // All shift-amount arithmetic must stay in the s32 amount type.
+  const auto *CheckStr = R"(
+
+  CHECK: [[COPY0:%[0-9]+]]:_(s64) = COPY
+  CHECK: [[COPY1:%[0-9]+]]:_(s64) = COPY
+  CHECK: [[COPY2:%[0-9]+]]:_(s64) = COPY
+  CHECK: [[AMT:%[0-9]+]]:_(s32) = G_TRUNC [[COPY2]]
+  CHECK: [[MASK:%[0-9]+]]:_(s32) = G_CONSTANT i32 63
+  CHECK: [[AND:%[0-9]+]]:_(s32) = G_AND [[AMT]]:_, [[MASK]]:_
+  CHECK: [[SUB:%[0-9]+]]:_(s32) = G_SUB [[MASK]]:_, [[AND]]:_
+  CHECK: [[ONE:%[0-9]+]]:_(s32) = G_CONSTANT i32 1
+  CHECK: [[SHL:%[0-9]+]]:_(s64) = G_SHL [[COPY0]]:_, [[AND]]:_(s32)
+  CHECK: [[LSHR:%[0-9]+]]:_(s64) = G_LSHR [[COPY1]]:_, [[ONE]]:_(s32)
+  CHECK: [[LSHR2:%[0-9]+]]:_(s64) = G_LSHR [[LSHR]]:_, [[SUB]]:_(s32)
+  CHECK: [[OR:%[0-9]+]]:_(s64) = G_OR [[SHL]]:_, [[LSHR2]]:_
+  )";
+
+  // Check
+  EXPECT_TRUE(CheckMachineFunction(*MF, CheckStr)) << *MF;
+}
+
 } // namespace