Index: lib/Transforms/Scalar/LoopStrengthReduce.cpp
===================================================================
--- lib/Transforms/Scalar/LoopStrengthReduce.cpp
+++ lib/Transforms/Scalar/LoopStrengthReduce.cpp
@@ -1697,6 +1697,12 @@
   bool FindIVUserForCond(ICmpInst *Cond, IVStrideUse *&CondUse);
   ICmpInst *OptimizeMax(ICmpInst *Cond, IVStrideUse* &CondUse);
   void OptimizeLoopTermCond();
+  SmallVector<Instruction *, 16> *FindUnrolledReduction(Instruction *Inst,
+                                                        PHINode *Phi,
+                                                        bool first);
+  void TransformUnrolledReduction(SmallVector<Instruction *, 16> *Chain,
+                                  PHINode *Phi);
+  void OptimizeUnrolledReductions();
 
   void ChainInstruction(Instruction *UserInst, Instruction *IVOper,
                         SmallVectorImpl<ChainUsers> &ChainUsersVec);
@@ -2236,6 +2242,237 @@
   }
 }
 
+/// FindUnrolledReduction - Return the chain of instructions corresponding to
+/// unrolled iterations of a reduction in a loop, if any.
+SmallVector<Instruction *, 16> *
+LSRInstance::FindUnrolledReduction(Instruction *Inst, PHINode *Phi,
+                                   bool first = true) {
+
+  // Unsafe algebra would be required in order to reorder FP operations.
+  if (isa<FPMathOperator>(Inst) && !Inst->getFastMathFlags().unsafeAlgebra())
+    return NULL;
+
+  // Each value computed in the chain should be used exactly once, except for
+  // the final reduction value, which may be used arbitrarily at loop exit.
+  if (first) {
+    for (User *U : Inst->users())
+      if (isa<Instruction>(U) && U != Phi &&
+          cast<Instruction>(U)->getParent() == Inst->getParent())
+        return NULL;
+  }
+  else if (Inst->getNumUses() > 1)
+    return NULL;
+
+  for (unsigned int i = 0; i < Inst->getNumOperands(); ++i) {
+    Value *Operand = Inst->getOperand(i);
+
+    // Subtraction is not commutative; only consider the first operand.
+    if ((Inst->getOpcode() == Instruction::Sub ||
+         Inst->getOpcode() == Instruction::FSub) && i > 0)
+      break;
+
+    // Walk up the dependency chain as long as the same opcode is
+    // encountered.
+    if (isa<BinaryOperator>(Operand)) {
+      BinaryOperator *BinaryOp = cast<BinaryOperator>(Operand);
+
+      if (BinaryOp->getOpcode() == Inst->getOpcode()) {
+        SmallVector<Instruction *, 16> *Chain =
+          FindUnrolledReduction(BinaryOp, Phi, false);
+
+        if (Chain) {
+          Chain->push_back(Inst);
+          return Chain;
+        }
+      }
+    }
+
+    // The specified phi node was encountered, which means we have found a
+    // dependency chain.
+    else if (isa<PHINode>(Operand)) {
+      if (Phi == cast<PHINode>(Operand)) {
+        SmallVector<Instruction *, 16> *Chain =
+          new SmallVector<Instruction *, 16>();
+        Chain->push_back(Inst);
+        return Chain;
+      }
+    }
+  }
+  return NULL;
+}
+
+/// TransformUnrolledReduction - Break the dependencies in the chain.
+void
+LSRInstance::TransformUnrolledReduction(SmallVector<Instruction *, 16> *Chain,
+                                        PHINode *Phi) {
+  assert((Chain && !Chain->empty()) && "empty chains are not allowed");
+  BasicBlock *LoopBlock = Phi->getParent();
+  BasicBlock *LoopExit = LoopBlock->getNextNode();
+
+  // Retrieve the register used for the reduction.
+  Value *Val = Phi->getIncomingValueForBlock(LoopBlock);
+  assert((isa<Instruction>(Val)) && "unexpected dependency chain");
+  Instruction *Red = cast<Instruction>(Val);
+
+  // Opcode-specific parameters.
+  Instruction::BinaryOps Opcode = (Instruction::BinaryOps)(Red->getOpcode());
+  int neutral = 0;
+
+  if (Opcode == Instruction::Mul ||
+      Opcode == Instruction::FMul)
+    neutral = 1;
+  else if (Opcode == Instruction::And)
+    neutral = -1; // The neutral element of 'and' is all-ones, not 1.
+
+  if (Opcode == Instruction::Sub)
+    Opcode = Instruction::Add;
+  else if (Opcode == Instruction::FSub)
+    Opcode = Instruction::FAdd;
+
+  // Retrieve the reduction at loop exit.
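+  // If the value is used through a phi in the exit block (as in LCSSA form),
+  // rebuild the final reduction on top of that phi rather than on top of the
+  // in-loop value.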
+  bool renamed = false;
+  for (User *U : Red->users()) {
+    if (isa<PHINode>(U) && cast<PHINode>(U)->getParent() == LoopExit) {
+      Red = cast<Instruction>(U);
+      renamed = true;
+      break;
+    }
+  }
+
+  // We will insert code at the loop exit to compute a reduction over several
+  // reduction variables, one for each element in the chain (i.e. as many as
+  // the unroll factor of the loop).
+  Instruction *InsertionPt = LoopExit->getFirstNonPHI();
+  Instruction *Previous = Phi;
+  Instruction *NewRed = Red, *FirstRed = NULL;
+
+  for (Instruction *Inst : *Chain) {
+    bool isLast = (Phi->getIncomingValueForBlock(LoopBlock) == Inst);
+    Twine Name = Inst->getName();
+
+    PHINode *NewPhi;
+    if (!isLast) {
+      // Create a new reduction variable.
+      NewPhi = PHINode::Create(Inst->getType(), 2, Name + ".acc",
+                               LoopBlock->begin());
+
+      Constant *Init = NULL;
+      if (Inst->getType()->isFloatingPointTy())
+        Init = ConstantFP::get(Inst->getType(), neutral);
+      else
+        Init = ConstantInt::get(Inst->getType(), neutral, true);
+
+      NewPhi->addIncoming(Init, LoopBlock->getPrevNode());
+      NewPhi->addIncoming(Inst, LoopBlock);
+    }
+    else
+      NewPhi = Phi; // Reuse the existing reduction variable.
+
+    // Break the dependency by using the new reduction variable.
+    bool found = false;
+    for (unsigned int i = 0; i < Inst->getNumOperands(); ++i) {
+      if (Inst->getOperand(i) == Previous) {
+        Inst->setOperand(i, NewPhi);
+        found = true;
+        break;
+      }
+    }
+
+    assert(found && "unexpected dependency chain");
+
+    if (!isLast) {
+      // Reduce all the reduction variables at loop exit.
+      NewPhi = PHINode::Create(Inst->getType(), 1, Name + ".phi",
+                               LoopExit->begin());
+      NewPhi->addIncoming(Inst, LoopBlock);
+
+      NewRed = BinaryOperator::Create(Opcode, NewPhi, NewRed, Name + ".red",
+                                      InsertionPt);
+
+      // Remember the first operation, as we will replace one of its uses
+      // below.
+      if (FirstRed == NULL)
+        FirstRed = NewRed;
+    }
+
+    Previous = Inst;
+  }
+
+  // Replace uses of the original reduction variable with the new one.
+  Red->replaceAllUsesWith(NewRed);
+  FirstRed->setOperand(1, Red);
+  if (!renamed)
+    for (unsigned int i = 0; i < Phi->getNumIncomingValues(); ++i)
+      if (Phi->getIncomingValue(i) == NewRed) {
+        Phi->setIncomingValue(i, Red);
+        break;
+      }
+}
+
+
+/// OptimizeUnrolledReductions - Break dependencies between unrolled
+/// iterations of reductions in loops. This should be particularly effective
+/// for superscalar targets.
+void
+LSRInstance::OptimizeUnrolledReductions() {
+
+  std::vector<BasicBlock *>::const_iterator BBIter, BBEnd;
+
+  for (BBIter = L->getBlocks().begin(), BBEnd = L->getBlocks().end();
+       BBIter != BBEnd; ++BBIter) {
+    SmallVector<SmallVector<Instruction *, 16> *, 2> Chains;
+    SmallVector<PHINode *, 2> PhiNodes;
+
+    // Search for phi nodes.
+    for (BasicBlock::iterator I = (*BBIter)->begin(), E = (*BBIter)->end();
+         I != E; ++I) {
+      if (isa<PHINode>(I)) {
+        for (unsigned int i = 0; i < I->getNumOperands(); ++i) {
+          Value *Operand = I->getOperand(i);
+
+          // Search for binary operators feeding phi nodes.
+          if (isa<BinaryOperator>(Operand)) {
+            BinaryOperator *BinaryOp = cast<BinaryOperator>(Operand);
+
+            // Must be in the same basic block.
+            if (BinaryOp->getParent() == I->getParent()) {
+
+              // Look for interesting opcodes.
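+              // Integer add/sub/mul and the bitwise operations can be
+              // re-associated safely; their FP counterparts are only handled
+              // when unsafe algebra is allowed, which FindUnrolledReduction
+              // checks via the fast-math flags.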
+              if (BinaryOp->getOpcode() == Instruction::FAdd ||
+                  BinaryOp->getOpcode() == Instruction::FSub ||
+                  BinaryOp->getOpcode() == Instruction::FMul ||
+                  BinaryOp->getOpcode() == Instruction::Add ||
+                  BinaryOp->getOpcode() == Instruction::Sub ||
+                  BinaryOp->getOpcode() == Instruction::Mul ||
+                  BinaryOp->getOpcode() == Instruction::And ||
+                  BinaryOp->getOpcode() == Instruction::Or ||
+                  BinaryOp->getOpcode() == Instruction::Xor) {
+                PHINode *Phi = cast<PHINode>(I);
+                SmallVector<Instruction *, 16> *Chain =
+                  FindUnrolledReduction(BinaryOp, Phi);
+
+                // Record the dependency chain if its length is greater than
+                // one (phi node excluded): it corresponds to a reduction
+                // that has been unrolled.
+                if (Chain && Chain->size() > 1) {
+                  Chains.push_back(Chain);
+                  PhiNodes.push_back(Phi);
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+
+    // Break the dependency chains once the basic block has been processed.
+    while (!Chains.empty() && !PhiNodes.empty()) {
+      SmallVector<Instruction *, 16> *Chain = Chains.pop_back_val();
+      PHINode *Phi = PhiNodes.pop_back_val();
+      TransformUnrolledReduction(Chain, Phi);
+      delete Chain;
+    }
+  }
+}
+
 /// reconcileNewOffset - Determine if the given use can accommodate a fixup
 /// at the given offset and other details. If so, update the use and
 /// return true.
@@ -4910,6 +5147,7 @@
   // First, perform some low-level loop optimizations.
   OptimizeShadowIV();
   OptimizeLoopTermCond();
+  OptimizeUnrolledReductions();
 
   // If loop preparation eliminates all interesting IV users, bail.
   if (IU.empty()) return;
Index: test/Transforms/LoopStrengthReduce/X86/ivchain-X86.ll
===================================================================
--- test/Transforms/LoopStrengthReduce/X86/ivchain-X86.ll
+++ test/Transforms/LoopStrengthReduce/X86/ivchain-X86.ll
@@ -12,7 +12,7 @@
 ; X64: shlq $2
 ; no other address computation in the preheader
 ; X64-NEXT: xorl
-; X64-NEXT: .align
+; X64: .align
 ; X64: %loop
 ; no complex address modes
 ; X64-NOT: (%{{[^)]+}},%{{[^)]+}},
Index: test/Transforms/LoopStrengthReduce/unrolled-reduction.ll
===================================================================
--- test/Transforms/LoopStrengthReduce/unrolled-reduction.ll
+++ test/Transforms/LoopStrengthReduce/unrolled-reduction.ll
@@ -0,0 +1,174 @@
+; RUN: opt < %s -loop-reduce -S | FileCheck %s
+
+; CHECK-LABEL: @add(
+; CHECK: .acc = phi
+; CHECK-NEXT: .acc = phi
+; CHECK-NEXT: .acc = phi
+
+; CHECK: add
+; CHECK-NEXT: add
+; CHECK-NEXT: add
+; CHECK-NEXT: add
+
+; CHECK: exit:
+; CHECK: .phi = phi
+; CHECK-NEXT: .phi = phi
+; CHECK-NEXT: .phi = phi
+
+; CHECK: .red = add
+; CHECK-NEXT: .red = add
+; CHECK-NEXT: .red = add
+
+define i32 @add(i32* %a, i32* %b, i32 %x) nounwind {
+entry:
+  br label %loop
+loop:
+  %iv = phi i32* [ %a, %entry ], [ %iv4, %loop ]
+  %s = phi i32 [ 0, %entry ], [ %s4, %loop ]
+  %v = load i32* %iv
+  %iv1 = getelementptr inbounds i32* %iv, i32 %x
+  %v1 = load i32* %iv1
+  %iv2 = getelementptr inbounds i32* %iv1, i32 %x
+  %v2 = load i32* %iv2
+  %iv3 = getelementptr inbounds i32* %iv2, i32 %x
+  %v3 = load i32* %iv3
+  %s1 = add i32 %s, %v
+  %s2 = add i32 %s1, %v1
+  %s3 = add i32 %s2, %v2
+  %s4 = add i32 %s3, %v3
+  %iv4 = getelementptr inbounds i32* %iv3, i32 %x
+  %cmp = icmp eq i32* %iv4, %b
+  br i1 %cmp, label %exit, label %loop
+exit:
+  ret i32 %s4
+}
+
+; CHECK-LABEL: @sub(
+; CHECK: .acc = phi
+; CHECK-NEXT: .acc = phi
+; CHECK-NEXT: .acc = phi
+
+; CHECK: sub
+; CHECK-NEXT: sub
+; CHECK-NEXT: sub
+; CHECK-NEXT: sub
+
+; CHECK: exit:
+; CHECK: .phi = phi
+; CHECK-NEXT: .phi = phi
+; CHECK-NEXT: .phi = phi
+
+; CHECK: .red = add
+; CHECK-NEXT: .red = add
+; CHECK-NEXT: .red = add
+
+define i32 @sub(i32* %a, i32* %b, i32 %x) nounwind {
+entry:
+  br label %loop
+loop:
+  %iv = phi i32* [ %a, %entry ], [ %iv4, %loop ]
+  %s = phi i32 [ 0, %entry ], [ %s4, %loop ]
+  %v = load i32* %iv
+  %iv1 = getelementptr inbounds i32* %iv, i32 %x
+  %v1 = load i32* %iv1
+  %iv2 = getelementptr inbounds i32* %iv1, i32 %x
+  %v2 = load i32* %iv2
+  %iv3 = getelementptr inbounds i32* %iv2, i32 %x
+  %v3 = load i32* %iv3
+  %s1 = sub i32 %s, %v
+  %s2 = sub i32 %s1, %v1
+  %s3 = sub i32 %s2, %v2
+  %s4 = sub i32 %s3, %v3
+  %iv4 = getelementptr inbounds i32* %iv3, i32 %x
+  %cmp = icmp eq i32* %iv4, %b
+  br i1 %cmp, label %exit, label %loop
+exit:
+  ret i32 %s4
+}
+
+; CHECK-LABEL: @fadd(
+; CHECK: .acc = phi
+; CHECK-NEXT: .acc = phi
+; CHECK-NEXT: .acc = phi
+
+; CHECK: fadd
+; CHECK-NEXT: fadd
+; CHECK-NEXT: fadd
+; CHECK-NEXT: fadd
+
+; CHECK: exit:
+; CHECK: .phi = phi
+; CHECK-NEXT: .phi = phi
+; CHECK-NEXT: .phi = phi
+
+; CHECK: .red = fadd
+; CHECK-NEXT: .red = fadd
+; CHECK-NEXT: .red = fadd
+
+define float @fadd(float* %a, float* %b, i32 %x) nounwind {
+entry:
+  br label %loop
+loop:
+  %iv = phi float* [ %a, %entry ], [ %iv4, %loop ]
+  %s = phi float [ 0.0, %entry ], [ %s4, %loop ]
+  %v = load float* %iv
+  %iv1 = getelementptr inbounds float* %iv, i32 %x
+  %v1 = load float* %iv1
+  %iv2 = getelementptr inbounds float* %iv1, i32 %x
+  %v2 = load float* %iv2
+  %iv3 = getelementptr inbounds float* %iv2, i32 %x
+  %v3 = load float* %iv3
+  %s1 = fadd fast float %s, %v
+  %s2 = fadd fast float %s1, %v1
+  %s3 = fadd fast float %s2, %v2
+  %s4 = fadd fast float %s3, %v3
+  %iv4 = getelementptr inbounds float* %iv3, i32 %x
+  %cmp = icmp eq float* %iv4, %b
+  br i1 %cmp, label %exit, label %loop
+exit:
+  ret float %s4
+}
+
+; CHECK-LABEL: @fsub(
+; CHECK: .acc = phi
+; CHECK-NEXT: .acc = phi
+; CHECK-NEXT: .acc = phi
+
+; CHECK: fsub
+; CHECK-NEXT: fsub
+; CHECK-NEXT: fsub
+; CHECK-NEXT: fsub
+
+; CHECK: exit:
+; CHECK: .phi = phi
+; CHECK-NEXT: .phi = phi
+; CHECK-NEXT: .phi = phi
+
+; CHECK: .red = fadd
+; CHECK-NEXT: .red = fadd
+; CHECK-NEXT: .red = fadd
+
+define float @fsub(float* %a, float* %b, i32 %x) nounwind {
+entry:
+  br label %loop
+loop:
+  %iv = phi float* [ %a, %entry ], [ %iv4, %loop ]
+  %s = phi float [ 0.0, %entry ], [ %s4, %loop ]
+  %v = load float* %iv
+  %iv1 = getelementptr inbounds float* %iv, i32 %x
+  %v1 = load float* %iv1
+  %iv2 = getelementptr inbounds float* %iv1, i32 %x
+  %v2 = load float* %iv2
+  %iv3 = getelementptr inbounds float* %iv2, i32 %x
+  %v3 = load float* %iv3
+  %s1 = fsub fast float %s, %v
+  %s2 = fsub fast float %s1, %v1
+  %s3 = fsub fast float %s2, %v2
+  %s4 = fsub fast float %s3, %v3
+  %iv4 = getelementptr inbounds float* %iv3, i32 %x
+  %cmp = icmp eq float* %iv4, %b
+  br i1 %cmp, label %exit, label %loop
+exit:
+  ret float %s4
+}
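
For reference, the rewrite the @add test expects has roughly the following
shape (a hand-derived sketch of the intended -loop-reduce output, not actual
tool output; exact value names, phi order, and LSR's usual rewriting of the
address arithmetic will differ):

  loop:
    %s1.acc = phi i32 [ 0, %entry ], [ %s1, %loop ]
    %s2.acc = phi i32 [ 0, %entry ], [ %s2, %loop ]
    %s3.acc = phi i32 [ 0, %entry ], [ %s3, %loop ]
    %s = phi i32 [ 0, %entry ], [ %s4, %loop ]
    ...
    %s1 = add i32 %s1.acc, %v    ; four independent accumulators: no serial
    %s2 = add i32 %s2.acc, %v1   ; dependency between the four adds
    %s3 = add i32 %s3.acc, %v2
    %s4 = add i32 %s, %v3
    ...
  exit:
    %s1.phi = phi i32 [ %s1, %loop ]
    %s2.phi = phi i32 [ %s2, %loop ]
    %s3.phi = phi i32 [ %s3, %loop ]
    %s1.red = add i32 %s1.phi, %s4   ; combine the partial sums once, at exit
    %s2.red = add i32 %s2.phi, %s1.red
    %s3.red = add i32 %s3.phi, %s2.red
    ret i32 %s3.red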