diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp --- a/llvm/lib/CodeGen/CodeGenPrepare.cpp +++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp @@ -369,6 +369,7 @@ bool optimizeInst(Instruction *I, bool &ModifiedDT); bool optimizeMemoryInst(Instruction *MemoryInst, Value *Addr, Type *AccessTy, unsigned AddrSpace); + bool optimizeGatherScatterInst(Instruction *MemoryInst, Value *Ptr); bool optimizeInlineAsmInst(CallInst *CS); bool optimizeCallInst(CallInst *CI, bool &ModifiedDT); bool optimizeExt(Instruction *&I); @@ -2034,7 +2035,12 @@ II->eraseFromParent(); return true; } + break; } + case Intrinsic::masked_gather: + return optimizeGatherScatterInst(II, II->getArgOperand(0)); + case Intrinsic::masked_scatter: + return optimizeGatherScatterInst(II, II->getArgOperand(1)); } SmallVector PtrOps; @@ -5176,6 +5182,124 @@ return true; } +bool CodeGenPrepare::optimizeGatherScatterInst(Instruction *MemoryInst, + Value *Ptr) { + const GetElementPtrInst *GEP = dyn_cast(Ptr); + if (!GEP || !GEP->hasIndices()) + return false; + + // If the GEP and the gather/scatter aren't in the same BB, don't optimize. + // FIXME: We should support this by sinking the GEP. + if (MemoryInst->getParent() != GEP->getParent()) + return false; + + SmallVector Ops(GEP->op_begin(), GEP->op_end()); + + bool RewriteGEP = false; + + if (Ops[0]->getType()->isVectorTy()) { + Ops[0] = const_cast(getSplatValue(Ops[0])); + if (!Ops[0]) + return false; + RewriteGEP = true; + } + + unsigned FinalIndex = Ops.size() - 1; + gep_type_iterator GTI = gep_type_begin(*GEP); + + // Ensure all but the last index is 0. + for (unsigned i = 1; i < FinalIndex; ++i, ++GTI) { + auto *C = dyn_cast(Ops[i]); + if (!C) + return false; + if (isa(C->getType())) + C = C->getSplatValue(); + auto *CI = dyn_cast_or_null(C); + if (!CI || !CI->isZero()) + return false; + // Scalarize the index if needed. 
+ Ops[i] = CI; + } + + unsigned NumElts = Ptr->getType()->getVectorNumElements(); + + IRBuilder<> Builder(MemoryInst); + + Type *ScalarIndexTy = DL->getIndexType(Ops[0]->getType()->getScalarType()); + + Value *NewAddr; + // We need different handling for structs and sequential types. + if (GTI.isStruct()) { + // Scalarize the struct index if needed. + if (Ops[FinalIndex]->getType()->isVectorTy()) + Ops[FinalIndex] = cast(Ops[FinalIndex])->getSplatValue(); + + NewAddr = Builder.CreateGEP(Ops[0], makeArrayRef(Ops).drop_front()); + Type *IndexTy = VectorType::get(ScalarIndexTy, NumElts); + NewAddr = Builder.CreateGEP(NewAddr, Constant::getNullValue(IndexTy)); + } else { + // Try to scalarize the final index. + if (Ops[FinalIndex]->getType()->isVectorTy()) { + if (Value *V = const_cast(getSplatValue(Ops[FinalIndex]))) { + auto *C = dyn_cast(V); + // Don't scalarize all zeros vector. + if (!C || !C->isZero()) { + Ops[FinalIndex] = V; + RewriteGEP = true; + } + } + } + + // If we made any changes or we have extra operands, we need to generate + // new instructions. + if (!RewriteGEP && Ops.size() == 2) + return false; + + // If the final index isn't a vector, emit a scalar GEP containing all ops + // and a vector GEP with all zeroes final index. + if (!Ops[FinalIndex]->getType()->isVectorTy()) { + NewAddr = Builder.CreateGEP(Ops[0], makeArrayRef(Ops).drop_front()); + Type *IndexTy = VectorType::get(ScalarIndexTy, NumElts); + NewAddr = Builder.CreateGEP(NewAddr, Constant::getNullValue(IndexTy)); + } else { + Value *Base = Ops[0]; + Value *Index = Ops[FinalIndex]; + + // Create a scalar GEP if there are more than 2 operands. + if (Ops.size() != 2) { + // Replace the last index with 0. 
+ Ops[FinalIndex] = Constant::getNullValue(ScalarIndexTy); + Base = Builder.CreateGEP(Base, makeArrayRef(Ops).drop_front()); + } + + NewAddr = Builder.CreateGEP(Base, Index); + } + } + + MemoryInst->replaceUsesOfWith(Ptr, NewAddr); + + // If we have no uses, recursively delete the value and all dead instructions + // using it. + if (Ptr->use_empty()) { + // This can cause recursive deletion, which can invalidate our iterator. + // Use a WeakTrackingVH to hold onto it in case this happens. + Value *CurValue = &*CurInstIterator; + WeakTrackingVH IterHandle(CurValue); + BasicBlock *BB = CurInstIterator->getParent(); + + RecursivelyDeleteTriviallyDeadInstructions(Ptr, TLInfo); + + if (IterHandle != CurValue) { + // If the iterator instruction was recursively deleted, start over at the + // start of the block. + CurInstIterator = BB->begin(); + SunkAddrs.clear(); + } + } + + return true; +} + /// If there are any memory operands, use OptimizeMemoryInst to sink their /// address computing into the block when possible / profitable. bool CodeGenPrepare::optimizeInlineAsmInst(CallInst *CS) { diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h @@ -518,7 +518,6 @@ void resolveOrClearDbgInfo(); SDValue getValue(const Value *V); - bool findValue(const Value *V) const; /// Return the SDNode for the specified IR value if it exists. 
SDNode *getNodeForIRValue(const Value *V) { diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -1435,12 +1435,6 @@ return Val; } -// Return true if SDValue exists for the given Value -bool SelectionDAGBuilder::findValue(const Value *V) const { - return (NodeMap.find(V) != NodeMap.end()) || - (FuncInfo.ValueMap.find(V) != FuncInfo.ValueMap.end()); -} - /// getNonRegisterValue - Return an SDValue for the given Value, but /// don't look in FuncInfo.ValueMap for a virtual register. SDValue SelectionDAGBuilder::getNonRegisterValue(const Value *V) { @@ -4238,70 +4232,49 @@ // In all other cases the function returns 'false'. static bool getUniformBase(const Value *Ptr, SDValue &Base, SDValue &Index, ISD::MemIndexType &IndexType, SDValue &Scale, - SelectionDAGBuilder *SDB) { + SelectionDAGBuilder *SDB, const BasicBlock *CurBB) { SelectionDAG& DAG = SDB->DAG; - LLVMContext &Context = *DAG.getContext(); + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + const DataLayout &DL = DAG.getDataLayout(); assert(Ptr->getType()->isVectorTy() && "Uexpected pointer type"); - const GetElementPtrInst *GEP = dyn_cast(Ptr); - if (!GEP) - return false; - const Value *BasePtr = GEP->getPointerOperand(); - if (BasePtr->getType()->isVectorTy()) { - BasePtr = getSplatValue(BasePtr); - if (!BasePtr) + // Handle splat constant pointer. + if (auto *C = dyn_cast(Ptr)) { + C = C->getSplatValue(); + if (!C) return false; - } - unsigned FinalIndex = GEP->getNumOperands() - 1; - Value *IndexVal = GEP->getOperand(FinalIndex); - gep_type_iterator GTI = gep_type_begin(*GEP); + Base = SDB->getValue(C); - // Ensure all the other indices are 0. 
- for (unsigned i = 1; i < FinalIndex; ++i, ++GTI) { - auto *C = dyn_cast(GEP->getOperand(i)); - if (!C) - return false; - if (isa(C->getType())) - C = C->getSplatValue(); - auto *CI = dyn_cast_or_null(C); - if (!CI || !CI->isZero()) - return false; + unsigned NumElts = Ptr->getType()->getVectorNumElements(); + EVT VT = EVT::getVectorVT(*DAG.getContext(), TLI.getPointerTy(DL), NumElts); + Index = DAG.getConstant(0, SDB->getCurSDLoc(), VT); + IndexType = ISD::SIGNED_SCALED; + Scale = DAG.getTargetConstant(1, SDB->getCurSDLoc(), TLI.getPointerTy(DL)); + return true; } - // The operands of the GEP may be defined in another basic block. - // In this case we'll not find nodes for the operands. - if (!SDB->findValue(BasePtr)) + const GetElementPtrInst *GEP = dyn_cast(Ptr); + if (!GEP || GEP->getParent() != CurBB) return false; - Constant *C = dyn_cast(IndexVal); - if (!C && !SDB->findValue(IndexVal)) + + if (GEP->getNumOperands() != 2) return false; - const TargetLowering &TLI = DAG.getTargetLoweringInfo(); - const DataLayout &DL = DAG.getDataLayout(); - StructType *STy = GTI.getStructTypeOrNull(); + const Value *BasePtr = GEP->getPointerOperand(); + const Value *IndexVal = GEP->getOperand(GEP->getNumOperands() - 1); + + // Make sure the base is scalar and the index is a vector. 
+ if (BasePtr->getType()->isVectorTy() || !IndexVal->getType()->isVectorTy()) + return false; - if (STy) { - const StructLayout *SL = DL.getStructLayout(STy); - unsigned Field = cast(IndexVal)->getUniqueInteger().getZExtValue(); - Scale = DAG.getTargetConstant(1, SDB->getCurSDLoc(), TLI.getPointerTy(DL)); - Index = DAG.getConstant(SL->getElementOffset(Field), - SDB->getCurSDLoc(), TLI.getPointerTy(DL)); - } else { - Scale = DAG.getTargetConstant( - DL.getTypeAllocSize(GEP->getResultElementType()), - SDB->getCurSDLoc(), TLI.getPointerTy(DL)); - Index = SDB->getValue(IndexVal); - } Base = SDB->getValue(BasePtr); + Index = SDB->getValue(IndexVal); IndexType = ISD::SIGNED_SCALED; - - if (STy || !Index.getValueType().isVector()) { - unsigned GEPWidth = GEP->getType()->getVectorNumElements(); - EVT VT = EVT::getVectorVT(Context, Index.getValueType(), GEPWidth); - Index = DAG.getSplatBuildVector(VT, SDLoc(Index), Index); - } + Scale = DAG.getTargetConstant( + DL.getTypeAllocSize(GEP->getResultElementType()), + SDB->getCurSDLoc(), TLI.getPointerTy(DL)); return true; } @@ -4325,7 +4298,8 @@ SDValue Index; ISD::MemIndexType IndexType; SDValue Scale; - bool UniformBase = getUniformBase(Ptr, Base, Index, IndexType, Scale, this); + bool UniformBase = getUniformBase(Ptr, Base, Index, IndexType, Scale, this, + I.getParent()); unsigned AS = Ptr->getType()->getScalarType()->getPointerAddressSpace(); MachineMemOperand *MMO = DAG.getMachineFunction(). @@ -4440,7 +4414,8 @@ SDValue Index; ISD::MemIndexType IndexType; SDValue Scale; - bool UniformBase = getUniformBase(Ptr, Base, Index, IndexType, Scale, this); + bool UniformBase = getUniformBase(Ptr, Base, Index, IndexType, Scale, this, + I.getParent()); unsigned AS = Ptr->getType()->getScalarType()->getPointerAddressSpace(); MachineMemOperand *MMO = DAG.getMachineFunction(). 
diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp --- a/llvm/lib/IR/Verifier.cpp +++ b/llvm/lib/IR/Verifier.cpp @@ -2883,7 +2883,7 @@ for (unsigned i = 0, e = FTy->getNumParams(); i != e; ++i) Assert(Call.getArgOperand(i)->getType() == FTy->getParamType(i), "Call parameter type does not match function signature!", - Call.getArgOperand(i), FTy->getParamType(i), Call); + Call.getArgOperand(i)->getType(), FTy->getParamType(i), Call); AttributeList Attrs = Call.getAttributes(); diff --git a/llvm/test/CodeGen/X86/masked_gather.ll b/llvm/test/CodeGen/X86/masked_gather.ll --- a/llvm/test/CodeGen/X86/masked_gather.ll +++ b/llvm/test/CodeGen/X86/masked_gather.ll @@ -1721,11 +1721,10 @@ ; AVX512-NEXT: vptestnmd %zmm0, %zmm0, %k0 ; AVX512-NEXT: kshiftlw $8, %k0, %k0 ; AVX512-NEXT: kshiftrw $8, %k0, %k1 -; AVX512-NEXT: vpbroadcastd {{.*#+}} zmm0 = [3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3] +; AVX512-NEXT: vpxor %xmm0, %xmm0, %xmm0 ; AVX512-NEXT: kmovw %k1, %k2 -; AVX512-NEXT: vpgatherdd c(,%zmm0,4), %zmm1 {%k2} -; AVX512-NEXT: vpbroadcastd {{.*#+}} zmm0 = [28,28,28,28,28,28,28,28,28,28,28,28,28,28,28,28] -; AVX512-NEXT: vpgatherdd c(,%zmm0), %zmm2 {%k1} +; AVX512-NEXT: vpgatherdd c+12(,%zmm0), %zmm1 {%k2} +; AVX512-NEXT: vpgatherdd c+28(,%zmm0), %zmm2 {%k1} ; AVX512-NEXT: vpaddd %ymm2, %ymm2, %ymm0 ; AVX512-NEXT: vpaddd %ymm0, %ymm1, %ymm0 ; AVX512-NEXT: retq diff --git a/llvm/test/CodeGen/X86/masked_gather_scatter.ll b/llvm/test/CodeGen/X86/masked_gather_scatter.ll --- a/llvm/test/CodeGen/X86/masked_gather_scatter.ll +++ b/llvm/test/CodeGen/X86/masked_gather_scatter.ll @@ -638,30 +638,38 @@ define <16 x float> @test11(float* %base, i32 %ind) { ; KNL_64-LABEL: test11: ; KNL_64: # %bb.0: -; KNL_64-NEXT: vpbroadcastd %esi, %zmm1 +; KNL_64-NEXT: movslq %esi, %rax +; KNL_64-NEXT: leaq (%rdi,%rax,4), %rax +; KNL_64-NEXT: vxorps %xmm1, %xmm1, %xmm1 ; KNL_64-NEXT: kxnorw %k0, %k0, %k1 -; KNL_64-NEXT: vgatherdps (%rdi,%zmm1,4), %zmm0 {%k1} +; KNL_64-NEXT: vgatherdps 
(%rax,%zmm1,4), %zmm0 {%k1} ; KNL_64-NEXT: retq ; ; KNL_32-LABEL: test11: ; KNL_32: # %bb.0: ; KNL_32-NEXT: movl {{[0-9]+}}(%esp), %eax -; KNL_32-NEXT: vbroadcastss {{[0-9]+}}(%esp), %zmm1 +; KNL_32-NEXT: shll $2, %eax +; KNL_32-NEXT: addl {{[0-9]+}}(%esp), %eax +; KNL_32-NEXT: vxorps %xmm1, %xmm1, %xmm1 ; KNL_32-NEXT: kxnorw %k0, %k0, %k1 ; KNL_32-NEXT: vgatherdps (%eax,%zmm1,4), %zmm0 {%k1} ; KNL_32-NEXT: retl ; ; SKX-LABEL: test11: ; SKX: # %bb.0: -; SKX-NEXT: vpbroadcastd %esi, %zmm1 +; SKX-NEXT: movslq %esi, %rax +; SKX-NEXT: leaq (%rdi,%rax,4), %rax +; SKX-NEXT: vxorps %xmm1, %xmm1, %xmm1 ; SKX-NEXT: kxnorw %k0, %k0, %k1 -; SKX-NEXT: vgatherdps (%rdi,%zmm1,4), %zmm0 {%k1} +; SKX-NEXT: vgatherdps (%rax,%zmm1,4), %zmm0 {%k1} ; SKX-NEXT: retq ; ; SKX_32-LABEL: test11: ; SKX_32: # %bb.0: ; SKX_32-NEXT: movl {{[0-9]+}}(%esp), %eax -; SKX_32-NEXT: vbroadcastss {{[0-9]+}}(%esp), %zmm1 +; SKX_32-NEXT: shll $2, %eax +; SKX_32-NEXT: addl {{[0-9]+}}(%esp), %eax +; SKX_32-NEXT: vxorps %xmm1, %xmm1, %xmm1 ; SKX_32-NEXT: kxnorw %k0, %k0, %k1 ; SKX_32-NEXT: vgatherdps (%eax,%zmm1,4), %zmm0 {%k1} ; SKX_32-NEXT: retl diff --git a/llvm/test/CodeGen/X86/pr45067.ll b/llvm/test/CodeGen/X86/pr45067.ll --- a/llvm/test/CodeGen/X86/pr45067.ll +++ b/llvm/test/CodeGen/X86/pr45067.ll @@ -6,13 +6,13 @@ define void @foo(<8 x i32>* %x, <8 x i1> %y) { ; CHECK-LABEL: foo: ; CHECK: ## %bb.0: -; CHECK-NEXT: vpcmpeqd %xmm1, %xmm1, %xmm1 -; CHECK-NEXT: vpbroadcastq _global@{{.*}}(%rip), %ymm2 -; CHECK-NEXT: vpgatherqd %xmm1, (,%ymm2), %xmm3 +; CHECK-NEXT: vpcmpeqd %ymm1, %ymm1, %ymm1 +; CHECK-NEXT: vpxor %xmm2, %xmm2, %xmm2 +; CHECK-NEXT: movq _global@{{.*}}(%rip), %rax +; CHECK-NEXT: vpgatherdd %ymm1, (%rax,%ymm2), %ymm3 ; CHECK-NEXT: vpmovzxwd {{.*#+}} ymm0 = xmm0[0],zero,xmm0[1],zero,xmm0[2],zero,xmm0[3],zero,xmm0[4],zero,xmm0[5],zero,xmm0[6],zero,xmm0[7],zero ; CHECK-NEXT: vpslld $31, %ymm0, %ymm0 -; CHECK-NEXT: vinserti128 $1, %xmm3, %ymm3, %ymm1 -; CHECK-NEXT: vpmaskmovd %ymm1, %ymm0, 
(%rdi) +; CHECK-NEXT: vpmaskmovd %ymm3, %ymm0, (%rdi) ; CHECK-NEXT: ud2 %tmp = call <8 x i32> @llvm.masked.gather.v8i32.v8p0i32(<8 x i32*> , i32 4, <8 x i1> , <8 x i32> undef) call void @llvm.masked.store.v8i32.p0v8i32(<8 x i32> %tmp, <8 x i32>* %x, i32 4, <8 x i1> %y)