diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
@@ -12,3 +12,39 @@
 def batch_matmul(A: f32(Batch, M, K), B: f32(Batch, K, N)) -> (C: f32(Batch, M, N)) {
   C(b, m, n) = std_addf(std_mulf(A(b, m, k), B(b, k, n)));
 }
+
+ods_def<ConvNWCOp>:
+def conv_nwc(I: f32(N, W, C), K: f32(F, KW, C)) -> (O: f32(N, W, F)) {
+  O(n, w, f) = std_addf(O(n, w, f),
+                        std_mulf(I(n, w + kw, c), K(f, kw, c)));
+}
+
+ods_def<ConvNCWOp>:
+def conv_ncw(I: f32(N, C, W), K: f32(F, C, KW)) -> (O: f32(N, F, W)) {
+  O(n, f, w) = std_addf(O(n, f, w),
+                        std_mulf(I(n, c, w + kw), K(f, c, kw)));
+}
+
+ods_def<ConvNHWCOp>:
+def conv_nhwc(I: f32(N, H, W, C), K: f32(F, KH, KW, C)) -> (O: f32(N, H, W, F)) {
+  O(n, h, w, f) = std_addf(O(n, h, w, f),
+                           std_mulf(I(n, h + kh, w + kw, c), K(f, kh, kw, c)));
+}
+
+ods_def<ConvNCHWOp>:
+def conv_nchw(I: f32(N, C, H, W), K: f32(F, C, KH, KW)) -> (O: f32(N, F, H, W)) {
+  O(n, f, h, w) = std_addf(O(n, f, h, w),
+                           std_mulf(I(n, c, h + kh, w + kw), K(f, c, kh, kw)));
+}
+
+ods_def<ConvNDHWCOp>:
+def conv_ndhwc(I: f32(N, D, H, W, C), K: f32(F, KD, KH, KW, C)) -> (O: f32(N, D, H, W, F)) {
+  O(n, d, h, w, f) = std_addf(O(n, d, h, w, f),
+                              std_mulf(I(n, d + kd, h + kh, w + kw, c), K(f, kd, kh, kw, c)));
+}
+
+ods_def<ConvNCDHWOp>:
+def conv_ncdhw(I: f32(N, C, D, H, W), K: f32(F, C, KD, KH, KW)) -> (O: f32(N, F, D, H, W)) {
+  O(n, f, d, h, w) = std_addf(O(n, f, d, h, w),
+                              std_mulf(I(n, c, d + kd, h + kh, w + kw), K(f, c, kd, kh, kw)));
+}
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1352,2 +1352,26 @@
   return foldMemRefCast(*this);
 }
+LogicalResult ConvNWCOp::fold(ArrayRef<Attribute>,
+                              SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
+LogicalResult ConvNCWOp::fold(ArrayRef<Attribute>,
+                              SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
+LogicalResult ConvNHWCOp::fold(ArrayRef<Attribute>,
+                               SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
+LogicalResult ConvNCHWOp::fold(ArrayRef<Attribute>,
+                               SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
+LogicalResult ConvNDHWCOp::fold(ArrayRef<Attribute>,
+                                SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
+LogicalResult ConvNCDHWOp::fold(ArrayRef<Attribute>,
+                                SmallVectorImpl<OpFoldResult> &) {
+  return foldMemRefCast(*this);
+}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -747,2 +747,14 @@
+  if (isa<ConvNWCOp>(op))
+    return linalgOpToLoopsImpl<LoopTy, ConvNWCOp>(op, builder);
+  if (isa<ConvNCWOp>(op))
+    return linalgOpToLoopsImpl<LoopTy, ConvNCWOp>(op, builder);
+  if (isa<ConvNHWCOp>(op))
+    return linalgOpToLoopsImpl<LoopTy, ConvNHWCOp>(op, builder);
+  if (isa<ConvNCHWOp>(op))
+    return linalgOpToLoopsImpl<LoopTy, ConvNCHWOp>(op, builder);
+  if (isa<ConvNDHWCOp>(op))
+    return linalgOpToLoopsImpl<LoopTy, ConvNDHWCOp>(op, builder);
+  if (isa<ConvNCDHWOp>(op))
+    return linalgOpToLoopsImpl<LoopTy, ConvNCDHWOp>(op, builder);
   llvm_unreachable("Unexpected op in linalgOpToLoopsImpl");
 }