diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -34,6 +34,41 @@
   }];
   let cppNamespace = "::mlir::linalg";
   let verify = [{ return detail::verifyContractionInterface($_op); }];
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{
+        Returns whether the given op has indexing maps that correspond to a
+        row-major matmul operation.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"isRowMajorMatmul",
+      /*args=*/(ins),
+      /*methodBody=*/[{
+        return mlir::isRowMajorMatmul($_op.indexing_maps());
+      }]>,
+    InterfaceMethod<
+      /*desc=*/[{
+        Returns whether the given op has indexing maps that correspond to a
+        column-major matmul operation.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"isColumnMajorMatmul",
+      /*args=*/(ins),
+      /*methodBody=*/[{
+        return mlir::isColumnMajorMatmul($_op.indexing_maps());
+      }]>,
+    InterfaceMethod<
+      /*desc=*/[{
+        Returns whether the given op has indexing maps that correspond to a
+        row-major batch matmul operation.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"isRowMajorBatchMatmul",
+      /*args=*/(ins),
+      /*methodBody=*/[{
+        return mlir::isRowMajorBatchMatmul($_op.indexing_maps());
+      }]>,
+  ];
 }
 
 // The 'LinalgStructuredInterface' provides access to the 'LinalgOp' interface.
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
@@ -18,30 +18,114 @@
   C(m, n) = std_addi(std_sexti32(std_muli(A(m, k), B(k, n))));
 }
 
+ods_def
+implements_interface :
+def matmul_i16_i16_i32(A: i16(M, K), B: i16(K, N)) -> (C: i32(M, N)) {
+  C(m, n) = std_addi(std_muli(std_sexti32(A(m, k)), std_sexti32(B(k, n))));
+}
+
+ods_def
+implements_interface :
+def matmul_i32_i32_i32(A: i32(M, K), B: i32(K, N)) -> (C: i32(M, N)) {
+  C(m, n) = std_addi(std_muli(A(m, k), B(k, n)));
+}
+
 ods_def
 implements_interface :
 def matvec(A: f32(M, N), y: f32(N)) -> (x: f32(M)) {
   x(m) = std_addf(std_mulf(A(m, n), y(n)));
 }
 
+ods_def
+implements_interface :
+def matvec_i8_i8_i32(A: i8(M, N), y: i8(N)) -> (x: i32(M)) {
+  x(m) = std_addi(std_muli(std_sexti32(A(m, n)), std_sexti32(y(n))));
+}
+
+ods_def
+implements_interface :
+def matvec_i16_i16_i32(A: i16(M, N), y: i16(N)) -> (x: i32(M)) {
+  x(m) = std_addi(std_muli(std_sexti32(A(m, n)), std_sexti32(y(n))));
+}
+
+ods_def
+implements_interface :
+def matvec_i32_i32_i32(A: i32(M, N), y: i32(N)) -> (x: i32(M)) {
+  x(m) = std_addi(std_muli(A(m, n), y(n)));
+}
+
 ods_def
 implements_interface :
 def vecmat(y: f32(M), A: f32(M, N)) -> (x: f32(N)) {
   x(n) = std_addf(std_mulf(y(m), A(m, n)));
 }
 
+ods_def
+implements_interface :
+def vecmat_i8_i8_i32(y: i8(M), A: i8(M, N)) -> (x: i32(N)) {
+  x(n) = std_addi(std_muli(std_sexti32(y(m)), std_sexti32(A(m, n))));
+}
+
+ods_def
+implements_interface :
+def vecmat_i16_i16_i32(y: i16(M), A: i16(M, N)) -> (x: i32(N)) {
+  x(n) = std_addi(std_muli(std_sexti32(y(m)), std_sexti32(A(m, n))));
+}
+
+ods_def
+implements_interface :
+def vecmat_i32_i32_i32(y: i32(M), A: i32(M, N)) -> (x: i32(N)) {
+  x(n) = std_addi(std_muli(y(m), A(m, n)));
+}
+
 ods_def
 implements_interface :
 def dot(A: f32(M), B: f32(M)) -> (C: f32()) {
   C() = std_addf(std_mulf(A(m), B(m)));
 }
 
+ods_def
+implements_interface :
+def dot_i8_i8_i32(A: i8(M), B: i8(M)) -> (C: i32()) {
+  C() = std_addi(std_muli(std_sexti32(A(m)), std_sexti32(B(m))));
+}
+
+ods_def
+implements_interface :
+def dot_i16_i16_i32(A: i16(M), B: i16(M)) -> (C: i32()) {
+  C() = std_addi(std_muli(std_sexti32(A(m)), std_sexti32(B(m))));
+}
+
+ods_def
+implements_interface :
+def dot_i32_i32_i32(A: i32(M), B: i32(M)) -> (C: i32()) {
+  C() = std_addi(std_muli(A(m), B(m)));
+}
+
 ods_def
 implements_interface :
 def batch_matmul(A: f32(Batch, M, K), B: f32(Batch, K, N)) -> (C: f32(Batch, M, N)) {
   C(b, m, n) = std_addf(std_mulf(A(b, m, k), B(b, k, n)));
 }
 
+ods_def
+implements_interface :
+def batch_matmul_i8_i8_i32(A: i8(Batch, M, K), B: i8(Batch, K, N)) -> (C: i32(Batch, M, N)) {
+  C(b, m, n) = std_addi(std_muli(std_sexti32(A(b, m, k)), std_sexti32(B(b, k, n))));
+}
+
+ods_def
+implements_interface :
+def batch_matmul_i16_i16_i32(A: i16(Batch, M, K), B: i16(Batch, K, N)) -> (C: i32(Batch, M, N)) {
+  C(b, m, n) = std_addi(std_muli(std_sexti32(A(b, m, k)), std_sexti32(B(b, k, n))));
+}
+
+ods_def
+implements_interface :
+def batch_matmul_i32_i32_i32(A: i32(Batch, M, K), B: i32(Batch, K, N)) -> (C: i32(Batch, M, N)) {
+  C(b, m, n) = std_addi(std_muli(A(b, m, k), B(b, k, n)));
+}
+
 ods_def:
 def conv_1d(I: f32(W), K: f32(KW)) -> (O: f32(W)) {
   O(w) = std_addf(std_mulf(I(w + kw), K(kw)));