diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -235,7 +235,7 @@ auto hasStrideOne = [](Range r) { return !isConstantIntValue(r.stride, 1); }; if (llvm::any_of(loopRanges, hasStrideOne)) return op->emitOpError("only stride-1 supported atm"); - auto destOperands = op.getDestinationOperands(b); + auto dest = op.getDestinationOperands(b); SmallVector<OpFoldResult> nonZeroNumThreads = llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) { @@ -348,7 +348,7 @@ resultSizes))) return op->emitOpError("output offsets couldn't be calculated"); SmallVector<OpFoldResult> strides(resultSizes.size(), b.getIndexAttr(1)); - b.setInsertionPointToStart(foreachThreadOp.getTerminator().getBody()); + b.setInsertionPointToEnd(foreachThreadOp.getTerminator().getBody()); b.create<tensor::ParallelInsertSliceOp>(loc, std::get<1>(it), std::get<2>(it), resultOffsets, resultSizes, strides); diff --git a/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir --- a/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir +++ b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir @@ -260,13 +260,13 @@ // CHECK-LABEL: tile_output_multi_1d_static( // CHECK-SAME: %[[IN1:[0-9a-z]+]]: tensor<100xf32> // CHECK-SAME: %[[IN2:[0-9a-z]+]]: tensor<100xf32> -// CHECK-SAME: %[[OUT1:[0-9a-z]+]]: tensor<100xf32> -// CHECK-SAME: %[[OUT2:[0-9a-z]+]]: tensor<100xf32> +// CHECK-SAME: %[[ORGOUT1:[0-9a-z]+]]: tensor<100xf32> +// CHECK-SAME: %[[ORGOUT2:[0-9a-z]+]]: tensor<100xf32> func.func @tile_output_multi_1d_static(%IN1: tensor<100xf32>, %IN2: tensor<100xf32>, %OUT1: tensor<100xf32>, %OUT2: tensor<100xf32>) -> (tensor<100xf32>, tensor<100xf32>) { // CHECK-DAG: %[[c0:.+]] = arith.constant 7 : -// CHECK: scf.foreach_thread (%[[IV0:.+]]) in (%[[c0]]) +// CHECK: scf.foreach_thread (%[[IV0:.+]]) in (%[[c0]]) shared_outs(%[[OUT1:[0-9a-z]+]] = 
%[[ORGOUT1]], %[[OUT2:[0-9a-z]+]] = %[[ORGOUT2]]) // CHECK: %[[TSMIN:.+]] = affine.min #[[$map0]](%[[IV0]]) // CHECK: %[[TS:.+]] = affine.max #[[$map1]](%[[TSMIN]]) // CHECK-NOT: affine.min @@ -278,8 +278,8 @@ // CHECK: %[[tOUT2:.+]] = tensor.extract_slice %[[OUT2]][%[[LB]]] [%[[TS]]] [1] : // CHECK: %[[RES1:[0-9]+]]:[[RES2:[0-9]+]] = linalg.generic // CHECK: scf.foreach_thread.perform_concurrently -// CHECK-NEXT: tensor.parallel_insert_slice %[[RES1]]#1 into %[[OUT2]][%[[LB]]] [%[[TS]]] [1] : // CHECK-NEXT: tensor.parallel_insert_slice %[[RES1]]#0 into %[[OUT1]][%[[LB]]] [%[[TS]]] [1] : +// CHECK-NEXT: tensor.parallel_insert_slice %[[RES1]]#1 into %[[OUT2]][%[[LB]]] [%[[TS]]] [1] : %res1, %res2 = linalg.generic { indexing_maps = [affine_map<(d0) -> (d0)>, @@ -307,3 +307,62 @@ } } +// ----- + +// CHECK-DAG: #[[$map0:.+]] = affine_map<(d0) -> (d0 * 75)> +// CHECK-DAG: #[[$map1:.+]] = affine_map<(d0, d1) -> (d1)> +// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0, d1) -> (d1, d0)> +// CHECK-DAG: #[[$map3:.+]] = affine_map<(d0, d1) -> (d0)> +// CHECK-DAG: #[[$map4:.+]] = affine_map<(d0, d1) -> (d0, d1)> + +// CHECK-LABEL: tile_output_multi_1d2d_static( +// CHECK-SAME: %[[IN1:[0-9a-z]+]]: tensor<100xf32> +// CHECK-SAME: %[[IN2:[0-9a-z]+]]: tensor<100x300xf32> +// CHECK-SAME: %[[IN3:[0-9a-z]+]]: tensor<300xf32> +// CHECK-SAME: %[[ORGOUT1:[0-9a-z]+]]: tensor<300x100xf32> +// CHECK-SAME: %[[ORGOUT2:[0-9a-z]+]]: tensor<300xf32> + func.func @tile_output_multi_1d2d_static(%IN1: tensor<100xf32>, %IN2: tensor<100x300xf32>, %IN3: tensor<300xf32>, + %OUT1: tensor<300x100xf32>, %OUT2: tensor<300xf32>) + -> (tensor<300x100xf32>, tensor<300xf32>) { +// CHECK-DAG: %[[c0:.+]] = arith.constant 4 : +// CHECK: scf.foreach_thread (%[[IV0:.+]]) in (%[[c0]]) shared_outs(%[[OUT1:[0-9a-z]+]] = %[[ORGOUT1]], %[[OUT2:[0-9a-z]+]] = %[[ORGOUT2]]) +// CHECK: %[[LB:.+]] = affine.apply #[[$map0]](%[[IV0]]) +// CHECK: %[[tIN1:.+]] = tensor.extract_slice %[[IN2]][0, %[[LB]]] [100, 75] +// CHECK: 
%[[tIN2:.+]] = tensor.extract_slice %[[IN3]][%[[LB]]] [75] +// CHECK: %[[tOUT1:.+]] = tensor.extract_slice %[[OUT1]][%[[LB]], 0] [75, 100] +// CHECK: %[[tOUT2:.+]] = tensor.extract_slice %[[OUT2]][%[[LB]]] [75] +// CHECK: %[[RES1:[0-9]+]]:[[RES2:[0-9]+]] = linalg.generic +// CHECK: scf.foreach_thread.perform_concurrently +// CHECK-NEXT: tensor.parallel_insert_slice %[[RES1]]#0 into %[[OUT1]][%[[LB]], 0] [75, 100] +// CHECK-NEXT: tensor.parallel_insert_slice %[[RES1]]#1 into %[[OUT2]][%[[LB]]] [75] + %res2, %res3 = linalg.generic { + indexing_maps = [affine_map<(d0,d1) -> (d1)>, + affine_map<(d0,d1) -> (d1,d0)>, + affine_map<(d0,d1) -> (d0)>, + affine_map<(d0,d1) -> (d0,d1)>, + affine_map<(d0,d1) -> (d0)> + ], + iterator_types = ["parallel", "parallel"] + } ins(%IN1, %IN2, %IN3 : tensor<100xf32>, tensor<100x300xf32>, tensor<300xf32>) + outs(%OUT1, %OUT2: tensor<300x100xf32>, tensor<300xf32>) { + ^bb0(%i1: f32, %i2: f32, %i3: f32, %o1: f32, %o2: f32): + %1 = arith.addf %i1, %o1 : f32 + %2 = arith.addf %i2, %1 : f32 + %3 = arith.addf %i3, %2 : f32 + linalg.yield %3, %i3 : f32, f32 + } -> (tensor<300x100xf32>, tensor<300xf32>) + + return %res2, %res3 : tensor<300x100xf32>, tensor<300xf32> + } + + transform.with_pdl_patterns { + ^bb0(%IN_MAT1: !pdl.operation): + transform.sequence %IN_MAT1 failures(propagate) { + ^bb1(%IN_MAT2: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %IN_MAT2 + %foreach_thread, %tiled_generic = transform.structured.tile_to_foreach_thread_op %0 num_threads [4] + } + } + + +