[mlir][linalg] Sign-extend operands before multiplying in mixed-precision named ops

Affected definitions (all in LinalgNamedStructuredOpsSpec.tc): matmul,
matvec, vecmat, dot and batch_matmul, each in its *_i8_i8_i32 and
*_i16_i16_i32 variant.

Before: the product was computed at the operand width and only then
widened, i.e. std_sexti32(std_muli(a, b)).
After: each operand is widened to i32 first and the multiply happens at
i32, i.e. std_muli(std_sexti32(a), std_sexti32(b)). The std_addi
accumulation into the i32 output is unchanged.

Why: the exact product of two narrow operands does not fit in the operand
type (e.g. 127 * 127 = 16129 does not fit in i8, and 32767 * 32767 does
not fit in i16), so multiplying before widening cannot represent it.
Every single i8*i8 or i16*i16 product does fit in i32, so widening first
keeps each individual product exact.

NOTE(review): this patch only changes the multiply. The running i32 sum
produced by std_addi can still exceed the i32 range for long reductions.
Confirm that is acceptable.

NOTE(review): the TODO in matmul_i8_i8_i32 ("C(m, n) += cast(A(m, k)) *
cast(B(k, n))") now essentially describes the new body. Consider dropping
or rewording it in a follow-up.

NOTE(review): this copy of the patch appears damaged. It is collapsed onto
two lines, and bracketed template arguments look stripped (e.g. the bare
"ods_def implements_interface :" lines and "cast(" without a type
argument). It will likely not apply as-is; regenerate it from the source
tree before use.

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc @@ -15,13 +15,13 @@ def matmul_i8_i8_i32(A: i8(M, K), B: i8(K, N)) -> (C: i32(M, N)) { // TODO: ideally something closer to // C(m, n) += cast(A(m, k)) * cast(B(k, n)) - C(m, n) = std_addi(C(m, n), std_sexti32(std_muli(A(m, k), B(k, n)))); + C(m, n) = std_addi(C(m, n), std_muli(std_sexti32(A(m, k)), std_sexti32(B(k, n)))); } ods_def implements_interface : def matmul_i16_i16_i32(A: i16(M, K), B: i16(K, N)) -> (C: i32(M, N)) { - C(m, n) = std_addi(C(m, n), std_sexti32(std_muli(A(m, k), B(k, n)))); + C(m, n) = std_addi(C(m, n), std_muli(std_sexti32(A(m, k)), std_sexti32(B(k, n)))); } ods_def @@ -39,13 +39,13 @@ ods_def implements_interface : def matvec_i8_i8_i32(A: i8(M, N), y: i8(N)) -> (x: i32(M)) { - x(m) = std_addi(x(m), std_sexti32(std_muli(A(m, n), y(n)))); + x(m) = std_addi(x(m), std_muli(std_sexti32(A(m, n)), std_sexti32(y(n)))); } ods_def implements_interface : def matvec_i16_i16_i32(A: i16(M, N), y: i16(N)) -> (x: i32(M)) { - x(m) = std_addi(x(m), std_sexti32(std_muli(A(m, n), y(n)))); + x(m) = std_addi(x(m), std_muli(std_sexti32(A(m, n)), std_sexti32(y(n)))); } ods_def @@ -63,13 +63,13 @@ ods_def implements_interface : def vecmat_i8_i8_i32(y: i8(M), A: i8(M, N)) -> (x: i32(N)) { - x(n) = std_addi(x(n), std_sexti32(std_muli(y(m), A(m, n)))); + x(n) = std_addi(x(n), std_muli(std_sexti32(y(m)), std_sexti32(A(m, n)))); } ods_def implements_interface : def vecmat_i16_i16_i32(y: i16(M), A: i16(M, N)) -> (x: i32(N)) { - x(n) = std_addi(x(n), std_sexti32(std_muli(y(m), A(m, n)))); + x(n) = std_addi(x(n), std_muli(std_sexti32(y(m)), std_sexti32(A(m, n)))); } ods_def @@ -87,13 +87,13 @@ ods_def implements_interface : def dot_i8_i8_i32(A: i8(M), B: i8(M)) -> (C: i32()) { - 
C() = std_addi(C(), std_sexti32(std_muli(A(m), B(m)))); + C() = std_addi(C(), std_muli(std_sexti32(A(m)), std_sexti32(B(m)))); } ods_def implements_interface : def dot_i16_i16_i32(A: i16(M), B: i16(M)) -> (C: i32()) { - C() = std_addi(C(), std_sexti32(std_muli(A(m), B(m)))); + C() = std_addi(C(), std_muli(std_sexti32(A(m)), std_sexti32(B(m)))); } ods_def @@ -112,14 +112,14 @@ implements_interface : def batch_matmul_i8_i8_i32(A: i8(Batch, M, K), B: i8(Batch, K, N)) -> (C: i32(Batch, M, N)) { C(b, m, n) = - std_addi(C(b, m, n), std_sexti32(std_muli(A(b, m, k), B(b, k, n)))); + std_addi(C(b, m, n), std_muli(std_sexti32(A(b, m, k)), std_sexti32(B(b, k, n)))); } ods_def implements_interface : def batch_matmul_i16_i16_i32(A: i16(Batch, M, K), B: i16(Batch, K, N)) -> (C: i32(Batch, M, N)) { C(b, m, n) = - std_addi(C(b, m, n), std_sexti32(std_muli(A(b, m, k), B(b, k, n)))); + std_addi(C(b, m, n), std_muli(std_sexti32(A(b, m, k)), std_sexti32(B(b, k, n)))); }