Index: lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
===================================================================
--- lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -3080,56 +3080,83 @@
   setValue(&I, StoreNode);
 }
 
+// Get a uniform base for the Gather/Scatter input.
 // Gather/scatter receive a vector of pointers.
 // This vector of pointers may be represented as a base pointer + vector of
-// indices, it depends on GEP and instruction preceeding GEP
-// that calculates indices
+// indices. Usually, the vector of pointers comes from a 'getelementptr'
+// instruction. Extracting a uniform base depends on the GEP and the instruction
+// preceding GEP that calculates indices.
 static bool getUniformBase(Value *& Ptr, SDValue& Base, SDValue& Index,
                            SelectionDAGBuilder* SDB) {
 
-  assert (Ptr->getType()->isVectorTy() && "Uexpected pointer type");
+  SelectionDAG& DAG = SDB->DAG;
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+
+  assert (Ptr->getType()->isVectorTy() && "Unexpected pointer type");
   GetElementPtrInst *Gep = dyn_cast<GetElementPtrInst>(Ptr);
   if (!Gep || Gep->getNumOperands() > 2)
     return false;
-  ShuffleVectorInst *ShuffleInst =
-    dyn_cast<ShuffleVectorInst>(Gep->getPointerOperand());
-  if (!ShuffleInst || !ShuffleInst->getMask()->isNullValue() ||
-      cast<Instruction>(ShuffleInst->getOperand(0))->getOpcode() !=
-      Instruction::InsertElement)
-    return false;
 
-  Ptr = cast<InsertElementInst>(ShuffleInst->getOperand(0))->getOperand(1);
+  Value *GepBasePtr = Gep->getPointerOperand();
+  Value *IndexVal = Gep->getOperand(1);
+  // The operands of the GEP may be defined in another basic block.
+  // In this case we'll not find nodes for the operands.
+  if (!SDB->findValue(GepBasePtr) || !SDB->findValue(IndexVal))
+    return false;
 
-  SelectionDAG& DAG = SDB->DAG;
-  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
-  // Check is the Ptr is inside current basic block
-  // If not, look for the shuffle instruction
-  if (SDB->findValue(Ptr))
+  // If GEP base is scalar - this is the uniform base we are looking for.
+  if (!GepBasePtr->getType()->isVectorTy()) {
+    Ptr = GepBasePtr;
     Base = SDB->getValue(Ptr);
-  else if (SDB->findValue(ShuffleInst)) {
-    SDValue ShuffleNode = SDB->getValue(ShuffleInst);
-    SDLoc sdl = ShuffleNode;
-    Base = DAG.getNode(
-        ISD::EXTRACT_VECTOR_ELT, sdl,
-        ShuffleNode.getValueType().getScalarType(), ShuffleNode,
-        DAG.getConstant(0, sdl, TLI.getVectorIdxTy(DAG.getDataLayout())));
-    SDB->setValue(Ptr, Base);
   }
-  else
-    return false;
-
-  Value *IndexVal = Gep->getOperand(1);
-  if (SDB->findValue(IndexVal)) {
-    Index = SDB->getValue(IndexVal);
+  else {
+    // The base is a vector. But maybe this vector is a splat.
+    // Try to find a preceding broadcast.
+    ShuffleVectorInst *ShuffleInst = dyn_cast<ShuffleVectorInst>(GepBasePtr);
+    if (!ShuffleInst || !ShuffleInst->getMask()->isNullValue())
+      return false;
+    // An all-zeros mask broadcasts lane 0 of the first shuffle operand, so the
+    // inserted scalar is the splat value only if it was inserted at constant
+    // index 0; otherwise lane 0 holds some other value.
+    InsertElementInst *InsertInst =
+        dyn_cast<InsertElementInst>(ShuffleInst->getOperand(0));
+    if (!InsertInst || !isa<ConstantInt>(InsertInst->getOperand(2)) ||
+        !cast<ConstantInt>(InsertInst->getOperand(2))->isZero())
+      return false;
 
-    if (SExtInst* Sext = dyn_cast<SExtInst>(IndexVal)) {
-      IndexVal = Sext->getOperand(0);
-      if (SDB->findValue(IndexVal))
-        Index = SDB->getValue(IndexVal);
+    Ptr = InsertInst->getOperand(1);
+    // Check if the Ptr is inside current basic block.
+    // If not, extract lane 0 of the shuffle instead.
+    if (SDB->findValue(Ptr))
+      Base = SDB->getValue(Ptr);
+    else {
+      SDValue ShuffleNode = SDB->getValue(ShuffleInst);
+      SDLoc sdl(ShuffleNode);
+      EVT IdxVT = TLI.getVectorIdxTy(DAG.getDataLayout());
+      Base = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, sdl,
+                         ShuffleNode.getValueType().getScalarType(),
+                         ShuffleNode, DAG.getConstant(0, sdl, IdxVT));
+      SDB->setValue(Ptr, Base);
     }
-    return true;
   }
-  return false;
+
+  Index = SDB->getValue(IndexVal);
+
+  // Suppress sign extension.
+  if (SExtInst* Sext = dyn_cast<SExtInst>(IndexVal)) {
+    IndexVal = Sext->getOperand(0);
+    if (SDB->findValue(IndexVal))
+      Index = SDB->getValue(IndexVal);
+  }
+  // A scalar index is shared by all lanes; splat it to the GEP's vector width.
+  if (!Index.getValueType().isVector()) {
+    unsigned GEPWidth = Gep->getType()->getVectorNumElements();
+    EVT VT = EVT::getVectorVT(*DAG.getContext(), Index.getValueType(),
+                              GEPWidth);
+    SmallVector<SDValue, 16> Ops(GEPWidth, Index);
+    Index = DAG.getNode(ISD::BUILD_VECTOR, SDLoc(Index), VT, Ops);
+  }
+  return true;
 }
 
 void SelectionDAGBuilder::visitMaskedScatter(const CallInst &I) {
Index: test/CodeGen/X86/masked_gather_scatter.ll
===================================================================
--- test/CodeGen/X86/masked_gather_scatter.ll
+++ test/CodeGen/X86/masked_gather_scatter.ll
@@ -140,3 +140,90 @@
   %res = add <16 x i32> %gt1, %gt2
   ret <16 x i32> %res
 }
+
+%struct.RT = type { i8, [10 x [20 x i32]], i8 }
+%struct.ST = type { i32, double, %struct.RT }
+
+; Masked gather for aggregate types
+; Test9 and Test10 should give the same result (scalar and vector indices in GEP)
+
+; KNL-LABEL: test9
+; KNL: vpbroadcastq %rdi, %zmm
+; KNL: vpmovsxdq
+; KNL: vpbroadcastq
+; KNL: vpmuludq
+; KNL: vpaddq
+; KNL: vpaddq
+; KNL: vpaddq
+; KNL: vpaddq
+; KNL: vpgatherqd (,%zmm
+
+define <8 x i32> @test9(%struct.ST* %base, <8 x i64> %ind1, <8 x i32>%ind5) {
+entry:
+  %broadcast.splatinsert = insertelement <8 x %struct.ST*> undef, %struct.ST* %base, i32 0
+  %broadcast.splat = shufflevector <8 x %struct.ST*> %broadcast.splatinsert, <8 x %struct.ST*> undef, <8 x i32> zeroinitializer
+
+  %arrayidx = getelementptr %struct.ST, <8 x %struct.ST*> %broadcast.splat, <8 x i64> %ind1, <8 x i32> <i32 2, i32 2, i32 2, i32 2, i32 2, i32 2, i32 2, i32 2>, <8 x i32><i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1>, <8 x i32> %ind5, <8 x i64> <i64 13, i64 13, i64 13, i64 13, i64 13, i64 13, i64 13, i64 13>
+  %res = call <8 x i32 > @llvm.masked.gather.v8i32(<8 x i32*>%arrayidx, i32 4, <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, <8 x i32> undef)
+  ret <8 x i32> %res
+}
+
+; KNL-LABEL: test10
+; KNL: vpbroadcastq %rdi, %zmm
+; KNL: vpmovsxdq
+; KNL: vpbroadcastq
+; KNL: vpmuludq
+; KNL: vpaddq
+; KNL: vpaddq
+; KNL: vpaddq
+; KNL: vpaddq
+; KNL: vpgatherqd (,%zmm
+define <8 x i32> @test10(%struct.ST* %base, <8 x i64> %i1, <8 x i32>%ind5) {
+entry:
+  %broadcast.splatinsert = insertelement <8 x %struct.ST*> undef, %struct.ST* %base, i32 0
+  %broadcast.splat = shufflevector <8 x %struct.ST*> %broadcast.splatinsert, <8 x %struct.ST*> undef, <8 x i32> zeroinitializer
+
+  %arrayidx = getelementptr %struct.ST, <8 x %struct.ST*> %broadcast.splat, <8 x i64> %i1, i32 2, i32 1, <8 x i32> %ind5, i64 13
+  %res = call <8 x i32 > @llvm.masked.gather.v8i32(<8 x i32*>%arrayidx, i32 4, <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, <8 x i32> undef)
+  ret <8 x i32> %res
+}
+
+; Splat index in GEP, requires broadcast
+; KNL-LABEL: test11
+; KNL: vpbroadcastd %esi, %zmm
+; KNL: vgatherdps (%rdi,%zmm
+define <16 x float> @test11(float* %base, i32 %ind) {
+
+  %broadcast.splatinsert = insertelement <16 x float*> undef, float* %base, i32 0
+  %broadcast.splat = shufflevector <16 x float*> %broadcast.splatinsert, <16 x float*> undef, <16 x i32> zeroinitializer
+
+  %gep.random = getelementptr float, <16 x float*> %broadcast.splat, i32 %ind
+
+  %res = call <16 x float> @llvm.masked.gather.v16f32(<16 x float*> %gep.random, i32 4, <16 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, <16 x float> undef)
+  ret <16 x float>%res
+}
+
+; We are checking the uniform base here. It is taken directly from input to vgatherdps
+; KNL-LABEL: test12
+; KNL: vgatherdps (%rdi,%zmm
+define <16 x float> @test12(float* %base, <16 x i32> %ind) {
+
+  %sext_ind = sext <16 x i32> %ind to <16 x i64>
+  %gep.random = getelementptr float, float *%base, <16 x i64> %sext_ind
+
+  %res = call <16 x float> @llvm.masked.gather.v16f32(<16 x float*> %gep.random, i32 4, <16 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, <16 x float> undef)
+  ret <16 x float>%res
+}
+
+; The splatted lane 0 comes from %vec, not %base (inserted at index 1), so there
+; is no uniform base and the full pointer vector must be the gather index.
+; KNL-LABEL: test13
+; KNL: vpgatherqd (,%zmm
+define <8 x i32> @test13(<8 x i32*> %vec, i32* %base, <8 x i32> %ind) {
+entry:
+  %insert = insertelement <8 x i32*> %vec, i32* %base, i32 1
+  %splat = shufflevector <8 x i32*> %insert, <8 x i32*> undef, <8 x i32> zeroinitializer
+  %gep = getelementptr i32, <8 x i32*> %splat, <8 x i32> %ind
+  %res = call <8 x i32> @llvm.masked.gather.v8i32(<8 x i32*> %gep, i32 4, <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, <8 x i32> undef)
+  ret <8 x i32> %res
+}