diff --git a/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h b/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h --- a/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h +++ b/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h @@ -20,6 +20,7 @@ using linalg_fill = OperationBuilder; using linalg_matmul = OperationBuilder; using linalg_matvec = OperationBuilder; +using linalg_vecmat = OperationBuilder; using linalg_range = ValueBuilder; using linalg_reshape = ValueBuilder; using linalg_slice = ValueBuilder; diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc @@ -8,6 +8,11 @@ x(m) = std_addf(std_mulf(A(m, n), y(n))); } +ods_def: +def vecmat(y: f32(M), A: f32(M, N)) -> (x: f32(N)) { + x(n) = std_addf(std_mulf(y(m), A(m, n))); +} + ods_def: def dot(A: f32(M), B: f32(M)) -> (C: f32()) { C() = std_addf(std_mulf(A(m), B(m))); @@ -66,4 +71,4 @@ def conv_3d_ncdhw(I: f32(N, C, D, H, W), K: f32(F, C, KD, KH, KW)) -> (O: f32(N, F, D, H, W)) { O(n, f, d, h, w) = std_addf(std_mulf( I(n, c, d + kd, h + kh, w + kw), K(f, c, kd, kh, kw))); -} \ No newline at end of file +} diff --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp --- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp +++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp @@ -244,6 +244,7 @@ LinalgOpConversion, LinalgOpConversion, LinalgOpConversion, + LinalgOpConversion, LinalgOpConversion, LinalgOpConversion, LinalgOpConversion, diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1350,6 +1350,7 @@ CANONICALIZERS_AND_FOLDERS(DotOp) CANONICALIZERS_AND_FOLDERS(MatmulOp) CANONICALIZERS_AND_FOLDERS(MatvecOp) +CANONICALIZERS_AND_FOLDERS(VecmatOp) CANONICALIZERS_AND_FOLDERS(ConvWOp) CANONICALIZERS_AND_FOLDERS(ConvNWCOp) CANONICALIZERS_AND_FOLDERS(ConvNCWOp) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp @@ -679,6 +679,8 @@ return linalgOpToLoopsImpl(op, builder); if (isa(op)) return linalgOpToLoopsImpl(op, builder); + if (isa(op)) + return linalgOpToLoopsImpl(op, builder); if (isa(op)) return linalgOpToLoopsImpl(op, builder); if (isa(op)) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -69,7 +69,7 @@ static LogicalResult isContraction(Operation *op) { // TODO: interface for named ops. if (isa(op)) + linalg::VecmatOp, linalg::DotOp>(op)) return success(); auto genericOp = dyn_cast(op); diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp --- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp @@ -449,6 +449,7 @@ patterns.insert, LinalgVectorizationPattern, LinalgVectorizationPattern, + LinalgVectorizationPattern, LinalgVectorizationPattern, LinalgVectorizationPattern>(funcOp.getContext()); applyPatternsAndFoldGreedily(funcOp, patterns);