diff --git a/llvm/lib/Target/X86/X86LowerAMXType.cpp b/llvm/lib/Target/X86/X86LowerAMXType.cpp --- a/llvm/lib/Target/X86/X86LowerAMXType.cpp +++ b/llvm/lib/Target/X86/X86LowerAMXType.cpp @@ -261,6 +261,7 @@ bool X86LowerAMXType::visit() { SmallVector DeadInsts; + SmallVector DeadBitcasts; for (BasicBlock *BB : post_order(&Func)) { for (BasicBlock::reverse_iterator II = BB->rbegin(), IE = BB->rend(); @@ -272,6 +273,23 @@ Value *Src = Bitcast->getOperand(0); Type *Ty = Bitcast->getType(); + auto CanonicalizeBitcast = [&]() { /* Queues Bitcast for deletion if it is unused, otherwise folds bitcast-of-bitcast round trips feeding it; returns true if anything was queued or rewritten. */ + if (Bitcast->user_empty()) { + DeadBitcasts.push_back(Bitcast); + return true; + } + bool Change = false; + Value *DstV = Src, *PreDst = Bitcast, *SrcV; + while (match(DstV, m_BitCast(m_Value(SrcV))) && + SrcV->getType() == PreDst->getType()) { /* Compare the types themselves: all fixed vectors share one TypeID, and replaceAllUsesWith needs identical types. */ + PreDst->replaceAllUsesWith(SrcV); + DeadBitcasts.push_back(cast(PreDst)); + PreDst = DstV; + DstV = SrcV; + Change = true; + } + return Change; + }; if (Ty->isPointerTy() && cast(Ty)->getElementType()->isX86_AMXTy()) { @@ -302,10 +320,8 @@ // store <256 x i32> %2, <256 x i32>* %0, align 64 } if (Bitcast->getType()->isX86_AMXTy()) { - if (Bitcast->user_empty()) { - DeadInsts.push_back(Bitcast); + if (CanonicalizeBitcast()) continue; - } LoadInst *LD = dyn_cast(Src); if (!LD) { if (transformBitcast(Bitcast)) @@ -333,10 +349,8 @@ if (LD->hasOneUse()) DeadInsts.push_back(LD); } else if (Src->getType()->isX86_AMXTy()) { - if (Bitcast->user_empty()) { - DeadInsts.push_back(Bitcast); + if (CanonicalizeBitcast()) continue; - } StoreInst *ST = nullptr; for (auto UI = Bitcast->use_begin(), UE = Bitcast->use_end(); UI != UE;) { @@ -378,7 +392,7 @@ } } - bool C = !DeadInsts.empty(); + bool C = !DeadInsts.empty() || !DeadBitcasts.empty(); SmallSet DeletedInst; auto DeleteInst = [&](Instruction *Inst) { @@ -401,6 +415,8 @@ }; for (auto *Inst : DeadInsts) DeleteInst(Inst); + for (auto *Inst : DeadBitcasts) + DeleteInst(Inst); return C; } diff --git a/llvm/test/CodeGen/X86/AMX/amx-type.ll 
b/llvm/test/CodeGen/X86/AMX/amx-type.ll --- a/llvm/test/CodeGen/X86/AMX/amx-type.ll +++ b/llvm/test/CodeGen/X86/AMX/amx-type.ll @@ -8,6 +8,17 @@ @buf = dso_local global [1024 x i8] zeroinitializer, align 16 @buf2 = dso_local global [1024 x i8] zeroinitializer, align 16 +define dso_local <256 x i32> @test_amx_bitcast(<256 x i32> %in) #2 { +; CHECK-LABEL: @test_amx_bitcast( +; CHECK-NEXT: entry: +; CHECK-NEXT: ret <256 x i32> [[IN:%.*]] +; +entry: + %amx = bitcast <256 x i32> %in to x86_amx + %vec = bitcast x86_amx %amx to <256 x i32> + ret <256 x i32> %vec +} + define dso_local void @test_amx_store(<256 x i32>* %in, i16 %m, i16 %n, i8 *%buf, i64 %s) #2 { ; CHECK-LABEL: @test_amx_store( ; CHECK-NEXT: entry: