diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc @@ -8,6 +8,13 @@ C(n, m) = std_addf(std_mulf(A(k, m), B(n, k))); } +ods_def: +def matmul_i8_i8_i32(A: i8(M, K), B: i8(K, N)) -> (C: i32(M, N)) { + // TODO: ideally something closer to + // C(m, n) += cast(A(m, k)) * cast(B(k, n)) + C(m, n) = std_addi(std_sexti32(std_muli(A(m, k), B(k, n)))); +} + ods_def: def matvec(A: f32(M, N), y: f32(N)) -> (x: f32(M)) { x(m) = std_addf(std_mulf(A(m, n), y(n))); diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h @@ -21,27 +21,63 @@ /// Abstract Transformation class applied in a sequence that also handles state /// through markers. struct Transformation { + Transformation() {} + Transformation(linalg::LinalgTransformationFilter::FilterFunction f) + : filter(f) {} virtual ~Transformation() = default; virtual OwningRewritePatternList - buildRewritePatterns(MLIRContext *context, linalg::LinalgMarker m) = 0; - linalg::LinalgMarker marker; + buildRewritePatterns(MLIRContext *context, + linalg::LinalgTransformationFilter m) = 0; + linalg::LinalgTransformationFilter::FilterFunction filter = nullptr; }; +/// SFINAE: Enqueue helper for ConcreteOpType that have a `getOperationName`. +template