[DAGCombiner] Price the extract_subvector(shuffle) fold by extract cost

Add TargetLowering::isExtractSubvectorFree, a stricter companion of
isExtractSubvectorCheap whose default implementation returns false.

foldExtractSubvectorFromShuffleVector now prices every EXTRACT_SUBVECTOR
it reasons about: 0 if the target reports it free, 1 if cheap, and 2
otherwise. The cost of the extract being folded away becomes a budget,
each new subvector extract is charged against it, and the fold bails out
as soon as the budget goes negative. Before this change each new extract
merely had to be cheap on its own.

AArch64 and ARM: isExtractSubvectorFree takes over the previous
isExtractSubvectorCheap logic (result type legal or custom, and an index
of either 0 or the result element count); isExtractSubvectorCheap now
forwards to it, so on these targets free and cheap coincide.

X86: isExtractSubvectorFree holds only for index 0 with a legal or
custom result type. isExtractSubvectorCheap returns true whenever the
free hook does, which is why the i1 branch drops its own "Index == 0"
alternative.

NOTE(review): a target that overrides only isExtractSubvectorCheap now
starts from a budget of 1 when the replaced extract is cheap, so a fold
needing two cheap subvector extracts, accepted before, is now rejected.
On X86 an index-0 extract (legal or custom result type) is free, giving
a budget of 0, so any fold needing a non-free extract is rejected there.
Confirm both are intended.

NOTE(review): the AArch64/ARM free hook ignores SrcVT, so it does not
check that the source is exactly twice the result; whether a non-zero
index really costs no instruction on each target should be confirmed.

NOTE(review): the unchanged X86 i1 condition compares
ResVT.getSizeInBits() == SrcVT.getSizeInBits() * 2 (a result twice the
size of its source) although the comment above it speaks of extracting
half of a vector; it looks inverted. Confirm before relying on it.

NOTE(review): the DAGCombiner context line "SmallVector NewMask;" seems
to have lost its template arguments in transport; context lines must
match the target file exactly for the hunk to apply.

diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -2840,6 +2840,15 @@
     return false;
   }
 
+  /// Return true if EXTRACT_SUBVECTOR is free (not merely cheap, see
+  /// isExtractSubvectorCheap) for extracting this result type from this
+  /// source type with this index. Only the target knows which index-dependent
+  /// lowering is free; by default nothing is considered free.
+  virtual bool isExtractSubvectorFree(EVT ResVT, EVT SrcVT,
+                                      unsigned Index) const {
+    return false;
+  }
+
   /// Return true if EXTRACT_SUBVECTOR is cheap for extracting this result type
   /// from this source type with this index. This is needed because
   /// EXTRACT_SUBVECTOR usually has custom lowering that depends on the index of
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -20785,6 +20785,18 @@
   assert((FirstExtractedEltIdx % NumEltsExtracted) == 0 &&
          "Extract index is not a multiple of the output vector length.");
 
+  auto GetExtractSubvectorCost = [&TLI, NarrowVT, WideVT](unsigned Index) {
+    if (TLI.isExtractSubvectorFree(NarrowVT, WideVT, Index))
+      return 0;
+    if (TLI.isExtractSubvectorCheap(NarrowVT, WideVT, Index))
+      return 1;
+    return 2; // Assume that all non-cheap subvector extracts are equal,
+              // and that replacing one non-cheap subvector extract
+              // with two cheap ones is a win.
+  };
+  // The extract being folded away pays for the extracts replacing it.
+  int Budget = GetExtractSubvectorCost(FirstExtractedEltIdx);
+
   int WideNumElts = WideVT.getVectorNumElements();
 
   SmallVector NewMask;
@@ -20844,8 +20856,9 @@
         return SDValue(); // We can't handle more than two subvectors.
       // How many elements into the WideVT does this subvector start?
       int Index = NumEltsExtracted * OpSubvecIdx;
-      // Bail out if the extraction isn't going to be cheap.
-      if (!TLI.isExtractSubvectorCheap(NarrowVT, WideVT, Index))
+      // Bail out once the new extracts cost more than the one they replace.
+      Budget -= GetExtractSubvectorCost(Index);
+      if (Budget < 0)
         return SDValue();
     }
 
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -647,6 +647,11 @@
   bool shouldConvertConstantLoadToIntImm(const APInt &Imm,
                                          Type *Ty) const override;
 
+  /// Return true if EXTRACT_SUBVECTOR is free for this result type
+  /// with this index.
+  bool isExtractSubvectorFree(EVT ResVT, EVT SrcVT,
+                              unsigned Index) const override;
+
   /// Return true if EXTRACT_SUBVECTOR is cheap for this result type
   /// with this index.
   bool isExtractSubvectorCheap(EVT ResVT, EVT SrcVT,
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -12885,14 +12885,19 @@
   return Shift < 3;
 }
 
-bool AArch64TargetLowering::isExtractSubvectorCheap(EVT ResVT, EVT SrcVT,
-                                                    unsigned Index) const {
+bool AArch64TargetLowering::isExtractSubvectorFree(EVT ResVT, EVT SrcVT,
+                                                   unsigned Index) const {
   if (!isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, ResVT))
     return false;
 
   return (Index == 0 || Index == ResVT.getVectorNumElements());
 }
 
+bool AArch64TargetLowering::isExtractSubvectorCheap(EVT ResVT, EVT SrcVT,
+                                                    unsigned Index) const {
+  return isExtractSubvectorFree(ResVT, SrcVT, Index);
+}
+
 /// Turn vector tests of the signbit in the form of:
 ///   xor (sra X, elt_size(X)-1), -1
 /// into:
diff --git a/llvm/lib/Target/ARM/ARMISelLowering.h b/llvm/lib/Target/ARM/ARMISelLowering.h
--- a/llvm/lib/Target/ARM/ARMISelLowering.h
+++ b/llvm/lib/Target/ARM/ARMISelLowering.h
@@ -612,6 +612,11 @@
     bool shouldConvertConstantLoadToIntImm(const APInt &Imm,
                                            Type *Ty) const override;
 
+    /// Return true if EXTRACT_SUBVECTOR is free for this result type
+    /// with this index.
+    bool isExtractSubvectorFree(EVT ResVT, EVT SrcVT,
+                                unsigned Index) const override;
+
     /// Return true if EXTRACT_SUBVECTOR is cheap for this result type
     /// with this index.
     bool isExtractSubvectorCheap(EVT ResVT, EVT SrcVT,
diff --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp
--- a/llvm/lib/Target/ARM/ARMISelLowering.cpp
+++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp
@@ -20765,14 +20765,19 @@
   return true;
 }
 
-bool ARMTargetLowering::isExtractSubvectorCheap(EVT ResVT, EVT SrcVT,
-                                                unsigned Index) const {
+bool ARMTargetLowering::isExtractSubvectorFree(EVT ResVT, EVT SrcVT,
+                                               unsigned Index) const {
   if (!isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, ResVT))
     return false;
 
   return (Index == 0 || Index == ResVT.getVectorNumElements());
 }
 
+bool ARMTargetLowering::isExtractSubvectorCheap(EVT ResVT, EVT SrcVT,
+                                                unsigned Index) const {
+  return isExtractSubvectorFree(ResVT, SrcVT, Index);
+}
+
 Instruction *ARMTargetLowering::makeDMB(IRBuilderBase &Builder,
                                         ARM_MB::MemBOpt Domain) const {
   Module *M = Builder.GetInsertBlock()->getParent()->getParent();
diff --git a/llvm/lib/Target/X86/X86ISelLowering.h b/llvm/lib/Target/X86/X86ISelLowering.h
--- a/llvm/lib/Target/X86/X86ISelLowering.h
+++ b/llvm/lib/Target/X86/X86ISelLowering.h
@@ -1347,6 +1347,11 @@
     bool decomposeMulByConstant(LLVMContext &Context, EVT VT,
                                 SDValue C) const override;
 
+    /// Return true if EXTRACT_SUBVECTOR is free for this result type
+    /// with this index.
+    bool isExtractSubvectorFree(EVT ResVT, EVT SrcVT,
+                                unsigned Index) const override;
+
     /// Return true if EXTRACT_SUBVECTOR is cheap for this result type
     /// with this index.
     bool isExtractSubvectorCheap(EVT ResVT, EVT SrcVT,
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -5590,16 +5590,27 @@
          (1 - MulC).isPowerOf2() || (-(MulC + 1)).isPowerOf2();
 }
 
+bool X86TargetLowering::isExtractSubvectorFree(EVT ResVT, EVT SrcVT,
+                                               unsigned Index) const {
+  if (!isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, ResVT))
+    return false;
+
+  return Index == 0;
+}
+
 bool X86TargetLowering::isExtractSubvectorCheap(EVT ResVT, EVT SrcVT,
                                                 unsigned Index) const {
   if (!isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, ResVT))
     return false;
 
+  if (isExtractSubvectorFree(ResVT, SrcVT, Index))
+    return true;
+
   // Mask vectors support all subregister combinations and operations that
   // extract half of vector.
   if (ResVT.getVectorElementType() == MVT::i1)
-    return Index == 0 || ((ResVT.getSizeInBits() == SrcVT.getSizeInBits()*2) &&
-                          (Index == ResVT.getVectorNumElements()));
+    return ((ResVT.getSizeInBits() == SrcVT.getSizeInBits() * 2) &&
+            (Index == ResVT.getVectorNumElements()));
 
   return (Index % ResVT.getVectorNumElements()) == 0;
 }