Index: lib/Target/ARM/ARMParallelDSP.cpp
===================================================================
--- lib/Target/ARM/ARMParallelDSP.cpp
+++ lib/Target/ARM/ARMParallelDSP.cpp
@@ -37,6 +37,9 @@
 #define DEBUG_TYPE "arm-parallel-dsp"
 
 STATISTIC(NumSMLAD , "Number of smlad instructions generated");
+STATISTIC(NumSMLADX, "Number of smladx instructions generated");
+STATISTIC(NumSMUAD, "Number of smuad instructions generated");
+STATISTIC(NumSMUADX, "Number of smuadx instructions generated");
 
 static cl::opt<bool>
 DisableParallelDSP("disable-arm-parallel-dsp", cl::Hidden, cl::init(false),
@@ -122,6 +125,7 @@
   /// Return the add instruction which is the root of the reduction.
   Instruction *getRoot() { return Root; }
 
+  /// Return whether this reduction is accumulating into a 64-bit result.
   bool is64Bit() const { return Root->getType()->isIntegerTy(64); }
 
   /// Return the incoming value to be accumulated. This may be null.
@@ -600,43 +604,57 @@
 void ARMParallelDSP::InsertParallelMACs(Reduction &R) {
 
-  auto CreateSMLAD = [&](LoadInst* WideLd0, LoadInst *WideLd1,
-                         Value *Acc, bool Exchange,
-                         Instruction *InsertAfter) {
+  auto CreateSMLAD = [&](LoadInst *LHS, LoadInst *RHS, Value *Acc,
+                         bool Exchange, Instruction *InsertAfter) {
     // Replace the reduction chain with an intrinsic call
-    Value* Args[] = { WideLd0, WideLd1, Acc };
+    IRBuilder<NoFolder> Builder(InsertAfter->getParent(),
+                                ++BasicBlock::iterator(InsertAfter));
+    Value* Args[] = { LHS, RHS, Acc };
     Function *SMLAD = nullptr;
-    if (Exchange)
+    if (Exchange) {
       SMLAD = Acc->getType()->isIntegerTy(32) ?
         Intrinsic::getDeclaration(M, Intrinsic::arm_smladx) :
         Intrinsic::getDeclaration(M, Intrinsic::arm_smlaldx);
-    else
+      ++NumSMLADX;
+    } else {
       SMLAD = Acc->getType()->isIntegerTy(32) ?
         Intrinsic::getDeclaration(M, Intrinsic::arm_smlad) :
         Intrinsic::getDeclaration(M, Intrinsic::arm_smlald);
+      ++NumSMLAD;
+    }
+    return Builder.CreateCall(SMLAD, Args);
+  };
 
+  auto CreateSMUAD = [&](LoadInst *LHS, LoadInst *RHS, bool Exchange,
+                         Instruction *InsertAfter) {
     IRBuilder<NoFolder> Builder(InsertAfter->getParent(),
                                 ++BasicBlock::iterator(InsertAfter));
-    Instruction *Call = Builder.CreateCall(SMLAD, Args);
-    NumSMLAD++;
-    return Call;
+    Value* Args[] = { LHS, RHS };
+    Function *SMUAD = Exchange ?
+      Intrinsic::getDeclaration(M, Intrinsic::arm_smuadx) :
+      Intrinsic::getDeclaration(M, Intrinsic::arm_smuad);
+
+    if (Exchange)
+      ++NumSMUADX;
+    else
+      ++NumSMUAD;
+
+    return Builder.CreateCall(SMUAD, Args);
   };
 
   Instruction *InsertAfter = R.getRoot();
   Value *Acc = R.getAccumulator();
-  if (!Acc)
-    Acc = ConstantInt::get(IntegerType::get(M->getContext(), 32), 0);
-
+  LLVM_DEBUG(dbgs() << "Root: " << *InsertAfter << "\n");
   IntegerType *Ty = IntegerType::get(M->getContext(), 32);
-  LLVM_DEBUG(dbgs() << "Root: " << *InsertAfter << "\n"
-             << "Acc: " << *Acc << "\n");
+
   for (auto &Pair : R.getMulPairs()) {
     MulCandidate *LHSMul = Pair.first;
     MulCandidate *RHSMul = Pair.second;
     LLVM_DEBUG(dbgs() << "Muls:\n"
                << "- " << *LHSMul->Root << "\n"
                << "- " << *RHSMul->Root << "\n");
+
     LoadInst *BaseLHS = LHSMul->getBaseLoad();
     LoadInst *BaseRHS = RHSMul->getBaseLoad();
     LoadInst *WideLHS = WideLoads.count(BaseLHS) ?
@@ -644,7 +662,13 @@
     LoadInst *WideRHS = WideLoads.count(BaseRHS) ?
       WideLoads[BaseRHS]->getLoad() : CreateWideLoad(RHSMul->VecLd, Ty);
-    Acc = CreateSMLAD(WideLHS, WideRHS, Acc, RHSMul->Exchange, InsertAfter);
+    if (!Acc && !R.is64Bit())
+      Acc = CreateSMUAD(WideLHS, WideRHS, RHSMul->Exchange, InsertAfter);
+    else {
+      if (!Acc)
+        Acc = ConstantInt::get(R.getRoot()->getType(), 0);
+      Acc = CreateSMLAD(WideLHS, WideRHS, Acc, RHSMul->Exchange, InsertAfter);
+    }
     InsertAfter = cast<Instruction>(Acc);
   }
   R.UpdateRoot(cast<Instruction>(Acc));
Index: test/CodeGen/ARM/ParallelDSP/blocks.ll
===================================================================
--- test/CodeGen/ARM/ParallelDSP/blocks.ll
+++ test/CodeGen/ARM/ParallelDSP/blocks.ll
@@ -25,13 +25,54 @@
   ret i32 %res
 }
 
-; CHECK-LABEL: multi_block
+; TODO: We should generate a smuad as the first intrinsic. We don't because
+; the reduction search continues until no more adds are found, and we only
+; query the type of the final result.
+; CHECK-LABEL: smuad_smlald
+; CHECK: call i64 @llvm.arm.smlald
+; CHECK: call i64 @llvm.arm.smlald
+define i64 @smuad_smlald(i16* %a, i16* %b) {
+entry:
+  %ld.a.0 = load i16, i16* %a
+  %sext.a.0 = sext i16 %ld.a.0 to i32
+  %ld.b.0 = load i16, i16* %b
+  %sext.b.0 = sext i16 %ld.b.0 to i32
+  %mul.0 = mul i32 %sext.a.0, %sext.b.0
+  %addr.a.1 = getelementptr i16, i16* %a, i32 1
+  %addr.b.1 = getelementptr i16, i16* %b, i32 1
+  %ld.a.1 = load i16, i16* %addr.a.1
+  %sext.a.1 = sext i16 %ld.a.1 to i32
+  %ld.b.1 = load i16, i16* %addr.b.1
+  %sext.b.1 = sext i16 %ld.b.1 to i32
+  %mul.1 = mul i32 %sext.a.1, %sext.b.1
+  %add = add i32 %mul.0, %mul.1
+  %acc = sext i32 %add to i64
+  %addr.a.2 = getelementptr i16, i16* %a, i32 2
+  %addr.b.2 = getelementptr i16, i16* %b, i32 2
+  %ld.a.2 = load i16, i16* %addr.a.2
+  %sext.a.2 = sext i16 %ld.a.2 to i64
+  %ld.b.2 = load i16, i16* %addr.b.2
+  %sext.b.2 = sext i16 %ld.b.2 to i64
+  %mul.2 = mul i64 %sext.a.2, %sext.b.2
+  %addr.a.3 = getelementptr i16, i16* %a, i32 3
+  %addr.b.3 = getelementptr i16, i16* %b, i32 3
+  %ld.a.3 = load i16, i16* %addr.a.3
+  %sext.a.3 = sext i16 %ld.a.3 to i64
+  %ld.b.3 = load i16, i16* %addr.b.3
+  %sext.b.3 = sext i16 %ld.b.3 to i64
+  %mul.3 = mul i64 %sext.a.3, %sext.b.3
+  %add.1 = add i64 %mul.2, %mul.3
+  %res = add i64 %add.1, %acc
+  ret i64 %res
+}
+
+; CHECK-LABEL: multi_block_smuad
 ; CHECK: [[CAST_A:%[^ ]+]] = bitcast i16* %a to i32*
 ; CHECK: [[A:%[^ ]+]] = load i32, i32* [[CAST_A]]
 ; CHECK: [[CAST_B:%[^ ]+]] = bitcast i16* %b to i32*
 ; CHECK: [[B:%[^ ]+]] = load i32, i32* [[CAST_B]]
-; CHECK call i32 @llvm.arm.smlad(i32 [[A]], i32 [[B]], i32 0)
-define i32 @multi_block(i16* %a, i16* %b, i32 %acc) {
+; CHECK: call i32 @llvm.arm.smuad(i32 [[A]], i32 [[B]])
+define i32 @multi_block_smuad(i16* %a, i16* %b, i32 %acc) {
 entry:
   %ld.a.0 = load i16, i16* %a
   %sext.a.0 = sext i16 %ld.a.0 to i32
@@ -53,9 +94,106 @@
   ret i32 %res
 }
 
-; CHECK-LABEL: multi_block_1
-; CHECK-NOT: call i32 @llvm.arm.smlad
-define i32 @multi_block_1(i16* %a, i16* %b, i32 %acc) {
+; CHECK-LABEL: single_block_smlald
+; CHECK: [[CAST_A:%[^ ]+]] = bitcast i16* %a to i32*
+; CHECK: [[A:%[^ ]+]] = load i32, i32* [[CAST_A]]
+; CHECK: [[CAST_B:%[^ ]+]] = bitcast i16* %b to i32*
+; CHECK: [[B:%[^ ]+]] = load i32, i32* [[CAST_B]]
+; CHECK: call i64 @llvm.arm.smlald(i32 [[A]], i32 [[B]], i64 0)
+define i64 @single_block_smlald(i16* %a, i16* %b) {
+entry:
+  %ld.a.0 = load i16, i16* %a
+  %sext.a.0 = sext i16 %ld.a.0 to i64
+  %ld.b.0 = load i16, i16* %b
+  %sext.b.0 = sext i16 %ld.b.0 to i64
+  %mul.0 = mul i64 %sext.a.0, %sext.b.0
+  %addr.a.1 = getelementptr i16, i16* %a, i32 1
+  %addr.b.1 = getelementptr i16, i16* %b, i32 1
+  %ld.a.1 = load i16, i16* %addr.a.1
+  %sext.a.1 = sext i16 %ld.a.1 to i64
+  %ld.b.1 = load i16, i16* %addr.b.1
+  %sext.b.1 = sext i16 %ld.b.1 to i64
+  %mul.1 = mul i64 %sext.a.1, %sext.b.1
+  %add = add i64 %mul.0, %mul.1
+  ret i64 %add
+}
+
+; CHECK-LABEL: single_block_smlald_32_64
+; CHECK: [[CAST_A:%[^ ]+]] = bitcast i16* %a to i32*
+; CHECK: [[A:%[^ ]+]] = load i32, i32* [[CAST_A]]
+; CHECK: [[CAST_B:%[^ ]+]] = bitcast i16* %b to i32*
+; CHECK: [[B:%[^ ]+]] = load i32, i32* [[CAST_B]]
+; CHECK: call i64 @llvm.arm.smlald(i32 [[A]], i32 [[B]], i64 0)
+define i64 @single_block_smlald_32_64(i16* %a, i16* %b) {
+entry:
+  %ld.a.0 = load i16, i16* %a
+  %sext.a.0 = sext i16 %ld.a.0 to i32
+  %ld.b.0 = load i16, i16* %b
+  %sext.b.0 = sext i16 %ld.b.0 to i32
+  %mul.0 = mul i32 %sext.a.0, %sext.b.0
+  %addr.a.1 = getelementptr i16, i16* %a, i32 1
+  %addr.b.1 = getelementptr i16, i16* %b, i32 1
+  %ld.a.1 = load i16, i16* %addr.a.1
+  %sext.a.1 = sext i16 %ld.a.1 to i64
+  %ld.b.1 = load i16, i16* %addr.b.1
+  %sext.b.1 = sext i16 %ld.b.1 to i64
+  %mul.1 = mul i64 %sext.a.1, %sext.b.1
+  %conv = sext i32 %mul.0 to i64
+  %add = add i64 %conv, %mul.1
+  ret i64 %add
+}
+
+; CHECK-LABEL: single_block_smuadx
+; CHECK: [[CAST_A:%[^ ]+]] = bitcast i16* %a to i32*
+; CHECK: [[A:%[^ ]+]] = load i32, i32* [[CAST_A]]
+; CHECK: [[CAST_B:%[^ ]+]] = bitcast i16* %b to i32*
+; CHECK: [[B:%[^ ]+]] = load i32, i32* [[CAST_B]]
+; CHECK: call i32 @llvm.arm.smuadx(i32 [[A]], i32 [[B]])
+define i32 @single_block_smuadx(i16* %a, i16* %b) {
+entry:
+  %ld.a.0 = load i16, i16* %a
+  %sext.a.0 = sext i16 %ld.a.0 to i32
+  %ld.b.0 = load i16, i16* %b
+  %sext.b.0 = sext i16 %ld.b.0 to i32
+  %addr.a.1 = getelementptr i16, i16* %a, i32 1
+  %addr.b.1 = getelementptr i16, i16* %b, i32 1
+  %ld.a.1 = load i16, i16* %addr.a.1
+  %sext.a.1 = sext i16 %ld.a.1 to i32
+  %ld.b.1 = load i16, i16* %addr.b.1
+  %sext.b.1 = sext i16 %ld.b.1 to i32
+  %mul.0 = mul i32 %sext.a.0, %sext.b.1
+  %mul.1 = mul i32 %sext.a.1, %sext.b.0
+  %add = add i32 %mul.0, %mul.1
+  ret i32 %add
+}
+
+; CHECK-LABEL: multi_block_smuadx_1
+; CHECK: [[CAST_A:%[^ ]+]] = bitcast i16* %a to i32*
+; CHECK: [[A:%[^ ]+]] = load i32, i32* [[CAST_A]]
+; CHECK: [[CAST_B:%[^ ]+]] = bitcast i16* %b to i32*
+; CHECK: [[B:%[^ ]+]] = load i32, i32* [[CAST_B]]
+; CHECK: call i32 @llvm.arm.smuadx(i32 [[B]], i32 [[A]])
+define i32 @multi_block_smuadx_1(i16* %a, i16* %b) {
+entry:
+  %ld.a.0 = load i16, i16* %a
+  %sext.a.0 = sext i16 %ld.a.0 to i32
+  %ld.b.0 = load i16, i16* %b
+  %sext.b.0 = sext i16 %ld.b.0 to i32
+  %addr.a.1 = getelementptr i16, i16* %a, i32 1
+  %ld.a.1 = load i16, i16* %addr.a.1
+  %sext.a.1 = sext i16 %ld.a.1 to i32
+  %mul.0 = mul i32 %sext.a.1, %sext.b.0
+  %addr.b.1 = getelementptr i16, i16* %b, i32 1
+  %ld.b.1 = load i16, i16* %addr.b.1
+  %sext.b.1 = sext i16 %ld.b.1 to i32
+  %mul.1 = mul i32 %sext.a.0, %sext.b.1
+  %add = add i32 %mul.0, %mul.1
+  ret i32 %add
+}
+
+; CHECK-LABEL: multi_block_fail
+; CHECK-NOT: call i32 @llvm.arm.sm{{.*}}ad
+define i32 @multi_block_fail(i16* %a, i16* %b, i32 %acc) {
 entry:
   %ld.a.0 = load i16, i16* %a
   %sext.a.0 = sext i16 %ld.a.0 to i32
Index: test/CodeGen/ARM/ParallelDSP/inner-full-unroll.ll
===================================================================
--- test/CodeGen/ARM/ParallelDSP/inner-full-unroll.ll
+++ test/CodeGen/ARM/ParallelDSP/inner-full-unroll.ll
@@ -17,8 +17,8 @@
 ; CHECK: [[CIJ_2:%[^ ]+]] = getelementptr inbounds i16, i16* [[CIJ]], i32 2
 ; CHECK: [[CIJ_2_CAST:%[^ ]+]] = bitcast i16* [[CIJ_2]] to i32*
 ; CHECK: [[CIJ_2_LD:%[^ ]+]] = load i32, i32* [[CIJ_2_CAST]], align 2
-; CHECK: [[SMLAD0:%[^ ]+]] = call i32 @llvm.arm.smlad(i32 [[CIJ_2_LD]], i32 [[BIJ_2_LD]], i32 0)
-; CHECK: [[SMLAD1:%[^ ]+]] = call i32 @llvm.arm.smlad(i32 [[CIJ_LD]], i32 [[BIJ_LD]], i32 [[SMLAD0]])
+; CHECK: [[SMUAD:%[^ ]+]] = call i32 @llvm.arm.smuad(i32 [[CIJ_2_LD]], i32 [[BIJ_2_LD]])
+; CHECK: [[SMLAD1:%[^ ]+]] = call i32 @llvm.arm.smlad(i32 [[CIJ_LD]], i32 [[BIJ_LD]], i32 [[SMUAD]])
 ; CHECK: store i32 [[SMLAD1]], i32* %arrayidx, align 4
 
 define void @full_unroll(i32* noalias nocapture %a, i16** noalias nocapture readonly %b, i16** noalias nocapture readonly %c, i32 %N) {
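
For reference, a minimal sketch of the rewrite this change enables, using illustrative value names that are not taken from the pass output (%A and %B stand for the two widened i32 loads, each packing two consecutive i16 elements):

; 32-bit reduction with no incoming accumulator:
;   before: %acc = call i32 @llvm.arm.smlad(i32 %A, i32 %B, i32 0)
;   after:  %acc = call i32 @llvm.arm.smuad(i32 %A, i32 %B)
; Exchanged operand pairs select smuadx instead. 64-bit reductions still
; start with smlald and a zero i64 accumulator, as there is no 64-bit
; smuad variant.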