diff --git a/llvm/lib/Target/X86/X86LowerAMXType.cpp b/llvm/lib/Target/X86/X86LowerAMXType.cpp
--- a/llvm/lib/Target/X86/X86LowerAMXType.cpp
+++ b/llvm/lib/Target/X86/X86LowerAMXType.cpp
@@ -934,9 +934,8 @@
   Value *Row = II->getOperand(0);
   Value *Col = II->getOperand(1);
   IRBuilder<> Builder(ST);
-  // Use the maximum column as stride. It must be the same with load
-  // stride.
-  Value *Stride = Builder.getInt64(64);
+  // Stride should be equal to col (measured in bytes).
+  Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty());
   Value *I8Ptr =
       Builder.CreateBitCast(ST->getOperand(1), Builder.getInt8PtrTy());
   std::array Args = {Row, Col, I8Ptr, Stride, Tile};
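With this change, the stride passed to tilestored64 is derived from the tile's column operand, sign-extended to i64, instead of the hard-coded 64-byte maximum row width. That keeps the store stride consistent with the stride used when the same buffer is loaded, which is what the updated CHECK lines below expect. As a minimal hand-written IR sketch of the lowered pattern (the function @copy_tile and its operand names are illustrative only, not taken from the patch):

  define void @copy_tile(i16 %row, i16 %col, ptr %src, ptr %dst) {
    ; The stride in bytes equals the column count sign-extended to i64,
    ; i.e. the width of one tile row, rather than the constant 64.
    %stride = sext i16 %col to i64
    %t = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, ptr %src, i64 %stride)
    call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, ptr %dst, i64 %stride, x86_amx %t)
    ret void
  }

  declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, ptr, i64)
  declare void @llvm.x86.tilestored64.internal(i16, i16, ptr, i64, x86_amx)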
diff --git a/llvm/test/CodeGen/X86/AMX/lat-transform-amx-bitcast.ll b/llvm/test/CodeGen/X86/AMX/lat-transform-amx-bitcast.ll
--- a/llvm/test/CodeGen/X86/AMX/lat-transform-amx-bitcast.ll
+++ b/llvm/test/CodeGen/X86/AMX/lat-transform-amx-bitcast.ll
@@ -74,7 +74,8 @@
 ; CHECK-NEXT: [[TMP1:%.*]] = sext i16 [[M]] to i64
 ; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[M]], ptr [[TMP0]], i64 [[TMP1]], x86_amx [[T1]])
 ; CHECK-NEXT: [[TMP2:%.*]] = load <256 x i32>, ptr [[TMP0]], align 1024
-; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[M]], ptr [[OUT:%.*]], i64 64, x86_amx [[T1]])
+; CHECK-NEXT: [[TMP3:%.*]] = sext i16 [[M]] to i64
+; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[M]], ptr [[OUT:%.*]], i64 [[TMP3]], x86_amx [[T1]])
 ; CHECK-NEXT: ret <256 x i32> [[TMP2]]
 ;
 entry:
@@ -129,7 +130,8 @@
 ; CHECK-NEXT: [[TMP8:%.*]] = ashr exact i64 [[TMP7]], 32
 ; CHECK-NEXT: [[TMP9:%.*]] = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP4]], i16 [[TMP6]], ptr [[TMP1:%.*]], i64 [[TMP8]])
 ; CHECK-NEXT: [[TMP10:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], ptr [[TMP0]], i64 0, i32 2
-; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[TMP4]], i16 [[TMP6]], ptr [[TMP10]], i64 64, x86_amx [[TMP9]])
+; CHECK-NEXT: [[TMP11:%.*]] = sext i16 [[TMP6]] to i64
+; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[TMP4]], i16 [[TMP6]], ptr [[TMP10]], i64 [[TMP11]], x86_amx [[TMP9]])
 ; CHECK-NEXT: ret void
 ;
   %4 = load i16, ptr %0, align 64
@@ -159,7 +161,8 @@
 ; CHECK-NEXT: [[TMP14:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], ptr [[TMP2]], i64 0, i32 2
 ; CHECK-NEXT: [[TMP15:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP9]], i16 [[TMP6]], ptr [[TMP14]], i64 64)
 ; CHECK-NEXT: [[TMP16:%.*]] = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 [[TMP4]], i16 [[TMP6]], i16 [[TMP8]], x86_amx [[TMP11]], x86_amx [[TMP13]], x86_amx [[TMP15]])
-; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[TMP4]], i16 [[TMP6]], ptr [[TMP10]], i64 64, x86_amx [[TMP16]])
+; CHECK-NEXT: [[TMP17:%.*]] = sext i16 [[TMP6]] to i64
+; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[TMP4]], i16 [[TMP6]], ptr [[TMP10]], i64 [[TMP17]], x86_amx [[TMP16]])
 ; CHECK-NEXT: ret void
 ;
   %4 = load i16, ptr %1, align 64
@@ -189,7 +192,8 @@
 ; CHECK-NEXT: [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP1]], i16 [[N:%.*]], ptr [[PB:%.*]], i64 64)
 ; CHECK-NEXT: [[TMP4:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], ptr [[PC:%.*]], i64 64)
 ; CHECK-NEXT: [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbsud.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP4]], x86_amx [[TMP2]], x86_amx [[TMP3]])
-; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], ptr [[PC]], i64 64, x86_amx [[T6]])
+; CHECK-NEXT: [[TMP5:%.*]] = sext i16 [[N]] to i64
+; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], ptr [[PC]], i64 [[TMP5]], x86_amx [[T6]])
 ; CHECK-NEXT: ret void
 ;
   %t0 = load <256 x i32>, ptr %pa, align 64
@@ -211,7 +215,8 @@
 ; CHECK-NEXT: [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP1]], i16 [[N:%.*]], ptr [[PB:%.*]], i64 64)
 ; CHECK-NEXT: [[TMP4:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], ptr [[PC:%.*]], i64 64)
 ; CHECK-NEXT: [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbusd.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP4]], x86_amx [[TMP2]], x86_amx [[TMP3]])
-; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], ptr [[PC]], i64 64, x86_amx [[T6]])
+; CHECK-NEXT: [[TMP5:%.*]] = sext i16 [[N]] to i64
+; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], ptr [[PC]], i64 [[TMP5]], x86_amx [[T6]])
 ; CHECK-NEXT: ret void
 ;
   %t0 = load <256 x i32>, ptr %pa, align 64
@@ -233,7 +238,8 @@
 ; CHECK-NEXT: [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP1]], i16 [[N:%.*]], ptr [[PB:%.*]], i64 64)
 ; CHECK-NEXT: [[TMP4:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], ptr [[PC:%.*]], i64 64)
 ; CHECK-NEXT: [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbuud.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP4]], x86_amx [[TMP2]], x86_amx [[TMP3]])
-; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], ptr [[PC]], i64 64, x86_amx [[T6]])
+; CHECK-NEXT: [[TMP5:%.*]] = sext i16 [[N]] to i64
+; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], ptr [[PC]], i64 [[TMP5]], x86_amx [[T6]])
 ; CHECK-NEXT: ret void
 ;
   %t0 = load <256 x i32>, ptr %pa, align 64
@@ -255,7 +261,8 @@
 ; CHECK-NEXT: [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP1]], i16 [[N:%.*]], ptr [[PB:%.*]], i64 64)
 ; CHECK-NEXT: [[TMP4:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], ptr [[PC:%.*]], i64 64)
 ; CHECK-NEXT: [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbf16ps.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP4]], x86_amx [[TMP2]], x86_amx [[TMP3]])
-; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], ptr [[PC]], i64 64, x86_amx [[T6]])
+; CHECK-NEXT: [[TMP5:%.*]] = sext i16 [[N]] to i64
+; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], ptr [[PC]], i64 [[TMP5]], x86_amx [[T6]])
 ; CHECK-NEXT: ret void
 ;
   %t0 = load <256 x i32>, ptr %pa, align 64