diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -3067,9 +3067,8 @@ static ParseResult parseNamedStructuredOpResults(OpAsmParser &parser, SmallVectorImpl &resultTypes) { - if (succeeded(parser.parseOptionalArrow())) - if (parser.parseTypeList(resultTypes)) - return failure(); + if (parser.parseOptionalArrowTypeList(resultTypes)) + return failure(); return success(); } diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir --- a/mlir/test/Dialect/Linalg/bufferize.mlir +++ b/mlir/test/Dialect/Linalg/bufferize.mlir @@ -85,7 +85,7 @@ ^bb0(%gen_arg1: f32, %out1: f32, %out2: f32): %tmp1 = math.exp %gen_arg1 : f32 linalg.yield %tmp1, %tmp1 : f32, f32 - } -> tensor<4xf32>, tensor<4xf32> + } -> (tensor<4xf32>, tensor<4xf32>) return %0, %1 : tensor<4xf32>, tensor<4xf32> } @@ -118,7 +118,7 @@ ^bb0(%gen_arg1: f32, %out1: f32, %out2: f32): %tmp1 = math.exp %gen_arg1 : f32 linalg.yield %tmp1, %tmp1 : f32, f32 - } -> tensor, tensor + } -> (tensor, tensor) return %0, %1 : tensor, tensor } diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -714,7 +714,7 @@ outs(%arg_0, %arg_1 : tensor, tensor) { ^bb0(%in: f32, %out_0: f32, %out_1: f32): linalg.yield %in, %in : f32, f32 - } -> tensor, tensor + } -> (tensor, tensor) %c0 = constant 0 : index %num_elem_0 = memref.dim %0, %c0 : tensor @@ -778,7 +778,7 @@ outs(%3, %3 : tensor, tensor) { ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32, %arg5 : f32): linalg.yield %arg3, %arg2 : f32, f32 - } -> tensor, tensor + } -> (tensor, tensor) return %4, %5 : tensor, tensor } // CHECK-LABEL: func @remove_no_op @@ -832,7 +832,7 @@ outs(%2, %2 : tensor, tensor) { ^bb0(%arg3: f32, %arg4 : f32, %arg5 : f32, %arg6 : f32): linalg.yield %arg2, %arg4 : f32, f32 - } -> tensor, tensor + } -> (tensor, tensor) return %3#0, %3#1 : tensor, tensor } // CHECK-LABEL: func @keep_not_noop diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -449,7 +449,7 @@ func @incorrect_region_arg_count(%m: memref) { // expected-error @+3 {{region expects 3 args, got 2}} %res = linalg.matmul ins(%m, %m : memref, memref) - -> tensor, tensor + -> (tensor, tensor) return } diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -424,6 +424,39 @@ // ----- +func @generic_with_multiple_tensor_outputs( + %arg0: tensor, %arg1: tensor, %arg2: i32) + -> (tensor, tensor) { + %c0 = constant 0 : index + %0 = linalg.init_tensor [] : tensor + %1 = linalg.fill(%0, %arg2) : tensor, i32 -> tensor + %2 = linalg.init_tensor [] : tensor + %3 = linalg.fill(%2, %arg2) : tensor, i32 -> tensor + %4:2 = linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>, affine_map<(d0) -> ()>], + iterator_types = ["reduction"]} + ins(%arg0, %arg1 : tensor, tensor) + outs(%1, %3 : tensor, tensor) { + ^bb0(%arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32): // no predecessors + %5 = cmpi sge, %arg3, %arg5 : i32 + %6 = select %5, %arg3, %arg5 : i32 + %7 = cmpi eq, %arg3, %arg5 : i32 + %8 = cmpi slt, %arg4, %arg6 : i32 + %9 = select %8, %arg4, %arg6 : i32 + %10 = select %5, %arg4, %arg6 : i32 + %11 = select %7, %9, %10 : i32 + linalg.yield %6, %11 : i32, i32 + } -> (tensor, tensor) + return %4#0, %4#1 : tensor, tensor +} +// CHECK-LABEL: func @generic_with_multiple_tensor_outputs +// CHECK: %{{.*}} = linalg.generic { +// CHECK-SAME: ins({{.*}} : tensor, tensor) +// CHECK-SAME: outs({{.*}} : tensor, tensor) +// CHECK: } -> (tensor, tensor) + +// ----- + #accesses_2 = [ affine_map<(i, j, k) -> (j, i)>, affine_map<(i, j, k) -> (i, k, i + j)>, diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir --- a/mlir/test/Dialect/Linalg/vectorization.mlir +++ b/mlir/test/Dialect/Linalg/vectorization.mlir @@ -386,9 +386,9 @@ // CHECK: %[[R9:.*]] = vector.transfer_write %[[TAN]], %[[ARG0]][%[[C0]], %[[C0]]] {{.*}} : vector<4x256xf32>, tensor<4x256xf32> linalg.yield %6, %8, %c1_f32, %9, %10, %11, %12, %13, %14, %15 : f32, f32, f32, f32, f32, f32, f32, f32, f32, f32 - } -> tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, + } -> (tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, - tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32> + tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>) // CHECK: return %[[R0]], %[[R1]], %[[R2]], %[[R3]], %[[R4]], %[[R5]], %[[R6]], %[[R7]], %[[R8]], %[[R9]] : tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32> return %r#0, %r#1, %r#2, %r#3, %r#4, %r#5, %r#6, %r#7, %r#8, %r#9: tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>,