diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h --- a/llvm/include/llvm/CodeGen/ISDOpcodes.h +++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h @@ -1358,6 +1358,26 @@ static const int LAST_MEM_INDEX_TYPE = UNSIGNED_UNSCALED + 1; +inline bool isIndexTypeScaled(MemIndexType IndexType) { + return IndexType == SIGNED_SCALED || IndexType == UNSIGNED_SCALED; +} + +inline bool isIndexTypeSigned(MemIndexType IndexType) { + return IndexType == SIGNED_SCALED || IndexType == SIGNED_UNSCALED; +} + +inline MemIndexType getSignedIndexType(MemIndexType IndexType) { + return isIndexTypeScaled(IndexType) ? SIGNED_SCALED : SIGNED_UNSCALED; +} + +inline MemIndexType getUnsignedIndexType(MemIndexType IndexType) { + return isIndexTypeScaled(IndexType) ? UNSIGNED_SCALED : UNSIGNED_UNSCALED; +} + +inline MemIndexType getUnscaledIndexType(MemIndexType IndexType) { + return isIndexTypeSigned(IndexType) ? SIGNED_UNSCALED : UNSIGNED_UNSCALED; +} + //===--------------------------------------------------------------------===// /// LoadExtType enum - This enum defines the three variants of LOADEXT /// (load with extension). 
diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h --- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h +++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h @@ -2702,14 +2702,8 @@ ISD::MemIndexType getIndexType() const { return static_cast<ISD::MemIndexType>(LSBaseSDNodeBits.AddressingMode); } - bool isIndexScaled() const { - return (getIndexType() == ISD::SIGNED_SCALED) || - (getIndexType() == ISD::UNSIGNED_SCALED); - } - bool isIndexSigned() const { - return (getIndexType() == ISD::SIGNED_SCALED) || - (getIndexType() == ISD::SIGNED_UNSCALED); - } + bool isIndexScaled() const { return isIndexTypeScaled(getIndexType()); } + bool isIndexSigned() const { return isIndexTypeSigned(getIndexType()); } // In the both nodes address is Op1, mask is Op2: // VPGatherSDNode (Chain, base, index, scale, mask, vlen) @@ -2790,17 +2784,8 @@ ISD::MemIndexType getIndexType() const { return static_cast<ISD::MemIndexType>(LSBaseSDNodeBits.AddressingMode); } - void setIndexType(ISD::MemIndexType IndexType) { - LSBaseSDNodeBits.AddressingMode = IndexType; - } - bool isIndexScaled() const { - return (getIndexType() == ISD::SIGNED_SCALED) || - (getIndexType() == ISD::UNSIGNED_SCALED); - } - bool isIndexSigned() const { - return (getIndexType() == ISD::SIGNED_SCALED) || - (getIndexType() == ISD::SIGNED_UNSCALED); - } + bool isIndexScaled() const { return isIndexTypeScaled(getIndexType()); } + bool isIndexSigned() const { return isIndexTypeSigned(getIndexType()); } // In the both nodes address is Op1, mask is Op2: // MaskedGatherSDNode (Chain, passthru, mask, base, index, scale) diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -10471,24 +10471,27 @@ } // Fold sext/zext of index into index type. 
-bool refineIndexType(MaskedGatherScatterSDNode *MGS, SDValue &Index, - bool Scaled, bool Signed, SelectionDAG &DAG) { +bool refineIndexType(SDValue &Index, ISD::MemIndexType &IndexType, + SelectionDAG &DAG) { const TargetLowering &TLI = DAG.getTargetLoweringInfo(); // It's always safe to look through zero extends. if (Index.getOpcode() == ISD::ZERO_EXTEND) { SDValue Op = Index.getOperand(0); - MGS->setIndexType(Scaled ? ISD::UNSIGNED_SCALED : ISD::UNSIGNED_UNSCALED); if (TLI.shouldRemoveExtendFromGSIndex(Op.getValueType())) { + IndexType = ISD::getUnsignedIndexType(IndexType); Index = Op; return true; + } else if (ISD::isIndexTypeSigned(IndexType)) { + IndexType = ISD::getUnsignedIndexType(IndexType); + return true; } } // It's only safe to look through sign extends when Index is signed. - if (Index.getOpcode() == ISD::SIGN_EXTEND && Signed) { + if (Index.getOpcode() == ISD::SIGN_EXTEND && + ISD::isIndexTypeSigned(IndexType)) { SDValue Op = Index.getOperand(0); - MGS->setIndexType(Scaled ? ISD::SIGNED_SCALED : ISD::SIGNED_UNSCALED); if (TLI.shouldRemoveExtendFromGSIndex(Op.getValueType())) { Index = Op; return true; @@ -10506,6 +10509,7 @@ SDValue Scale = MSC->getScale(); SDValue StoreVal = MSC->getValue(); SDValue BasePtr = MSC->getBasePtr(); + ISD::MemIndexType IndexType = MSC->getIndexType(); SDLoc DL(N); // Zap scatters with a zero mask. 
@@ -10514,17 +10518,16 @@ if (refineUniformBase(BasePtr, Index, MSC->isIndexScaled(), DAG)) { SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale}; - return DAG.getMaskedScatter( - DAG.getVTList(MVT::Other), MSC->getMemoryVT(), DL, Ops, - MSC->getMemOperand(), MSC->getIndexType(), MSC->isTruncatingStore()); + return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MSC->getMemoryVT(), + DL, Ops, MSC->getMemOperand(), IndexType, + MSC->isTruncatingStore()); } - if (refineIndexType(MSC, Index, MSC->isIndexScaled(), MSC->isIndexSigned(), - DAG)) { + if (refineIndexType(Index, IndexType, DAG)) { SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale}; - return DAG.getMaskedScatter( - DAG.getVTList(MVT::Other), MSC->getMemoryVT(), DL, Ops, - MSC->getMemOperand(), MSC->getIndexType(), MSC->isTruncatingStore()); + return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MSC->getMemoryVT(), + DL, Ops, MSC->getMemOperand(), IndexType, + MSC->isTruncatingStore()); } return SDValue(); @@ -10602,6 +10605,7 @@ SDValue Scale = MGT->getScale(); SDValue PassThru = MGT->getPassThru(); SDValue BasePtr = MGT->getBasePtr(); + ISD::MemIndexType IndexType = MGT->getIndexType(); SDLoc DL(N); // Zap gathers with a zero mask. 
@@ -10610,19 +10614,16 @@ if (refineUniformBase(BasePtr, Index, MGT->isIndexScaled(), DAG)) { SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale}; - return DAG.getMaskedGather(DAG.getVTList(N->getValueType(0), MVT::Other), - MGT->getMemoryVT(), DL, Ops, - MGT->getMemOperand(), MGT->getIndexType(), - MGT->getExtensionType()); + return DAG.getMaskedGather( + DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL, + Ops, MGT->getMemOperand(), IndexType, MGT->getExtensionType()); } - if (refineIndexType(MGT, Index, MGT->isIndexScaled(), MGT->isIndexSigned(), - DAG)) { + if (refineIndexType(Index, IndexType, DAG)) { SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale}; - return DAG.getMaskedGather(DAG.getVTList(N->getValueType(0), MVT::Other), - MGT->getMemoryVT(), DL, Ops, - MGT->getMemOperand(), MGT->getIndexType(), - MGT->getExtensionType()); + return DAG.getMaskedGather( + DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL, + Ops, MGT->getMemOperand(), IndexType, MGT->getExtensionType()); } return SDValue(); diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp --- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp @@ -8613,14 +8613,9 @@ ISD::MemIndexType TargetLowering::getCanonicalIndexType(ISD::MemIndexType IndexType, EVT MemVT, SDValue Offsets) const { - bool IsScaledIndex = - (IndexType == ISD::SIGNED_SCALED) || (IndexType == ISD::UNSIGNED_SCALED); - bool IsSignedIndex = - (IndexType == ISD::SIGNED_SCALED) || (IndexType == ISD::SIGNED_UNSCALED); - // Scaling is unimportant for bytes, canonicalize to unscaled. - if (IsScaledIndex && MemVT.getScalarType() == MVT::i8) - return IsSignedIndex ? 
ISD::SIGNED_UNSCALED : ISD::UNSIGNED_UNSCALED; + if (ISD::isIndexTypeScaled(IndexType) && MemVT.getScalarType() == MVT::i8) + return ISD::getUnscaledIndexType(IndexType); return IndexType; } diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -4711,10 +4711,8 @@ return DAG.getMergeValues({Select, Load.getValue(1)}, DL); } - bool IsScaled = - IndexType == ISD::SIGNED_SCALED || IndexType == ISD::UNSIGNED_SCALED; - bool IsSigned = - IndexType == ISD::SIGNED_SCALED || IndexType == ISD::SIGNED_UNSCALED; + bool IsScaled = MGT->isIndexScaled(); + bool IsSigned = MGT->isIndexSigned(); // SVE supports an index scaled by sizeof(MemVT.elt) only, everything else // must be calculated before hand. @@ -4727,7 +4725,7 @@ Scale = DAG.getTargetConstant(1, DL, Scale.getValueType()); SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale}; - IndexType = IsSigned ? ISD::SIGNED_UNSCALED : ISD::UNSIGNED_UNSCALED; + IndexType = getUnscaledIndexType(IndexType); return DAG.getMaskedGather(MGT->getVTList(), MemVT, DL, Ops, MGT->getMemOperand(), IndexType, ExtType); } @@ -4812,10 +4810,8 @@ EVT MemVT = MSC->getMemoryVT(); ISD::MemIndexType IndexType = MSC->getIndexType(); - bool IsScaled = - IndexType == ISD::SIGNED_SCALED || IndexType == ISD::UNSIGNED_SCALED; - bool IsSigned = - IndexType == ISD::SIGNED_SCALED || IndexType == ISD::SIGNED_UNSCALED; + bool IsScaled = MSC->isIndexScaled(); + bool IsSigned = MSC->isIndexSigned(); // SVE supports an index scaled by sizeof(MemVT.elt) only, everything else // must be calculated before hand. @@ -4828,7 +4824,7 @@ Scale = DAG.getTargetConstant(1, DL, Scale.getValueType()); SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale}; - IndexType = IsSigned ? 
ISD::SIGNED_UNSCALED : ISD::UNSIGNED_UNSCALED; + IndexType = getUnscaledIndexType(IndexType); return DAG.getMaskedScatter(MSC->getVTList(), MemVT, DL, Ops, MSC->getMemOperand(), IndexType, MSC->isTruncatingStore());