diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -6686,6 +6686,18 @@
       assert(Pair.first && "DebugLoc must be set.");
       ExternallyUsedValues[Pair.second].push_back(Pair.first);
     }
+
+    // The compare instruction of a min/max is the insertion point for new
+    // instructions and may be replaced with a new compare instruction.
+    auto getCmpForMinMaxReduction = [](Instruction *RdxRootInst) {
+      assert(isa<SelectInst>(RdxRootInst) &&
+             "Expected min/max reduction to have select root instruction");
+      Value *ScalarCond = cast<SelectInst>(RdxRootInst)->getCondition();
+      assert(isa<Instruction>(ScalarCond) &&
+             "Expected min/max reduction to have compare condition");
+      return cast<Instruction>(ScalarCond);
+    };
+
     // The reduction root is used as the insertion point for new instructions,
     // so set it as externally used to prevent it from being deleted.
     ExternallyUsedValues[ReductionRoot];
@@ -6741,15 +6753,6 @@
       DebugLoc Loc = cast<Instruction>(ReducedVals[i])->getDebugLoc();
       Value *VectorizedRoot = V.vectorizeTree(ExternallyUsedValues);
 
-      auto getCmpForMinMaxReduction = [](Instruction *RdxRootInst) {
-        assert(isa<SelectInst>(RdxRootInst) &&
-               "Expected min/max reduction to have select root instruction");
-        Value *ScalarCond = cast<SelectInst>(RdxRootInst)->getCondition();
-        assert(isa<Instruction>(ScalarCond) &&
-               "Expected min/max reduction to have compare condition");
-        return cast<Instruction>(ScalarCond);
-      };
-
       // Emit a reduction. For min/max, the root is a select, but the insertion
       // point is the compare condition of that select.
       Instruction *RdxRootInst = cast<Instruction>(ReductionRoot);
@@ -6793,8 +6796,20 @@
           VectorizedTree = VectReductionData.createOp(Builder, "op.extra", I);
         }
       }
-      // Update users.
+
+      // Update users. For a min/max reduction that ends with a compare and
+      // select, we also have to RAUW for the compare instruction feeding the
+      // reduction root. That's because the original compare may have extra uses
+      // besides the final select of the reduction.
+      if (ReductionData.isMinMax()) {
+        if (auto *VecSelect = dyn_cast<SelectInst>(VectorizedTree)) {
+          Instruction *ScalarCmp =
+              getCmpForMinMaxReduction(cast<Instruction>(ReductionRoot));
+          ScalarCmp->replaceAllUsesWith(VecSelect->getCondition());
+        }
+      }
       ReductionRoot->replaceAllUsesWith(VectorizedTree);
+
       // Mark all scalar reduction ops for deletion, they are replaced by the
       // vector reductions.
       V.eraseInstructions(IgnoreList);
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/reduction.ll b/llvm/test/Transforms/SLPVectorizer/X86/reduction.ll
--- a/llvm/test/Transforms/SLPVectorizer/X86/reduction.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/reduction.ll
@@ -91,7 +91,7 @@
 ; CHECK-NEXT:    [[TMP5:%.*]] = select i1 [[TMP4]], i32 [[TMP3]], i32 [[T4]]
 ; CHECK-NEXT:    [[C012345:%.*]] = icmp sgt i32 [[TMP5]], [[T5]]
 ; CHECK-NEXT:    [[T17:%.*]] = select i1 [[C012345]], i32 [[TMP5]], i32 [[T5]]
-; CHECK-NEXT:    [[THREE_OR_FOUR:%.*]] = select i1 undef, i32 3, i32 4
+; CHECK-NEXT:    [[THREE_OR_FOUR:%.*]] = select i1 [[TMP4]], i32 3, i32 4
 ; CHECK-NEXT:    store i32 [[THREE_OR_FOUR]], i32* [[P:%.*]], align 8
 ; CHECK-NEXT:    ret i32 [[T17]]
 ;
@@ -137,7 +137,7 @@
 ; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <2 x i32> [[RDX_MINMAX_SELECT]], i32 0
 ; CHECK-NEXT:    [[TMP4:%.*]] = icmp sgt i32 [[TMP3]], 0
 ; CHECK-NEXT:    [[OP_EXTRA:%.*]] = select i1 [[TMP4]], i32 [[TMP3]], i32 0
-; CHECK-NEXT:    [[SPEC_STORE_SELECT87:%.*]] = zext i1 undef to i32
+; CHECK-NEXT:    [[SPEC_STORE_SELECT87:%.*]] = zext i1 [[TMP4]] to i32
 ; CHECK-NEXT:    [[CMP23_2:%.*]] = icmp sgt i32 [[SPEC_STORE_SELECT87]], [[OP_EXTRA]]
 ; CHECK-NEXT:    ret i1 [[CMP23_2]]
 ;
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/used-reduced-op.ll b/llvm/test/Transforms/SLPVectorizer/X86/used-reduced-op.ll
--- a/llvm/test/Transforms/SLPVectorizer/X86/used-reduced-op.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/used-reduced-op.ll
@@ -71,7 +71,7 @@
 ; CHECK-NEXT:    [[NEG_1_1:%.*]] = sub nsw i32 0, [[SUB_1_1]]
 ; CHECK-NEXT:    [[TMP46:%.*]] = select i1 [[TMP45]], i32 [[NEG_1_1]], i32 [[SUB_1_1]]
 ; CHECK-NEXT:    [[CMP12_1_1:%.*]] = icmp slt i32 [[TMP46]], [[OP_EXTRA]]
-; CHECK-NEXT:    [[NARROW:%.*]] = or i1 [[CMP12_1_1]], undef
+; CHECK-NEXT:    [[NARROW:%.*]] = or i1 [[CMP12_1_1]], [[TMP44]]
 ; CHECK-NEXT:    [[SPEC_SELECT8_1_1:%.*]] = select i1 [[CMP12_1_1]], i32 [[TMP46]], i32 [[OP_EXTRA]]
 ; CHECK-NEXT:    [[SUB_2_1:%.*]] = sub i32 [[TMP30]], [[TMP3]]
 ; CHECK-NEXT:    [[TMP47:%.*]] = icmp slt i32 [[SUB_2_1]], 0