diff --git a/llvm/lib/Target/X86/X86LowerAMXType.cpp b/llvm/lib/Target/X86/X86LowerAMXType.cpp --- a/llvm/lib/Target/X86/X86LowerAMXType.cpp +++ b/llvm/lib/Target/X86/X86LowerAMXType.cpp @@ -700,11 +700,12 @@ class X86LowerAMXCast { Function &Func; + std::unique_ptr<DominatorTree> DT; public: X86LowerAMXCast(Function &F) : Func(F) {} void combineCastStore(IntrinsicInst *Cast, StoreInst *ST); - void combineLoadCast(IntrinsicInst *Cast, LoadInst *LD); + bool combineLoadCast(IntrinsicInst *Cast, LoadInst *LD); bool combineLdSt(SmallVectorImpl<Instruction *> &Casts); bool combineAMXcast(TargetLibraryInfo *TLI); bool transformAMXCast(IntrinsicInst *AMXCast); @@ -942,7 +943,8 @@ // --> // %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, // i8* %p, i64 64) -void X86LowerAMXCast::combineLoadCast(IntrinsicInst *Cast, LoadInst *LD) { +bool X86LowerAMXCast::combineLoadCast(IntrinsicInst *Cast, LoadInst *LD) { + bool EraseLoad = true; Value *Row = nullptr, *Col = nullptr; Use &U = *(Cast->use_begin()); unsigned OpNo = U.getOperandNo(); @@ -950,18 +952,45 @@ // TODO: If it is cast intrinsic or phi node, we can propagate the // shape information through def-use chain. if (!isAMXIntrinsic(II)) - return; + return false; std::tie(Row, Col) = getShape(II, OpNo); IRBuilder<> Builder(LD); // Use the maximun column as stride. Value *Stride = Builder.getInt64(64); - Value *I8Ptr = - Builder.CreateBitCast(LD->getOperand(0), Builder.getInt8PtrTy()); + Value *I8Ptr; + + // To save compile time, build the dominator tree lazily, only when it is + // really needed. + if (!DT) + DT = std::make_unique<DominatorTree>(Func); + if (!DT->dominates(Row, LD) || !DT->dominates(Col, LD)) { + // The shape is defined after the load, so the tile load cannot take the + // load's place. It is emitted right before the cast instead, which is + // only legal when the shape also dominates the cast; otherwise bail out + // and keep the original load and cast. + if (!DT->dominates(Row, Cast) || !DT->dominates(Col, Cast)) + return false; + // Spill the vector to a stack slot and reload it as a tile at the cast.
+ auto *AllocaAddr = + createAllocaInstAtEntry(Builder, Cast->getParent(), LD->getType()); + Builder.SetInsertPoint(&*std::next(LD->getIterator())); + Builder.CreateStore(LD, AllocaAddr); + + Builder.SetInsertPoint(Cast); + I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy()); + EraseLoad = false; + } else { + I8Ptr = Builder.CreateBitCast(LD->getOperand(0), Builder.getInt8PtrTy()); + } std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride}; Value *NewInst = Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args); Cast->replaceAllUsesWith(NewInst); + + // Only report the load as erasable when nothing else needs it: the spill + // path above still uses it for the store to the stack slot. + return EraseLoad; } bool X86LowerAMXCast::combineLdSt(SmallVectorImpl<Instruction *> &Casts) { @@ -995,10 +1024,11 @@ // --> // %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, // i8* %p, i64 64) - combineLoadCast(cast<IntrinsicInst>(Cast), Load); - // Set the operand is null so that load instruction can be erased. - Cast->setOperand(0, nullptr); - Load->eraseFromParent(); + if (combineLoadCast(cast<IntrinsicInst>(Cast), Load)) { + // Drop the cast's use of the load so that the load can be erased.
+ Cast->setOperand(0, nullptr); + Load->eraseFromParent(); + } } } return Change; diff --git a/llvm/test/CodeGen/X86/AMX/amx-combine.ll b/llvm/test/CodeGen/X86/AMX/amx-combine.ll --- a/llvm/test/CodeGen/X86/AMX/amx-combine.ll +++ b/llvm/test/CodeGen/X86/AMX/amx-combine.ll @@ -18,9 +18,9 @@ ; CHECK-NEXT: [[TMP1:%.*]] = alloca <256 x i32>, align 64 ; CHECK-NEXT: [[T1:%.*]] = call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 64) ; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 16, i16 64, ptr [[TMP1]], i64 64, x86_amx [[T1]]) -; CHECK-NEXT: [[TMP3:%.*]] = load <256 x i32>, ptr [[TMP1]], align 1024 +; CHECK-NEXT: [[TMP2:%.*]] = load <256 x i32>, ptr [[TMP1]], align 1024 ; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 16, i16 64, ptr [[P:%.*]], i64 64, x86_amx [[T1]]) -; CHECK-NEXT: ret <256 x i32> [[TMP3]] +; CHECK-NEXT: ret <256 x i32> [[TMP2]] ; %t1 = call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 64) %t2 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t1) @@ -30,8 +30,8 @@ define void @combine_load(ptr%p, ptr%p2) { ; CHECK-LABEL: @combine_load( -; CHECK-NEXT: [[TMP2:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, ptr [[P:%.*]], i64 64) -; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 16, i16 64, ptr [[P2:%.*]], i64 64, x86_amx [[TMP2]]) +; CHECK-NEXT: [[TMP1:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, ptr [[P:%.*]], i64 64) +; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 16, i16 64, ptr [[P2:%.*]], i64 64, x86_amx [[TMP1]]) ; CHECK-NEXT: ret void ; %t1 = load <256 x i32>, ptr %p, align 64 @@ -42,9 +42,9 @@ define void @combine_cast_across_store(ptr%p, ptr%p2) { ; CHECK-LABEL: @combine_cast_across_store(
-; CHECK-NEXT: [[TMP2:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, ptr [[P:%.*]], i64 64) +; CHECK-NEXT: [[TMP1:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, ptr [[P:%.*]], i64 64) ; CHECK-NEXT: store <256 x i32> zeroinitializer, ptr [[P]], align 64 -; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 16, i16 64, ptr [[P2:%.*]], i64 64, x86_amx [[TMP2]]) +; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 16, i16 64, ptr [[P2:%.*]], i64 64, x86_amx [[TMP1]]) ; CHECK-NEXT: ret void ; %t1 = load <256 x i32>, ptr %p, align 64 @@ -59,8 +59,8 @@ ; CHECK-NEXT: [[TMP1:%.*]] = alloca <256 x i32>, align 64 ; CHECK-NEXT: [[T1:%.*]] = load <256 x i32>, ptr [[P:%.*]], align 64 ; CHECK-NEXT: store <256 x i32> [[T1]], ptr [[TMP1]], align 1024 -; CHECK-NEXT: [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, ptr [[TMP1]], i64 64) -; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 16, i16 64, ptr [[P2:%.*]], i64 64, x86_amx [[TMP3]]) +; CHECK-NEXT: [[TMP2:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, ptr [[TMP1]], i64 64) +; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 16, i16 64, ptr [[P2:%.*]], i64 64, x86_amx [[TMP2]]) ; CHECK-NEXT: ret <256 x i32> [[T1]] ; %t1 = load <256 x i32>, ptr %p, align 64 @@ -75,9 +75,9 @@ ; CHECK-NEXT: [[TMP1:%.*]] = alloca <256 x i32>, align 64 ; CHECK-NEXT: [[T1:%.*]] = load <256 x i32>, ptr [[P:%.*]], align 64 ; CHECK-NEXT: store <256 x i32> [[T1]], ptr [[TMP1]], align 1024 -; CHECK-NEXT: [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 16, ptr [[TMP1]], i64 16) -; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 16, i16 64, ptr [[P2:%.*]], i64 64, x86_amx [[TMP3]]) -; CHECK-NEXT: [[TMP4:%.*]] = call x86_amx @llvm.x86.tdpbssd.internal(i16 16, i16 16, i16 64, x86_amx [[TMP3]], x86_amx [[TMP3]], x86_amx [[TMP3]]) +; CHECK-NEXT: [[TMP2:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16
16, i16 16, ptr [[TMP1]], i64 16) +; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 16, i16 64, ptr [[P2:%.*]], i64 64, x86_amx [[TMP2]]) +; CHECK-NEXT: [[TMP3:%.*]] = call x86_amx @llvm.x86.tdpbssd.internal(i16 16, i16 16, i16 64, x86_amx [[TMP2]], x86_amx [[TMP2]], x86_amx [[TMP2]]) ; CHECK-NEXT: ret <256 x i32> [[T1]] ; %t1 = load <256 x i32>, ptr %p, align 64 @@ -88,6 +88,48 @@ ret <256 x i32> %t3 } +; The rows of %b.amx are %a.col / 4, and %a.col is loaded after %b.tile, so %b.tile is spilled to the stack and reloaded as a tile right before the cast instead of being folded into a tileload. +%struct.__tile1024i_str = type <{ i16, i16, [60 x i8], <256 x i32> }> +define void @test_tile_dpbssd(ptr byval(%struct.__tile1024i_str) align 64 %a, ptr byval(%struct.__tile1024i_str) align 64 %b, ptr byval(%struct.__tile1024i_str) align 64 %c) { +; CHECK-LABEL: @test_tile_dpbssd( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[TMP0:%.*]] = alloca <256 x i32>, align 64 +; CHECK-NEXT: [[B_ROW_PTR:%.*]] = getelementptr inbounds i8, ptr [[B:%.*]], i64 2 +; CHECK-NEXT: [[B_ROW:%.*]] = load i16, ptr [[B_ROW_PTR]], align 2 +; CHECK-NEXT: [[B_TILE_PTR:%.*]] = getelementptr inbounds i8, ptr [[B]], i64 64 +; CHECK-NEXT: [[B_TILE:%.*]] = load <256 x i32>, ptr [[B_TILE_PTR]], align 64 +; CHECK-NEXT: store <256 x i32> [[B_TILE]], ptr [[TMP0]], align 1024 +; CHECK-NEXT: [[A_ROW:%.*]] = load i16, ptr [[A:%.*]], align 64 +; CHECK-NEXT: [[A_COL_PTR:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 2 +; CHECK-NEXT: [[A_COL:%.*]] = load i16, ptr [[A_COL_PTR]], align 2 +; CHECK-NEXT: [[TMP1:%.*]] = udiv i16 [[A_COL]], 4 +; CHECK-NEXT: [[A_TILE_PTR:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 64 +; CHECK-NEXT: [[TMP2:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[A_ROW]], i16 [[A_COL]], ptr [[A_TILE_PTR]], i64 64) +; CHECK-NEXT: [[C_TILE_PTR:%.*]] = getelementptr inbounds [[STRUCT___TILE1024I_STR:%.*]], ptr [[C:%.*]], i64 0, i32 3 +; CHECK-NEXT: [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[A_ROW]], i16 [[B_ROW]], ptr [[C_TILE_PTR]], i64 64) +; CHECK-NEXT: [[TMP4:%.*]] = call x86_amx
@llvm.x86.tileloadd64.internal(i16 [[TMP1]], i16 [[B_ROW]], ptr [[TMP0]], i64 64) +; CHECK-NEXT: [[RES:%.*]] = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 [[A_ROW]], i16 [[B_ROW]], i16 [[A_COL]], x86_amx [[TMP3]], x86_amx [[TMP2]], x86_amx [[TMP4]]) +; CHECK-NEXT: ret void +; +entry: + %b.row.ptr= getelementptr inbounds i8, ptr %b, i64 2 + %b.row = load i16, ptr %b.row.ptr, align 2 + %b.tile.ptr = getelementptr inbounds i8, ptr %b, i64 64 + %b.tile = load <256 x i32>, ptr %b.tile.ptr, align 64 + %a.row = load i16, ptr %a, align 64 + %a.col.ptr = getelementptr inbounds i8, ptr %a, i64 2 + %a.col = load i16, ptr %a.col.ptr, align 2 + %a.tile.ptr = getelementptr inbounds i8, ptr %a, i64 64 + %a.tile = load <256 x i32>, ptr %a.tile.ptr, align 64 + %c.tile.ptr = getelementptr inbounds %struct.__tile1024i_str, ptr %c, i64 0, i32 3 + %c.tile = load <256 x i32>, ptr %c.tile.ptr, align 64 + %c.amx = tail call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %c.tile) + %a.amx = tail call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %a.tile) + %b.amx = tail call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %b.tile) + %res = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 %a.row, i16 %b.row, i16 %a.col, x86_amx %c.amx, x86_amx %a.amx, x86_amx %b.amx) + ret void +} + declare x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32>) declare <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx) declare x86_amx @llvm.x86.tilezero.internal(i16, i16)