diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -2840,6 +2840,15 @@
     return false;
   }
 
+  /// Return true if EXTRACT_SUBVECTOR is free for extracting this result type
+  /// from this source type with this index. This is needed because
+  /// EXTRACT_SUBVECTOR usually has custom lowering that depends on the index of
+  /// the first element, and only the target knows which lowering is free.
+  virtual bool isExtractSubvectorFree(EVT ResVT, EVT SrcVT,
+                                      unsigned Index) const {
+    return false;
+  }
+
   /// Return true if EXTRACT_SUBVECTOR is cheap for extracting this result type
   /// from this source type with this index. This is needed because
   /// EXTRACT_SUBVECTOR usually has custom lowering that depends on the index of
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -20785,6 +20785,18 @@
   assert((FirstExtractedEltIdx % NumEltsExtracted) == 0 &&
          "Extract index is not a multiple of the output vector length.");
 
+  auto GetExtractSubvectorCost = [&TLI, NarrowVT, WideVT](unsigned Index) {
+    if (TLI.isExtractSubvectorFree(NarrowVT, WideVT, Index))
+      return 0;
+    if (TLI.isExtractSubvectorCheap(NarrowVT, WideVT, Index))
+      return 1;
+    return 2; // Assume that all non-cheap subvector extracts are equal,
+              // and that replacing one non-cheap subvector extract
+              // with two cheap ones is a win.
+  };
+
+  int Budget = GetExtractSubvectorCost(FirstExtractedEltIdx);
+
   int WideNumElts = WideVT.getVectorNumElements();
 
   SmallVector<int> NewMask;
@@ -20832,10 +20844,6 @@
       continue;
     }
 
-    // Profitability check: only deal with extractions from the first subvector.
-    if (OpSubvecIdx != 0)
-      return SDValue();
-
     const std::pair<SDValue, int> DemandedSubvector =
         std::make_pair(Op, OpSubvecIdx);
 
@@ -20844,8 +20852,9 @@
       return SDValue(); // We can't handle more than two subvectors.
 
     // How many elements into the WideVT does this subvector start?
    int Index = NumEltsExtracted * OpSubvecIdx;
-    // Bail out if the extraction isn't going to be cheap.
-    if (!TLI.isExtractSubvectorCheap(NarrowVT, WideVT, Index))
+    // Bail out if the extraction exhausted our budget.
+    Budget -= GetExtractSubvectorCost(Index);
+    if (Budget < 0)
       return SDValue();
   }
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -647,6 +647,11 @@
   bool shouldConvertConstantLoadToIntImm(const APInt &Imm,
                                          Type *Ty) const override;
 
+  /// Return true if EXTRACT_SUBVECTOR is free for this result type
+  /// with this index.
+  bool isExtractSubvectorFree(EVT ResVT, EVT SrcVT,
+                              unsigned Index) const override;
+
   /// Return true if EXTRACT_SUBVECTOR is cheap for this result type
   /// with this index.
   bool isExtractSubvectorCheap(EVT ResVT, EVT SrcVT,
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -12885,14 +12885,19 @@
   return Shift < 3;
 }
 
-bool AArch64TargetLowering::isExtractSubvectorCheap(EVT ResVT, EVT SrcVT,
-                                                    unsigned Index) const {
+bool AArch64TargetLowering::isExtractSubvectorFree(EVT ResVT, EVT SrcVT,
+                                                   unsigned Index) const {
   if (!isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, ResVT))
     return false;
 
   return (Index == 0 || Index == ResVT.getVectorNumElements());
 }
 
+bool AArch64TargetLowering::isExtractSubvectorCheap(EVT ResVT, EVT SrcVT,
+                                                    unsigned Index) const {
+  return isExtractSubvectorFree(ResVT, SrcVT, Index);
+}
+
 /// Turn vector tests of the signbit in the form of:
 ///   xor (sra X, elt_size(X)-1), -1
 /// into:
diff --git a/llvm/lib/Target/ARM/ARMISelLowering.h b/llvm/lib/Target/ARM/ARMISelLowering.h
--- a/llvm/lib/Target/ARM/ARMISelLowering.h
+++ b/llvm/lib/Target/ARM/ARMISelLowering.h
@@ -612,6 +612,11 @@
   bool shouldConvertConstantLoadToIntImm(const APInt &Imm,
                                          Type *Ty) const override;
 
+  /// Return true if EXTRACT_SUBVECTOR is free for this result type
+  /// with this index.
+  bool isExtractSubvectorFree(EVT ResVT, EVT SrcVT,
+                              unsigned Index) const override;
+
   /// Return true if EXTRACT_SUBVECTOR is cheap for this result type
   /// with this index.
   bool isExtractSubvectorCheap(EVT ResVT, EVT SrcVT,
diff --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp
--- a/llvm/lib/Target/ARM/ARMISelLowering.cpp
+++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp
@@ -20765,14 +20765,19 @@
   return true;
 }
 
-bool ARMTargetLowering::isExtractSubvectorCheap(EVT ResVT, EVT SrcVT,
-                                                unsigned Index) const {
+bool ARMTargetLowering::isExtractSubvectorFree(EVT ResVT, EVT SrcVT,
+                                               unsigned Index) const {
   if (!isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, ResVT))
     return false;
 
   return (Index == 0 || Index == ResVT.getVectorNumElements());
 }
 
+bool ARMTargetLowering::isExtractSubvectorCheap(EVT ResVT, EVT SrcVT,
+                                                unsigned Index) const {
+  return isExtractSubvectorFree(ResVT, SrcVT, Index);
+}
+
 Instruction *ARMTargetLowering::makeDMB(IRBuilderBase &Builder,
                                         ARM_MB::MemBOpt Domain) const {
   Module *M = Builder.GetInsertBlock()->getParent()->getParent();
diff --git a/llvm/lib/Target/X86/X86ISelLowering.h b/llvm/lib/Target/X86/X86ISelLowering.h
--- a/llvm/lib/Target/X86/X86ISelLowering.h
+++ b/llvm/lib/Target/X86/X86ISelLowering.h
@@ -1347,6 +1347,11 @@
   bool decomposeMulByConstant(LLVMContext &Context, EVT VT,
                               SDValue C) const override;
 
+  /// Return true if EXTRACT_SUBVECTOR is free for this result type
+  /// with this index.
+  bool isExtractSubvectorFree(EVT ResVT, EVT SrcVT,
+                              unsigned Index) const override;
+
  /// Return true if EXTRACT_SUBVECTOR is cheap for this result type
   /// with this index.
   bool isExtractSubvectorCheap(EVT ResVT, EVT SrcVT,
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -5590,16 +5590,27 @@
          (1 - MulC).isPowerOf2() || (-(MulC + 1)).isPowerOf2();
 }
 
+bool X86TargetLowering::isExtractSubvectorFree(EVT ResVT, EVT SrcVT,
+                                               unsigned Index) const {
+  if (!isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, ResVT))
+    return false;
+
+  return Index == 0;
+}
+
 bool X86TargetLowering::isExtractSubvectorCheap(EVT ResVT, EVT SrcVT,
                                                 unsigned Index) const {
   if (!isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, ResVT))
     return false;
 
+  if (isExtractSubvectorFree(ResVT, SrcVT, Index))
+    return true;
+
   // Mask vectors support all subregister combinations and operations that
   // extract half of vector.
   if (ResVT.getVectorElementType() == MVT::i1)
-    return Index == 0 || ((ResVT.getSizeInBits() == SrcVT.getSizeInBits()*2) &&
-                          (Index == ResVT.getVectorNumElements()));
+    return ((ResVT.getSizeInBits() == SrcVT.getSizeInBits() * 2) &&
+            (Index == ResVT.getVectorNumElements()));
 
   return (Index % ResVT.getVectorNumElements()) == 0;
 }
diff --git a/llvm/test/CodeGen/AArch64/arm64-neon-copy.ll b/llvm/test/CodeGen/AArch64/arm64-neon-copy.ll
--- a/llvm/test/CodeGen/AArch64/arm64-neon-copy.ll
+++ b/llvm/test/CodeGen/AArch64/arm64-neon-copy.ll
@@ -282,7 +282,7 @@
 define <1 x double> @ins2f1(<2 x double> %tmp1, <1 x double> %tmp2) {
 ; CHECK-LABEL: ins2f1:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    dup v0.2d, v0.d[1]
+; CHECK-NEXT:    ext v0.16b, v0.16b, v0.16b, #8
 ; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $q0
 ; CHECK-NEXT:    ret
   %tmp3 = extractelement <2 x double> %tmp1, i32 1
diff --git a/llvm/test/CodeGen/ARM/fp16-insert-extract.ll b/llvm/test/CodeGen/ARM/fp16-insert-extract.ll
--- a/llvm/test/CodeGen/ARM/fp16-insert-extract.ll
+++ b/llvm/test/CodeGen/ARM/fp16-insert-extract.ll
@@ -176,7 +176,7 @@
 ; CHECKHARD-NEXT:    vmov r0, s12
 ; CHECKHARD-NEXT:    vext.16 d16, d4, d5, #2
 ; CHECKHARD-NEXT:    vmovx.f16 s12, s4
-; CHECKHARD-NEXT:    vdup.16 q11, d3[1]
+; CHECKHARD-NEXT:    vdup.32 d21, d3[1]
 ; CHECKHARD-NEXT:    vrev32.16 d17, d16
 ; CHECKHARD-NEXT:    vext.16 d16, d16, d17, #3
 ; CHECKHARD-NEXT:    vrev32.16 d17, d3
@@ -207,8 +207,8 @@
 ; CHECKHARD-NEXT:    vmov r0, s11
 ; CHECKHARD-NEXT:    vmov.16 d20[3], r0
 ; CHECKHARD-NEXT:    vmov r0, s10
-; CHECKHARD-NEXT:    vext.16 d20, d20, d22, #1
-; CHECKHARD-NEXT:    vdup.16 q11, d3[2]
+; CHECKHARD-NEXT:    vext.16 d20, d20, d21, #1
+; CHECKHARD-NEXT:    vdup.32 d21, d3[2]
 ; CHECKHARD-NEXT:    vext.16 d19, d20, d20, #3
 ; CHECKHARD-NEXT:    vadd.f16 q8, q8, q9
 ; CHECKHARD-NEXT:    vext.16 d18, d0, d1, #2
@@ -223,7 +223,7 @@
 ; CHECKHARD-NEXT:    vmov.16 d20[2], r0
 ; CHECKHARD-NEXT:    vmov r0, s0
 ; CHECKHARD-NEXT:    vmov.16 d20[3], r0
-; CHECKHARD-NEXT:    vext.16 d20, d20, d22, #1
+; CHECKHARD-NEXT:    vext.16 d20, d20, d21, #1
 ; CHECKHARD-NEXT:    vext.16 d19, d20, d20, #3
 ; CHECKHARD-NEXT:    vadd.f16 q0, q8, q9
 ; CHECKHARD-NEXT:    bx lr
@@ -235,7 +235,7 @@
 ; CHECKSOFT-NEXT:    vmov r0, s12
 ; CHECKSOFT-NEXT:    vext.16 d16, d4, d5, #2
 ; CHECKSOFT-NEXT:    vmovx.f16 s12, s4
-; CHECKSOFT-NEXT:    vdup.16 q11, d3[1]
+; CHECKSOFT-NEXT:    vdup.32 d21, d3[1]
 ; CHECKSOFT-NEXT:    vrev32.16 d17, d16
 ; CHECKSOFT-NEXT:    vext.16 d16, d16, d17, #3
 ; CHECKSOFT-NEXT:    vrev32.16 d17, d3
@@ -266,8 +266,8 @@
 ; CHECKSOFT-NEXT:    vmov r0, s11
 ; CHECKSOFT-NEXT:    vmov.16 d20[3], r0
 ; CHECKSOFT-NEXT:    vmov r0, s10
-; CHECKSOFT-NEXT:    vext.16 d20, d20, d22, #1
-; CHECKSOFT-NEXT:    vdup.16 q11, d3[2]
+; CHECKSOFT-NEXT:    vext.16 d20, d20, d21, #1
+; CHECKSOFT-NEXT:    vdup.32 d21, d3[2]
 ; CHECKSOFT-NEXT:    vext.16 d19, d20, d20, #3
 ; CHECKSOFT-NEXT:    vadd.f16 q8, q8, q9
 ; CHECKSOFT-NEXT:    vext.16 d18, d0, d1, #2
@@ -282,7 +282,7 @@
 ; CHECKSOFT-NEXT:    vmov.16 d20[2], r0
 ; CHECKSOFT-NEXT:    vmov r0, s0
 ; CHECKSOFT-NEXT:    vmov.16 d20[3], r0
-; CHECKSOFT-NEXT:    vext.16 d20, d20, d22, #1
+; CHECKSOFT-NEXT:    vext.16 d20, d20, d21, #1
 ; CHECKSOFT-NEXT:    vext.16 d19, d20, d20, #3
 ; CHECKSOFT-NEXT:    vadd.f16 q0, q8, q9
 ; CHECKSOFT-NEXT:    bx lr