diff --git a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
@@ -589,7 +589,15 @@
   // Fold away bit casts of the loaded value by loading the desired type.
   // Note that we should not do this for pointer<->integer casts,
   // because that would result in type punning.
-  if (LI.hasOneUse())
+  if (LI.hasOneUse()) {
+    // Don't transform when the type is x86_amx; it keeps the pass that
+    // lowers the x86_amx type happy.
+    if (auto *BC = dyn_cast<BitCastInst>(LI.user_back())) {
+      Value *V = BC->getOperand(0);
+      if (BC->getType()->isX86_AMXTy() || V->getType()->isX86_AMXTy())
+        return nullptr;
+    }
+
     if (auto* CI = dyn_cast<CastInst>(LI.user_back()))
       if (CI->isNoopCast(DL) && LI.getType()->isPtrOrPtrVectorTy() ==
                                     CI->getDestTy()->isPtrOrPtrVectorTy())
@@ -599,6 +607,7 @@
           IC.eraseInstFromFunction(*CI);
           return &LI;
         }
+  }
 
   // FIXME: We should also canonicalize loads of vectors when their elements are
   // cast to other types.
diff --git a/llvm/test/Transforms/InstCombine/load.ll b/llvm/test/Transforms/InstCombine/load.ll
--- a/llvm/test/Transforms/InstCombine/load.ll
+++ b/llvm/test/Transforms/InstCombine/load.ll
@@ -422,3 +422,19 @@
   call void @use.v2.p1(<2 x i8 addrspace(1)*> %Y)
   ret <2 x i64> %X
 }
+
+define dso_local void @test_amx_load_store(<256 x i32>* %src, x86_amx* %dst) {
+entry:
+  %vec = load <256 x i32>, <256 x i32>* %src, align 64
+  %bc = bitcast <256 x i32> %vec to x86_amx
+  store x86_amx %bc, x86_amx* %dst
+  ret void
+}
+
+define dso_local void @test_amx_load_store2(<256 x i32>* %dst, x86_amx* %src) {
+entry:
+  %vec = load x86_amx, x86_amx* %src, align 64
+  %bc = bitcast x86_amx %vec to <256 x i32>
+  store <256 x i32> %bc, <256 x i32>* %dst
+  ret void
+}
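
Note on the guarded fold (a sketch for context, not part of the patch; the function
and value names below are hypothetical): when the load has one use, InstCombine
canonicalizes a no-op cast of the loaded value into a load of the cast's destination
type. For example:

  define float @load_then_bitcast(i32* %p) {
    %v = load i32, i32* %p, align 4
    %f = bitcast i32 %v to float
    ret float %f
  }

  ; instcombine retypes the load and moves the cast onto the pointer:
  ;   %0 = bitcast i32* %p to float*
  ;   %f = load float, float* %0

For x86_amx, that rewrite would create a plain load or store of the opaque tile type
and hide the <256 x i32> <-> x86_amx bitcasts that the AMX type-lowering pass
pattern-matches; the new early return suppresses the fold, so the two tests above are
expected to pass through instcombine unchanged.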