diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1775,6 +1775,27 @@
   }
 
   if (auto tiledLoopOp = dyn_cast<TiledLoopOp>(parentOp)) {
+    // Check if output args with tensor types match results types.
+    SmallVector<Value, 2> tensorOuts;
+    llvm::copy_if(
+        tiledLoopOp.outputs(), std::back_inserter(tensorOuts),
+        [&](Value out) { return out.getType().isa<RankedTensorType>(); });
+    if (tensorOuts.size() != op.values().size())
+      return op.emitOpError("expected number of tensor output args = ")
+             << tensorOuts.size() << " to match the number of yield operands = "
+             << op.values().size();
+
+    TypeRange tensorTypes(llvm::makeArrayRef(tensorOuts));
+    for (auto &item :
+         llvm::enumerate(llvm::zip(tensorTypes, op.getOperandTypes()))) {
+      Type outType, resultType;
+      unsigned index = item.index();
+      std::tie(outType, resultType) = item.value();
+      if (outType != resultType)
+        return op.emitOpError("expected yield operand ")
+               << index << " with type = " << resultType
+               << " to match output arg type = " << outType;
+    }
     return success();
   }
   return op.emitOpError("expected parent op with LinalgOp interface");
@@ -1964,7 +1985,14 @@
   return !region().isAncestor(value.getParentRegion());
 }
 
-static LogicalResult verify(TiledLoopOp op) { return success(); }
+static LogicalResult verify(TiledLoopOp op) {
+  // Check if iterator types are provided for every loop dimension.
+  if (op.iterator_types().size() != op.getNumLoops())
+    return op.emitOpError("expected iterator types array attribute size = ")
+           << op.iterator_types().size()
+           << " to match the number of loops = " << op.getNumLoops();
+  return success();
+}
 
 namespace {
 
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -758,3 +758,87 @@
          outs(%output : memref<1x2x3x1xf32>)
   return
 }
+
+// -----
+
+#map0 = affine_map<(d0) -> (24, -d0 + 192)>
+#map1 = affine_map<(d0, d1)[s0] -> (d0 * 192 + s0 + d1)>
+#map2 = affine_map<(d0) -> (16, -d0 + 192)>
+
+func private @foo(%A: memref<192x192xf32>, %B: memref<192x192xf32>,
+                  %C: memref<192x192xf32>) -> ()
+
+func @tiled_loop_incorrect_num_yield_operands(%A: memref<192x192xf32>,
+    %B: memref<192x192xf32>, %C: memref<192x192xf32>,
+    %C_tensor: tensor<192x192xf32>) {
+  %c24 = constant 24 : index
+  %c0 = constant 0 : index
+  %c192 = constant 192 : index
+  %0 = linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c192, %c192)
+      step (%c24, %c24)
+      ins (%A, %B: memref<192x192xf32>, memref<192x192xf32>)
+      outs (%C_tensor, %C: tensor<192x192xf32>, memref<192x192xf32>) {
+    call @foo(%A, %B, %C)
+      : (memref<192x192xf32>, memref<192x192xf32>, memref<192x192xf32>) -> ()
+    // expected-error @+1 {{expected number of tensor output args = 1 to match the number of yield operands = 0}}
+    linalg.yield
+  }
+  return
+}
+
+// -----
+
+#map0 = affine_map<(d0) -> (24, -d0 + 192)>
+#map1 = affine_map<(d0, d1)[s0] -> (d0 * 192 + s0 + d1)>
+#map2 = affine_map<(d0) -> (16, -d0 + 192)>
+
+func private @foo(%A: memref<192x192xf32>, %B: memref<192x192xf32>,
+                  %C: memref<192x192xf32>) -> tensor<f32>
+
+func @tiled_loop_incorrect_yield_operand_type(%A: memref<192x192xf32>,
+    %B: memref<192x192xf32>, %C: memref<192x192xf32>,
+    %C_tensor: tensor<192x192xf32>) {
+  %c24 = constant 24 : index
+  %c0 = constant 0 : index
+  %c192 = constant 192 : index
+  %0 = linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c192, %c192)
+      step (%c24, %c24)
+      ins (%A, %B: memref<192x192xf32>, memref<192x192xf32>)
+      outs (%C_tensor, %C: tensor<192x192xf32>, memref<192x192xf32>) {
+    %1 = call @foo(%A, %B, %C)
+      : (memref<192x192xf32>, memref<192x192xf32>, memref<192x192xf32>) -> tensor<f32>
+    // expected-error @+1 {{expected yield operand 0 with type = 'tensor<f32>' to match output arg type = 'tensor<192x192xf32>'}}
+    linalg.yield %1 : tensor<f32>
+  }
+  return
+}
+
+// -----
+
+#map0 = affine_map<(d0) -> (24, -d0 + 192)>
+#map1 = affine_map<(d0, d1)[s0] -> (d0 * 192 + s0 + d1)>
+#map2 = affine_map<(d0) -> (16, -d0 + 192)>
+
+func private @foo(%A: memref<192x192xf32>, %B: memref<192x192xf32>,
+                  %C: memref<192x192xf32>) -> ()
+
+func @tiled_loop_incorrect_iterator_types_count(%A: memref<192x192xf32>,
+    %B: memref<192x192xf32>, %C: memref<192x192xf32>,
+    %C_tensor: tensor<192x192xf32>) {
+  %c24 = constant 24 : index
+  %c0 = constant 0 : index
+  %c192 = constant 192 : index
+  // expected-error @+1 {{expected iterator types array attribute size = 1 to match the number of loops = 2}}
+  %0 = "linalg.tiled_loop"(%c0, %c0, %c192, %c192, %c24, %c24, %A, %B, %C_tensor, %C) ( {
+  ^bb0(%arg4: index, %arg5: index):  // no predecessors
+    call @foo(%A, %B, %C)
+      : (memref<192x192xf32>, memref<192x192xf32>, memref<192x192xf32>) -> ()
+    linalg.yield %C_tensor : tensor<192x192xf32>
+  }) {
+    iterator_types = ["parallel"],
+    operand_segment_sizes = dense<2> : vector<5xi32>
+  } : (index, index, index, index, index, index, memref<192x192xf32>,
+       memref<192x192xf32>, tensor<192x192xf32>, memref<192x192xf32>
+      ) -> tensor<192x192xf32>
+  return
+}
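
For reference, below is a minimal sketch (not part of the patch) of IR that both new checks accept, modeled on the negative tests above: two loops with two iterator types, and one tensor output matched by exactly one yield operand of the same type. The helper @foo returning tensor<192x192xf32> and the function name @tiled_loop_valid are illustrative only, and the sketch assumes the custom linalg.tiled_loop assembly defaults every loop's iterator type to "parallel", as the yield tests above rely on.

func private @foo(%A: memref<192x192xf32>, %B: memref<192x192xf32>,
                  %C: memref<192x192xf32>) -> tensor<192x192xf32>

func @tiled_loop_valid(%A: memref<192x192xf32>, %B: memref<192x192xf32>,
    %C: memref<192x192xf32>, %C_tensor: tensor<192x192xf32>) -> tensor<192x192xf32> {
  %c24 = constant 24 : index
  %c0 = constant 0 : index
  %c192 = constant 192 : index
  %0 = linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c192, %c192)
      step (%c24, %c24)
      ins (%A, %B: memref<192x192xf32>, memref<192x192xf32>)
      outs (%C_tensor, %C: tensor<192x192xf32>, memref<192x192xf32>) {
    %1 = call @foo(%A, %B, %C)
      : (memref<192x192xf32>, memref<192x192xf32>, memref<192x192xf32>) -> tensor<192x192xf32>
    // One yield operand per tensor output arg, with matching type.
    linalg.yield %1 : tensor<192x192xf32>
  }
  return %0 : tensor<192x192xf32>
}

Dropping the yield operand, yielding a value of a different tensor type, or supplying an iterator_types array shorter than the number of loops turns this into the diagnostics exercised by the three tests above.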