Index: lib/Target/AMDGPU/SIPeepholeSDWA.cpp =================================================================== --- lib/Target/AMDGPU/SIPeepholeSDWA.cpp +++ lib/Target/AMDGPU/SIPeepholeSDWA.cpp @@ -30,6 +30,7 @@ #include "llvm/CodeGen/MachineFunctionPass.h" #include "llvm/CodeGen/MachineInstrBuilder.h" #include +#include using namespace llvm; @@ -49,14 +50,14 @@ const SIRegisterInfo *TRI; const SIInstrInfo *TII; - std::unordered_map> SDWAOperands; + std::unordered_map> SDWAOperands; Optional foldToImm(const MachineOperand &Op) const; public: static char ID; - typedef SmallVector, 4> SDWAOperandsVector; + typedef SmallVector, 4> SDWAOperandsVector; SIPeepholeSDWA() : MachineFunctionPass(ID) { initializeSIPeepholeSDWAPass(*PassRegistry::getPassRegistry()); @@ -427,13 +428,13 @@ break; if (Opcode == AMDGPU::V_LSHLREV_B32_e32) { - auto SDWADst = make_unique( + auto SDWADst = std::make_shared( Dst, Src1, *Imm == 16 ? WORD_1 : BYTE_3, UNUSED_PAD); DEBUG(dbgs() << "Match: " << MI << "To: " << *SDWADst << '\n'); SDWAOperands[&MI] = std::move(SDWADst); ++NumSDWAPatternsFound; } else { - auto SDWASrc = make_unique( + auto SDWASrc = std::make_shared( Src1, Dst, *Imm == 16 ? WORD_1 : BYTE_3, false, false, Opcode == AMDGPU::V_LSHRREV_B32_e32 ? false : true); DEBUG(dbgs() << "Match: " << MI << "To: " << *SDWASrc << '\n'); @@ -468,12 +469,12 @@ if (Opcode == AMDGPU::V_LSHLREV_B16_e32) { auto SDWADst = - make_unique(Dst, Src1, BYTE_1, UNUSED_PAD); + std::make_shared(Dst, Src1, BYTE_1, UNUSED_PAD); DEBUG(dbgs() << "Match: " << MI << "To: " << *SDWADst << '\n'); SDWAOperands[&MI] = std::move(SDWADst); ++NumSDWAPatternsFound; } else { - auto SDWASrc = make_unique( + auto SDWASrc = std::make_shared( Src1, Dst, BYTE_1, false, false, Opcode == AMDGPU::V_LSHRREV_B16_e32 ? 
false : true); DEBUG(dbgs() << "Match: " << MI << "To: " << *SDWASrc << '\n'); @@ -535,7 +536,7 @@ TRI->isPhysicalRegister(Dst->getReg())) break; - auto SDWASrc = make_unique( + auto SDWASrc = std::make_shared( Src0, Dst, SrcSel, false, false, Opcode == AMDGPU::V_BFE_U32 ? false : true); DEBUG(dbgs() << "Match: " << MI << "To: " << *SDWASrc << '\n'); @@ -563,7 +564,7 @@ TRI->isPhysicalRegister(Dst->getReg())) break; - auto SDWASrc = make_unique( + auto SDWASrc = std::make_shared( Src1, Dst, *Imm == 0x0000ffff ? WORD_0 : BYTE_0); DEBUG(dbgs() << "Match: " << MI << "To: " << *SDWASrc << '\n'); SDWAOperands[&MI] = std::move(SDWASrc); @@ -691,15 +692,52 @@ TRI = ST.getRegisterInfo(); TII = ST.getInstrInfo(); - std::unordered_map PotentialMatches; - + + // Find all SDWA operands in MF. matchSDWAOperands(MF); - for (auto &OperandPair : SDWAOperands) { - auto &Operand = OperandPair.second; + std::unordered_map PotentialMatches; + + // There should be no intersection between SDWA operands and potential MIs + // e.g.: + // v_and_b32 v0, 0xff, v1 -> src:v1 sel:BYTE_0 + // v_and_b32 v2, 0xff, v0 -> src:v0 sel:BYTE_0 + // v_add_u32 v3, v4, v2 + // + // In that example it is possible that we would fold 2nd instruction into 3rd + // (v_add_u32_sdwa) and then try to fold 1st instruction into 2nd (that was + // already destroyed). + // We keep track of every SDWA operand that should be "pulled out" - operand + // created from MI matched by another SDWA operand. + std::unordered_set> PulledOut; + + for (const auto &OperandPair : SDWAOperands) { + const auto &Operand = OperandPair.second; MachineInstr *PotentialMI = Operand->potentialToConvert(TII); if (PotentialMI) { - PotentialMatches[PotentialMI].push_back(std::move(Operand)); + if (SDWAOperands.count(PotentialMI) > 0) { + PulledOut.insert(SDWAOperands[PotentialMI]); + } + PotentialMatches[PotentialMI].push_back(Operand); + } + } + + // Remove all potential matches that should be pulled out. + for (auto PotentialIt = 
PotentialMatches.begin(); + PotentialIt != PotentialMatches.end(); ) { + auto &Ops = PotentialIt->second; + for (auto OpsIt = Ops.begin(); OpsIt != Ops.end(); ) { + if (PulledOut.count(*OpsIt) > 0) { + OpsIt = Ops.erase(OpsIt); + } else { + ++OpsIt; + } + } + + if (Ops.empty()) { + PotentialIt = PotentialMatches.erase(PotentialIt); + } else { + ++PotentialIt; } } Index: test/CodeGen/AMDGPU/sdwa-peephole.ll =================================================================== --- test/CodeGen/AMDGPU/sdwa-peephole.ll +++ test/CodeGen/AMDGPU/sdwa-peephole.ll @@ -393,3 +393,53 @@ store <2 x i16> %add, <2 x i16> addrspace(1)* %out, align 4 ret void } + + +; Check that "pulling out" SDWA operands works correctly. +; GCN-LABEL: {{^}}pulled_out_test: +; NOSDWA-DAG: v_and_b32_e32 v{{[0-9]+}}, v{{[0-9]+}}, v{{[0-9]+}} +; NOSDWA-DAG: v_lshlrev_b16_e32 v{{[0-9]+}}, 8, v{{[0-9]+}} +; NOSDWA-DAG: v_and_b32_e32 v{{[0-9]+}}, v{{[0-9]+}}, v{{[0-9]+}} +; NOSDWA-DAG: v_lshlrev_b16_e32 v{{[0-9]+}}, 8, v{{[0-9]+}} +; NOSDWA: v_or_b32_e32 v{{[0-9]+}}, v{{[0-9]+}}, v{{[0-9]+}} +; NOSDWA-NOT: v_and_b32_sdwa +; NOSDWA-NOT: v_or_b32_sdwa + +; SDWA-DAG: v_and_b32_sdwa v{{[0-9]+}}, v{{[0-9]+}}, v{{[0-9]+}} dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 +; SDWA-DAG: v_lshlrev_b16_e32 v{{[0-9]+}}, 8, v{{[0-9]+}} +; SDWA-DAG: v_and_b32_sdwa v{{[0-9]+}}, v{{[0-9]+}}, v{{[0-9]+}} dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1 +; SDWA-DAG: v_lshlrev_b16_e32 v{{[0-9]+}}, 8, v{{[0-9]+}} +; SDWA: v_or_b32_sdwa v{{[0-9]+}}, v{{[0-9]+}}, v{{[0-9]+}} dst_sel:WORD_1 dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:DWORD + +define amdgpu_kernel void @pulled_out_test(<8 x i8> addrspace(1)* %sourceA, <8 x i8> addrspace(1)* %destValues) { +entry: + %idxprom = ashr exact i64 15, 32 + %arrayidx = getelementptr inbounds <8 x i8>, <8 x i8> addrspace(1)* %sourceA, i64 %idxprom + %0 = load <8 x i8>, <8 x i8> addrspace(1)* %arrayidx, align 8 + + %1 = extractelement <8 x i8> %0, i32 
0 + %2 = extractelement <8 x i8> %0, i32 1 + %3 = extractelement <8 x i8> %0, i32 2 + %4 = extractelement <8 x i8> %0, i32 3 + %5 = extractelement <8 x i8> %0, i32 4 + %6 = extractelement <8 x i8> %0, i32 5 + %7 = extractelement <8 x i8> %0, i32 6 + %8 = extractelement <8 x i8> %0, i32 7 + + %9 = insertelement <2 x i8> undef, i8 %1, i32 0 + %10 = insertelement <2 x i8> %9, i8 %2, i32 1 + %11 = insertelement <2 x i8> undef, i8 %3, i32 0 + %12 = insertelement <2 x i8> %11, i8 %4, i32 1 + %13 = insertelement <2 x i8> undef, i8 %5, i32 0 + %14 = insertelement <2 x i8> %13, i8 %6, i32 1 + %15 = insertelement <2 x i8> undef, i8 %7, i32 0 + %16 = insertelement <2 x i8> %15, i8 %8, i32 1 + + %17 = shufflevector <2 x i8> %10, <2 x i8> %12, <4 x i32> + %18 = shufflevector <2 x i8> %14, <2 x i8> %16, <4 x i32> + %19 = shufflevector <4 x i8> %17, <4 x i8> %18, <8 x i32> + + %arrayidx5 = getelementptr inbounds <8 x i8>, <8 x i8> addrspace(1)* %destValues, i64 %idxprom + store <8 x i8> %19, <8 x i8> addrspace(1)* %arrayidx5, align 8 + ret void +}