# Patch summary (free-text preamble; patch tools skip text before the first
# "diff --git" header).
#
# LinalgNamedStructuredOps.yaml:
#   * The matmul_transpose_b entry is added ahead of mmt4d and removed from
#     its previous position after mmt4d; its doc string is re-wrapped.
#   * The batch_matmul_transpose_b entry is added directly after batch_matmul
#     and removed from its previous position ahead of batch_reduce_matmul;
#     its doc string is re-wrapped.
# core_named_ops.py:
#   * Adds @linalg_structured_op definitions for matmul_transpose_b and
#     batch_matmul_transpose_b.
#
# NOTE(review): matmul_transpose_b exposes a `cast` type_fn_attr (default
# cast_signed) while batch_matmul_transpose_b hard-codes cast_signed in both
# the YAML and the Python definition -- confirm the asymmetry is intended.
# NOTE(review): leading indentation of context lines and of the added Python
# lines was reconstructed from the nesting of the entries; verify against the
# target files before applying.
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -400,6 +400,79 @@
                     - !ScalarExpression
                       scalar_arg: BZp
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: matmul_transpose_b
+  cpp_class_name: MatmulTransposeBOp
+  doc: |-
+    Performs a matrix multiplication of two 2D inputs with rhs operand
+    transposed.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+  implements:
+  - LinalgContractionOpInterface
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: A
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
+  - !LinalgOperandDefConfig
+    name: B
+    kind: input_tensor
+    type_var: T2
+    shape_map: affine_map<()[s0, s1, s2] -> (s2, s1)>
+  - !LinalgOperandDefConfig
+    name: C
+    kind: output_tensor
+    type_var: U
+    shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
+  - !LinalgOperandDefConfig
+    name: cast
+    kind: type_fn_attr
+    default_fn: cast_signed
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
+    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d1, d2)>
+    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
+  iterator_types:
+  - parallel
+  - parallel
+  - reduction
+  assignments:
+  - !ScalarAssign
+    arg: C
+    value: !ScalarExpression
+      scalar_fn:
+        kind: binary
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: C
+        - !ScalarExpression
+          scalar_fn:
+            kind: binary
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              scalar_fn:
+                kind: type
+                attr_name: cast
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: A
+            - !ScalarExpression
+              scalar_fn:
+                kind: type
+                attr_name: cast
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: B
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: mmt4d
   cpp_class_name: Mmt4DOp
@@ -480,10 +553,10 @@
                   scalar_arg: rhs
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
-  name: matmul_transpose_b
-  cpp_class_name: MatmulTransposeBOp
+  name: batch_matmul
+  cpp_class_name: BatchMatmulOp
   doc: |-
-    Performs a matrix multiplication of two 2D inputs with rhs operand transposed.
+    Performs a batched matrix multiplication of two 3D inputs.
 
     Numeric casting is performed on the operands to the inner multiply, promoting
     them to the same data type as the accumulator/output.
@@ -495,29 +568,26 @@
     name: A
     kind: input_tensor
     type_var: T1
-    shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
+    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
   - !LinalgOperandDefConfig
     name: B
     kind: input_tensor
     type_var: T2
-    shape_map: affine_map<()[s0, s1, s2] -> (s2, s1)>
+    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s2, s3)>
   - !LinalgOperandDefConfig
     name: C
     kind: output_tensor
     type_var: U
-    shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
-  - !LinalgOperandDefConfig
-    name: cast
-    kind: type_fn_attr
-    default_fn: cast_signed
+    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
   indexing_maps: !LinalgIndexingMapsConfig
     static_indexing_maps:
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d1, d2)>
-    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
+    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)>
+    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d2)>
+    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d2)>
   iterator_types:
   - parallel
   - parallel
+  - parallel
   - reduction
   assignments:
   - !ScalarAssign
@@ -537,7 +607,7 @@
             - !ScalarExpression
               scalar_fn:
                 kind: type
-                attr_name: cast
+                fn_name: cast_signed
                 type_var: U
                 operands:
                 - !ScalarExpression
@@ -545,17 +615,18 @@
             - !ScalarExpression
               scalar_fn:
                 kind: type
-                attr_name: cast
+                fn_name: cast_signed
                 type_var: U
                 operands:
                 - !ScalarExpression
                   scalar_arg: B
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
-  name: batch_matmul
-  cpp_class_name: BatchMatmulOp
+  name: batch_matmul_transpose_b
+  cpp_class_name: BatchMatmulTransposeBOp
   doc: |-
-    Performs a batched matrix multiplication of two 3D inputs.
+    Performs a batched matrix multiplication of two 3D inputs where rhs operand
+    has its non-batch dimensions transposed.
 
     Numeric casting is performed on the operands to the inner multiply, promoting
     them to the same data type as the accumulator/output.
@@ -572,7 +643,7 @@
     name: B
     kind: input_tensor
     type_var: T2
-    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s2, s3)>
+    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s3, s2)>
   - !LinalgOperandDefConfig
     name: C
     kind: output_tensor
@@ -581,7 +652,7 @@
   indexing_maps: !LinalgIndexingMapsConfig
     static_indexing_maps:
     - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)>
-    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d2)>
+    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d2, d3)>
     - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d2)>
   iterator_types:
   - parallel
@@ -725,76 +796,6 @@
                     - !ScalarExpression
                       scalar_arg: BZp
 --- !LinalgOpConfig
-metadata: !LinalgOpMetadata
-  name: batch_matmul_transpose_b
-  cpp_class_name: BatchMatmulTransposeBOp
-  doc: |-
-    Performs a batched matrix multiplication of two 3D inputs where rhs operand has its non-batch
-    dimensions transposed.
-
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output.
-  implements:
-  - LinalgContractionOpInterface
-structured_op: !LinalgStructuredOpConfig
-  args:
-  - !LinalgOperandDefConfig
-    name: A
-    kind: input_tensor
-    type_var: T1
-    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
-  - !LinalgOperandDefConfig
-    name: B
-    kind: input_tensor
-    type_var: T2
-    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s3, s2)>
-  - !LinalgOperandDefConfig
-    name: C
-    kind: output_tensor
-    type_var: U
-    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
-  indexing_maps: !LinalgIndexingMapsConfig
-    static_indexing_maps:
-    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)>
-    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d2, d3)>
-    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d2)>
-  iterator_types:
-  - parallel
-  - parallel
-  - parallel
-  - reduction
-  assignments:
-  - !ScalarAssign
-    arg: C
-    value: !ScalarExpression
-      scalar_fn:
-        kind: binary
-        fn_name: add
-        operands:
-        - !ScalarExpression
-          scalar_arg: C
-        - !ScalarExpression
-          scalar_fn:
-            kind: binary
-            fn_name: mul
-            operands:
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                fn_name: cast_signed
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: A
-            - !ScalarExpression
-              scalar_fn:
-                kind: type
-                fn_name: cast_signed
-                type_var: U
-                operands:
-                - !ScalarExpression
-                  scalar_arg: B
----- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: batch_reduce_matmul
   cpp_class_name: BatchReduceMatmulOp
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -107,6 +107,22 @@
     )
 
 
+@linalg_structured_op
+def matmul_transpose_b(A=TensorDef(T1, S.M, S.K),
+                       B=TensorDef(T2, S.N, S.K),
+                       C=TensorDef(U, S.M, S.N, output=True),
+                       cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
+  """Performs a matrix multiplication of two 2D inputs with rhs operand
+  transposed.
+
+  Numeric casting is performed on the operands to the inner multiply, promoting
+  them to the same data type as the accumulator/output.
+  """
+  domain(D.m, D.n, D.k)
+  implements(ContractionOpInterface)
+  C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.n, D.k])
+
+
 @linalg_structured_op
 def mmt4d(
     lhs=TensorDef(TV.LhsType, S.M, S.K, S.M0, S.K0),
@@ -148,6 +164,23 @@
     )
 
 
+@linalg_structured_op
+def batch_matmul_transpose_b(A=TensorDef(T1, Batch, S.M, S.K),
+                             B=TensorDef(T2, Batch, S.N, S.K),
+                             C=TensorDef(U, Batch, S.M, S.N, output=True)):
+  """Performs a batched matrix multiplication of two 3D inputs where rhs operand
+  has its non-batch dimensions transposed.
+
+  Numeric casting is performed on the operands to the inner multiply, promoting
+  them to the same data type as the accumulator/output.
+  """
+  domain(D.b, D.m, D.n, D.k)
+  implements(ContractionOpInterface)
+  C[D.b, D.m,
+    D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
+        U, B[D.b, D.n, D.k])
+
+
 @linalg_structured_op
 def quantized_batch_matmul(
     A=TensorDef(T1, Batch, S.M, S.K),