diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -573,6 +573,15 @@
            << ") to be equal to the number of output tensors ("
            << linalgOp.getOutputTensorOperands().size() << ")";
 
+  // Check that every iterator type is one of getAllIteratorTypeNames().
+  auto iteratorTypesRange =
+      linalgOp.iterator_types().getAsValueRange<StringAttr>();
+  for (StringRef iteratorType : iteratorTypesRange) {
+    if (!llvm::is_contained(getAllIteratorTypeNames(), iteratorType))
+      return op->emitOpError("unexpected iterator_type (")
+             << iteratorType << ")";
+  }
+
   // Before checking indexing maps, we need to make sure the attributes
   // referenced by it are valid.
   if (linalgOp.hasDynamicIndexingMaps())
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -95,6 +95,19 @@
 
 // -----
 
+func @generic_wrong_iterator(%arg0: memref<1xi32>) {
+  // expected-error @+1 {{op unexpected iterator_type (random)}}
+  linalg.generic {
+    indexing_maps = [ affine_map<(i) -> (i)> ],
+    iterator_types = ["random"]}
+      outs(%arg0 : memref<1xi32>) {
+    ^bb(%i : i32):
+      linalg.yield %i : i32
+  }
+}
+
+// -----
+
 func @generic_one_d_view(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>) {
   // expected-error @+1 {{expected operand rank (1) to match the result rank of indexing_map #0 (2)}}
   linalg.generic {