diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -5139,7 +5139,12 @@
       assert(PTy->canLosslesslyBitCastTo(FTy->getParamType(i)) &&
              "Must be able to losslessly bit cast to param");
-      ArgValue = Builder.CreateBitCast(ArgValue, PTy);
+      if (PTy->isX86_AMXTy() || ArgValue->getType()->isX86_AMXTy())
+        ArgValue =
+            Builder.CreateIntrinsic(Intrinsic::x86_vector_amx_cast,
+                                    {PTy, ArgValue->getType()}, {ArgValue});
+      else
+        ArgValue = Builder.CreateBitCast(ArgValue, PTy);
     }
 
     Args.push_back(ArgValue);
@@ -5163,7 +5168,11 @@
     assert(V->getType()->canLosslesslyBitCastTo(RetTy) &&
            "Must be able to losslessly bit cast result type");
-    V = Builder.CreateBitCast(V, RetTy);
+    if (RetTy->isX86_AMXTy() || V->getType()->isX86_AMXTy())
+      V = Builder.CreateIntrinsic(Intrinsic::x86_vector_amx_cast,
+                                  {RetTy, V->getType()}, {V});
+    else
+      V = Builder.CreateBitCast(V, RetTy);
   }
 
   return RValue::get(V);
diff --git a/clang/test/CodeGen/X86/amx_cast.c b/clang/test/CodeGen/X86/amx_cast.c
new file mode 100644
--- /dev/null
+++ b/clang/test/CodeGen/X86/amx_cast.c
@@ -0,0 +1,90 @@
+// RUN: %clang_cc1 %s -O2 -ffreestanding -triple=x86_64-unknown-unknown -target-feature +avx512f -target-feature +amx-int8 \
+// RUN: -target-feature +amx-bf16 -emit-llvm -o - -Werror -pedantic | FileCheck %s --check-prefixes=CHECK
+
+#include <immintrin.h>
+
+char buf[1024];
+#define STRIDE 32
+
+char buf2[1024];
+
+void test1() {
+//CHECK: %0 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 8, i16 8, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #1
+//CHECK: %1 = tail call <256 x i32> @llvm.x86.vector.amx.cast.v256i32.x86amx(x86_amx %0) #1
+//CHECK: %2 = tail call x86_amx @llvm.x86.vector.amx.cast.x86amx.v256i32(<256 x i32> %1) #1
+//CHECK: tail call void @llvm.x86.tilestored64.internal(i16 8, i16 8, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32, x86_amx %2) #1
+  __tile1024i a = {8, 8};
+  __tile1024i b = {8, 8};
+
+  __tile_loadd(&a, buf, STRIDE);
+  __tile_stored(buf, STRIDE, a);
+}
+
+void test2() {
+//CHECK: %0 = tail call x86_amx @llvm.x86.vector.amx.cast.x86amx.v256i32(<256 x i32> zeroinitializer) #1
+//CHECK: tail call void @llvm.x86.tilestored64.internal(i16 8, i16 8, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32, x86_amx %0) #1
+  __tile1024i a = {8, 8};
+
+  __tile_stored(buf, STRIDE, a);
+}
+
+#define TILE_SZ 16
+void inner_product(int *A_mem, int *B_mem, int *C_mem, int M, int N, int K) {
+//CHECK: for.body6: ; preds = %for.body6.lr.ph, %for.cond.cleanup9
+//CHECK: %indvars.iv200 = phi i64 [ 0, %for.body6.lr.ph ], [ %indvars.iv.next201, %for.cond.cleanup9 ]
+//CHECK: %1 = tail call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 64) #1
+//CHECK: %2 = tail call <256 x i32> @llvm.x86.vector.amx.cast.v256i32.x86amx(x86_amx %1) #1
+//CHECK: %3 = shl nsw i64 %indvars.iv200, 4
+//CHECK: br i1 %cmp8163, label %for.body10.lr.ph, label %for.cond.cleanup9
+//CHECK: for.body10.lr.ph: ; preds = %for.body6
+//CHECK: %add.ptr19 = getelementptr inbounds i32, i32* %B_mem, i64 %3
+//CHECK: br label %for.body10
+//CHECK: for.cond.cleanup9: ; preds = %for.body10, %for.body6
+//CHECK: %c.sroa.8127.2.lcssa = phi <256 x i32> [ %2, %for.body6 ], [ %18, %for.body10 ]
+//CHECK: %add.ptr31 = getelementptr inbounds i32, i32* %add.ptr28, i64 %3
+//CHECK: %4 = bitcast i32* %add.ptr31 to i8*
+//CHECK: %5 = tail call x86_amx @llvm.x86.vector.amx.cast.x86amx.v256i32(<256 x i32> %c.sroa.8127.2.lcssa) #1
+//CHECK: tail call void @llvm.x86.tilestored64.internal(i16 16, i16 64, i8* %4, i64 %mul24, x86_amx %5) #1
+//CHECK: %indvars.iv.next201 = add nuw nsw i64 %indvars.iv200, 1
+//CHECK: %exitcond205.not = icmp eq i64 %indvars.iv.next201, %wide.trip.count204
+//CHECK: br i1 %exitcond205.not, label %for.cond.cleanup5, label %for.body6, !llvm.loop !4
+//CHECK: for.body10: ; preds = %for.body10.lr.ph, %for.body10
+//CHECK: %indvars.iv = phi i64 [ 0, %for.body10.lr.ph ], [ %indvars.iv.next, %for.body10 ]
+//CHECK: %c.sroa.8127.2164 = phi <256 x i32> [ %2, %for.body10.lr.ph ], [ %18, %for.body10 ]
+//CHECK: %6 = shl nsw i64 %indvars.iv, 4
+//CHECK: %add.ptr14 = getelementptr inbounds i32, i32* %add.ptr, i64 %6
+//CHECK: %7 = bitcast i32* %add.ptr14 to i8*
+//CHECK: %8 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, i8* %7, i64 %mul15) #1
+//CHECK: %9 = tail call <256 x i32> @llvm.x86.vector.amx.cast.v256i32.x86amx(x86_amx %8) #1
+//CHECK: %10 = mul nsw i64 %6, %conv23
+//CHECK: %add.ptr22 = getelementptr inbounds i32, i32* %add.ptr19, i64 %10
+//CHECK: %11 = bitcast i32* %add.ptr22 to i8*
+//CHECK: %12 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, i8* %11, i64 %mul24) #1
+//CHECK: %13 = tail call <256 x i32> @llvm.x86.vector.amx.cast.v256i32.x86amx(x86_amx %12) #1
+//CHECK: %14 = tail call x86_amx @llvm.x86.vector.amx.cast.x86amx.v256i32(<256 x i32> %c.sroa.8127.2164) #1
+//CHECK: %15 = tail call x86_amx @llvm.x86.vector.amx.cast.x86amx.v256i32(<256 x i32> %9) #1
+//CHECK: %16 = tail call x86_amx @llvm.x86.vector.amx.cast.x86amx.v256i32(<256 x i32> %13) #1
+//CHECK: %17 = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 16, i16 64, i16 64, x86_amx %14, x86_amx %15, x86_amx %16) #1
+//CHECK: %18 = tail call <256 x i32> @llvm.x86.vector.amx.cast.v256i32.x86amx(x86_amx %17) #1
+//CHECK: %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+//CHECK: %exitcond.not = icmp eq i64 %indvars.iv.next, %wide.trip.count
+//CHECK: br i1 %exitcond.not, label %for.cond.cleanup9, label %for.body10, !llvm.loop !5
+
+  const int m = M / TILE_SZ;
+  const int n = N / TILE_SZ;
+  const int k = K / TILE_SZ;
+
+  for (int i = 0; i < m; i++)
+    for (int j = 0; j < n; j++) {
+      __tile1024i c = {TILE_SZ, TILE_SZ*sizeof(int)};
+      __tile_zero(&c);
+      for (int l = 0; l < k; l++) {
+        __tile1024i a = {TILE_SZ, TILE_SZ*sizeof(int)};
+        __tile1024i b = {TILE_SZ, TILE_SZ*sizeof(int)};
+        __tile_loadd(&a, A_mem+(i*TILE_SZ)*K+l*TILE_SZ, K*sizeof(int));
+        __tile_loadd(&b, B_mem+(l*TILE_SZ)*N+j*TILE_SZ, N*sizeof(int));
+        __tile_dpbssd(&c, a, b);
+      }
+      // C is an M x N row-major matrix, so the tile row offset scales by N.
+      __tile_stored(C_mem+(i*TILE_SZ)*N+j*TILE_SZ, N*sizeof(int), c);
+    }
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsX86.td b/llvm/include/llvm/IR/IntrinsicsX86.td
--- a/llvm/include/llvm/IR/IntrinsicsX86.td
+++ b/llvm/include/llvm/IR/IntrinsicsX86.td
@@ -5085,6 +5085,8 @@
                 [llvm_i16_ty, llvm_i16_ty, llvm_i16_ty, llvm_x86amx_ty,
                  llvm_x86amx_ty, llvm_x86amx_ty], []>;
+  def int_x86_vector_amx_cast :
+      Intrinsic<[llvm_any_ty], [llvm_any_ty], [IntrNoMem]>;
 }
 
 //===----------------------------------------------------------------------===//
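Not part of the patch: a minimal standalone IR sketch of the round trip the new overloaded intrinsic enables. The .x86amx.v256i32 / .v256i32.x86amx suffixes follow the default overload name mangling (return type first, then parameter type) and match the CHECK lines above; the function name @cast_round_trip is invented for illustration.

  define <256 x i32> @cast_round_trip(<256 x i32> %v) {
    ; vector -> tile: instantiated with x86_amx result, <256 x i32> operand
    %t = call x86_amx @llvm.x86.vector.amx.cast.x86amx.v256i32(<256 x i32> %v)
    ; tile -> vector: the opposite instantiation of the same intrinsic
    %r = call <256 x i32> @llvm.x86.vector.amx.cast.v256i32.x86amx(x86_amx %t)
    ret <256 x i32> %r
  }

  declare x86_amx @llvm.x86.vector.amx.cast.x86amx.v256i32(<256 x i32>)
  declare <256 x i32> @llvm.x86.vector.amx.cast.v256i32.x86amx(x86_amx)

Because the intrinsic is declared IntrNoMem, the calls are side-effect free and duplicate casts can be CSE'd; unlike the plain bitcast the front-end change replaces, the intrinsic marks an explicit boundary between x86_amx tile values and ordinary vector IR for later lowering passes.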