# Add two named Linalg contraction ops whose right-hand operand is transposed:
#
#   linalg.matmul_transpose_b (MatmulTransposeBOp)
#     C[m, n] += cast(A[m, k]) * cast(B[n, k])
#     A is MxK, B is NxK, C is MxN. Both operands are cast to the output type U;
#     the cast function is chosen by the `cast` type_fn_attr (default cast_signed).
#
#   linalg.batch_matmul_transpose_b (BatchMatmulTransposeBOp)
#     C[b, m, n] += cast_signed(A[b, m, k]) * cast_signed(B[b, n, k])
#     A is BxMxK, B is BxNxK, C is BxMxN. The cast is fixed to cast_signed; unlike
#     matmul_transpose_b this op declares no `cast` attribute.
#
# Both ops implement LinalgContractionOpInterface. named-ops.mlir gains one
# CHECK test per op using static memref shapes.
#
# NOTE(review): LinalgNamedStructuredOps.yaml is presumably generated from the
# OpDSL definitions (core_named_ops.py); confirm the Python side carries matching
# definitions so that regenerating the YAML does not drop these ops.
Index: mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
===================================================================
--- mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -479,6 +479,78 @@
                 - !ScalarExpression
                   scalar_arg: rhs
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: matmul_transpose_b
+  cpp_class_name: MatmulTransposeBOp
+  doc: |-
+    Performs a matrix multiplication of two 2D inputs with rhs operand transposed.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+  implements:
+  - LinalgContractionOpInterface
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: A
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
+  - !LinalgOperandDefConfig
+    name: B
+    kind: input_tensor
+    type_var: T2
+    shape_map: affine_map<()[s0, s1, s2] -> (s2, s1)>
+  - !LinalgOperandDefConfig
+    name: C
+    kind: output_tensor
+    type_var: U
+    shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
+  - !LinalgOperandDefConfig
+    name: cast
+    kind: type_fn_attr
+    default_fn: cast_signed
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
+    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d1, d2)>
+    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
+  iterator_types:
+  - parallel
+  - parallel
+  - reduction
+  assignments:
+  - !ScalarAssign
+    arg: C
+    value: !ScalarExpression
+      scalar_fn:
+        kind: binary
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: C
+        - !ScalarExpression
+          scalar_fn:
+            kind: binary
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              scalar_fn:
+                kind: type
+                attr_name: cast
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: A
+            - !ScalarExpression
+              scalar_fn:
+                kind: type
+                attr_name: cast
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: B
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: batch_matmul
   cpp_class_name: BatchMatmulOp
@@ -653,6 +725,76 @@
                     - !ScalarExpression
                       scalar_arg: BZp
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: batch_matmul_transpose_b
+  cpp_class_name: BatchMatmulTransposeBOp
+  doc: |-
+    Performs a batched matrix multiplication of two 3D inputs where rhs operand has its non-batch
+    dimensions transposed.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+  implements:
+  - LinalgContractionOpInterface
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: A
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
+  - !LinalgOperandDefConfig
+    name: B
+    kind: input_tensor
+    type_var: T2
+    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s3, s2)>
+  - !LinalgOperandDefConfig
+    name: C
+    kind: output_tensor
+    type_var: U
+    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)>
+    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d2, d3)>
+    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d2)>
+  iterator_types:
+  - parallel
+  - parallel
+  - parallel
+  - reduction
+  assignments:
+  - !ScalarAssign
+    arg: C
+    value: !ScalarExpression
+      scalar_fn:
+        kind: binary
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: C
+        - !ScalarExpression
+          scalar_fn:
+            kind: binary
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              scalar_fn:
+                kind: type
+                fn_name: cast_signed
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: A
+            - !ScalarExpression
+              scalar_fn:
+                kind: type
+                fn_name: cast_signed
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: B
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: batch_reduce_matmul
   cpp_class_name: BatchReduceMatmulOp
Index: mlir/test/Dialect/Linalg/named-ops.mlir
===================================================================
--- mlir/test/Dialect/Linalg/named-ops.mlir
+++ mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1070,3 +1070,25 @@
   linalg.batch_reduce_matmul ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2: memref<?x?xf32>)
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @matmul_transpose_b
+// CHECK: linalg.matmul_transpose_b
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<3x5xf32>, memref<7x5xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)
+func.func @matmul_transpose_b(%arg0: memref<3x5xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
+  linalg.matmul_transpose_b ins(%arg0, %arg1 : memref<3x5xf32>, memref<7x5xf32>) outs(%arg2: memref<3x7xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @batchmatmul_transpose_b
+// CHECK: linalg.batch_matmul_transpose_b
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<2x3x5xf32>, memref<2x7x5xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<2x3x7xf32>)
+func.func @batchmatmul_transpose_b(%arg0: memref<2x3x5xf32>, %arg1: memref<2x7x5xf32>, %arg2: memref<2x3x7xf32>) {
+  linalg.batch_matmul_transpose_b ins(%arg0, %arg1 : memref<2x3x5xf32>, memref<2x7x5xf32>) outs(%arg2: memref<2x3x7xf32>)
+  return
+}