diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1,12 +1,12 @@
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
-  name: polymorphic_matmul
-  cpp_op_name: PolymorphicMatmulOp
+  name: matmul
+  cpp_op_name: MatmulOp
   doc: |-
-    Type polymorphic matrix multiplication.
+    Performs a matrix multiplication of two 2D inputs.
 
-    This op is presently here to test a new path for generation and will replace
-    the existing 'matmul' op when ready. Do not use.
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
   implements:
   - LinalgContractionOpInterface
 structured_op: !LinalgStructuredOpConfig
@@ -60,4 +60,249 @@
                 operands:
                 - !ScalarExpression
                   scalar_arg: B
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: batch_matmul
+  cpp_op_name: BatchMatmulOp
+  doc: |-
+    Performs a batched matrix multiplication of two 3D inputs.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+  implements:
+  - LinalgContractionOpInterface
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !<LinalgTensorDef>
+    name: A
+    usage: input
+    shape: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
+    element_type_var: T1
+  - !<LinalgTensorDef>
+    name: B
+    usage: input
+    shape: affine_map<()[s0, s1, s2, s3] -> (s0, s3, s2)>
+    element_type_var: T2
+  - !<LinalgTensorDef>
+    name: C
+    usage: output
+    shape: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
+    element_type_var: U
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)>
+    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d2)>
+    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d2)>
+  iterator_types:
+  - parallel
+  - parallel
+  - parallel
+  - reduction
+  assignments:
+  - !ScalarAssign
+    arg: C
+    value: !ScalarExpression
+      scalar_apply:
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: C
+        - !ScalarExpression
+          scalar_apply:
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              symbolic_cast:
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: A
+            - !ScalarExpression
+              symbolic_cast:
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: B
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: matvec
+  cpp_op_name: MatvecOp
+  doc: |-
+    Performs a matrix-vector multiplication.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+  implements:
+  - LinalgContractionOpInterface
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !<LinalgTensorDef>
+    name: A
+    usage: input
+    shape: affine_map<()[s0, s1] -> (s0, s1)>
+    element_type_var: T1
+  - !<LinalgTensorDef>
+    name: y
+    usage: input
+    shape: affine_map<()[s0, s1] -> (s1)>
+    element_type_var: T2
+  - !<LinalgTensorDef>
+    name: x
+    usage: output
+    shape: affine_map<()[s0, s1] -> (s0)>
+    element_type_var: U
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1)[s0, s1] -> (d0, d1)>
+    - affine_map<(d0, d1)[s0, s1] -> (d1)>
+    - affine_map<(d0, d1)[s0, s1] -> (d0)>
+  iterator_types:
+  - parallel
+  - reduction
+  assignments:
+  - !ScalarAssign
+    arg: x
+    value: !ScalarExpression
+      scalar_apply:
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: x
+        - !ScalarExpression
+          scalar_apply:
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              symbolic_cast:
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: A
+            - !ScalarExpression
+              symbolic_cast:
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: y
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: vecmat
+  cpp_op_name: VecmatOp
+  doc: |-
+    Performs a vector-matrix multiplication.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+  implements:
+  - LinalgContractionOpInterface
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !<LinalgTensorDef>
+    name: y
+    usage: input
+    shape: affine_map<()[s0, s1] -> (s1)>
+    element_type_var: T1
+  - !<LinalgTensorDef>
+    name: A
+    usage: input
+    shape: affine_map<()[s0, s1] -> (s1, s0)>
+    element_type_var: T2
+  - !<LinalgTensorDef>
+    name: x
+    usage: output
+    shape: affine_map<()[s0, s1] -> (s0)>
+    element_type_var: U
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1)[s0, s1] -> (d1)>
+    - affine_map<(d0, d1)[s0, s1] -> (d1, d0)>
+    - affine_map<(d0, d1)[s0, s1] -> (d0)>
+  iterator_types:
+  - parallel
+  - reduction
+  assignments:
+  - !ScalarAssign
+    arg: x
+    value: !ScalarExpression
+      scalar_apply:
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: x
+        - !ScalarExpression
+          scalar_apply:
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              symbolic_cast:
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: y
+            - !ScalarExpression
+              symbolic_cast:
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: A
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: dot
+  cpp_op_name: DotOp
+  doc: |-
+    Performs a dot product of two vectors to a scalar result.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+  implements:
+  - LinalgContractionOpInterface
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !<LinalgTensorDef>
+    name: A
+    usage: input
+    shape: affine_map<()[s0] -> (s0)>
+    element_type_var: T1
+  - !<LinalgTensorDef>
+    name: B
+    usage: input
+    shape: affine_map<()[s0] -> (s0)>
+    element_type_var: T2
+  - !<LinalgTensorDef>
+    name: C
+    usage: output
+    shape: affine_map<()[s0] -> ()>
+    element_type_var: U
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0)[s0] -> (d0)>
+    - affine_map<(d0)[s0] -> (d0)>
+    - affine_map<(d0)[s0] -> ()>
+  iterator_types:
+  - reduction
+  assignments:
+  - !ScalarAssign
+    arg: C
+    value: !ScalarExpression
+      scalar_apply:
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: C
+        - !ScalarExpression
+          scalar_apply:
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              symbolic_cast:
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: A
+            - !ScalarExpression
+              symbolic_cast:
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: B
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
@@ -1,9 +1,3 @@
-ods_def<MatmulOp>
-implements_interface<LinalgContractionOpInterface> :
-def matmul(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) {
-  C(m, n) = std_addf(C(m, n), std_mulf(A(m, k), B(k, n)));
-}
-
 ods_def<MatmulColumnMajorOp>
 implements_interface<LinalgContractionOpInterface> :
 def matmul_column_major(A: f32(K, M), B: f32(N, K)) -> (C: f32(N, M)) {
@@ -30,12 +24,6 @@
   C(m, n) = std_addi(C(m, n), std_muli(A(m, k), B(k, n)));
 }
 
-ods_def<MatvecOp>
-implements_interface<LinalgContractionOpInterface> :
-def matvec(A: f32(M, N), y: f32(N)) -> (x: f32(M)) {
-  x(m) = std_addf(x(m), std_mulf(A(m, n), y(n)));
-}
-
 ods_def<MatvecI8I8I32Op>
 implements_interface<LinalgContractionOpInterface> :
 def matvec_i8_i8_i32(A: i8(M, N), y: i8(N)) -> (x: i32(M)) {
@@ -54,12 +42,6 @@
   x(m) = std_addi(x(m), std_muli(A(m, n), y(n)));
 }
 
-ods_def<VecmatOp>
-implements_interface<LinalgContractionOpInterface> :
-def vecmat(y: f32(M), A: f32(M, N)) -> (x: f32(N)) {
-  x(n) = std_addf(x(n), std_mulf(y(m), A(m, n)));
-}
-
 ods_def<VecmatI8I8I32Op>
 implements_interface<LinalgContractionOpInterface> :
 def vecmat_i8_i8_i32(y: i8(M), A: i8(M, N)) -> (x: i32(N)) {
@@ -78,12 +60,6 @@
   x(n) = std_addi(x(n), std_muli(y(m), A(m, n)));
 }
 
-ods_def<DotOp>
-implements_interface<LinalgContractionOpInterface> :
-def dot(A: f32(M), B: f32(M)) -> (C: f32()) {
-  C() = std_addf(C(), std_mulf(A(m), B(m)));
-}
-
 ods_def<DotI8I8I32Op>
 implements_interface<LinalgContractionOpInterface> :
 def dot_i8_i8_i32(A: i8(M), B: i8(M)) -> (C: i32()) {
@@ -102,12 +78,6 @@
   C() = std_addi(C(), std_muli(A(m), B(m)));
 }
 
-ods_def<BatchMatmulOp>
-implements_interface<LinalgContractionOpInterface> :
-def batch_matmul(A: f32(Batch, M, K), B: f32(Batch, K, N)) -> (C: f32(Batch, M, N)) {
-  C(b, m, n) = std_addf(C(b, m, n), std_mulf(A(b, m, k), B(b, k, n)));
-}
-
 ods_def<BatchMatmulI8I8I32Op>
 implements_interface<LinalgContractionOpInterface> :
 def batch_matmul_i8_i8_i32(A: i8(Batch, M, K), B: i8(Batch, K, N)) -> (C: i32(Batch, M, N)) {
diff --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
--- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt %s -split-input-file -linalg-generalize-named-ops | FileCheck %s
 
 func @generalize_matmul_tensor_f32(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
-  %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xf32>, tensor<8x32xf32>)
+  %0 = linalg.matmul ins(%A, %B: tensor<16x8xf32>, tensor<8x32xf32>)
                     outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
   return %0: tensor<16x32xf32>
 }
@@ -16,7 +16,7 @@
 // -----
 
 func @generalize_matmul_tensor_i32(%A : tensor<16x8xi32>, %B: tensor<8x32xi32>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> {
-  %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xi32>, tensor<8x32xi32>)
+  %0 = linalg.matmul ins(%A, %B: tensor<16x8xi32>, tensor<8x32xi32>)
                     outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32>
   return %0: tensor<16x32xi32>
 }
@@ -31,7 +31,7 @@
 // -----
 // Verifies floating point to integer cast.
 func @generalize_matmul_tensor_f32_f32_i16(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>, %C: tensor<16x32xi16>) -> tensor<16x32xi16> {
-  %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xf32>, tensor<8x32xf32>)
+  %0 = linalg.matmul ins(%A, %B: tensor<16x8xf32>, tensor<8x32xf32>)
                     outs(%C: tensor<16x32xi16>) -> tensor<16x32xi16>
   return %0: tensor<16x32xi16>
 }
@@ -48,7 +48,7 @@
 // -----
 // Verifies sign extension cast.
 func @generalize_matmul_tensor_i8_i8_i32(%A : tensor<16x8xi8>, %B: tensor<8x32xi8>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> {
-  %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xi8>, tensor<8x32xi8>)
+  %0 = linalg.matmul ins(%A, %B: tensor<16x8xi8>, tensor<8x32xi8>)
                     outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32>
   return %0: tensor<16x32xi32>
 }
@@ -65,7 +65,7 @@
 // -----
 // Verifies that different argument types are legal.
 func @generalize_matmul_tensor_i8_i16_i32(%A : tensor<16x8xi8>, %B: tensor<8x32xi16>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> {
-  %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xi8>, tensor<8x32xi16>)
+  %0 = linalg.matmul ins(%A, %B: tensor<16x8xi8>, tensor<8x32xi16>)
                     outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32>
   return %0: tensor<16x32xi32>
 }
@@ -82,7 +82,7 @@
 // -----
 // Somewhat non-sensical but checks integer truncation cast.
 func @generalize_matmul_tensor_i32_i32_i16(%A : tensor<16x8xi32>, %B: tensor<8x32xi32>, %C: tensor<16x32xi16>) -> tensor<16x32xi16> {
-  %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xi32>, tensor<8x32xi32>)
+  %0 = linalg.matmul ins(%A, %B: tensor<16x8xi32>, tensor<8x32xi32>)
                     outs(%C: tensor<16x32xi16>) -> tensor<16x32xi16>
   return %0: tensor<16x32xi16>
 }
@@ -99,7 +99,7 @@
 // -----
 // Verifies integer to floating point cast.
 func @generalize_matmul_tensor_i8_i8_f32(%A : tensor<16x8xi8>, %B: tensor<8x32xi8>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
-  %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xi8>, tensor<8x32xi8>)
+  %0 = linalg.matmul ins(%A, %B: tensor<16x8xi8>, tensor<8x32xi8>)
                     outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
   return %0: tensor<16x32xf32>
 }
@@ -116,7 +116,7 @@
 // -----
 // Verifies floating point extension cast.
 func @generalize_matmul_tensor_f16_f16_f32(%A : tensor<16x8xf16>, %B: tensor<8x32xf16>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
-  %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xf16>, tensor<8x32xf16>)
+  %0 = linalg.matmul ins(%A, %B: tensor<16x8xf16>, tensor<8x32xf16>)
                     outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
   return %0: tensor<16x32xf32>
 }
@@ -133,7 +133,7 @@
 // -----
 // Verifies floating point truncation.
 func @generalize_matmul_tensor_f64_f64_f32(%A : tensor<16x8xf64>, %B: tensor<8x32xf64>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
-  %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xf64>, tensor<8x32xf64>)
+  %0 = linalg.matmul ins(%A, %B: tensor<16x8xf64>, tensor<8x32xf64>)
                     outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
   return %0: tensor<16x32xf32>
 }
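
For reviewers who want to exercise the new path end to end, here is a minimal sketch (not part of the patch) of one of the other YAML-defined ops in IR. The function name, file name, and tensor shapes are illustrative assumptions; the op syntax mirrors the tests above. Running it through -linalg-generalize-named-ops should rewrite the named op into a linalg.generic whose body sign-extends the i8 operands to i32 before the multiply-accumulate, as the i8_i8_i32 matmul test checks:

// Illustrative only: the new YAML-backed linalg.batch_matmul with mixed
// operand/accumulator types. Assuming this is saved as example.mlir, run:
//   mlir-opt -linalg-generalize-named-ops example.mlir
func @batch_matmul_tensor_i8_i8_i32(%A : tensor<4x16x8xi8>, %B: tensor<4x8x32xi8>, %C: tensor<4x16x32xi32>) -> tensor<4x16x32xi32> {
  %0 = linalg.batch_matmul ins(%A, %B: tensor<4x16x8xi8>, tensor<4x8x32xi8>)
                    outs(%C: tensor<4x16x32xi32>) -> tensor<4x16x32xi32>
  return %0: tensor<4x16x32xi32>
}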