diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -232,8 +232,6 @@
     return op->emitOpError("only stride-1 supported atm");
   // TODO: support `getTiledImplementation` with >1 produced tiled ops.
   auto destOperands = op.getDestinationOperands(b);
-  if (destOperands.size() != 1)
-    return op->emitOpError("only single dest operand supported atm");
 
   SmallVector nonZeroNumThreads =
       llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) {
diff --git a/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
--- a/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
+++ b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
@@ -249,3 +249,63 @@
     %1:2 = transform.structured.tile_to_foreach_thread_op %0 tile_sizes [%sz, 20]
   }
 }
+
+// -----
+
+// Tiles a linalg.generic with two results to 7 threads: all four operands are
+// sliced with one offset/size; each result is inserted into its own output.
+// CHECK-DAG: #[[$map0:.+]] = affine_map<(d0) -> (d0 * -15 + 100, 15)>
+// CHECK-DAG: #[[$map1:.+]] = affine_map<(d0) -> (0, d0)>
+// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0) -> (d0 * 15)>
+// CHECK-DAG: #[[$map3:.+]] = affine_map<(d0) -> (d0)>
+
+// CHECK-LABEL: tile_output_multi_1d_static(
+//  CHECK-SAME:   %[[IN1:[0-9a-z]+]]: tensor<100xf32>
+//  CHECK-SAME:   %[[IN2:[0-9a-z]+]]: tensor<100xf32>
+//  CHECK-SAME:   %[[OUT1:[0-9a-z]+]]: tensor<100xf32>
+//  CHECK-SAME:   %[[OUT2:[0-9a-z]+]]: tensor<100xf32>
+  func.func @tile_output_multi_1d_static(%IN1: tensor<100xf32>, %IN2: tensor<100xf32>,
+                                         %OUT1: tensor<100xf32>, %OUT2: tensor<100xf32>)
+                                         -> (tensor<100xf32>, tensor<100xf32>) {
+//  CHECK-DAG: %[[NTHREADS:.+]] = arith.constant 7 :
+//      CHECK: scf.foreach_thread (%[[IV0:.+]]) in (%[[NTHREADS]])
+//      CHECK:   %[[TSMIN:.+]] = affine.min #[[$map0]](%[[IV0]])
+//      CHECK:   %[[TS:.+]] = affine.max #[[$map1]](%[[TSMIN]])
+//  CHECK-NOT:   affine.min
+//  CHECK-NOT:   affine.max
+//      CHECK:   %[[LB:.+]] = affine.apply #[[$map2]](%[[IV0]])
+//      CHECK:   %[[tIN1:.+]] = tensor.extract_slice %[[IN1]][%[[LB]]] [%[[TS]]] [1] :
+//      CHECK:   %[[tIN2:.+]] = tensor.extract_slice %[[IN2]][%[[LB]]] [%[[TS]]] [1] :
+//      CHECK:   %[[tOUT1:.+]] = tensor.extract_slice %[[OUT1]][%[[LB]]] [%[[TS]]] [1] :
+//      CHECK:   %[[tOUT2:.+]] = tensor.extract_slice %[[OUT2]][%[[LB]]] [%[[TS]]] [1] :
+//      CHECK:   %[[RES:[0-9]+]]:[[NRES:[0-9]+]] = linalg.generic
+//      CHECK:   scf.foreach_thread.perform_concurrently
+// CHECK-NEXT:    tensor.parallel_insert_slice %[[RES]]#1 into %[[OUT2]][%[[LB]]] [%[[TS]]] [1] :
+// CHECK-NEXT:    tensor.parallel_insert_slice %[[RES]]#0 into %[[OUT1]][%[[LB]]] [%[[TS]]] [1] :
+    %res1, %res2 = linalg.generic
+    {
+      indexing_maps = [affine_map<(d0) -> (d0)>,
+                       affine_map<(d0) -> (d0)>,
+                       affine_map<(d0) -> (d0)>,
+                       affine_map<(d0) -> (d0)>],
+      iterator_types = ["parallel"]
+    } ins(%IN1, %IN2 : tensor<100xf32>, tensor<100xf32>)
+      outs(%OUT1, %OUT2 : tensor<100xf32>, tensor<100xf32>)
+    {
+      ^bb0(%a1: f32, %a2: f32, %a3: f32, %a4: f32):
+        %1 = arith.addf %a1, %a3 : f32
+        %2 = arith.addf %a2, %a4 : f32
+        linalg.yield %1, %2 : f32,f32
+    } -> (tensor<100xf32>, tensor<100xf32>)
+    return %res1, %res2 : tensor<100xf32>, tensor<100xf32>
+  }
+
+  transform.with_pdl_patterns {
+  ^bb0(%arg0: !pdl.operation):
+    transform.sequence %arg0 failures(propagate) {
+    ^bb1(%arg1: !pdl.operation):
+      %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+      %foreach_thread, %tiled_generic = transform.structured.tile_to_foreach_thread_op %0 num_threads [7]
+    }
+  }
+