Sign-extend i8/i16 operands to i32 before multiplying in the mixed-precision
Linalg named ops (matmul, matvec, vecmat, dot, batch_matmul).

The previous body, std_sexti32(std_muli(A, B)), multiplied in the narrow
input type and only sign-extended the already-wrapped product. The new body,
std_muli(std_sexti32(A), std_sexti32(B)), forms the product in i32.

The matmul_i8_i8_i32 vectorization test is updated to match: both vector
operands are sexti'd to i32 and vector.contract now runs on i32 vectors,
so the trailing sexti of the contraction result is gone.

NOTE(review): reconstructed from a whitespace-collapsed copy in which
tag-like "<...>" tokens (<k>, <n>, <m>, <i32>, <add>, and the ods_def /
implements_interface arguments) had been stripped. The restored op and
interface names follow the upstream ods_def<CamelCaseOp> convention; confirm
them, and the exact context-line whitespace, against the target tree before
applying.

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
@@ -15,13 +15,13 @@
 def matmul_i8_i8_i32(A: i8(M, K), B: i8(K, N)) -> (C: i32(M, N)) {
   // TODO: ideally something closer to
   //   C(m, n) += cast<i32>(A(m, k)) * cast<i32>(B(k, n))
-  C(m, n) = std_addi<k>(C(m, n), std_sexti32(std_muli(A(m, k), B(k, n))));
+  C(m, n) = std_addi<k>(C(m, n), std_muli(std_sexti32(A(m, k)), std_sexti32(B(k, n))));
 }
 
 ods_def<MatmulI16I16I32Op>
 implements_interface<LinalgContractionOpInterface> :
 def matmul_i16_i16_i32(A: i16(M, K), B: i16(K, N)) -> (C: i32(M, N)) {
-  C(m, n) = std_addi<k>(C(m, n), std_sexti32(std_muli(A(m, k), B(k, n))));
+  C(m, n) = std_addi<k>(C(m, n), std_muli(std_sexti32(A(m, k)), std_sexti32(B(k, n))));
 }
 
 ods_def<MatmulI32I32I32Op>
@@ -39,13 +39,13 @@
 ods_def<MatvecI8I8I32Op>
 implements_interface<LinalgContractionOpInterface> :
 def matvec_i8_i8_i32(A: i8(M, N), y: i8(N)) -> (x: i32(M)) {
-  x(m) = std_addi<n>(x(m), std_sexti32(std_muli(A(m, n), y(n))));
+  x(m) = std_addi<n>(x(m), std_muli(std_sexti32(A(m, n)), std_sexti32(y(n))));
 }
 
 ods_def<MatvecI16I16I32Op>
 implements_interface<LinalgContractionOpInterface> :
 def matvec_i16_i16_i32(A: i16(M, N), y: i16(N)) -> (x: i32(M)) {
-  x(m) = std_addi<n>(x(m), std_sexti32(std_muli(A(m, n), y(n))));
+  x(m) = std_addi<n>(x(m), std_muli(std_sexti32(A(m, n)), std_sexti32(y(n))));
 }
 
 ods_def<MatvecI32I32I32Op>
@@ -63,13 +63,13 @@
 ods_def<VecmatI8I8I32Op>
 implements_interface<LinalgContractionOpInterface> :
 def vecmat_i8_i8_i32(y: i8(M), A: i8(M, N)) -> (x: i32(N)) {
-  x(n) = std_addi<m>(x(n), std_sexti32(std_muli(y(m), A(m, n))));
+  x(n) = std_addi<m>(x(n), std_muli(std_sexti32(y(m)), std_sexti32(A(m, n))));
 }
 
 ods_def<VecmatI16I16I32Op>
 implements_interface<LinalgContractionOpInterface> :
 def vecmat_i16_i16_i32(y: i16(M), A: i16(M, N)) -> (x: i32(N)) {
-  x(n) = std_addi<m>(x(n), std_sexti32(std_muli(y(m), A(m, n))));
+  x(n) = std_addi<m>(x(n), std_muli(std_sexti32(y(m)), std_sexti32(A(m, n))));
 }
 
 ods_def<VecmatI32I32I32Op>
@@ -87,13 +87,13 @@
 ods_def<DotI8I8I32Op>
 implements_interface<LinalgContractionOpInterface> :
 def dot_i8_i8_i32(A: i8(M), B: i8(M)) -> (C: i32()) {
-  C() = std_addi<m>(C(), std_sexti32(std_muli(A(m), B(m))));
+  C() = std_addi<m>(C(), std_muli(std_sexti32(A(m)), std_sexti32(B(m))));
 }
 
 ods_def<DotI16I16I32Op>
 implements_interface<LinalgContractionOpInterface> :
 def dot_i16_i16_i32(A: i16(M), B: i16(M)) -> (C: i32()) {
-  C() = std_addi<m>(C(), std_sexti32(std_muli(A(m), B(m))));
+  C() = std_addi<m>(C(), std_muli(std_sexti32(A(m)), std_sexti32(B(m))));
 }
 
 ods_def<DotI32I32I32Op>
@@ -112,14 +112,14 @@
 implements_interface<LinalgContractionOpInterface> :
 def batch_matmul_i8_i8_i32(A: i8(Batch, M, K), B: i8(Batch, K, N)) -> (C: i32(Batch, M, N)) {
   C(b, m, n) =
-      std_addi<k>(C(b, m, n), std_sexti32(std_muli(A(b, m, k), B(b, k, n))));
+      std_addi<k>(C(b, m, n), std_muli(std_sexti32(A(b, m, k)), std_sexti32(B(b, k, n))));
 }
 
 ods_def<BatchMatmulI16I16I32Op>
 implements_interface<LinalgContractionOpInterface> :
 def batch_matmul_i16_i16_i32(A: i16(Batch, M, K), B: i16(Batch, K, N)) -> (C: i32(Batch, M, N)) {
   C(b, m, n) =
-      std_addi<k>(C(b, m, n), std_sexti32(std_muli(A(b, m, k), B(b, k, n))));
+      std_addi<k>(C(b, m, n), std_muli(std_sexti32(A(b, m, k)), std_sexti32(B(b, k, n))));
 }
 
 
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -373,17 +373,18 @@
 //  CHECK-SAME: %[[ARG2:[a-z0-9]+]]: memref<4x12xi32>
 func @matmul_i8_i8_i32(%a: memref<4x6xi8>, %b: memref<6x12xi8>, %c: memref<4x12xi32>) {
   //   CHECK-DAG:   %[[C0:.*]] = constant 0 : index
-  //   CHECK-DAG:   %[[VEC_C0:.*]] = constant dense<0> : vector<4x12xi8>
+  //   CHECK-DAG:   %[[VEC_C0:.*]] = constant dense<0> : vector<4x12xi32>
   //   CHECK-DAG:   %[[V0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x6xi8>, vector<4x6xi8>
   //   CHECK-DAG:   %[[V1:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : memref<6x12xi8>, vector<6x12xi8>
   //   CHECK-DAG:   %[[V2:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], {{.*}} : memref<4x12xi32>, vector<4x12xi32>
+  //   CHECK-DAG:   %[[V0_32:.*]] = sexti %[[V0]] : vector<4x6xi8> to vector<4x6xi32>
+  //   CHECK-DAG:   %[[V1_32:.*]] = sexti %[[V1]] : vector<6x12xi8> to vector<6x12xi32>
   //
   // linalg contraction lowers to %tmp = vector.contract %a, %b, %c0 followed by addf %c, %tmp.
   // a later canonicalization fuses the add into vector.contract.
-  //       CHECK:   %[[C:.*]] = vector.contract {{.*}} iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[V0]], %[[V1]], %[[VEC_C0]]
-  //  CHECK-SAME:     vector<4x6xi8>, vector<6x12xi8> into vector<4x12xi8>
-  //       CHECK:   %[[C32:.*]] = sexti %[[C]] : vector<4x12xi8> to vector<4x12xi32>
-  //       CHECK:   %[[RES:.*]] = addi %[[V2]], %[[C32]] : vector<4x12xi32>
+  //       CHECK:   %[[C:.*]] = vector.contract {{.*}} iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[V0_32]], %[[V1_32]], %[[VEC_C0]]
+  //  CHECK-SAME:     vector<4x6xi32>, vector<6x12xi32> into vector<4x12xi32>
+  //       CHECK:   %[[RES:.*]] = addi %[[V2]], %[[C]] : vector<4x12xi32>
   //       CHECK:   vector.transfer_write %[[RES]], %[[ARG2]][%[[C0]], %[[C0]]] {masked = [false, false]}
   //  CHECK-SAME:     vector<4x12xi32>, memref<4x12xi32>
   linalg.matmul_i8_i8_i32 ins(%a, %b : memref<4x6xi8>, memref<6x12xi8>)