diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h @@ -22,13 +22,21 @@ namespace mlir { namespace linalg { +class LinalgOp; /// Returns the values obtained by applying `map` to the list of values. SmallVector applyMapToValues(OpBuilder &b, Location loc, AffineMap map, ValueRange values); +/// Checks whether `linalgOp` conforms to ContractionOpInterface. +// TODO: embed within `isa` if possible / natural. +bool isaContractionOpInterface(LinalgOp linalgOp); + namespace detail { +/// Verify that `op` conforms to ContractionOpInterface. +LogicalResult verifyContractionInterface(Operation *op); + /// Verify that `op` conforms to the invariants of StructuredOpInterface LogicalResult verifyStructuredOpInterface(Operation *op); diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -15,8 +15,28 @@ include "mlir/IR/OpBase.td" -// The linalg 'LinalgStructuredInterface' provides access to the 'LinalgOp' -// interface. +// The 'LinalgContractionOpInterface' provides access to the +// 'ContractionOpInterface'. +def LinalgContractionOpInterface : OpInterface<"ContractionOpInterface"> { + let description = [{ + A Linalg contraction is defined in general terms: + 1. Has 2 input shapes and 1 output shape. + 2. Has at least one reduction dimension. + 3. Has only projected permutation indexing maps. + 4. Its body computes `u5(u1(c) + u2(u3(a) * u4(b)))` on some field + (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent scalar unary + operations that may change the type (e.g. for mixed-precision). 
+ As a consequence, when vectorization of such an op occurs, the only special + behavior is that the (unique) MulOpType is vectorized into a + `vector.contract`. All other ops are handled in a generic fashion. + In the future, we may wish to allow more input arguments and elementwise and + constant operations that do not involve the reduction dimension(s). + }]; + let cppNamespace = "::mlir::linalg"; + let verify = [{ return detail::verifyContractionInterface($_op); }]; +} + +// The 'LinalgStructuredInterface' provides access to the 'LinalgOp' interface. def LinalgStructuredInterface : OpInterface<"LinalgOp"> { let cppNamespace = "::mlir::linalg"; let methods = [ diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc @@ -1,36 +1,43 @@ -ods_def: +ods_def +implements_interface : def matmul(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) { C(m, n) = std_addf(std_mulf(A(m, k), B(k, n))); } -ods_def: +ods_def +implements_interface : def matmul_column_major(A: f32(K, M), B: f32(N, K)) -> (C: f32(N, M)) { C(n, m) = std_addf(std_mulf(A(k, m), B(n, k))); } -ods_def: +ods_def +implements_interface : def matmul_i8_i8_i32(A: i8(M, K), B: i8(K, N)) -> (C: i32(M, N)) { // TODO: ideally something closer to // C(m, n) += cast(A(m, k)) * cast(B(k, n)) C(m, n) = std_addi(std_sexti32(std_muli(A(m, k), B(k, n)))); } -ods_def: +ods_def +implements_interface : def matvec(A: f32(M, N), y: f32(N)) -> (x: f32(M)) { x(m) = std_addf(std_mulf(A(m, n), y(n))); } -ods_def: +ods_def +implements_interface : def vecmat(y: f32(M), A: f32(M, N)) -> (x: f32(N)) { x(n) = std_addf(std_mulf(y(m), A(m, n))); } -ods_def: +ods_def +implements_interface : def dot(A: f32(M), B: f32(M)) -> (C: f32()) { C() = std_addf(std_mulf(A(m), B(m))); } -ods_def: +ods_def 
+implements_interface : def batch_matmul(A: f32(Batch, M, K), B: f32(Batch, K, N)) -> (C: f32(Batch, M, N)) { C(b, m, n) = std_addf(std_mulf(A(b, m, k), B(b, k, n))); } diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h @@ -35,23 +35,33 @@ typename OptionsType, typename = std::enable_if_t::value>> -void sfinae_enqueue(OwningRewritePatternList &patterList, OptionsType options, +void sfinae_enqueue(OwningRewritePatternList &patternList, OptionsType options, MLIRContext *context, StringRef opName, linalg::LinalgTransformationFilter m) { assert(opName == ConcreteOpType::getOperationName() && "explicit name must match ConcreteOpType::getOperationName"); - patterList.insert>(context, options, m); + patternList.insert>(context, options, m); } /// SFINAE: Enqueue helper for OpType that do not have a `getOperationName` /// (e.g. LinalgOp, other interfaces, Operation*). template