Index: lib/Target/ARM/ARMParallelDSP.cpp =================================================================== --- lib/Target/ARM/ARMParallelDSP.cpp +++ lib/Target/ARM/ARMParallelDSP.cpp @@ -1,4 +1,4 @@ -//===- ParallelDSP.cpp - Parallel DSP Pass --------------------------------===// +//===- ARMParallelDSP.cpp - Parallel DSP Pass -----------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -18,13 +18,10 @@ #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/LoopAccessAnalysis.h" -#include "llvm/Analysis/LoopPass.h" -#include "llvm/Analysis/LoopInfo.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/NoFolder.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" -#include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Pass.h" #include "llvm/PassRegistry.h" #include "llvm/PassSupport.h" @@ -176,13 +173,11 @@ } }; - class ARMParallelDSP : public LoopPass { + class ARMParallelDSP : public FunctionPass { ScalarEvolution *SE; AliasAnalysis *AA; TargetLibraryInfo *TLI; DominatorTree *DT; - LoopInfo *LI; - Loop *L; const DataLayout *DL; Module *M; std::map LoadPairs; @@ -204,63 +199,33 @@ /// products to a 32-bit accumulate operand. Optionally, the instruction can /// exchange the halfwords of the second operand before performing the /// arithmetic. 
- bool MatchSMLAD(Loop *L); + bool MatchSMLAD(Function &F); public: static char ID; - ARMParallelDSP() : LoopPass(ID) { } - - bool doInitialization(Loop *L, LPPassManager &LPM) override { - LoadPairs.clear(); - WideLoads.clear(); - return true; - } + ARMParallelDSP() : FunctionPass(ID) { } void getAnalysisUsage(AnalysisUsage &AU) const override { - LoopPass::getAnalysisUsage(AU); + FunctionPass::getAnalysisUsage(AU); AU.addRequired(); AU.addRequired(); AU.addRequired(); AU.addRequired(); - AU.addRequired(); AU.addRequired(); AU.addRequired(); - AU.addPreserved(); AU.setPreservesCFG(); } - bool runOnLoop(Loop *TheLoop, LPPassManager &) override { + bool runOnFunction(Function &F) override { if (DisableParallelDSP) return false; - if (skipLoop(TheLoop)) - return false; - - L = TheLoop; SE = &getAnalysis().getSE(); AA = &getAnalysis().getAAResults(); TLI = &getAnalysis().getTLI(); DT = &getAnalysis().getDomTree(); - LI = &getAnalysis().getLoopInfo(); auto &TPC = getAnalysis(); - BasicBlock *Header = TheLoop->getHeader(); - if (!Header) - return false; - - // TODO: We assume the loop header and latch to be the same block. - // This is not a fundamental restriction, but lifting this would just - // require more work to do the transformation and then patch up the CFG. 
- if (Header != TheLoop->getLoopLatch()) { - LLVM_DEBUG(dbgs() << "The loop header is not the loop latch: not " - "running pass ARMParallelDSP\n"); - return false; - } - - if (!TheLoop->getLoopPreheader()) - InsertPreheaderForLoop(L, DT, LI, nullptr, true); - - Function &F = *Header->getParent(); M = F.getParent(); DL = &M->getDataLayout(); @@ -285,17 +250,10 @@ return false; } - LoopAccessInfo LAI(L, SE, TLI, AA, DT, LI); - LLVM_DEBUG(dbgs() << "\n== Parallel DSP pass ==\n"); LLVM_DEBUG(dbgs() << " - " << F.getName() << "\n\n"); - if (!RecordMemoryOps(Header)) { - LLVM_DEBUG(dbgs() << " - No sequential loads found.\n"); - return false; - } - - bool Changes = MatchSMLAD(L); + bool Changes = MatchSMLAD(F); return Changes; } }; @@ -378,6 +336,8 @@ bool ARMParallelDSP::RecordMemoryOps(BasicBlock *BB) { SmallVector Loads; SmallVector Writes; + LoadPairs.clear(); + WideLoads.clear(); // Collect loads and instruction that may write to memory. For now we only // record loads which are simple, sign-extended and have a single user. @@ -456,7 +416,7 @@ return LoadPairs.size() > 1; } -// Loop Pass that needs to identify integer add/sub reductions of 16-bit vector +// The Pass that needs to identify integer add/sub reductions of 16-bit vector // multiplications. // To use SMLAD: // 1) we first need to find integer add then look for this pattern: @@ -484,16 +444,13 @@ // If constants are used instead of loads, these will need to be hoisted // out and into a register. // -// If loop invariants are used instead of loads, these need to be packed -// before the loop begins. -// -bool ARMParallelDSP::MatchSMLAD(Loop *L) { +bool ARMParallelDSP::MatchSMLAD(Function &F) { // Search recursively back through the operands to find a tree of values that // form a multiply-accumulate chain. The search records the Add and Mul // instructions that form the reduction and allows us to find a single value // to be used as the initial input to the accumlator. 
- std::function Search = [&] - (Value *V, Reduction &R) -> bool { + std::function Search = [&] + (Value *V, BasicBlock *BB, Reduction &R) -> bool { // If we find a non-instruction, try to use it as the initial accumulator // value. This may have already been found during the search in which case @@ -502,6 +459,9 @@ if (!I) return R.InsertAcc(V); + if (I->getParent() != BB) + return false; + switch (I->getOpcode()) { default: break; @@ -512,8 +472,8 @@ // Adds should be adding together two muls, or another add and a mul to // be within the mac chain. One of the operands may also be the // accumulator value at which point we should stop searching. - bool ValidLHS = Search(I->getOperand(0), R); - bool ValidRHS = Search(I->getOperand(1), R); + bool ValidLHS = Search(I->getOperand(0), BB, R); + bool ValidRHS = Search(I->getOperand(1), BB, R); if (!ValidLHS && !ValidLHS) return false; else if (ValidLHS && ValidRHS) { @@ -539,36 +499,41 @@ return false; } case Instruction::SExt: - return Search(I->getOperand(0), R); + return Search(I->getOperand(0), BB, R); } return false; }; bool Changed = false; - SmallPtrSet AllAdds; - BasicBlock *Latch = L->getLoopLatch(); - for (Instruction &I : reverse(*Latch)) { - if (I.getOpcode() != Instruction::Add) + for (auto &BB : F) { + SmallPtrSet AllAdds; + if (!RecordMemoryOps(&BB)) { + LLVM_DEBUG(dbgs() << " - No sequential loads found.\n"); continue; + } + for (Instruction &I : reverse(BB)) { + if (I.getOpcode() != Instruction::Add) + continue; - if (AllAdds.count(&I)) - continue; + if (AllAdds.count(&I)) + continue; - const auto *Ty = I.getType(); - if (!Ty->isIntegerTy(32) && !Ty->isIntegerTy(64)) - continue; + const auto *Ty = I.getType(); + if (!Ty->isIntegerTy(32) && !Ty->isIntegerTy(64)) + continue; - Reduction R(&I); - if (!Search(&I, R)) - continue; + Reduction R(&I); + if (!Search(&I, &BB, R)) + continue; - if (!CreateParallelPairs(R)) - continue; + if (!CreateParallelPairs(R)) + continue; - InsertParallelMACs(R); - Changed = 
true; - AllAdds.insert(R.getAdds().begin(), R.getAdds().end()); + InsertParallelMACs(R); + Changed = true; + AllAdds.insert(R.getAdds().begin(), R.getAdds().end()); + } } return Changed; @@ -845,6 +810,6 @@ char ARMParallelDSP::ID = 0; INITIALIZE_PASS_BEGIN(ARMParallelDSP, "arm-parallel-dsp", - "Transform loops to use DSP intrinsics", false, false) + "Transform functions to use DSP intrinsics", false, false) INITIALIZE_PASS_END(ARMParallelDSP, "arm-parallel-dsp", - "Transform loops to use DSP intrinsics", false, false) + "Transform functions to use DSP intrinsics", false, false) Index: test/CodeGen/ARM/O3-pipeline.ll =================================================================== --- test/CodeGen/ARM/O3-pipeline.ll +++ test/CodeGen/ARM/O3-pipeline.ll @@ -37,8 +37,7 @@ ; CHECK-NEXT: Scalar Evolution Analysis ; CHECK-NEXT: Basic Alias Analysis (stateless AA impl) ; CHECK-NEXT: Function Alias Analysis Results -; CHECK-NEXT: Loop Pass Manager -; CHECK-NEXT: Transform loops to use DSP intrinsics +; CHECK-NEXT: Transform functions to use DSP intrinsics ; CHECK-NEXT: Interleaved Access Pass ; CHECK-NEXT: ARM IR optimizations ; CHECK-NEXT: Dominator Tree Construction Index: test/CodeGen/ARM/ParallelDSP/blocks.ll =================================================================== --- /dev/null +++ test/CodeGen/ARM/ParallelDSP/blocks.ll @@ -0,0 +1,79 @@ +; RUN: opt -arm-parallel-dsp -mtriple=armv7-a -S %s -o - | FileCheck %s + +; CHECK-LABEL: single_block +; CHECK: [[CAST_A:%[^ ]+]] = bitcast i16* %a to i32* +; CHECK: [[A:%[^ ]+]] = load i32, i32* [[CAST_A]] +; CHECK: [[CAST_B:%[^ ]+]] = bitcast i16* %b to i32* +; CHECK: [[B:%[^ ]+]] = load i32, i32* [[CAST_B]] +; CHECK call i32 @llvm.arm.smlad(i32 [[A]], i32 [[B]], i32 %acc) +define i32 @single_block(i16* %a, i16* %b, i32 %acc) { +entry: + %ld.a.0 = load i16, i16* %a + %sext.a.0 = sext i16 %ld.a.0 to i32 + %ld.b.0 = load i16, i16* %b + %sext.b.0 = sext i16 %ld.b.0 to i32 + %mul.0 = mul i32 %sext.a.0, %sext.b.0 + 
%addr.a.1 = getelementptr i16, i16* %a, i32 1 + %addr.b.1 = getelementptr i16, i16* %b, i32 1 + %ld.a.1 = load i16, i16* %addr.a.1 + %sext.a.1 = sext i16 %ld.a.1 to i32 + %ld.b.1 = load i16, i16* %addr.b.1 + %sext.b.1 = sext i16 %ld.b.1 to i32 + %mul.1 = mul i32 %sext.a.1, %sext.b.1 + %add = add i32 %mul.0, %mul.1 + %res = add i32 %add, %acc + ret i32 %res +} + +; CHECK-LABEL: multi_block +; CHECK: [[CAST_A:%[^ ]+]] = bitcast i16* %a to i32* +; CHECK: [[A:%[^ ]+]] = load i32, i32* [[CAST_A]] +; CHECK: [[CAST_B:%[^ ]+]] = bitcast i16* %b to i32* +; CHECK: [[B:%[^ ]+]] = load i32, i32* [[CAST_B]] +; CHECK: call i32 @llvm.arm.smlad(i32 [[A]], i32 [[B]], i32 0) +define i32 @multi_block(i16* %a, i16* %b, i32 %acc) { +entry: + %ld.a.0 = load i16, i16* %a + %sext.a.0 = sext i16 %ld.a.0 to i32 + %ld.b.0 = load i16, i16* %b + %sext.b.0 = sext i16 %ld.b.0 to i32 + %mul.0 = mul i32 %sext.a.0, %sext.b.0 + %addr.a.1 = getelementptr i16, i16* %a, i32 1 + %addr.b.1 = getelementptr i16, i16* %b, i32 1 + %ld.a.1 = load i16, i16* %addr.a.1 + %sext.a.1 = sext i16 %ld.a.1 to i32 + %ld.b.1 = load i16, i16* %addr.b.1 + %sext.b.1 = sext i16 %ld.b.1 to i32 + %mul.1 = mul i32 %sext.a.1, %sext.b.1 + %add = add i32 %mul.0, %mul.1 + br label %bb.1 + +bb.1: + %res = add i32 %add, %acc + ret i32 %res +} + +; CHECK-LABEL: multi_block_1 +; CHECK-NOT: call i32 @llvm.arm.smlad +define i32 @multi_block_1(i16* %a, i16* %b, i32 %acc) { +entry: + %ld.a.0 = load i16, i16* %a + %sext.a.0 = sext i16 %ld.a.0 to i32 + %ld.b.0 = load i16, i16* %b + %sext.b.0 = sext i16 %ld.b.0 to i32 + %mul.0 = mul i32 %sext.a.0, %sext.b.0 + br label %bb.1 + +bb.1: + %addr.a.1 = getelementptr i16, i16* %a, i32 1 + %addr.b.1 = getelementptr i16, i16* %b, i32 1 + %ld.a.1 = load i16, i16* %addr.a.1 + %sext.a.1 = sext i16 %ld.a.1 to i32 + %ld.b.1 = load i16, i16* %addr.b.1 + %sext.b.1 = sext i16 %ld.b.1 to i32 + %mul.1 = mul i32 %sext.a.1, %sext.b.1 + %add = add i32 %mul.0, %mul.1 + %res = add i32 %add, %acc + ret i32 %res +} + 
Index: test/CodeGen/ARM/ParallelDSP/smlad12.ll =================================================================== --- test/CodeGen/ARM/ParallelDSP/smlad12.ll +++ test/CodeGen/ARM/ParallelDSP/smlad12.ll @@ -2,7 +2,7 @@ ; ; The loop header is not the loop latch. ; -; CHECK-NOT: call i32 @llvm.arm.smlad +; CHECK: call i32 @llvm.arm.smlad ; define dso_local i32 @test(i32 %arg, i32* nocapture readnone %arg1, i16* nocapture readonly %arg2, i16* nocapture readonly %arg3) { entry: