diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc @@ -1,36 +1,43 @@ -ods_def: +ods_def +implements_interface : def matmul(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) { C(m, n) = std_addf(std_mulf(A(m, k), B(k, n))); } -ods_def: +ods_def +implements_interface : def matmul_column_major(A: f32(K, M), B: f32(N, K)) -> (C: f32(N, M)) { C(n, m) = std_addf(std_mulf(A(k, m), B(n, k))); } -ods_def: +ods_def +implements_interface : def matmul_i8_i8_i32(A: i8(M, K), B: i8(K, N)) -> (C: i32(M, N)) { // TODO: ideally something closer to // C(m, n) += cast(A(m, k)) * cast(B(k, n)) C(m, n) = std_addi(std_sexti32(std_muli(A(m, k), B(k, n)))); } -ods_def: +ods_def +implements_interface : def matvec(A: f32(M, N), y: f32(N)) -> (x: f32(M)) { x(m) = std_addf(std_mulf(A(m, n), y(n))); } -ods_def: +ods_def +implements_interface : def vecmat(y: f32(M), A: f32(M, N)) -> (x: f32(N)) { x(n) = std_addf(std_mulf(y(m), A(m, n))); } -ods_def: +ods_def +implements_interface : def dot(A: f32(M), B: f32(M)) -> (C: f32()) { C() = std_addf(std_mulf(A(m), B(m))); } -ods_def: +ods_def +implements_interface : def batch_matmul(A: f32(Batch, M, K), B: f32(Batch, K, N)) -> (C: f32(Batch, M, N)) { C(b, m, n) = std_addf(std_mulf(A(b, m, k), B(b, k, n))); } diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h @@ -35,23 +35,33 @@ typename OptionsType, typename = std::enable_if_t::value>> -void sfinae_enqueue(OwningRewritePatternList &patterList, OptionsType options, +void sfinae_enqueue(OwningRewritePatternList &patternList, OptionsType options, MLIRContext *context, StringRef opName, linalg::LinalgTransformationFilter m) { assert(opName == ConcreteOpType::getOperationName() && "explicit name must match ConcreteOpType::getOperationName"); - patterList.insert>(context, options, m); + patternList.insert>(context, options, m); } /// SFINAE: Enqueue helper for OpType that do not have a `getOperationName` /// (e.g. LinalgOp, other interfaces, Operation*). template