Index: lib/Target/AArch64/AArch64A57FPLoadBalancing.cpp
===================================================================
--- lib/Target/AArch64/AArch64A57FPLoadBalancing.cpp
+++ lib/Target/AArch64/AArch64A57FPLoadBalancing.cpp
@@ -66,6 +66,18 @@
 //===----------------------------------------------------------------------===//
 // Helper functions
 
+// Is the instruction an AESE or AESD (opcode AESErr or AESDrr)?
+static bool isAESEnDe(MachineInstr *MI) {
+  return (MI->getOpcode() == AArch64::AESErr ||
+          MI->getOpcode() == AArch64::AESDrr);
+}
+
+// Is the instruction an AESMC or AESIMC (opcode AESMCrr or AESIMCrr)?
+static bool isAESMix(MachineInstr *MI) {
+  return (MI->getOpcode() == AArch64::AESMCrr ||
+          MI->getOpcode() == AArch64::AESIMCrr);
+}
+
 // Is the instruction a type of multiply on 64-bit (or 32-bit) FPRs?
 static bool isMul(MachineInstr *MI) {
   switch (MI->getOpcode()) {
@@ -550,10 +562,16 @@
       MachineOperand &MO = I->getOperand(0);
 
       bool Change = TransformAll || getColor(MO.getReg()) != C;
+      bool IsAES = isAESEnDe(I) || isAESMix(I);
+      Change = Change || IsAES;
       if (G->requiresFixup() && &*I == G->getLast())
         Change = false;
 
       if (Change) {
+        // Keep the same accumulation register for AES chains: when operand 1
+        // is killed here, reuse its register for the def instead of Reg.
+        if (IsAES && I->getOperand(1).isKill())
+          Reg = I->getOperand(1).getReg();
+
         Substs[MO.getReg()] = Reg;
         MO.setReg(Reg);
         MRI->setPhysRegUsed(Reg);
@@ -642,9 +660,69 @@
     ActiveChains[DestReg] = G.get();
     AllChains.insert(std::move(G));
 
+  } else if (isAESEnDe(MI)) {
+    // AESE and AESD are executed by FMA functional units and the Dest register
+    // is the Accum register, treat them as MLAs.
+    unsigned DestReg = MI->getOperand(0).getReg();
+    unsigned SrcReg = MI->getOperand(2).getReg();
+
+    if (DestReg != SrcReg)
+      maybeKillChain(MI->getOperand(2), Idx, ActiveChains);
+
+    if (ActiveChains.find(DestReg) != ActiveChains.end()) {
+      DEBUG(dbgs() << "Chain found for AESE/AESD dest register "
+            << TRI->getName(DestReg) << " in MI " << *MI);
+
+      // DestReg is the AccumReg, so no need to check if it's killed.
+      DEBUG(dbgs() << "Instruction was successfully added to chain.\n");
+      ActiveChains[DestReg]->add(MI, Idx, getColor(DestReg));
+      return;
+    }
+
+    // No active chain on DestReg: start a new chain keyed on DestReg.
+    maybeKillChain(MI->getOperand(0), Idx, ActiveChains);
+    DEBUG(dbgs() << "Creating new chain for AESE/AESD dest register "
+          << TRI->getName(DestReg) << " at " << *MI);
+    auto G = llvm::make_unique<Chain>(MI, Idx, getColor(DestReg));
+    ActiveChains[DestReg] = G.get();
+    AllChains.insert(std::move(G));
+
+  } else if (isAESMix(MI)) {
+    // AESMC/AESIMC: operand 0 is the def, operand 1 the source being chained.
+    unsigned DestReg = MI->getOperand(0).getReg();
+    unsigned SrcReg = MI->getOperand(1).getReg();
+
+    if (DestReg != SrcReg)
+      maybeKillChain(MI->getOperand(0), Idx, ActiveChains);
+
+    if (ActiveChains.find(SrcReg) != ActiveChains.end()) {
+      DEBUG(dbgs() << "Chain found for AESMC/AESIMC src register "
+            << TRI->getName(SrcReg) << " in MI " << *MI);
+
+      DEBUG(dbgs() << "Instruction was successfully added to chain.\n");
+      ActiveChains[SrcReg]->add(MI, Idx, getColor(SrcReg));
+      // Handle cases where the destination is not the same as the accumulator.
+      if (DestReg != SrcReg) {
+        DEBUG(dbgs() << "Transfer chain ownership from "
+              << TRI->getName(SrcReg) << " to "
+              << TRI->getName(DestReg) << "\n");
+        ActiveChains[DestReg] = ActiveChains[SrcReg];
+        ActiveChains.erase(SrcReg);
+      }
+      return;
+    }
+
+    // No active chain on SrcReg: start a new chain keyed on DestReg.
+    maybeKillChain(MI->getOperand(0), Idx, ActiveChains);
+    DEBUG(dbgs() << "Creating new chain for AESMC/AESIMC dest register "
+          << TRI->getName(DestReg) << " at " << *MI);
+    auto G = llvm::make_unique<Chain>(MI, Idx, getColor(DestReg));
+    ActiveChains[DestReg] = G.get();
+    AllChains.insert(std::move(G));
+
   } else {
 
-    // Non-MUL or MLA instruction. Invalidate any chain in the uses or defs
+    // Not MUL, MLA or AES instruction. Invalidate any chain in the uses or defs
     // lists.
     for (auto &I : MI->uses())
       maybeKillChain(I, Idx, ActiveChains);
Index: test/CodeGen/AArch64/aes-load-balancing.ll
===================================================================
--- /dev/null
+++ test/CodeGen/AArch64/aes-load-balancing.ll
@@ -0,0 +1,28 @@
+; RUN: llc < %s -mcpu=cortex-a57 | FileCheck %s
+; RUN: llc < %s -mcpu=cortex-a53 | FileCheck %s
+
+target triple = "aarch64--linux-gnu"
+
+declare <16 x i8> @llvm.aarch64.crypto.aese(<16 x i8>, <16 x i8>)
+declare <16 x i8> @llvm.aarch64.crypto.aesd(<16 x i8>, <16 x i8>)
+declare <16 x i8> @llvm.aarch64.crypto.aesmc(<16 x i8>)
+declare <16 x i8> @llvm.aarch64.crypto.aesimc(<16 x i8>)
+
+; Check that we use the same accumulation register for mixed AES instructions.
+define i32 @aes_load_balancing(<16 x i8>* %x, <16 x i8>* %y, <16 x i8>* %z) {
+;CHECK-LABEL: aes_load_balancing:
+;CHECK: aese v0.16b, v{{[0-9]+}}.16b
+;CHECK: aesmc v0.16b, v0.16b
+;CHECK: aesd v0.16b, v{{[0-9]+}}.16b
+;CHECK: aesimc v0.16b, v0.16b
+entry:
+  %0 = load <16 x i8>* %x, align 16
+  %1 = load <16 x i8>* %y, align 16
+  %2 = load <16 x i8>* %z, align 16
+  %3 = tail call <16 x i8> @llvm.aarch64.crypto.aese(<16 x i8> %0, <16 x i8> %1)
+  %4 = tail call <16 x i8> @llvm.aarch64.crypto.aesmc(<16 x i8> %3)
+  %5 = tail call <16 x i8> @llvm.aarch64.crypto.aesd(<16 x i8> %4, <16 x i8> %2)
+  %6 = tail call <16 x i8> @llvm.aarch64.crypto.aesimc(<16 x i8> %5)
+  store <16 x i8> %6, <16 x i8>* %x, align 16
+  ret i32 0
+}