diff --git a/llvm/lib/Target/X86/X86LowerAMXType.cpp b/llvm/lib/Target/X86/X86LowerAMXType.cpp
--- a/llvm/lib/Target/X86/X86LowerAMXType.cpp
+++ b/llvm/lib/Target/X86/X86LowerAMXType.cpp
@@ -95,6 +95,29 @@
   return std::make_pair(Row, Col);
 }
 
+// %1 = load x86_amx, x86_amx* %0, align 64
+// %2 = call x86_amx @llvm.x86.tdpbssd.internal(%1, %1, %1, ...)
+// -->
+// %1 = call x86_amx @llvm.x86.tileloadd64.internal()
+// %2 = call x86_amx @llvm.x86.tdpbssd.internal(%1, %1, %1, ...)
+static void transformTileLoad(LoadInst *LD) {
+  Value *Row = nullptr, *Col = nullptr;
+  Use &U = *(LD->use_begin());
+  unsigned OpNo = U.getOperandNo();
+  auto *II = cast<IntrinsicInst>(U.getUser());
+  std::tie(Row, Col) = getShape(II, OpNo);
+  IRBuilder<> Builder(LD);
+  // Use the maximum column as stride.
+  Value *Stride = Builder.getInt64(64);
+  Value *I8Ptr =
+      Builder.CreateBitCast(LD->getOperand(0), Builder.getInt8PtrTy());
+  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
+
+  Value *NewInst =
+      Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
+  LD->replaceAllUsesWith(NewInst);
+}
+
 // %src = load <256 x i32>, <256 x i32>* %addr, align 64
 // %2 = bitcast <256 x i32> %src to x86_amx
 // -->
@@ -215,6 +238,26 @@
   return true;
 }
 
+// %addr = bitcast <256 x i32>* %tile to x86_amx*
+// store x86_amx %9, x86_amx* %addr, align 64
+// -->
+// call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
+//                                           %stride64, %9)
+static void transformTileStore(StoreInst *ST) {
+  auto *II = cast<IntrinsicInst>(ST->getValueOperand());
+  Value *Row = II->getOperand(0);
+  Value *Col = II->getOperand(1);
+  IRBuilder<> Builder(ST);
+  // Use the maximum column as stride. It must be the same as the
+  // load stride.
+  Value *Stride = Builder.getInt64(64);
+  Value *I8Ptr =
+      Builder.CreateBitCast(ST->getOperand(1), Builder.getInt8PtrTy());
+  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride,
+                                 ST->getValueOperand()};
+  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
+}
+
 namespace {
 class X86LowerAMXType {
   Function &Func;
@@ -236,6 +279,36 @@
           continue;
 
         Value *Src = Bitcast->getOperand(0);
+        Type *Ty = Bitcast->getType();
+
+        if (Ty->isPointerTy() &&
+            cast<PointerType>(Ty)->getElementType()->isX86_AMXTy()) {
+          for (auto UI = Bitcast->use_begin(), UE = Bitcast->use_end();
+               UI != UE;) {
+            Value *I = (UI++)->getUser();
+            auto *LD = dyn_cast<LoadInst>(I);
+            // %0 = bitcast <256 x i32>* %tile to x86_amx*
+            // %1 = load x86_amx, x86_amx* %0, align 64
+            if (LD) {
+              transformTileLoad(LD);
+              DeadInsts.push_back(LD);
+            }
+            auto *ST = dyn_cast<StoreInst>(I);
+            if (ST) {
+              // %addr = bitcast <256 x i32>* %tile to x86_amx*
+              // store x86_amx %9, x86_amx* %addr, align 64
+              // -->
+              // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
+              //                                           %stride64, %9)
+              transformTileStore(ST);
+              DeadInsts.push_back(ST);
+            }
+          }
+          // If the dst type is <256 x i32>*, it is a valid instruction.
+          // %0 = bitcast x86_amx* %tile to <256 x i32>*
+          // %1 = load <256 x i32>, <256 x i32>* %0, align 64
+          // store <256 x i32> %2, <256 x i32>* %0, align 64
+        }
         if (Bitcast->getType()->isX86_AMXTy()) {
           if (Bitcast->user_empty()) {
             DeadInsts.push_back(Bitcast);
diff --git a/llvm/test/CodeGen/X86/AMX/amx-type.ll b/llvm/test/CodeGen/X86/AMX/amx-type.ll
--- a/llvm/test/CodeGen/X86/AMX/amx-type.ll
+++ b/llvm/test/CodeGen/X86/AMX/amx-type.ll
@@ -6,6 +6,38 @@
 @buf = dso_local global [1024 x i8] zeroinitializer, align 64
 @buf2 = dso_local global [1024 x i8] zeroinitializer, align 64
 
+define dso_local void @test_amx_store(<256 x i32>* %in, i16 %m, i16 %n, i8 *%buf, i64 %s) #2 {
+; CHECK-LABEL: @test_amx_store(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[T0:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[N:%.*]], i8* [[BUF:%.*]], i64 [[S:%.*]]) [[ATTR3:#.*]]
+; CHECK-NEXT:    [[ADDR:%.*]] = bitcast <256 x i32>* [[IN:%.*]] to x86_amx*
+; CHECK-NEXT:    [[TMP0:%.*]] = bitcast x86_amx* [[ADDR]] to i8*
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], i8* [[TMP0]], i64 64, x86_amx [[T0]])
+; CHECK-NEXT:    ret void
+;
+entry:
+  %t0 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %m, i16 %n, i8* %buf, i64 %s) #3
+  %addr = bitcast <256 x i32>* %in to x86_amx*
+  store x86_amx %t0, x86_amx* %addr, align 64
+  ret void
+}
+
+define dso_local void @test_amx_load(<256 x i32>* %in, i16 %m, i16 %n, i8 *%buf, i64 %s) #2 {
+; CHECK-LABEL: @test_amx_load(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[T0:%.*]] = bitcast <256 x i32>* [[IN:%.*]] to x86_amx*
+; CHECK-NEXT:    [[TMP0:%.*]] = bitcast x86_amx* [[T0]] to i8*
+; CHECK-NEXT:    [[TMP1:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[N:%.*]], i8* [[TMP0]], i64 64)
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], i8* [[BUF:%.*]], i64 [[S:%.*]], x86_amx [[TMP1]]) [[ATTR3]]
+; CHECK-NEXT:    ret void
+;
+entry:
+  %t0 = bitcast <256 x i32>* %in to x86_amx*
+  %t1 = load x86_amx, x86_amx* %t0, align 64
+  call void @llvm.x86.tilestored64.internal(i16 %m, i16 %n, i8* %buf, i64 %s, x86_amx %t1) #3
+  ret void
+}
+
 ; test bitcast x86_amx to <256 x i32>
 define dso_local void @test_user_empty(i16 %m, i16 %n, i8 *%buf, i64 %s) {
 ; CHECK-LABEL: @test_user_empty(
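For illustration only, and not part of the patch or its tests: a hand-written LLVM IR sketch of the case the transformTileLoad comment describes, where the loaded tile is consumed by @llvm.x86.tdpbssd.internal and the row/column shape is recovered from that user through getShape. The function name @sketch_amx_dpbssd and all value names are hypothetical; the tileloadd64/tilestored64 signatures match the tests above, and the tdpbssd.internal signature shown in the declare is an assumption.

define dso_local void @sketch_amx_dpbssd(<256 x i32>* %tile, i16 %m, i16 %n, i16 %k, i8* %buf, i64 %s) {
entry:
  ; %c is assumed to be the m x n accumulator tile and %b the k x n operand tile.
  %c = call x86_amx @llvm.x86.tileloadd64.internal(i16 %m, i16 %n, i8* %buf, i64 %s)
  %b = call x86_amx @llvm.x86.tileloadd64.internal(i16 %k, i16 %n, i8* %buf, i64 %s)
  ; transformTileLoad should rewrite this load into a tileloadd64 call whose
  ; shape getShape derives from the tdpbssd user below (the load feeds operand 4).
  %p = bitcast <256 x i32>* %tile to x86_amx*
  %a = load x86_amx, x86_amx* %p, align 64
  %d = call x86_amx @llvm.x86.tdpbssd.internal(i16 %m, i16 %n, i16 %k, x86_amx %c, x86_amx %a, x86_amx %b)
  ; transformTileStore should rewrite this store into a tilestored64 call whose
  ; shape is taken from the first two operands of the stored intrinsic call (%m, %n).
  store x86_amx %d, x86_amx* %p, align 64
  ret void
}

declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64)
declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)

After the pass, both accesses through %p would be expected to go through an i8* bitcast with the fixed stride of 64, as in the CHECK lines of test_amx_load and test_amx_store.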