diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -995,6 +995,9 @@
     return op.emitOpError("expects memref elemental types to match");
   if (oType.getRank() != iType.getRank() || oType.getRank() != fType.getRank())
     return op.emitOpError("expects memref ranks to match");
+  // All three ranks are equal here, so checking the output rank covers them all.
+  if (oType.getRank() <= 2)
+    return op.emitOpError("expects memref ranks to be greater than 2");
   if (auto strides = op.strides()) {
     if (failed(
             verifyStrideOrDilation(op, strides->getValue(), /*isStride=*/true)))
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -428,6 +428,14 @@
 
 // -----
 
+// Rank-1 operands: the conv verifier must reject ranks that are not greater than 2.
+func @conv_rank_limit(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %arg2: memref<?xf32>) {
+  // expected-error @+1 {{expects memref ranks to be greater than 2}}
+  linalg.conv(%arg0, %arg1, %arg2) : memref<?xf32>, memref<?xf32>, memref<?xf32>
+}
+
+// -----
+
 // expected-error @+1 {{unknown Linalg type}}
 !invalid_type = type !linalg.unknown
 