[VectorCombine] Add option to only run scalarization transforms.

Add a ScalarizationOnly flag to VectorCombinePass. It defaults to false,
so every existing construction site keeps running all folds.

When the flag is set, only the scalarizing folds run:
scalarizeBinopOrCmp, scalarizeLoadExtract and foldSingleElementStore.
The folds that may introduce new vector operations are skipped:
vectorizeLoadInsert, foldExtractExtract, foldBitcastShuf and
foldExtractedCmps.

Pipeline changes:
- The early VectorCombine run scheduled when EnableMatrix is set now uses
  ScalarizationOnly=true.
- The legacy pass continues to run every fold.

Note: with ScalarizationOnly=false, foldExtractedCmps now runs before
scalarizeBinopOrCmp instead of after it.

diff --git a/llvm/include/llvm/Transforms/Vectorize/VectorCombine.h b/llvm/include/llvm/Transforms/Vectorize/VectorCombine.h
--- a/llvm/include/llvm/Transforms/Vectorize/VectorCombine.h
+++ b/llvm/include/llvm/Transforms/Vectorize/VectorCombine.h
@@ -20,8 +20,15 @@
 namespace llvm {
 
 /// Optimize scalar/vector interactions in IR using target cost models.
-struct VectorCombinePass : public PassInfoMixin<VectorCombinePass> {
+class VectorCombinePass : public PassInfoMixin<VectorCombinePass> {
+  /// If true, only perform scalarization combines and do not introduce new
+  /// vector operations.
+  bool ScalarizationOnly;
+
 public:
+  explicit VectorCombinePass(bool ScalarizationOnly = false)
+      : ScalarizationOnly(ScalarizationOnly) {}
+
   PreservedAnalyses run(Function &F, FunctionAnalysisManager &);
 };
 
diff --git a/llvm/lib/Passes/PassBuilderPipelines.cpp b/llvm/lib/Passes/PassBuilderPipelines.cpp
--- a/llvm/lib/Passes/PassBuilderPipelines.cpp
+++ b/llvm/lib/Passes/PassBuilderPipelines.cpp
@@ -503,7 +503,7 @@
   // The matrix extension can introduce large vector operations early, which can
   // benefit from running vector-combine early on.
   if (EnableMatrix)
-    FPM.addPass(VectorCombinePass());
+    FPM.addPass(VectorCombinePass(/*ScalarizationOnly=*/true));
 
   // Eliminate redundancies.
   FPM.addPass(MergedLoadStoreMotionPass());
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -63,8 +63,10 @@
 class VectorCombine {
 public:
   VectorCombine(Function &F, const TargetTransformInfo &TTI,
-                const DominatorTree &DT, AAResults &AA, AssumptionCache &AC)
-      : F(F), Builder(F.getContext()), TTI(TTI), DT(DT), AA(AA), AC(AC) {}
+                const DominatorTree &DT, AAResults &AA, AssumptionCache &AC,
+                bool ScalarizationOnly)
+      : F(F), Builder(F.getContext()), TTI(TTI), DT(DT), AA(AA), AC(AC),
+        ScalarizationOnly(ScalarizationOnly) {}
 
   bool run();
 
@@ -75,6 +77,11 @@
   const DominatorTree &DT;
   AAResults &AA;
   AssumptionCache &AC;
+
+  /// If true, only perform scalarization combines and do not introduce new
+  /// vector operations.
+  bool ScalarizationOnly;
+
   InstructionWorklist Worklist;
 
   bool vectorizeLoadInsert(Instruction &I);
@@ -1071,11 +1078,13 @@
   bool MadeChange = false;
   auto FoldInst = [this, &MadeChange](Instruction &I) {
     Builder.SetInsertPoint(&I);
-    MadeChange |= vectorizeLoadInsert(I);
-    MadeChange |= foldExtractExtract(I);
-    MadeChange |= foldBitcastShuf(I);
+    if (!ScalarizationOnly) {
+      MadeChange |= vectorizeLoadInsert(I);
+      MadeChange |= foldExtractExtract(I);
+      MadeChange |= foldBitcastShuf(I);
+      MadeChange |= foldExtractedCmps(I);
+    }
     MadeChange |= scalarizeBinopOrCmp(I);
-    MadeChange |= foldExtractedCmps(I);
     MadeChange |= scalarizeLoadExtract(I);
     MadeChange |= foldSingleElementStore(I);
   };
@@ -1137,7 +1146,7 @@
     auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
     auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
     auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
-    VectorCombine Combiner(F, TTI, DT, AA, AC);
+    VectorCombine Combiner(F, TTI, DT, AA, AC, /*ScalarizationOnly=*/false);
     return Combiner.run();
   }
 };
@@ -1161,7 +1170,7 @@
   TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F);
   DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
   AAResults &AA = FAM.getResult<AAManager>(F);
-  VectorCombine Combiner(F, TTI, DT, AA, AC);
+  VectorCombine Combiner(F, TTI, DT, AA, AC, ScalarizationOnly);
   if (!Combiner.run())
     return PreservedAnalyses::all();
   PreservedAnalyses PA;
diff --git a/llvm/test/Transforms/PhaseOrdering/AArch64/matrix-extract-insert.ll b/llvm/test/Transforms/PhaseOrdering/AArch64/matrix-extract-insert.ll
--- a/llvm/test/Transforms/PhaseOrdering/AArch64/matrix-extract-insert.ll
+++ b/llvm/test/Transforms/PhaseOrdering/AArch64/matrix-extract-insert.ll
@@ -308,18 +308,16 @@
 
 define <4 x float> @reverse_hadd_v4f32(<4 x float> %a, <4 x float> %b) {
 ; CHECK-LABEL: @reverse_hadd_v4f32(
-; CHECK-NEXT:    [[SHIFT:%.*]] = shufflevector <4 x float> [[A:%.*]], <4 x float> poison, <4 x i32>
-; CHECK-NEXT:    [[TMP1:%.*]] = fadd <4 x float> [[SHIFT]], [[A]]
-; CHECK-NEXT:    [[SHIFT1:%.*]] = shufflevector <4 x float> [[A]], <4 x float> poison, <4 x i32>
-; CHECK-NEXT:    [[TMP2:%.*]] = fadd <4 x float> [[SHIFT1]], [[A]]
-; CHECK-NEXT:    [[TMP3:%.*]] = shufflevector <4 x float> [[TMP1]], <4 x float> [[TMP2]], <4 x i32>
-; CHECK-NEXT:    [[SHIFT2:%.*]] = shufflevector <4 x float> [[B:%.*]], <4 x float> poison, <4 x i32>
-; CHECK-NEXT:    [[TMP4:%.*]] = fadd <4 x float> [[SHIFT2]], [[B]]
-; CHECK-NEXT:    [[TMP5:%.*]] = shufflevector <4 x float> [[TMP3]], <4 x float> [[TMP4]], <4 x i32>
-; CHECK-NEXT:    [[SHIFT3:%.*]] = shufflevector <4 x float> [[B]], <4 x float> poison, <4 x i32>
-; CHECK-NEXT:    [[TMP6:%.*]] = fadd <4 x float> [[SHIFT3]], [[B]]
-; CHECK-NEXT:    [[TMP7:%.*]] = shufflevector <4 x float> [[TMP5]], <4 x float> [[TMP6]], <4 x i32>
-; CHECK-NEXT:    ret <4 x float> [[TMP7]]
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <4 x float> [[A:%.*]], <4 x float> undef, <2 x i32>
+; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <4 x float> [[A]], <4 x float> undef, <2 x i32>
+; CHECK-NEXT:    [[TMP3:%.*]] = fadd <2 x float> [[TMP1]], [[TMP2]]
+; CHECK-NEXT:    [[TMP4:%.*]] = shufflevector <2 x float> [[TMP3]], <2 x float> poison, <4 x i32>
+; CHECK-NEXT:    [[TMP5:%.*]] = shufflevector <4 x float> [[B:%.*]], <4 x float> undef, <2 x i32>
+; CHECK-NEXT:    [[TMP6:%.*]] = shufflevector <4 x float> [[B]], <4 x float> undef, <2 x i32>
+; CHECK-NEXT:    [[TMP7:%.*]] = fadd <2 x float> [[TMP5]], [[TMP6]]
+; CHECK-NEXT:    [[TMP8:%.*]] = shufflevector <2 x float> [[TMP7]], <2 x float> poison, <4 x i32>
+; CHECK-NEXT:    [[TMP9:%.*]] = shufflevector <4 x float> [[TMP8]], <4 x float> [[TMP4]], <4 x i32>
+; CHECK-NEXT:    ret <4 x float> [[TMP9]]
 ;
   %vecext = extractelement <4 x float> %a, i32 0
   %vecext1 = extractelement <4 x float> %a, i32 1