diff --git a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
@@ -589,7 +589,16 @@
   // Fold away bit casts of the loaded value by loading the desired type.
   // Note that we should not do this for pointer<->integer casts,
   // because that would result in type punning.
-  if (LI.hasOneUse())
+  if (LI.hasOneUse()) {
+    // Don't transform when the type is x86_amx, it makes the pass that lowers
+    // x86_amx type happy.
+    if (auto *BC = dyn_cast<BitCastInst>(LI.user_back())) {
+      assert(!LI.getType()->isX86_AMXTy() &&
+             "load from x86_amx* should not happen!");
+      if (BC->getType()->isX86_AMXTy())
+        return nullptr;
+    }
+
     if (auto* CI = dyn_cast<CastInst>(LI.user_back()))
       if (CI->isNoopCast(DL) &&
           LI.getType()->isPtrOrPtrVectorTy() == CI->getDestTy()->isPtrOrPtrVectorTy())
@@ -599,6 +608,7 @@
           IC.eraseInstFromFunction(*CI);
           return &LI;
         }
+  }
 
   // FIXME: We should also canonicalize loads of vectors when their elements are
   // cast to other types.
@@ -1114,10 +1124,12 @@
   // Fold away bit casts of the stored value by storing the original type.
   if (auto *BC = dyn_cast<BitCastInst>(V)) {
+    assert(!BC->getType()->isX86_AMXTy() &&
+           "store to x86_amx* should not happen!");
     V = BC->getOperand(0);
-    // Don't transform when the type is x86_amx, it make the pass that lower
+    // Don't transform when the type is x86_amx, it makes the pass that lowers
     // x86_amx type happy.
-    if (BC->getType()->isX86_AMXTy() || V->getType()->isX86_AMXTy())
+    if (V->getType()->isX86_AMXTy())
       return false;
     if (!SI.isAtomic() || isSupportedAtomicType(V->getType())) {
       combineStoreToNewValue(IC, SI, V);
       return true;
diff --git a/llvm/test/Transforms/InstCombine/X86/x86-amx-load-store.ll b/llvm/test/Transforms/InstCombine/X86/x86-amx-load-store.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/X86/x86-amx-load-store.ll
@@ -0,0 +1,38 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -instcombine -S < %s | FileCheck %s
+; RUN: opt -passes=instcombine -S < %s | FileCheck %s
+
+; Prohibit pointer cast for amx.
+define dso_local void @test_amx_load_store(<256 x i32>* %src, i8* %dst) {
+; CHECK-LABEL: @test_amx_load_store(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[VEC:%.*]] = load <256 x i32>, <256 x i32>* [[SRC:%.*]], align 64
+; CHECK-NEXT:    [[BC:%.*]] = bitcast <256 x i32> [[VEC]] to x86_amx
+; CHECK-NEXT:    tail call void @llvm.x86.tilestored64.internal(i16 16, i16 16, i8* [[DST:%.*]], i64 64, x86_amx [[BC]])
+; CHECK-NEXT:    ret void
+;
+entry:
+  %vec = load <256 x i32>, <256 x i32>* %src, align 64
+  %bc = bitcast <256 x i32> %vec to x86_amx
+  tail call void @llvm.x86.tilestored64.internal(i16 16, i16 16, i8* %dst, i64 64, x86_amx %bc)
+  ret void
+}
+
+; Prohibit pointer cast for amx.
+define dso_local void @test_amx_load_store2(<256 x i32>* %dst, i8* %src) {
+; CHECK-LABEL: @test_amx_load_store2(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[AMX:%.*]] = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 16, i8* [[SRC:%.*]], i64 64)
+; CHECK-NEXT:    [[BC:%.*]] = bitcast x86_amx [[AMX]] to <256 x i32>
+; CHECK-NEXT:    store <256 x i32> [[BC]], <256 x i32>* [[DST:%.*]], align 1024
+; CHECK-NEXT:    ret void
+;
+entry:
+  %amx = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 16, i8* %src, i64 64)
+  %bc = bitcast x86_amx %amx to <256 x i32>
+  store <256 x i32> %bc, <256 x i32>* %dst
+  ret void
+}
+
+declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64)
+declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, x86_amx)