diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -4601,7 +4601,8 @@
     if (auto *MD = dyn_cast<MetadataAsValue>(V))
       visitMetadataAsValue(*MD, Call.getCaller());
     if (auto *Const = dyn_cast<Constant>(V))
-      Assert(!Const->getType()->isX86_AMXTy(),
+      Assert(!(Const->getType()->isX86_AMXTy() && !Const->isNullValue() &&
+               !isa<UndefValue>(V)),
              "const x86_amx is not allowed in argument!");
   }
 
diff --git a/llvm/lib/Target/X86/X86LowerAMXType.cpp b/llvm/lib/Target/X86/X86LowerAMXType.cpp
--- a/llvm/lib/Target/X86/X86LowerAMXType.cpp
+++ b/llvm/lib/Target/X86/X86LowerAMXType.cpp
@@ -305,6 +305,23 @@
     for (BasicBlock::reverse_iterator II = BB->rbegin(), IE = BB->rend();
          II != IE;) {
       Instruction &Inst = *II++;
+      // Handle x86_amx operands that are undef or zeroinitializer.
+      if (auto *Intr = dyn_cast<IntrinsicInst>(&Inst)) {
+        for (unsigned I = 1, E = Intr->getNumOperands(); I != E; ++I) {
+          Value *OpV = Intr->getOperand(I);
+          if (OpV->getType()->isX86_AMXTy() &&
+              (isa<UndefValue>(OpV) ||
+               (isa<Constant>(OpV) && cast<Constant>(OpV)->isNullValue()))) {
+            IRBuilder<> Builder(Intr);
+            Value *Row, *Col;
+            std::tie(Row, Col) = getShape(Intr, I);
+            std::array<Value *, 2> Args = {Row, Col};
+            Intr->setOperand(
+                I, Builder.CreateIntrinsic(Intrinsic::x86_tilezero_internal,
+                                           None, Args));
+          }
+        }
+      }
       auto *Bitcast = dyn_cast<BitCastInst>(&Inst);
       if (!Bitcast)
         continue;
diff --git a/llvm/test/CodeGen/X86/AMX/amx-type.ll b/llvm/test/CodeGen/X86/AMX/amx-type.ll
--- a/llvm/test/CodeGen/X86/AMX/amx-type.ll
+++ b/llvm/test/CodeGen/X86/AMX/amx-type.ll
@@ -303,6 +303,32 @@
   ret void
 }
 
+define dso_local void @__tile_dpbf16ps_with_undef_amxtype(i16 %m, i16 %n, i16 %k, <256 x i32>* %pc, <256 x i32>* %pa, <256 x i32>* %pb) {
+; CHECK-LABEL: @__tile_dpbf16ps_with_undef_amxtype(
+; CHECK-NEXT:    [[TMP1:%.*]] = udiv i16 [[K:%.*]], 4
+; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <256 x i32>* [[PA:%.*]] to i8*
+; CHECK-NEXT:    [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[K]], i8* [[TMP2]], i64 64)
+; CHECK-NEXT:    [[T2:%.*]] = load <256 x i32>, <256 x i32>* [[PB:%.*]], align 64
+; CHECK-NEXT:    [[TMP4:%.*]] = bitcast <256 x i32>* [[PC:%.*]] to i8*
+; CHECK-NEXT:    [[TMP5:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N:%.*]], i8* [[TMP4]], i64 64)
+; CHECK-NEXT:    [[TMP6:%.*]] = call x86_amx @llvm.x86.tilezero.internal(i16 [[TMP1]], i16 [[N]])
+; CHECK-NEXT:    [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbf16ps.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP5]], x86_amx [[TMP3]], x86_amx [[TMP6]])
+; CHECK-NEXT:    [[TMP7:%.*]] = bitcast <256 x i32>* [[PC]] to i8*
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], i8* [[TMP7]], i64 64, x86_amx [[T6]])
+; CHECK-NEXT:    ret void
+;
+  %t0 = load <256 x i32>, <256 x i32>* %pa, align 64
+  %t1 = bitcast <256 x i32> %t0 to x86_amx
+  %t2 = load <256 x i32>, <256 x i32>* %pb, align 64
+  %t3 = bitcast <256 x i32> %t2 to x86_amx
+  %t4 = load <256 x i32>, <256 x i32>* %pc, align 64
+  %t5 = bitcast <256 x i32> %t4 to x86_amx
+  %t6 = tail call x86_amx @llvm.x86.tdpbf16ps.internal(i16 %m, i16 %n, i16 %k, x86_amx %t5, x86_amx %t1, x86_amx undef)
+  %t7 = bitcast x86_amx %t6 to <256 x i32>
+  store <256 x i32> %t7, <256 x i32>* %pc, align 64
+  ret void
+}
+
 define dso_local void @__tile_stored(i8* %0, i64 %1, %struct.__tile_str* nocapture readonly byval(%struct.__tile_str) align 64 %2) local_unnamed_addr {
 ; CHECK-LABEL: @__tile_stored(
 ; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR:%.*]], %struct.__tile_str* [[TMP2:%.*]], i64 0, i32 0
@@ -330,6 +356,60 @@
   ret void
 }
 
+define dso_local void @__tile_undef_stored(i8* %0, i64 %1, %struct.__tile_str* nocapture readonly byval(%struct.__tile_str) align 64 %2) local_unnamed_addr {
+; CHECK-LABEL: @__tile_undef_stored(
+; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR:%.*]], %struct.__tile_str* [[TMP2:%.*]], i64 0, i32 0
+; CHECK-NEXT:    [[TMP5:%.*]] = load i16, i16* [[TMP4]], align 64
+; CHECK-NEXT:    [[TMP6:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP2]], i64 0, i32 1
+; CHECK-NEXT:    [[TMP7:%.*]] = load i16, i16* [[TMP6]], align 2
+; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP2]], i64 0, i32 2
+; CHECK-NEXT:    [[TMP9:%.*]] = load <256 x i32>, <256 x i32>* [[TMP8]], align 64
+; CHECK-NEXT:    [[TMP10:%.*]] = shl i64 [[TMP1:%.*]], 32
+; CHECK-NEXT:    [[TMP11:%.*]] = ashr exact i64 [[TMP10]], 32
+; CHECK-NEXT:    [[TMP12:%.*]] = call x86_amx @llvm.x86.tilezero.internal(i16 [[TMP5]], i16 [[TMP7]])
+; CHECK-NEXT:    tail call void @llvm.x86.tilestored64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP0:%.*]], i64 [[TMP11]], x86_amx [[TMP12]])
+; CHECK-NEXT:    ret void
+;
+  %4 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %2, i64 0, i32 0
+  %5 = load i16, i16* %4, align 64
+  %6 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %2, i64 0, i32 1
+  %7 = load i16, i16* %6, align 2
+  %8 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %2, i64 0, i32 2
+  %9 = load <256 x i32>, <256 x i32>* %8, align 64
+  %10 = bitcast <256 x i32> %9 to x86_amx
+  %11 = shl i64 %1, 32
+  %12 = ashr exact i64 %11, 32
+  tail call void @llvm.x86.tilestored64.internal(i16 %5, i16 %7, i8* %0, i64 %12, x86_amx undef)
+  ret void
+}
+
+define dso_local void @__tile_zeroinitializer_stored(i8* %0, i64 %1, %struct.__tile_str* nocapture readonly byval(%struct.__tile_str) align 64 %2) local_unnamed_addr {
+; CHECK-LABEL: @__tile_zeroinitializer_stored(
+; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR:%.*]], %struct.__tile_str* [[TMP2:%.*]], i64 0, i32 0
+; CHECK-NEXT:    [[TMP5:%.*]] = load i16, i16* [[TMP4]], align 64
+; CHECK-NEXT:    [[TMP6:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP2]], i64 0, i32 1
+; CHECK-NEXT:    [[TMP7:%.*]] = load i16, i16* [[TMP6]], align 2
+; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP2]], i64 0, i32 2
+; CHECK-NEXT:    [[TMP9:%.*]] = load <256 x i32>, <256 x i32>* [[TMP8]], align 64
+; CHECK-NEXT:    [[TMP10:%.*]] = shl i64 [[TMP1:%.*]], 32
+; CHECK-NEXT:    [[TMP11:%.*]] = ashr exact i64 [[TMP10]], 32
+; CHECK-NEXT:    [[TMP12:%.*]] = call x86_amx @llvm.x86.tilezero.internal(i16 [[TMP5]], i16 [[TMP7]])
+; CHECK-NEXT:    tail call void @llvm.x86.tilestored64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP0:%.*]], i64 [[TMP11]], x86_amx [[TMP12]])
+; CHECK-NEXT:    ret void
+;
+  %4 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %2, i64 0, i32 0
+  %5 = load i16, i16* %4, align 64
+  %6 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %2, i64 0, i32 1
+  %7 = load i16, i16* %6, align 2
+  %8 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %2, i64 0, i32 2
+  %9 = load <256 x i32>, <256 x i32>* %8, align 64
+  %10 = bitcast <256 x i32> %9 to x86_amx
+  %11 = shl i64 %1, 32
+  %12 = ashr exact i64 %11, 32
+  tail call void @llvm.x86.tilestored64.internal(i16 %5, i16 %7, i8* %0, i64 %12, x86_amx zeroinitializer)
+  ret void
+}
+
 declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64)
 declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)
 declare x86_amx @llvm.x86.tdpbsud.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)