Add the linalg.batch_reduce_matmul named structured op.

The op multiplies two batched 3-D operands and reduces the per-batch products
into a single 2-D accumulator:

    C[m, n] += cast(U, A[b, m, k]) * cast(U, B[b, k, n])

The iteration domain is (b, m, n, k); `b` and `k` are reduction dimensions,
`m` and `n` are parallel.

Notes on this revision of the patch:
  * YAML: the shape_map of output C is rank 2, (s1, s2), matching its
    indexing map (d1, d2) and the OpDSL definition C=TensorDef(U, S.M, S.N).
  * OpDSL: each operand is cast to U before the multiply, as the docstring and
    the YAML scalar expression mul(cast(A), cast(B)) describe; the parentheses
    of the assignment are balanced.
  * named-ops.mlir: the dynamic-shape test uses fully spelled memref types
    (3-D inputs, 2-D output). The hunk carries a single line of leading
    context.
  * Leading whitespace of unchanged context lines was reconstructed; apply
    with whitespace-insensitive matching (e.g. `git apply --ignore-whitespace`)
    if a hunk does not match exactly.

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -548,6 +548,76 @@
                 - !ScalarExpression
                   scalar_arg: B
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: batch_reduce_matmul
+  cpp_class_name: BatchReduceMatmulOp
+  doc: |-
+    Performs a batch-reduce matrix multiplication of two 3D inputs.
+    The partial multiplication results are reduced into a 2D output.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+  implements:
+  - LinalgContractionOpInterface
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: A
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
+  - !LinalgOperandDefConfig
+    name: B
+    kind: input_tensor
+    type_var: T2
+    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s3, s2)>
+  - !LinalgOperandDefConfig
+    name: C
+    kind: output_tensor
+    type_var: U
+    shape_map: affine_map<()[s0, s1, s2, s3] -> (s1, s2)>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)>
+    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d2)>
+    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d1, d2)>
+  iterator_types:
+  - reduction
+  - parallel
+  - parallel
+  - reduction
+  assignments:
+  - !ScalarAssign
+    arg: C
+    value: !ScalarExpression
+      scalar_fn:
+        kind: binary
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: C
+        - !ScalarExpression
+          scalar_fn:
+            kind: binary
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              scalar_fn:
+                kind: type
+                fn_name: cast_signed
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: A
+            - !ScalarExpression
+              scalar_fn:
+                kind: type
+                fn_name: cast_signed
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: B
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: quantized_batch_matmul
   cpp_class_name: QuantizedBatchMatmulOp
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -150,6 +150,20 @@
                        TypeFn.cast_signed(U, AZp)) * (TypeFn.cast_signed(
                            U, B[D.b, D.k, D.n]) - TypeFn.cast_signed(U, BZp))
 
+@linalg_structured_op
+def batch_reduce_matmul(A=TensorDef(T1, Batch, S.M, S.K),
+                        B=TensorDef(T2, Batch, S.K, S.N),
+                        C=TensorDef(U, S.M, S.N, output=True)):
+  """Performs a batch-reduce matrix multiplication of two 3D inputs.
+  The partial multiplication results are reduced into a 2D output.
+
+  Numeric casting is performed on the operands to the inner multiply, promoting
+  them to the same data type as the accumulator/output.
+  """
+  domain(D.b, D.m, D.n, D.k)
+  implements(ContractionOpInterface)
+  C[D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed(
+      U, B[D.b, D.k, D.n])
 
 @linalg_structured_op
 def matvec(A=TensorDef(T1, S.M, S.N),
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -248,3 +248,27 @@
 // CHECK:         %[[MUL:.+]] = arith.mulf %[[BBARG0_F32]], %[[BBARG1_F32]]
 // CHECK:         %[[ADD:.+]] = arith.addf %[[BBARG2]], %[[MUL]]
 // CHECK:         linalg.yield %[[ADD]] : f32
+
+// -----
+
+func.func @batch_reduce_gemm(%lhs: memref<7x8x9xf32>, %rhs: memref<7x9x8xf32>, %out: memref<8x8xf32>) {
+  linalg.batch_reduce_matmul ins(%lhs, %rhs: memref<7x8x9xf32>, memref<7x9x8xf32>)
+                             outs(%out: memref<8x8xf32>)
+  return
+}
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
+
+// CHECK: @batch_reduce_gemm
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME: iterator_types = ["reduction", "parallel", "parallel", "reduction"]}
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<7x8x9xf32>, memref<7x9x8xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<8x8xf32>
+// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
+// CHECK:         %[[MUL:.+]] = arith.mulf %[[BBARG0]], %[[BBARG1]] : f32
+// CHECK:         %[[ADD:.+]] = arith.addf %[[BBARG2]], %[[MUL]] : f32
+// CHECK:         linalg.yield %[[ADD]] : f32
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -796,1 +796,21 @@
 }
+
+// -----
+
+func.func @batch_reduce_matmul(%arg0: tensor<8x128x256xf32>, %arg1: tensor<8x256x512xf32>, %arg2: tensor<128x512xf32>) -> tensor<128x512xf32> {
+  // CHECK: %{{.+}} = linalg.batch_reduce_matmul
+  // CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<8x128x256xf32>, tensor<8x256x512xf32>)
+  // CHECK-SAME: outs(%{{.+}} : tensor<128x512xf32>) -> tensor<128x512xf32>
+  %0 = linalg.batch_reduce_matmul ins(%arg0, %arg1 : tensor<8x128x256xf32>, tensor<8x256x512xf32>) outs(%arg2: tensor<128x512xf32>) -> tensor<128x512xf32>
+  return %0: tensor<128x512xf32>
+}
+
+// -----
+
+func.func @batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
+  // CHECK: linalg.batch_reduce_matmul
+  // CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?x?xf32>, memref<?x?x?xf32>)
+  // CHECK-SAME: outs(%{{.+}} : memref<?x?xf32>)
+  linalg.batch_reduce_matmul ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2: memref<?x?xf32>)
+  return
+}