Index: llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h
===================================================================
--- llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h
+++ llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h
@@ -203,6 +203,20 @@
 Predicate all(Predicate P0, Predicate P1, Args... args) {
   return all(all(P0, P1), args...);
 }
+
+/// True iff P0 or P1 are true.
+template<typename Predicate>
+Predicate any(Predicate P0, Predicate P1) {
+  return [=](const LegalityQuery &Query) {
+    return P0(Query) || P1(Query);
+  };
+}
+/// True iff any given predicates are true.
+template<typename Predicate, typename... Args>
+Predicate any(Predicate P0, Predicate P1, Args... args) {
+  return any(any(P0, P1), args...);
+}
+
 /// True iff the given type index is the specified types.
 LegalityPredicate typeIs(unsigned TypeIdx, LLT TypesInit);
 /// True iff the given type index is one of the specified types.
@@ -228,13 +242,16 @@
 /// space.
 LegalityPredicate isPointer(unsigned TypeIdx, unsigned AddrSpace);
 
+/// True iff the specified type index is a vector with element type \p EltTy.
+LegalityPredicate elementTypeIs(unsigned TypeIdx, LLT EltTy);
+
 /// True iff the specified type index is a scalar that's narrower than the given
 /// size.
-LegalityPredicate narrowerThan(unsigned TypeIdx, unsigned Size);
+LegalityPredicate scalarNarrowerThan(unsigned TypeIdx, unsigned Size);
 
 /// True iff the specified type index is a scalar that's wider than the given
 /// size.
-LegalityPredicate widerThan(unsigned TypeIdx, unsigned Size);
+LegalityPredicate scalarWiderThan(unsigned TypeIdx, unsigned Size);
 
 /// True iff the specified type index is a scalar or vector with an element type
 /// that's narrower than the given size.
@@ -257,6 +274,15 @@
 
 /// True iff the specified type indices are both the same bit size.
 LegalityPredicate sameSize(unsigned TypeIdx0, unsigned TypeIdx1);
+
+/// True iff the first type index has a larger total bit size than the second
+/// type index.
+LegalityPredicate largerThan(unsigned TypeIdx0, unsigned TypeIdx1);
+
+/// True iff the first type index has a smaller total bit size than the second
+/// type index.
+LegalityPredicate smallerThan(unsigned TypeIdx0, unsigned TypeIdx1);
+
 /// True iff the specified MMO index has a size that is not a power of 2
 LegalityPredicate memSizeInBytesNotPow2(unsigned MMOIdx);
 /// True iff the specified type index is a vector whose element count is not a
@@ -774,7 +800,7 @@
     using namespace LegalityPredicates;
     using namespace LegalizeMutations;
     return actionIf(LegalizeAction::WidenScalar,
-                    narrowerThan(TypeIdx, Ty.getSizeInBits()),
+                    scalarNarrowerThan(TypeIdx, Ty.getSizeInBits()),
                     changeTo(typeIdx(TypeIdx), Ty));
   }
 
@@ -792,7 +818,7 @@
     using namespace LegalityPredicates;
     using namespace LegalizeMutations;
     return actionIf(LegalizeAction::NarrowScalar,
-                    widerThan(TypeIdx, Ty.getSizeInBits()),
+                    scalarWiderThan(TypeIdx, Ty.getSizeInBits()),
                     changeTo(typeIdx(TypeIdx), Ty));
   }
 
@@ -806,7 +832,11 @@
     return actionIf(
         LegalizeAction::NarrowScalar,
         [=](const LegalityQuery &Query) {
-          return widerThan(TypeIdx, Ty.getSizeInBits()) && Predicate(Query);
+          // scalarWiderThan() only builds the predicate; it must be invoked
+          // on Query, otherwise the std::function is merely tested for
+          // non-emptiness and the size check is silently skipped.
+          return scalarWiderThan(TypeIdx, Ty.getSizeInBits())(Query) &&
+                 Predicate(Query);
         },
         changeElementTo(typeIdx(TypeIdx), Ty));
   }
Index: llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp
===================================================================
--- llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp
+++ llvm/lib/CodeGen/GlobalISel/LegalityPredicates.cpp
@@ -80,22 +80,46 @@
   };
 }
 
-LegalityPredicate LegalityPredicates::narrowerThan(unsigned TypeIdx,
-                                                   unsigned Size) {
+LegalityPredicate LegalityPredicates::elementTypeIs(unsigned TypeIdx,
+                                                    LLT EltTy) {
+  return [=](const LegalityQuery &Query) {
+    const LLT QueryTy = Query.Types[TypeIdx];
+    return QueryTy.isVector() && QueryTy.getElementType() == EltTy;
+  };
+}
+
+LegalityPredicate LegalityPredicates::scalarNarrowerThan(unsigned TypeIdx,
+                                                         unsigned Size) {
   return [=](const LegalityQuery &Query) {
     const LLT QueryTy = Query.Types[TypeIdx];
     return QueryTy.isScalar() && QueryTy.getSizeInBits() < Size;
   };
 }
 
-LegalityPredicate LegalityPredicates::widerThan(unsigned TypeIdx,
-                                                unsigned Size) {
+LegalityPredicate LegalityPredicates::scalarWiderThan(unsigned TypeIdx,
+                                                      unsigned Size) {
   return [=](const LegalityQuery &Query) {
     const LLT QueryTy = Query.Types[TypeIdx];
     return QueryTy.isScalar() && QueryTy.getSizeInBits() > Size;
   };
 }
 
+LegalityPredicate LegalityPredicates::smallerThan(unsigned TypeIdx0,
+                                                  unsigned TypeIdx1) {
+  return [=](const LegalityQuery &Query) {
+    return Query.Types[TypeIdx0].getSizeInBits() <
+           Query.Types[TypeIdx1].getSizeInBits();
+  };
+}
+
+LegalityPredicate LegalityPredicates::largerThan(unsigned TypeIdx0,
+                                                 unsigned TypeIdx1) {
+  return [=](const LegalityQuery &Query) {
+    return Query.Types[TypeIdx0].getSizeInBits() >
+           Query.Types[TypeIdx1].getSizeInBits();
+  };
+}
+
 LegalityPredicate LegalityPredicates::scalarOrEltNarrowerThan(unsigned TypeIdx,
                                                               unsigned Size) {
   return [=](const LegalityQuery &Query) {
Index: llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
===================================================================
--- llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
+++ llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
@@ -158,13 +158,6 @@
   };
 }
 
-static LegalityPredicate elementTypeIs(unsigned TypeIdx, LLT Type) {
-  return [=](const LegalityQuery &Query) {
-    const LLT QueryTy = Query.Types[TypeIdx];
-    return QueryTy.isVector() && QueryTy.getElementType() == Type;
-  };
-}
-
 static LegalityPredicate elementTypeIsLegal(unsigned TypeIdx) {
   return [=](const LegalityQuery &Query) {
     const LLT QueryTy = Query.Types[TypeIdx];
@@ -183,20 +176,6 @@
   };
 }
 
-static LegalityPredicate smallerThan(unsigned TypeIdx0, unsigned TypeIdx1) {
-  return [=](const LegalityQuery &Query) {
-    return Query.Types[TypeIdx0].getSizeInBits() <
-           Query.Types[TypeIdx1].getSizeInBits();
-  };
-}
-
-static LegalityPredicate greaterThan(unsigned TypeIdx0, unsigned TypeIdx1) {
-  return [=](const LegalityQuery &Query) {
-    return Query.Types[TypeIdx0].getSizeInBits() >
-           Query.Types[TypeIdx1].getSizeInBits();
-  };
-}
-
 AMDGPULegalizerInfo::AMDGPULegalizerInfo(const GCNSubtarget &ST_,
                                          const GCNTargetMachine &TM)
   :  ST(ST_) {
@@ -680,7 +659,7 @@
   // TODO: Should have same legality without v_perm_b32
   getActionDefinitionsBuilder(G_BSWAP)
     .legalFor({S32})
-    .lowerIf(narrowerThan(0, 32))
+    .lowerIf(scalarNarrowerThan(0, 32))
     // FIXME: Fixing non-power-of-2 before clamp is workaround for
     // narrowScalar limitation.
     .widenScalarToNextPow2(0)
@@ -707,7 +686,7 @@
       [](const LegalityQuery &Query) {
         return std::make_pair(1, LLT::scalar(Query.Types[0].getSizeInBits()));
       })
-    .narrowScalarIf(greaterThan(1, 0),
+    .narrowScalarIf(largerThan(1, 0),
       [](const LegalityQuery &Query) {
         return std::make_pair(1, LLT::scalar(Query.Types[0].getSizeInBits()));
       });
@@ -724,7 +703,7 @@
         return std::make_pair(0, LLT::scalar(Query.Types[1].getSizeInBits()));
       })
     .narrowScalarIf(
-      greaterThan(0, 1),
+      largerThan(0, 1),
       [](const LegalityQuery &Query) {
         return std::make_pair(0, LLT::scalar(Query.Types[1].getSizeInBits()));
       });
@@ -1238,7 +1217,7 @@
         })
       // Try to widen to s16 first for small types.
       // TODO: Only do this on targets with legal s16 shifts
-      .minScalarOrEltIf(narrowerThan(LitTyIdx, 16), LitTyIdx, S16)
+      .minScalarOrEltIf(scalarNarrowerThan(LitTyIdx, 16), LitTyIdx, S16)
       .widenScalarToNextPow2(LitTyIdx, /*Min*/ 16)
       .moreElementsIf(isSmallOddVector(BigTyIdx), oneMoreElement(BigTyIdx))
       .fewerElementsIf(all(typeIs(0, S16), vectorWiderThan(1, 32),