diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1831,6 +1831,13 @@
   /// which is split (or expanded) into two not necessarily identical pieces.
   std::pair<EVT, EVT> GetSplitDestVTs(const EVT &VT) const;
 
+  /// Compute the VTs needed for the low/hi parts of a type, dependent on an
+  /// enveloping VT that has been split into two identical pieces.
+  /// \p LoVT and \p HiVT are the halves of the enveloping type; \p VT must
+  /// have more elements than \p LoVT and no more than both halves combined.
+  std::pair<EVT, EVT> GetDependentSplitDestVTs(const EVT &VT, const EVT &LoVT,
+                                               const EVT &HiVT) const;
+
   /// Split the vector with EXTRACT_SUBVECTOR using the provides
   /// VTs and return the low/high part.
   std::pair<SDValue, SDValue> SplitVector(const SDValue &N, const SDLoc &DL,
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -1571,7 +1571,8 @@
 
   EVT MemoryVT = MLD->getMemoryVT();
   EVT LoMemVT, HiMemVT;
-  std::tie(LoMemVT, HiMemVT) = DAG.GetSplitDestVTs(MemoryVT);
+  std::tie(LoMemVT, HiMemVT) =
+      DAG.GetDependentSplitDestVTs(MemoryVT, LoVT, HiVT);
 
   SDValue PassThruLo, PassThruHi;
   if (getTypeAction(PassThru.getValueType()) == TargetLowering::TypeSplitVector)
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -9461,6 +9461,45 @@
   return std::make_pair(LoVT, HiVT);
 }
 
+/// GetDependentSplitDestVTs - Compute the VTs needed for the low/hi parts
+/// of a type, dependent on an enveloping VT that has been split into two
+/// identical pieces.
+///
+/// \p LoVT and \p HiVT are the two halves of the enveloping type. The low
+/// result has as many elements as \p LoVT; the high result holds the
+/// remaining elements of \p VT. \p VT must therefore have more elements than
+/// \p LoVT, and no more than both halves combined.
+std::pair<EVT, EVT>
+SelectionDAG::GetDependentSplitDestVTs(const EVT &VT, const EVT &LoVT,
+                                       const EVT &HiVT) const {
+  assert(VT.isVector() && "Dependent VT must be a vector type");
+  assert(VT.getVectorElementType() == LoVT.getVectorElementType() &&
+         "Dependent VT and low half must have the same element type");
+  assert(VT.getVectorElementType() == HiVT.getVectorElementType() &&
+         "Dependent VT and high half must have the same element type");
+  assert(VT.getVectorNumElements() <=
+             (LoVT.getVectorNumElements() + HiVT.getVectorNumElements()) &&
+         "Dependent VT does not have sufficient elements");
+  // The high element count below is an unsigned difference: reject a VT that
+  // fits entirely in the low half rather than letting the count wrap around
+  // (or asking for a zero-element vector).
+  assert(VT.getVectorNumElements() > LoVT.getVectorNumElements() &&
+         "Dependent VT must have more elements than the low half");
+  // Examples:
+  //   custom VL=5 with enveloping VL=8, split 4/4 yields 4/1
+  //   custom VL=6 with enveloping VL=8, split 4/4 yields 4/2
+  //   custom VL=13 with enveloping VL=16, split 8/8 yields 8/5
+  //   custom VL=14 with enveloping VL=16, split 8/8 yields 8/6
+  //   etc.
+  bool IsScalable = VT.isScalableVector();
+  EVT DepLoVT = EVT::getVectorVT(*getContext(), LoVT.getVectorElementType(),
+                                 LoVT.getVectorNumElements(), IsScalable);
+  EVT DepHiVT = EVT::getVectorVT(
+      *getContext(), HiVT.getVectorElementType(),
+      VT.getVectorNumElements() - LoVT.getVectorNumElements(), IsScalable);
+  return std::make_pair(DepLoVT, DepHiVT);
+}
+
 /// SplitVector - Split the vector with EXTRACT_SUBVECTOR and return the
 /// low/high part.
 std::pair<SDValue, SDValue>
diff --git a/llvm/test/CodeGen/X86/pr45563-2.ll b/llvm/test/CodeGen/X86/pr45563-2.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/X86/pr45563-2.ll
@@ -0,0 +1,41 @@
+; REQUIRES: asserts
+; RUN: llc < %s -debug-only=isel -O3 -mattr=avx 2>&1 | FileCheck %s
+
+; Bug 45563:
+; The SplitVecRes_MLOAD method should split an extended value type
+; according to the halving of the enveloping type to avoid all sorts
+; of inconsistencies downstream. For example for an extended value type
+; with VL=14 and enveloping type VL=16 that is split 8/8, the extended
+; type should be split 8/6 and not 7/7.
+
+define <9 x float> @mload_split9(<9 x i1> %mask, <9 x float>* %addr, <9 x float> %dst) {
+; CHECK-LABEL: Type-legalized selection DAG: %bb.0 'mload_split9:'
+; CHECK-DAG:   v8f32,ch = masked_load<(load 32 from %ir.addr, align 4)>
+; CHECK-DAG:   v8f32,ch = masked_load<(load 4 from %ir.addr + 32)>
+; CHECK:       Optimized type-legalized selection DAG: %bb.0 'mload_split9:'
+  %res = call <9 x float> @llvm.masked.load.v9f32.p0v9f32(<9 x float>* %addr, i32 4, <9 x i1>%mask, <9 x float> %dst)
+  ret <9 x float> %res
+}
+
+define <13 x float> @mload_split13(<13 x i1> %mask, <13 x float>* %addr, <13 x float> %dst) {
+; CHECK-LABEL: Type-legalized selection DAG: %bb.0 'mload_split13:'
+; CHECK-DAG:   v8f32,ch = masked_load<(load 32 from %ir.addr, align 4)>
+; CHECK-DAG:   v8f32,ch = masked_load<(load 20 from %ir.addr + 32, align 4)>
+; CHECK:       Optimized type-legalized selection DAG: %bb.0 'mload_split13:'
+  %res = call <13 x float> @llvm.masked.load.v13f32.p0v13f32(<13 x float>* %addr, i32 4, <13 x i1>%mask, <13 x float> %dst)
+  ret <13 x float> %res
+}
+
+define <14 x float> @mload_split14(<14 x i1> %mask, <14 x float>* %addr, <14 x float> %dst) {
+; CHECK-LABEL: Type-legalized selection DAG: %bb.0 'mload_split14:'
+; CHECK-DAG:   v8f32,ch = masked_load<(load 32 from %ir.addr, align 4)>
+; CHECK-DAG:   v8f32,ch = masked_load<(load 24 from %ir.addr + 32, align 4)>
+; CHECK:       Optimized type-legalized selection DAG: %bb.0 'mload_split14:'
+  %res = call <14 x float> @llvm.masked.load.v14f32.p0v14f32(<14 x float>* %addr, i32 4, <14 x i1>%mask, <14 x float> %dst)
+  ret <14 x float> %res
+}
+
+declare <9 x float> @llvm.masked.load.v9f32.p0v9f32(<9 x float>* %addr, i32 %align, <9 x i1> %mask, <9 x float> %dst)
+declare <13 x float> @llvm.masked.load.v13f32.p0v13f32(<13 x float>* %addr, i32 %align, <13 x i1> %mask, <13 x float> %dst)
+declare <14 x float> @llvm.masked.load.v14f32.p0v14f32(<14 x float>* %addr, i32 %align, <14 x i1> %mask, <14 x float> %dst)
+