diff --git a/llvm/lib/Target/ARM/MVEGatherScatterLowering.cpp b/llvm/lib/Target/ARM/MVEGatherScatterLowering.cpp --- a/llvm/lib/Target/ARM/MVEGatherScatterLowering.cpp +++ b/llvm/lib/Target/ARM/MVEGatherScatterLowering.cpp @@ -149,10 +149,10 @@ bool optimiseOffsets(Value *Offsets, BasicBlock *BB, LoopInfo *LI); // Pushes the given add out of the loop void pushOutAdd(PHINode *&Phi, Value *OffsSecondOperand, unsigned StartIndex); - // Pushes the given mul out of the loop - void pushOutMul(PHINode *&Phi, Value *IncrementPerRound, - Value *OffsSecondOperand, unsigned LoopIncrement, - IRBuilder<> &Builder); + // Pushes the given mul or shl out of the loop + void pushOutMulShl(unsigned Opc, PHINode *&Phi, Value *IncrementPerRound, + Value *OffsSecondOperand, unsigned LoopIncrement, + IRBuilder<> &Builder); }; } // end anonymous namespace @@ -342,7 +342,8 @@ const Instruction *I = cast(V); if (I->getOpcode() == Instruction::Add || - I->getOpcode() == Instruction::Mul) { + I->getOpcode() == Instruction::Mul || + I->getOpcode() == Instruction::Shl) { Optional Op0 = getIfConst(I->getOperand(0)); Optional Op1 = getIfConst(I->getOperand(1)); if (!Op0 || !Op1) @@ -351,6 +352,8 @@ return Optional{Op0.getValue() + Op1.getValue()}; if (I->getOpcode() == Instruction::Mul) return Optional{Op0.getValue() * Op1.getValue()}; + if (I->getOpcode() == Instruction::Shl) + return Optional{Op0.getValue() << Op1.getValue()}; } return Optional{}; } @@ -888,11 +891,11 @@ Phi->removeIncomingValue(StartIndex); } -void MVEGatherScatterLowering::pushOutMul(PHINode *&Phi, - Value *IncrementPerRound, - Value *OffsSecondOperand, - unsigned LoopIncrement, - IRBuilder<> &Builder) { +void MVEGatherScatterLowering::pushOutMulShl(unsigned Opcode, PHINode *&Phi, + Value *IncrementPerRound, + Value *OffsSecondOperand, + unsigned LoopIncrement, + IRBuilder<> &Builder) { LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising mul instruction\n"); // Create a new scalar add outside of the loop and transform it to a splat @@ -901,12 +904,13 @@ Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1)->back()); // Create a new index - Value *StartIndex = BinaryOperator::Create( - Instruction::Mul, Phi->getIncomingValue(LoopIncrement == 1 ? 0 : 1), - OffsSecondOperand, "PushedOutMul", InsertionPoint); + Value *StartIndex = + BinaryOperator::Create((Instruction::BinaryOps)Opcode, + Phi->getIncomingValue(LoopIncrement == 1 ? 0 : 1), + OffsSecondOperand, "PushedOutMul", InsertionPoint); Instruction *Product = - BinaryOperator::Create(Instruction::Mul, IncrementPerRound, + BinaryOperator::Create((Instruction::BinaryOps)Opcode, IncrementPerRound, OffsSecondOperand, "Product", InsertionPoint); // Increment NewIndex by Product instead of the multiplication Instruction *NewIncrement = BinaryOperator::Create( @@ -936,7 +940,8 @@ return Gatscat; } else { unsigned OpCode = cast(U)->getOpcode(); - if ((OpCode == Instruction::Add || OpCode == Instruction::Mul) && + if ((OpCode == Instruction::Add || OpCode == Instruction::Mul || + OpCode == Instruction::Shl) && hasAllGatScatUsers(cast(U))) { continue; } @@ -956,7 +961,8 @@ return false; Instruction *Offs = cast(Offsets); if (Offs->getOpcode() != Instruction::Add && - Offs->getOpcode() != Instruction::Mul) + Offs->getOpcode() != Instruction::Mul && + Offs->getOpcode() != Instruction::Shl) return false; Loop *L = LI->getLoopFor(BB); if (L == nullptr) @@ -1063,8 +1069,9 @@ pushOutAdd(NewPhi, OffsSecondOperand, IncrementingBlock == 1 ? 0 : 1); break; case Instruction::Mul: - pushOutMul(NewPhi, IncrementPerRound, OffsSecondOperand, IncrementingBlock, - Builder); + case Instruction::Shl: + pushOutMulShl(Offs->getOpcode(), NewPhi, IncrementPerRound, + OffsSecondOperand, IncrementingBlock, Builder); break; default: return false; diff --git a/llvm/test/CodeGen/Thumb2/mve-gather-increment.ll b/llvm/test/CodeGen/Thumb2/mve-gather-increment.ll --- a/llvm/test/CodeGen/Thumb2/mve-gather-increment.ll +++ b/llvm/test/CodeGen/Thumb2/mve-gather-increment.ll @@ -1410,24 +1410,22 @@ ; CHECK-NEXT: .LBB15_1: @ %vector.ph ; CHECK-NEXT: adr r3, .LCPI15_0 ; CHECK-NEXT: vldrw.u32 q0, [r3] -; CHECK-NEXT: vmov.i32 q1, #0x4 +; CHECK-NEXT: vadd.i32 q0, q0, r1 ; CHECK-NEXT: dlstp.32 lr, r2 ; CHECK-NEXT: .LBB15_2: @ %vector.body ; CHECK-NEXT: @ =>This Inner Loop Header: Depth=1 -; CHECK-NEXT: vshl.i32 q2, q0, #2 -; CHECK-NEXT: vadd.i32 q0, q0, q1 -; CHECK-NEXT: vldrw.u32 q3, [r1, q2, uxtw #2] -; CHECK-NEXT: vstrw.32 q3, [r0], #16 +; CHECK-NEXT: vldrw.u32 q1, [q0, #64]! +; CHECK-NEXT: vstrw.32 q1, [r0], #16 ; CHECK-NEXT: letp lr, .LBB15_2 ; CHECK-NEXT: @ %bb.3: @ %for.cond.cleanup ; CHECK-NEXT: pop {r7, pc} ; CHECK-NEXT: .p2align 4 ; CHECK-NEXT: @ %bb.4: ; CHECK-NEXT: .LCPI15_0: -; CHECK-NEXT: .long 0 @ 0x0 -; CHECK-NEXT: .long 1 @ 0x1 -; CHECK-NEXT: .long 2 @ 0x2 -; CHECK-NEXT: .long 3 @ 0x3 +; CHECK-NEXT: .long 4294967232 @ 0xffffffc0 +; CHECK-NEXT: .long 4294967248 @ 0xffffffd0 +; CHECK-NEXT: .long 4294967264 @ 0xffffffe0 +; CHECK-NEXT: .long 4294967280 @ 0xfffffff0 entry: %cmp6 = icmp sgt i32 %n, 0 br i1 %cmp6, label %vector.ph, label %for.cond.cleanup diff --git a/llvm/test/CodeGen/Thumb2/mve-scatter-increment.ll b/llvm/test/CodeGen/Thumb2/mve-scatter-increment.ll --- a/llvm/test/CodeGen/Thumb2/mve-scatter-increment.ll +++ b/llvm/test/CodeGen/Thumb2/mve-scatter-increment.ll @@ -236,24 +236,22 @@ ; CHECK-NEXT: .LBB4_1: @ %vector.ph ; CHECK-NEXT: adr r3, .LCPI4_0 ; CHECK-NEXT: vldrw.u32 q0, [r3] -; CHECK-NEXT: vmov.i32 q1, #0x4 +; CHECK-NEXT: vadd.i32 q0, q0, r1 ; CHECK-NEXT: dlstp.32 lr, r2 ; CHECK-NEXT: .LBB4_2: @ %vector.body ; CHECK-NEXT: @ =>This Inner Loop Header: Depth=1 -; CHECK-NEXT: vshl.i32 q3, q0, #2 -; CHECK-NEXT: vadd.i32 q0, q0, q1 -; CHECK-NEXT: vldrw.u32 q2, [r0], #16 -; CHECK-NEXT: vstrw.32 q2, [r1, q3, uxtw #2] +; CHECK-NEXT: vldrw.u32 q1, [r0], #16 +; CHECK-NEXT: vstrw.32 q1, [q0, #64]! ; CHECK-NEXT: letp lr, .LBB4_2 ; CHECK-NEXT: @ %bb.3: @ %for.cond.cleanup ; CHECK-NEXT: pop {r7, pc} ; CHECK-NEXT: .p2align 4 ; CHECK-NEXT: @ %bb.4: ; CHECK-NEXT: .LCPI4_0: -; CHECK-NEXT: .long 0 @ 0x0 -; CHECK-NEXT: .long 1 @ 0x1 -; CHECK-NEXT: .long 2 @ 0x2 -; CHECK-NEXT: .long 3 @ 0x3 +; CHECK-NEXT: .long 4294967232 @ 0xffffffc0 +; CHECK-NEXT: .long 4294967248 @ 0xffffffd0 +; CHECK-NEXT: .long 4294967264 @ 0xffffffe0 +; CHECK-NEXT: .long 4294967280 @ 0xfffffff0 entry: %cmp6 = icmp sgt i32 %n, 0 br i1 %cmp6, label %vector.ph, label %for.cond.cleanup