diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -483,7 +483,8 @@
   name: matmul_transpose_b
   cpp_class_name: MatmulTransposeBOp
   doc: |-
-    Performs a matrix multiplication of two 2D inputs with rhs operand transposed.
+    Performs a matrix multiplication of two 2D inputs with rhs operand
+    transposed.
 
     Numeric casting is performed on the operands to the inner multiply, promoting
     them to the same data type as the accumulator/output.
@@ -729,8 +730,8 @@
   name: batch_matmul_transpose_b
   cpp_class_name: BatchMatmulTransposeBOp
   doc: |-
-    Performs a batched matrix multiplication of two 3D inputs where rhs operand has its non-batch
-    dimensions transposed.
+    Performs a batched matrix multiplication of two 3D inputs where rhs operand
+    has its non-batch dimensions transposed.
 
     Numeric casting is performed on the operands to the inner multiply, promoting
     them to the same data type as the accumulator/output.
@@ -4722,4 +4723,3 @@
         scalar_const: '2.3283063999999999E-10 : f64'
       - !ScalarExpression
         scalar_arg: min
-
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -116,6 +116,22 @@
                                  TV.AccumType, rhs[D.n, D.k, D.n0, D.k0])
 
 
+@linalg_structured_op
+def matmul_transpose_b(A=TensorDef(T1, S.M, S.K),
+                       B=TensorDef(T2, S.N, S.K),
+                       C=TensorDef(U, S.M, S.N, output=True),
+                       cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
+  """Performs a matrix multiplication of two 2D inputs with rhs operand
+  transposed.
+
+  Numeric casting is performed on the operands to the inner multiply, promoting
+  them to the same data type as the accumulator/output.
+  """
+  domain(D.m, D.n, D.k)
+  implements(ContractionOpInterface)
+  C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.n, D.k])
+
+
 @linalg_structured_op
 def batch_matmul(A=TensorDef(T1, Batch, S.M, S.K),
                  B=TensorDef(T2, Batch, S.K, S.N),
@@ -151,6 +167,23 @@
       U, B[D.b, D.k, D.n]) - TypeFn.cast_signed(U, BZp))
 
 
+@linalg_structured_op
+def batch_matmul_transpose_b(A=TensorDef(T1, Batch, S.M, S.K),
+                             B=TensorDef(T2, Batch, S.N, S.K),
+                             C=TensorDef(U, Batch, S.M, S.N, output=True)):
+  """Performs a batched matrix multiplication of two 3D inputs where rhs operand
+  has its non-batch dimensions transposed.
+
+  Numeric casting is performed on the operands to the inner multiply, promoting
+  them to the same data type as the accumulator/output.
+  """
+  domain(D.b, D.m, D.n, D.k)
+  implements(ContractionOpInterface)
+  C[D.b, D.m,
+    D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
+        U, B[D.b, D.n, D.k])
+
+
 @linalg_structured_op
 def batch_reduce_matmul(A=TensorDef(T1, Batch, S.M, S.K),
                         B=TensorDef(T2, Batch, S.K, S.N),