diff --git a/llvm/lib/Target/X86/X86PartialReduction.cpp b/llvm/lib/Target/X86/X86PartialReduction.cpp --- a/llvm/lib/Target/X86/X86PartialReduction.cpp +++ b/llvm/lib/Target/X86/X86PartialReduction.cpp @@ -49,11 +49,8 @@ } private: - bool tryMAddPattern(BinaryOperator *BO); - bool tryMAddReplacement(Value *Op, BinaryOperator *Add); - - bool trySADPattern(BinaryOperator *BO); - bool trySADReplacement(Value *Op, BinaryOperator *Add); + bool tryMAddReplacement(Instruction *Op); + bool trySADReplacement(Instruction *Op); }; } @@ -66,139 +63,24 @@ INITIALIZE_PASS(X86PartialReduction, DEBUG_TYPE, "X86 Partial Reduction", false, false) -static bool isVectorReductionOp(const BinaryOperator &BO) { - if (!BO.getType()->isVectorTy()) +bool X86PartialReduction::tryMAddReplacement(Instruction *Op) { + if (!ST->hasSSE2()) return false; - unsigned Opcode = BO.getOpcode(); - - switch (Opcode) { - case Instruction::Add: - case Instruction::Mul: - case Instruction::And: - case Instruction::Or: - case Instruction::Xor: - break; - case Instruction::FAdd: - case Instruction::FMul: - if (auto *FPOp = dyn_cast(&BO)) - if (FPOp->getFastMathFlags().isFast()) - break; - LLVM_FALLTHROUGH; - default: + // Need at least 8 elements. + if (cast(Op->getType())->getNumElements() < 8) return false; - } - unsigned ElemNum = cast(BO.getType())->getNumElements(); - // Ensure the reduction size is a power of 2. - if (!isPowerOf2_32(ElemNum)) + // Element type should be i32. + if (!cast(Op->getType())->getElementType()->isIntegerTy(32)) return false; - unsigned ElemNumToReduce = ElemNum; - - // Do DFS search on the def-use chain from the given instruction. We only - // allow four kinds of operations during the search until we reach the - // instruction that extracts the first element from the vector: - // - // 1. The reduction operation of the same opcode as the given instruction. - // - // 2. PHI node. - // - // 3.
ShuffleVector instruction together with a reduction operation that - // does a partial reduction. - // - // 4. ExtractElement that extracts the first element from the vector, and we - // stop searching the def-use chain here. - // - // 3 & 4 above perform a reduction on all elements of the vector. We push defs - // from 1-3 to the stack to continue the DFS. The given instruction is not - // a reduction operation if we meet any other instructions other than those - // listed above. - - SmallVector UsersToVisit{&BO}; - SmallPtrSet Visited; - bool ReduxExtracted = false; - - while (!UsersToVisit.empty()) { - auto User = UsersToVisit.back(); - UsersToVisit.pop_back(); - if (!Visited.insert(User).second) - continue; - - for (const auto *U : User->users()) { - auto *Inst = dyn_cast(U); - if (!Inst) - return false; - - if (Inst->getOpcode() == Opcode || isa(U)) { - if (auto *FPOp = dyn_cast(Inst)) - if (!isa(FPOp) && !FPOp->getFastMathFlags().isFast()) - return false; - UsersToVisit.push_back(U); - } else if (auto *ShufInst = dyn_cast(U)) { - // Detect the following pattern: A ShuffleVector instruction together - // with a reduction that do partial reduction on the first and second - // ElemNumToReduce / 2 elements, and store the result in - // ElemNumToReduce / 2 elements in another vector. - - unsigned ResultElements = ShufInst->getType()->getNumElements(); - if (ResultElements < ElemNum) - return false; - - if (ElemNumToReduce == 1) - return false; - if (!isa(U->getOperand(1))) - return false; - for (unsigned i = 0; i < ElemNumToReduce / 2; ++i) - if (ShufInst->getMaskValue(i) != int(i + ElemNumToReduce / 2)) - return false; - for (unsigned i = ElemNumToReduce / 2; i < ElemNum; ++i) - if (ShufInst->getMaskValue(i) != -1) - return false; - - // There is only one user of this ShuffleVector instruction, which - // must be a reduction operation.
- if (!U->hasOneUse()) - return false; - - auto *U2 = dyn_cast(*U->user_begin()); - if (!U2 || U2->getOpcode() != Opcode) - return false; - - // Check operands of the reduction operation. - if ((U2->getOperand(0) == U->getOperand(0) && U2->getOperand(1) == U) || - (U2->getOperand(1) == U->getOperand(0) && U2->getOperand(0) == U)) { - UsersToVisit.push_back(U2); - ElemNumToReduce /= 2; - } else - return false; - } else if (isa(U)) { - // At this moment we should have reduced all elements in the vector. - if (ElemNumToReduce != 1) - return false; - - auto *Val = dyn_cast(U->getOperand(1)); - if (!Val || !Val->isZero()) - return false; - - ReduxExtracted = true; - } else - return false; - } - } - return ReduxExtracted; -} - -bool X86PartialReduction::tryMAddReplacement(Value *Op, BinaryOperator *Add) { - BasicBlock *BB = Add->getParent(); - - auto *BO = dyn_cast(Op); - if (!BO || BO->getOpcode() != Instruction::Mul || !BO->hasOneUse() || - BO->getParent() != BB) + auto *Mul = dyn_cast(Op); + if (!Mul || Mul->getOpcode() != Instruction::Mul) return false; - Value *LHS = BO->getOperand(0); - Value *RHS = BO->getOperand(1); + Value *LHS = Mul->getOperand(0); + Value *RHS = Mul->getOperand(1); // LHS and RHS should be only used once or if they are the same then only // used twice. Only check this when SSE4.1 is enabled and we have zext/sext @@ -219,7 +101,7 @@ auto CanShrinkOp = [&](Value *Op) { auto IsFreeTruncation = [&](Value *Op) { if (auto *Cast = dyn_cast(Op)) { - if (Cast->getParent() == BB && + if (Cast->getParent() == Mul->getParent() && (Cast->getOpcode() == Instruction::SExt || Cast->getOpcode() == Instruction::ZExt) && Cast->getOperand(0)->getType()->getScalarSizeInBits() <= 16) @@ -232,16 +114,16 @@ // If the operation can be freely truncated and has enough sign bits we // can shrink.
if (IsFreeTruncation(Op) && - ComputeNumSignBits(Op, *DL, 0, nullptr, BO) > 16) + ComputeNumSignBits(Op, *DL, 0, nullptr, Mul) > 16) return true; // SelectionDAG has limited support for truncating through an add or sub if // the inputs are freely truncatable. if (auto *BO = dyn_cast(Op)) { - if (BO->getParent() == BB && + if (BO->getParent() == Mul->getParent() && IsFreeTruncation(BO->getOperand(0)) && IsFreeTruncation(BO->getOperand(1)) && - ComputeNumSignBits(Op, *DL, 0, nullptr, BO) > 16) + ComputeNumSignBits(Op, *DL, 0, nullptr, Mul) > 16) return true; } @@ -252,7 +134,7 @@ if (!CanShrinkOp(LHS) && !CanShrinkOp(RHS)) return false; - IRBuilder<> Builder(Add); + IRBuilder<> Builder(Mul); auto *MulTy = cast(Op->getType()); unsigned NumElts = MulTy->getNumElements(); @@ -266,8 +148,11 @@ EvenMask[i] = i * 2; OddMask[i] = i * 2 + 1; } - Value *EvenElts = Builder.CreateShuffleVector(BO, BO, EvenMask); - Value *OddElts = Builder.CreateShuffleVector(BO, BO, OddMask); + // Creating a new mul so the replaceAllUsesWith below doesn't replace the + // uses in the shuffles we're creating. + Value *NewMul = Builder.CreateMul(Mul->getOperand(0), Mul->getOperand(1)); + Value *EvenElts = Builder.CreateShuffleVector(NewMul, NewMul, EvenMask); + Value *OddElts = Builder.CreateShuffleVector(NewMul, NewMul, OddMask); Value *MAdd = Builder.CreateAdd(EvenElts, OddElts); // Concatenate zeroes to extend back to the original type. @@ -276,34 +161,21 @@ Value *Zero = Constant::getNullValue(MAdd->getType()); Value *Concat = Builder.CreateShuffleVector(MAdd, Zero, ConcatMask); - // Replaces the use of mul in the original Add with the pmaddwd and zeroes. - Add->replaceUsesOfWith(BO, Concat); - Add->setHasNoSignedWrap(false); - Add->setHasNoUnsignedWrap(false); + Mul->replaceAllUsesWith(Concat); + Mul->eraseFromParent(); return true; } -// Try to replace operans of this add with pmaddwd patterns.
-bool X86PartialReduction::tryMAddPattern(BinaryOperator *BO) { +bool X86PartialReduction::trySADReplacement(Instruction *Op) { if (!ST->hasSSE2()) return false; - // Need at least 8 elements. - if (cast(BO->getType())->getNumElements() < 8) - return false; - - // Element type should be i32. - if (!cast(BO->getType())->getElementType()->isIntegerTy(32)) + // TODO: There's nothing special about i32, any integer type above i16 should + // work just as well. + if (!cast(Op->getType())->getElementType()->isIntegerTy(32)) return false; - bool Changed = false; - Changed |= tryMAddReplacement(BO->getOperand(0), BO); - Changed |= tryMAddReplacement(BO->getOperand(1), BO); - return Changed; -} - -bool X86PartialReduction::trySADReplacement(Value *Op, BinaryOperator *Add) { // Operand should be a select. auto *SI = dyn_cast(Op); if (!SI) @@ -337,7 +209,7 @@ if (!Op0 || !Op1) return false; - IRBuilder<> Builder(Add); + IRBuilder<> Builder(SI); auto *OpTy = cast(Op->getType()); unsigned NumElts = OpTy->getNumElements(); @@ -355,7 +227,7 @@ IntrinsicNumElts = 16; } - Function *PSADBWFn = Intrinsic::getDeclaration(Add->getModule(), IID); + Function *PSADBWFn = Intrinsic::getDeclaration(SI->getModule(), IID); if (NumElts < 16) { // Pad input with zeroes. @@ -419,27 +291,155 @@ Ops[0] = Builder.CreateShuffleVector(Ops[0], Zero, ConcatMask); } - // Replaces the uses of Op in Add with the new sequence. - Add->replaceUsesOfWith(Op, Ops[0]); - Add->setHasNoSignedWrap(false); - Add->setHasNoUnsignedWrap(false); + SI->replaceAllUsesWith(Ops[0]); + SI->eraseFromParent(); return true; } -bool X86PartialReduction::trySADPattern(BinaryOperator *BO) { - if (!ST->hasSSE2()) - return false; +// Walk backwards from the ExtractElementInst and determine if it is the end of +// a horizontal reduction. Return the input to the reduction if we find one. +static Value *matchAddReduction(const ExtractElementInst &EE) { + // Make sure we're extracting index 0.
+ auto *Index = dyn_cast(EE.getIndexOperand()); + if (!Index || !Index->isNullValue()) + return nullptr; - // TODO: There's nothing special about i32, any integer type above i16 should - // work just as well. - if (!cast(BO->getType())->getElementType()->isIntegerTy(32)) + const auto *BO = dyn_cast(EE.getVectorOperand()); + if (!BO || BO->getOpcode() != Instruction::Add || !BO->hasOneUse()) + return nullptr; + + unsigned NumElems = cast(BO->getType())->getNumElements(); + // Ensure the reduction size is a power of 2. + if (!isPowerOf2_32(NumElems)) + return nullptr; + + const Value *Op = BO; + unsigned Stages = Log2_32(NumElems); + for (unsigned i = 0; i != Stages; ++i) { + const auto *BO = dyn_cast(Op); + if (!BO || BO->getOpcode() != Instruction::Add) + return nullptr; + + // If this isn't the first add, then it should only have 2 users, the + // shuffle and another add which we checked in the previous iteration. + if (i != 0 && !BO->hasNUses(2)) + return nullptr; + + Value *LHS = BO->getOperand(0); + Value *RHS = BO->getOperand(1); + + auto *Shuffle = dyn_cast(LHS); + if (Shuffle) { + Op = RHS; + } else { + Shuffle = dyn_cast(RHS); + Op = LHS; + } + + // The first operand of the shuffle should be the same as the other operand + // of the bin op. + if (!Shuffle || Shuffle->getOperand(0) != Op) + return nullptr; + + // Verify the shuffle has the expected (at this stage of the pyramid) mask. + unsigned MaskEnd = 1 << i; + for (unsigned Index = 0; Index < MaskEnd; ++Index) + if (Shuffle->getMaskValue(Index) != (int)(MaskEnd + Index)) + return nullptr; + } + + return const_cast(Op); +} + +// See if this BO is reachable from this Phi by walking forward through single +// use BinaryOperators with the same opcode. If we get back then we know we've +// found a loop and it is safe to step through this Add to find more leaves. +static bool isReachableFromPHI(PHINode *Phi, BinaryOperator *BO) { + // The PHI itself should only have one use.
+ if (!Phi->hasOneUse()) return false; - bool Changed = false; - Changed |= trySADReplacement(BO->getOperand(0), BO); - Changed |= trySADReplacement(BO->getOperand(1), BO); - return Changed; + Instruction *U = cast(*Phi->user_begin()); + if (U == BO) + return true; + + while (U->hasOneUse() && U->getOpcode() == BO->getOpcode()) + U = cast(*U->user_begin()); + + return U == BO; +} + +// Collect all the leaves of the tree of adds that feeds into the horizontal +// reduction. Root is the Value that is used by the horizontal reduction. +// We look through single use phis, single use adds, or adds that are used by +// a phi that forms a loop with the add. +static void collectLeaves(Value *Root, SmallVectorImpl &Leaves) { + SmallPtrSet Visited; + SmallVector Worklist; + Worklist.push_back(Root); + + while (!Worklist.empty()) { + Value *V = Worklist.pop_back_val(); + if (!Visited.insert(V).second) + continue; + + if (auto *PN = dyn_cast(V)) { + // PHI node should have single use unless it is the root node, then it + // has 2 uses. + if (!PN->hasNUses(PN == Root ? 2 : 1)) + break; + + // Push incoming values to the worklist. + for (Value *InV : PN->incoming_values()) + Worklist.push_back(InV); + + continue; + } + + if (auto *BO = dyn_cast(V)) { + if (BO->getOpcode() == Instruction::Add) { + // Simple case. Single use, just push its operands to the worklist. + if (BO->hasNUses(BO == Root ? 2 : 1)) { + for (Value *Op : BO->operands()) + Worklist.push_back(Op); + continue; + } + + // If there is an additional use, make sure it is an unvisited phi user + // of BO that gets us back to this node. + if (BO->hasNUses(BO == Root ? 3 : 2)) { + PHINode *PN = nullptr; + for (auto *U : BO->users()) + if (auto *P = dyn_cast(U)) + if (!Visited.count(P)) + PN = P; + + // If we didn't find a 2-input PHI then this isn't a case we can + // handle. + if (!PN || PN->getNumIncomingValues() != 2) + continue; + + // Walk forward from this phi to see if it reaches back to this add.
+ if (!isReachableFromPHI(PN, BO)) + continue; + + // The phi forms a loop with this Add, push its operands. + for (Value *Op : BO->operands()) + Worklist.push_back(Op); + } + } + } + + // Not an add or phi, make it a leaf. + if (auto *I = dyn_cast(V)) { + if (!V->hasNUses(I == Root ? 2 : 1)) + continue; + + // Add this as a leaf. + Leaves.push_back(I); + } + } } bool X86PartialReduction::runOnFunction(Function &F) { @@ -458,22 +458,29 @@ bool MadeChange = false; for (auto &BB : F) { for (auto &I : BB) { - auto *BO = dyn_cast(&I); - if (!BO) + auto *EE = dyn_cast(&I); + if (!EE) continue; - if (!isVectorReductionOp(*BO)) + // First find a reduction tree. + // FIXME: Do we need to handle other opcodes than Add? + Value *Root = matchAddReduction(*EE); + if (!Root) continue; - if (BO->getOpcode() == Instruction::Add) { - if (tryMAddPattern(BO)) { + SmallVector Leaves; + collectLeaves(Root, Leaves); + + for (Instruction *I : Leaves) { + if (tryMAddReplacement(I)) { MadeChange = true; continue; } - if (trySADPattern(BO)) { + + // Don't do SAD matching on the root node. SelectionDAG already + // has support for that and currently generates better code.
+ if (I != Root && trySADReplacement(I)) MadeChange = true; - continue; - } } } } diff --git a/llvm/test/CodeGen/X86/madd.ll b/llvm/test/CodeGen/X86/madd.ll --- a/llvm/test/CodeGen/X86/madd.ll +++ b/llvm/test/CodeGen/X86/madd.ll @@ -2657,9 +2657,9 @@ ; AVX-LABEL: madd_double_reduction: ; AVX: # %bb.0: ; AVX-NEXT: vmovdqu (%rdi), %xmm0 +; AVX-NEXT: vpmaddwd (%rsi), %xmm0, %xmm0 ; AVX-NEXT: vmovdqu (%rdx), %xmm1 ; AVX-NEXT: vpmaddwd (%rcx), %xmm1, %xmm1 -; AVX-NEXT: vpmaddwd (%rsi), %xmm0, %xmm0 ; AVX-NEXT: vpaddd %xmm0, %xmm1, %xmm0 ; AVX-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1] ; AVX-NEXT: vpaddd %xmm1, %xmm0, %xmm0 @@ -2720,9 +2720,9 @@ ; AVX-NEXT: movq {{[0-9]+}}(%rsp), %r10 ; AVX-NEXT: movq {{[0-9]+}}(%rsp), %rax ; AVX-NEXT: vmovdqu (%rdi), %xmm0 +; AVX-NEXT: vpmaddwd (%rsi), %xmm0, %xmm0 ; AVX-NEXT: vmovdqu (%rdx), %xmm1 ; AVX-NEXT: vpmaddwd (%rcx), %xmm1, %xmm1 -; AVX-NEXT: vpmaddwd (%rsi), %xmm0, %xmm0 ; AVX-NEXT: vmovdqu (%r8), %xmm2 ; AVX-NEXT: vpmaddwd (%r9), %xmm2, %xmm2 ; AVX-NEXT: vpaddd %xmm2, %xmm0, %xmm0 diff --git a/llvm/test/CodeGen/X86/sad.ll b/llvm/test/CodeGen/X86/sad.ll --- a/llvm/test/CodeGen/X86/sad.ll +++ b/llvm/test/CodeGen/X86/sad.ll @@ -1061,9 +1061,9 @@ ; AVX-LABEL: sad_double_reduction: ; AVX: # %bb.0: # %bb ; AVX-NEXT: vmovdqu (%rdi), %xmm0 +; AVX-NEXT: vpsadbw (%rsi), %xmm0, %xmm0 ; AVX-NEXT: vmovdqu (%rdx), %xmm1 ; AVX-NEXT: vpsadbw (%rcx), %xmm1, %xmm1 -; AVX-NEXT: vpsadbw (%rsi), %xmm0, %xmm0 ; AVX-NEXT: vpaddd %xmm0, %xmm1, %xmm0 ; AVX-NEXT: vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1] ; AVX-NEXT: vpaddd %xmm1, %xmm0, %xmm0