diff --git a/llvm/lib/Target/X86/X86LowerAMXType.cpp b/llvm/lib/Target/X86/X86LowerAMXType.cpp
--- a/llvm/lib/Target/X86/X86LowerAMXType.cpp
+++ b/llvm/lib/Target/X86/X86LowerAMXType.cpp
@@ -95,6 +95,76 @@
   return std::make_pair(Row, Col);
 }
 
+static std::pair<Value *, Value *> getShape(PHINode *PN) {
+  PHINode *InComingPN = nullptr;
+  // Iterate over the incoming values to find an AMX intrinsic.
+  for (Value *IncValue : PN->incoming_values()) {
+    if (dyn_cast<LoadInst>(IncValue))
+      continue;
+    // The first two operands are the row and column when defining a tile.
+    if (auto *II = dyn_cast<IntrinsicInst>(IncValue)) {
+      Value *Row = II->getArgOperand(0);
+      Value *Col = II->getArgOperand(1);
+      return std::make_pair(Row, Col);
+    }
+    if (!InComingPN)
+      InComingPN = dyn_cast<PHINode>(IncValue);
+  }
+  if (InComingPN)
+    return getShape(InComingPN);
+
+  // The incoming values are from loads.
+  return std::make_pair(nullptr, nullptr);
+}
+
+// %1 = load x86_amx, x86_amx* %0, align 64
+// %2 = call x86_amx @llvm.x86.tdpbssd.internal(%1, %1, %1, ...)
+// -->
+// %1 = call x86_amx @llvm.x86.tileloadd64.internal()
+// %2 = call x86_amx @llvm.x86.tdpbssd.internal(%1, %1, %1, ...)
+static void transformTileLoad(LoadInst *LD) {
+  Value *Row = nullptr, *Col = nullptr;
+  PHINode *PN = nullptr;
+  for (auto UI = LD->use_begin(), UE = LD->use_end(); UI != UE;) {
+    Use &U = *(UI++);
+    if (dyn_cast<StoreInst>(U.getUser()))
+      continue;
+    if (auto *II = dyn_cast<IntrinsicInst>(U.getUser())) {
+      unsigned OpNo = U.getOperandNo();
+      std::tie(Row, Col) = getShape(II, OpNo);
+      break;
+    }
+    // %1 = phi x86_amx [ %2, %for.body14 ], [ %3, %for.body24 ]
+    if (!PN) {
+      PN = dyn_cast<PHINode>(U.getUser());
+      continue;
+    }
+    // store x86_amx %9, x86_amx* %addr, align 64
+  }
+  // No user is an AMX intrinsic, so get the shape from the PHI node.
+  if (!Row) {
+    if (PN)
+      std::tie(Row, Col) = getShape(PN);
+    // All users are store instructions. Transform them to vector
+    // load/store.
+    else {
+      // TODO: perform the transform.
+      return;
+    }
+  }
+
+  IRBuilder<> Builder(LD);
+  // Use the maximum column as stride.
+  Value *Stride = Builder.getInt64(64);
+  Value *I8Ptr =
+      Builder.CreateBitCast(LD->getOperand(0), Builder.getInt8PtrTy());
+  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
+
+  Value *NewInst =
+      Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
+  LD->replaceAllUsesWith(NewInst);
+}
+
 // %src = load <256 x i32>, <256 x i32>* %addr, align 64
 // %2 = bitcast <256 x i32> %src to x86_amx
 // -->
@@ -126,13 +196,17 @@
 // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
 // %stride64, %13)
 static void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) {
-
+  Value *Row, *Col;
   Value *Tile = Bitcast->getOperand(0);
-  auto *II = cast<IntrinsicInst>(Tile);
-  // Tile is output from AMX intrinsic. The first operand of the
-  // intrinsic is row, the second operand of the intrinsic is column.
-  Value *Row = II->getOperand(0);
-  Value *Col = II->getOperand(1);
+  if (auto *II = dyn_cast<IntrinsicInst>(Tile)) {
+    // Tile is output from AMX intrinsic. The first operand of the
+    // intrinsic is row, the second operand of the intrinsic is column.
+    Row = II->getOperand(0);
+    Col = II->getOperand(1);
+  } else if (auto *PN = dyn_cast<PHINode>(Tile))
+    std::tie(Row, Col) = getShape(PN);
+  // TODO: else the def is a load; transform load/store to vector load/store.
+
   IRBuilder<> Builder(ST);
   // Use the maximum column as stride. It must be the same with load
   // stride.
@@ -215,6 +289,26 @@
   return true;
 }
 
+// %addr = bitcast <256 x i32>* %tile to x86_amx*
+// store x86_amx %9, x86_amx* %addr, align 64
+// -->
+// call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
+// %stride64, %9)
+static void transformTileStore(StoreInst *ST) {
+  auto *II = cast<IntrinsicInst>(ST->getValueOperand());
+  Value *Row = II->getOperand(0);
+  Value *Col = II->getOperand(1);
+  IRBuilder<> Builder(ST);
+  // Use the maximum column as stride. It must be the same as the load
+  // stride.
+  Value *Stride = Builder.getInt64(64);
+  Value *I8Ptr =
+      Builder.CreateBitCast(ST->getOperand(1), Builder.getInt8PtrTy());
+  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride,
+                                 ST->getValueOperand()};
+  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
+}
+
 namespace {
 class X86LowerAMXType {
   Function &Func;
@@ -236,6 +330,36 @@
           continue;
 
         Value *Src = Bitcast->getOperand(0);
+        Type *Ty = Bitcast->getType();
+
+        if (Ty->isPointerTy() &&
+            cast<PointerType>(Ty)->getElementType()->isX86_AMXTy()) {
+          for (auto UI = Bitcast->use_begin(), UE = Bitcast->use_end();
+               UI != UE;) {
+            Value *I = (UI++)->getUser();
+            auto *LD = dyn_cast<LoadInst>(I);
+            // %0 = bitcast <256 x i32>* %tile to x86_amx*
+            // %1 = load x86_amx, x86_amx* %0, align 64
+            if (LD) {
+              transformTileLoad(LD);
+              DeadInsts.push_back(LD);
+            }
+            auto *ST = dyn_cast<StoreInst>(I);
+            if (ST) {
+              // %addr = bitcast <256 x i32>* %tile to x86_amx*
+              // store x86_amx %9, x86_amx* %addr, align 64
+              // -->
+              // call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
+              // %stride64, %9)
+              transformTileStore(ST);
+              DeadInsts.push_back(ST);
+            }
+          }
+          // If the dst type is <256 x i32>*, it is a valid instruction.
+          // %0 = bitcast x86_amx* %tile to <256 x i32>*
+          // %1 = load <256 x i32>, <256 x i32>* %0, align 64
+          // store <256 x i32> %2, <256 x i32>* %0, align 64
+        }
         if (Bitcast->getType()->isX86_AMXTy()) {
           if (Bitcast->user_empty()) {
             DeadInsts.push_back(Bitcast);
diff --git a/llvm/test/CodeGen/X86/AMX/amx-type.ll b/llvm/test/CodeGen/X86/AMX/amx-type.ll
--- a/llvm/test/CodeGen/X86/AMX/amx-type.ll
+++ b/llvm/test/CodeGen/X86/AMX/amx-type.ll
@@ -6,6 +6,38 @@
 @buf = dso_local global [1024 x i8] zeroinitializer, align 64
 @buf2 = dso_local global [1024 x i8] zeroinitializer, align 64
 
+define dso_local void @test_amx_store(<256 x i32>* %in, i16 %m, i16 %n, i8 *%buf, i64 %s) #2 {
+; CHECK-LABEL: @test_amx_store(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[T0:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[N:%.*]], i8* [[BUF:%.*]], i64 [[S:%.*]])
+; CHECK-NEXT:    [[ADDR:%.*]] = bitcast <256 x i32>* [[IN:%.*]] to x86_amx*
+; CHECK-NEXT:    [[TMP0:%.*]] = bitcast x86_amx* [[ADDR]] to i8*
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], i8* [[TMP0]], i64 64, x86_amx [[T0]])
+; CHECK-NEXT:    ret void
+;
+entry:
+  %t0 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %m, i16 %n, i8* %buf, i64 %s) #3
+  %addr = bitcast <256 x i32>* %in to x86_amx*
+  store x86_amx %t0, x86_amx* %addr, align 64
+  ret void
+}
+
+define dso_local void @test_amx_load(<256 x i32>* %in, i16 %m, i16 %n, i8 *%buf, i64 %s) #2 {
+; CHECK-LABEL: @test_amx_load(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[T0:%.*]] = bitcast <256 x i32>* [[IN:%.*]] to x86_amx*
+; CHECK-NEXT:    [[TMP0:%.*]] = bitcast x86_amx* [[T0]] to i8*
+; CHECK-NEXT:    [[TMP1:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[N:%.*]], i8* [[TMP0]], i64 64)
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], i8* [[BUF:%.*]], i64 [[S:%.*]], x86_amx [[TMP1]])
+; CHECK-NEXT:    ret void
+;
+entry:
+  %t0 = bitcast <256 x i32>* %in to x86_amx*
+  %t1 = load x86_amx, x86_amx* %t0, align 64
+  call void @llvm.x86.tilestored64.internal(i16 %m, i16 %n, i8* %buf, i64 %s, x86_amx %t1) #3
+  ret void
+}
+
 ; test bitcast x86_amx to <256 x i32>
 define dso_local void @test_user_empty(i16 %m, i16 %n, i8 *%buf, i64 %s) {
 ; CHECK-LABEL: @test_user_empty(
@@ -225,6 +257,66 @@
   ret void
 }
 
+define linkonce_odr dso_local void @test_amxptr(<256 x i32>* %arrayidx16, <256 x i32>* %arrayidx29, <256 x i32>* %arrayidx35) {
+; CHECK-LABEL: @test_amxptr(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br label [[FOR_COND9:%.*]]
+; CHECK:       for.cond9:
+; CHECK-NEXT:    br i1 undef, label [[FOR_BODY14:%.*]], label [[EXIT:%.*]]
+; CHECK:       for.body14:
+; CHECK-NEXT:    [[TMP0:%.*]] = bitcast <256 x i32>* [[ARRAYIDX16:%.*]] to x86_amx*
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast x86_amx* [[TMP0]] to i8*
+; CHECK-NEXT:    [[TMP2:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 1, i16 4, i8* [[TMP1]], i64 64)
+; CHECK-NEXT:    br label [[FOR_COND18:%.*]]
+; CHECK:       for.cond18:
+; CHECK-NEXT:    [[TMP3:%.*]] = phi x86_amx [ [[TMP2]], [[FOR_BODY14]] ], [ [[T11:%.*]], [[FOR_BODY24:%.*]] ]
+; CHECK-NEXT:    br i1 undef, label [[FOR_BODY24]], label [[FOR_COND_CLEANUP23:%.*]]
+; CHECK:       for.cond.cleanup23:
+; CHECK-NEXT:    [[TMP4:%.*]] = bitcast <256 x i32>* [[ARRAYIDX16]] to i8*
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 1, i16 4, i8* [[TMP4]], i64 64, x86_amx [[TMP3]])
+; CHECK-NEXT:    br label [[FOR_COND9]]
+; CHECK:       for.body24:
+; CHECK-NEXT:    [[TMP5:%.*]] = bitcast <256 x i32>* [[ARRAYIDX29:%.*]] to i8*
+; CHECK-NEXT:    [[TMP6:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 1, i16 4, i8* [[TMP5]], i64 64)
+; CHECK-NEXT:    [[TMP7:%.*]] = bitcast <256 x i32>* [[ARRAYIDX35:%.*]] to i8*
+; CHECK-NEXT:    [[TMP8:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 4, i16 4, i8* [[TMP7]], i64 64)
+; CHECK-NEXT:    [[T11]] = call x86_amx @llvm.x86.tdpbssd.internal(i16 1, i16 4, i16 4, x86_amx [[TMP3]], x86_amx [[TMP6]], x86_amx [[TMP8]])
+; CHECK-NEXT:    br label [[FOR_COND18]]
+; CHECK:       exit:
+; CHECK-NEXT:    ret void
+;
+entry:
+  br label %for.cond9
+
+for.cond9:                                        ; preds = %for.cond.cleanup23, %entry
+  br i1 undef, label %for.body14, label %exit
+
+for.body14:                                       ; preds = %for.cond9
+  %0 = bitcast <256 x i32>* %arrayidx16 to x86_amx*
+  %t51 = load x86_amx, x86_amx* %0, align 64
+  br label %for.cond18
+
+for.cond18:                                       ; preds = %for.body24, %for.body14
+  %1 = phi x86_amx [ %t51, %for.body14 ], [ %t11, %for.body24 ]
+  br i1 undef, label %for.body24, label %for.cond.cleanup23
+
+for.cond.cleanup23:                               ; preds = %for.cond18
+  %2 = bitcast x86_amx %1 to <256 x i32>
+  store <256 x i32> %2, <256 x i32>* %arrayidx16, align 64
+  br label %for.cond9
+
+for.body24:                                       ; preds = %for.cond18
+  %t6 = load <256 x i32>, <256 x i32>* %arrayidx29, align 64
+  %t7 = load <256 x i32>, <256 x i32>* %arrayidx35, align 64
+  %t9 = bitcast <256 x i32> %t6 to x86_amx
+  %t10 = bitcast <256 x i32> %t7 to x86_amx
+  %t11 = call x86_amx @llvm.x86.tdpbssd.internal(i16 1, i16 4, i16 4, x86_amx %1, x86_amx %t9, x86_amx %t10)
+  br label %for.cond18
+
+exit:                                             ; preds = %for.cond9
+  ret void
+}
+
 declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64)
 declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)
 declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, x86_amx)