diff --git a/llvm/lib/Transforms/Scalar/Scalarizer.cpp b/llvm/lib/Transforms/Scalar/Scalarizer.cpp --- a/llvm/lib/Transforms/Scalar/Scalarizer.cpp +++ b/llvm/lib/Transforms/Scalar/Scalarizer.cpp @@ -193,6 +193,7 @@ bool visitCastInst(CastInst &CI); bool visitBitCastInst(BitCastInst &BCI); bool visitInsertElementInst(InsertElementInst &IEI); + bool visitExtractElementInst(ExtractElementInst &EEI); bool visitShuffleVectorInst(ShuffleVectorInst &SVI); bool visitPHINode(PHINode &PHI); bool visitLoadInst(LoadInst &LI); @@ -755,7 +756,7 @@ Value *NewElt = IEI.getOperand(1); Value *InsIdx = IEI.getOperand(2); - if (isa(InsIdx)) + if (isa(InsIdx)) return false; ValueVector Res; @@ -771,6 +772,29 @@ return true; } +/* Rewrites an extractelement as a chain of selects over the scattered elements of its vector operand: step I compares the index operand with I and, on a match, picks element I, otherwise keeps the previous result. Returns false without rewriting when operand 0 is not a vector type or when the isa check on the index operand succeeds; returns true after handing the result to gather(). NOTE(review): isa/dyn_cast appear here without template arguments, which looks like extraction damage -- confirm the intended types against the original patch. */ bool ScalarizerVisitor::visitExtractElementInst(ExtractElementInst &EEI) { + VectorType *VT = dyn_cast(EEI.getOperand(0)->getType()); + if (!VT) + return false; + + unsigned NumSrcElems = VT->getNumElements(); + IRBuilder<> Builder(&EEI); + Scatterer Op0 = scatter(&EEI, EEI.getOperand(0)); + Value *ExtIdx = EEI.getOperand(1); + if (isa(ExtIdx)) + return false; + + /* Seed the chain with undef: it stays the result when the index equals none of 0..NumSrcElems-1. */ Value *Res = UndefValue::get(VT->getElementType()); + for (unsigned I = 0; I < NumSrcElems; ++I) { + Res = Builder.CreateSelect( + Builder.CreateICmpEQ(ExtIdx, ConstantInt::get(ExtIdx->getType(), I), + ExtIdx->getName() + ".is." + Twine(I)), + Op0[I], Res, EEI.getName() + ".upto" + Twine(I)); + } + /* The result is a single scalar, so gather() receives a one-element list. */ gather(&EEI, {Res}); + return true; +} + bool ScalarizerVisitor::visitShuffleVectorInst(ShuffleVectorInst &SVI) { VectorType *VT = dyn_cast(SVI.getType()); if (!VT) @@ -891,16 +915,20 @@ if (!Op->use_empty()) { // The value is still needed, so recreate it using a series of // InsertElements. 
- auto *Ty = cast(Op->getType()); - Value *Res = UndefValue::get(Ty); - BasicBlock *BB = Op->getParent(); - unsigned Count = Ty->getNumElements(); - IRBuilder<> Builder(Op); - if (isa(Op)) - Builder.SetInsertPoint(BB, BB->getFirstInsertionPt()); - for (unsigned I = 0; I < Count; ++I) - Res = Builder.CreateInsertElement(Res, CV[I], Builder.getInt32(I), - Op->getName() + ".upto" + Twine(I)); + Value *Res = UndefValue::get(Op->getType()); + /* When the dyn_cast on Op's type succeeds, rebuild the value one element at a time with insertelement, starting from undef. NOTE(review): the dyn_cast/isa template arguments are missing here -- confirm against the original patch. */ if (auto *Ty = dyn_cast(Op->getType())) { + BasicBlock *BB = Op->getParent(); + unsigned Count = Ty->getNumElements(); + IRBuilder<> Builder(Op); + if (isa(Op)) + Builder.SetInsertPoint(BB, BB->getFirstInsertionPt()); + for (unsigned I = 0; I < Count; ++I) + Res = Builder.CreateInsertElement(Res, CV[I], Builder.getInt32(I), + Op->getName() + ".upto" + Twine(I)); + } else { + /* Otherwise CV must hold exactly one value whose type equals Op's type; that value replaces Op directly. */ assert(CV.size() == 1 && Op->getType() == CV[0]->getType()); + Res = CV[0]; + } Res->takeName(Op); Op->replaceAllUsesWith(Res); } diff --git a/llvm/test/Transforms/Scalarizer/basic.ll b/llvm/test/Transforms/Scalarizer/basic.ll --- a/llvm/test/Transforms/Scalarizer/basic.ll +++ b/llvm/test/Transforms/Scalarizer/basic.ll @@ -577,6 +577,36 @@ ret <2 x float> %res } +; Test that variable extracts are scalarized. 
+define i32 @f23(<4 x i32> *%src, i32 %index) { +; CHECK-LABEL: @f23( +; CHECK: %src.i0 = bitcast <4 x i32>* %src to i32* +; CHECK: %val0.i0 = load i32, i32* %src.i0, align 16 +; CHECK: %src.i1 = getelementptr i32, i32* %src.i0, i32 1 +; CHECK: %val0.i1 = load i32, i32* %src.i1, align 4 +; CHECK: %src.i2 = getelementptr i32, i32* %src.i0, i32 2 +; CHECK: %val0.i2 = load i32, i32* %src.i2, align 8 +; CHECK: %src.i3 = getelementptr i32, i32* %src.i0, i32 3 +; CHECK: %val0.i3 = load i32, i32* %src.i3, align 4 +; CHECK: %val1.i0 = shl i32 1, %val0.i0 +; CHECK: %val1.i1 = shl i32 2, %val0.i1 +; CHECK: %val1.i2 = shl i32 3, %val0.i2 +; CHECK: %val1.i3 = shl i32 4, %val0.i3 +; CHECK: %index.is.0 = icmp eq i32 %index, 0 +; CHECK: %val2.upto0 = select i1 %index.is.0, i32 %val1.i0, i32 undef +; CHECK: %index.is.1 = icmp eq i32 %index, 1 +; CHECK: %val2.upto1 = select i1 %index.is.1, i32 %val1.i1, i32 %val2.upto0 +; CHECK: %index.is.2 = icmp eq i32 %index, 2 +; CHECK: %val2.upto2 = select i1 %index.is.2, i32 %val1.i2, i32 %val2.upto1 +; CHECK: %index.is.3 = icmp eq i32 %index, 3 +; CHECK: %val2 = select i1 %index.is.3, i32 %val1.i3, i32 %val2.upto2 +; CHECK: ret i32 %val2 + %val0 = load <4 x i32> , <4 x i32> *%src + %val1 = shl <4 x i32> <i32 1, i32 2, i32 3, i32 4>, %val0 + %val2 = extractelement <4 x i32> %val1, i32 %index + ret i32 %val2 +} + !0 = !{ !"root" } !1 = !{ !"set1", !0 } !2 = !{ !"set2", !0 }