diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -660,25 +660,31 @@ def TileToForeachThreadOp : Op, TransformOpInterface]> {
   let description = [{
-    Tile a TilingInterface op to a tiled `scf.foreach_thread`. Tiling is
-    applied by either specifying `num_threads` or `tile_size`. If `num_threads`
-    is specified, then the tile size for each dimension `i` is calculated
-    dynamically via `ceilDiv(dimSize[i], num_threads[i])`.
-    If non-empty, the `thread_dim_mapping` is added as an attribute to the
-    resulting `scf.foreach_thread`.
-    Zero tile sizes indicate that the dimension is not tiled and can be
+    Tile a TilingInterface op to a tiled `scf.foreach_thread`.
+
+    Tiling is applied by either specifying `num_threads` or `tile_sizes`. If
+    `num_threads` is specified, then the tile size for each dimension `i` is
+    calculated dynamically via `ceilDiv(dimSize[i], num_threads[i])`.
+    `num_threads` and `tile_sizes` can be either static index attributes or
+    SSA values of PDL operation handle type (or a mix thereof). An operation
+    handle must be mapped to exactly one op that has exactly one result of
+    index type.
+
+    Static zero tile sizes indicate that the dimension is not tiled and can be
     thought of as tiling by the full size of data.
+
     It is the user's responsibility to ensure that `num_threads/tile_sizes` is
     a valid tiling specification (i.e. that only tiles parallel dimensions,
     e.g. in the Linalg case).
-
+
+    If non-empty, the `thread_dim_mapping` is added as an attribute to the
+    resulting `scf.foreach_thread`.
+
     #### Return modes
-
+
     This operation ignores ops that do not implement the TilingInterface and
     drops them in the return.
@@ -696,36 +702,46 @@
     ### Example using `num_threads`
 
     ```
-    %0 = pdl_match @match_matmul in %arg1 
+    %0 = pdl_match @match_matmul in %arg1
     %3:2 = transform.structured.tile_to_foreach_thread_op %0 num_threads [10, 20]
     ```
 
     ### Example using `tile_sizes`
-    
+
     ```
-    %0 = pdl_match @match_matmul in %arg1 
-    %3:2 = transform.structured.tile_to_foreach_thread_op %0 tile_sizes [10, 20, 0]
+    %0 = pdl_match @match_matmul in %arg1
+    %sz = pdl_match @match_size_op in %arg1
+    %3:2 = transform.structured.tile_to_foreach_thread_op %0 tile_sizes [0, %sz, 20]
     ```
   }];
 
   let arguments = (ins PDL_Operation:$target,
-                   // TODO: dynamic number of threads.
-                   OptionalAttr<I64ArrayAttr>:$num_threads,
-                   OptionalAttr<I64ArrayAttr>:$tile_sizes,
+                   Variadic<PDL_Operation>:$num_threads,
+                   Variadic<PDL_Operation>:$tile_sizes,
+                   DefaultValuedAttr<I64ArrayAttr, "{}">:$static_num_threads,
+                   DefaultValuedAttr<I64ArrayAttr, "{}">:$static_tile_sizes,
                    OptionalAttr<I64ArrayAttr>:$thread_dim_mapping);
   let results = (outs PDL_Operation:$foreach_thread_op,
                       PDL_Operation:$tiled_op);
-
   let assemblyFormat = [{
-    $target (`num_threads` $num_threads^) : (`tile_sizes` $tile_sizes)?
-    (`(` `mapped` `to` `dims` $thread_dim_mapping^ `)`)? attr-dict
+    $target oilist(
+        `num_threads` custom<DynamicIndexList>($num_threads,
+                                               $static_num_threads,
+                                               "ShapedType::kDynamicSize") |
+        `tile_sizes` custom<DynamicIndexList>($tile_sizes,
+                                              $static_tile_sizes,
+                                              "ShapedType::kDynamicSize"))
+    (`(` `mapped` `to` `dims` $thread_dim_mapping^ `)`)? attr-dict
   }];
+  let hasVerifier = 1;
 
   let extraClassDeclaration = [{
-    ::mlir::DiagnosedSilenceableFailure applyToOne(
-        ::mlir::TilingInterface target,
-        ::llvm::SmallVectorImpl<::mlir::Operation *> &results,
+    ::mlir::DiagnosedSilenceableFailure apply(
+        ::mlir::transform::TransformResults &transformResults,
         ::mlir::transform::TransformState &state);
+
+    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedNumThreads();
+    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedTileSizes();
   }];
 }
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1020,33 +1020,137 @@
 // TileToForeachThreadOp
 //===----------------------------------------------------------------------===//
 
-DiagnosedSilenceableFailure transform::TileToForeachThreadOp::applyToOne(
-    TilingInterface target, SmallVectorImpl<Operation *> &results,
+DiagnosedSilenceableFailure transform::TileToForeachThreadOp::apply(
+    transform::TransformResults &transformResults,
     transform::TransformState &state) {
   IRRewriter rewriter(getContext());
-  rewriter.setInsertionPoint(target);
-  auto maybeThreadDimMappingAttr = getThreadDimMapping();
-  auto dimMapping =
-      llvm::to_vector(maybeThreadDimMappingAttr
-                          ? extractFromI64ArrayAttr(*maybeThreadDimMappingAttr)
-                          : ArrayRef<int64_t>{});
-
-  FailureOr<ForeachThreadTilingResult> tilingResult = failure();
-  if (Optional<ArrayAttr> numThreads = getNumThreads())
-    tilingResult = linalg::tileToForeachThreadOp(
-        rewriter, target, getAsOpFoldResult(*numThreads), dimMapping);
-
-  if (Optional<ArrayAttr> tileSizes = getTileSizes())
-    tilingResult = linalg::tileToForeachThreadOpUsingTileSizes(
-        rewriter, target, getAsOpFoldResult(*tileSizes), dimMapping);
-
-  if (failed(tilingResult))
-    return emitDefaultSilenceableFailure(target);
-  rewriter.replaceOp(target, tilingResult->tileOp->getResults());
-  results.assign({tilingResult->tileOp, tilingResult->tiledOp});
+  ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
+
+  // If the target payload ops are empty, there is nothing to do.
+  if (targets.empty()) {
+    transformResults.set(getForeachThreadOp().cast<OpResult>(), {});
+    transformResults.set(getTiledOp().cast<OpResult>(), {});
+    return DiagnosedSilenceableFailure(success());
+  }
+
+  // Result payload ops.
+  SmallVector<Operation *> tileOps;
+  SmallVector<Operation *> tiledOps;
+
+  // Given a list of OpFoldResults that are either index attrs or op handles,
+  // return a list of OpFoldResults where all op handles are replaced with the
+  // first (and only) OpResult of that payload op. (There must be exactly one
+  // mapped payload op and it must have exactly one index result.)
+  auto getOpResultsOrIndexAttrs =
+      [&](SmallVector<OpFoldResult> &result,
+          ArrayRef<OpFoldResult> opHandlesOrIndexAttrs) {
+        for (OpFoldResult ofr : opHandlesOrIndexAttrs) {
+          if (ofr.is<Attribute>()) {
+            result.push_back(ofr);
+            continue;
+          }
+          ArrayRef<Operation *> dynamicNumThreads =
+              state.getPayloadOps(ofr.get<Value>());
+          if (dynamicNumThreads.size() != 1) {
+            DiagnosedSilenceableFailure diag =
+                emitSilenceableError()
+                << "handle must be mapped to exactly 1 payload op";
+            diag.attachNote(ofr.get<Value>().getLoc())
+                << "mapped to " << dynamicNumThreads.size() << " ops";
+            return diag;
+          }
+          Operation *op = dynamicNumThreads[0];
+          if (op->getNumResults() != 1 ||
+              !op->getResult(0).getType().isIndex()) {
+            DiagnosedSilenceableFailure diag =
+                emitSilenceableError()
+                << "payload op must have exactly 1 index result";
+            diag.attachNote(op->getLoc())
+                << "has " << op->getNumResults() << " results";
+            return diag;
+          }
+          result.push_back(op->getResult(0));
+        }
+
+        return DiagnosedSilenceableFailure(success());
+      };
+
+  // getMixedNumThreads are OpFoldResults[index attributes or PDL operation].
+  // Convert to OpFoldResults[index attributes or payload op].
+  SmallVector<OpFoldResult> numThreads;
+  DiagnosedSilenceableFailure status =
+      getOpResultsOrIndexAttrs(numThreads, getMixedNumThreads());
+  if (!status.succeeded())
+    return status;
+
+  // getMixedTileSizes are OpFoldResults[index attributes or PDL operation].
+  // Convert to OpFoldResults[index attributes or payload op].
+  SmallVector<OpFoldResult> tileSizes;
+  status = getOpResultsOrIndexAttrs(tileSizes, getMixedTileSizes());
+  if (!status.succeeded())
+    return status;
+
+  // Transform all targets one by one.
+  for (Operation *target : targets) {
+    auto tilableOp = dyn_cast<TilingInterface>(target);
+    if (!tilableOp) {
+      DiagnosedSilenceableFailure diag =
+          emitSilenceableError() << "only TilingInterface ops are supported";
+      diag.attachNote(target->getLoc()) << "target op";
+      return diag;
+    }
+    rewriter.setInsertionPoint(tilableOp);
+    auto maybeThreadDimMappingAttr = getThreadDimMapping();
+    auto dimMapping = llvm::to_vector(
+        maybeThreadDimMappingAttr
+            ? extractFromI64ArrayAttr(*maybeThreadDimMappingAttr)
+            : ArrayRef<int64_t>{});
+
+    FailureOr<ForeachThreadTilingResult> tilingResult = failure();
+    if (!getMixedNumThreads().empty()) {
+      tilingResult = linalg::tileToForeachThreadOp(rewriter, tilableOp,
+                                                   numThreads, dimMapping);
+    } else {
+      tilingResult = linalg::tileToForeachThreadOpUsingTileSizes(
+          rewriter, tilableOp, tileSizes, dimMapping);
+    }
+
+    if (failed(tilingResult))
+      return emitDefaultSilenceableFailure(tilableOp);
+    rewriter.replaceOp(tilableOp, tilingResult->tileOp->getResults());
+
+    tileOps.push_back(tilingResult->tileOp);
+    tiledOps.push_back(tilingResult->tiledOp);
+  }
+
+  transformResults.set(getForeachThreadOp().cast<OpResult>(), tileOps);
+  transformResults.set(getTiledOp().cast<OpResult>(), tiledOps);
+  return DiagnosedSilenceableFailure(success());
 }
 
+void transform::TileToForeachThreadOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  consumesHandle(getTarget(), effects);
+  onlyReadsHandle(getTileSizes(), effects);
+  onlyReadsHandle(getNumThreads(), effects);
+  producesHandle(getResults(), effects);
+}
+
+SmallVector<OpFoldResult> TileToForeachThreadOp::getMixedNumThreads() {
+  return getMixedSizes(getStaticNumThreads(), getNumThreads());
+}
+
+SmallVector<OpFoldResult> TileToForeachThreadOp::getMixedTileSizes() {
+  return getMixedSizes(getStaticTileSizes(), getTileSizes());
+}
+
+LogicalResult TileToForeachThreadOp::verify() {
+  if (getMixedNumThreads().empty() == getMixedTileSizes().empty())
+    return emitOpError("either num_threads or tile_sizes must be specified");
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // VectorizeOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
--- a/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
+++ b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
@@ -203,3 +203,49 @@
 // CHECK: %[[OFF:.*]] = affine.apply #[[$map0]](%[[ARG]])
 // CHECK: scf.foreach_thread.perform_concurrently {
 // CHECK: tensor.parallel_insert_slice %{{.*}} into %{{.*}}[%[[OFF]]] [2] [1] : tensor<2xf32> into tensor<4xf32>
+
+// -----
+
+// In this test case, matmul dims and tile size are dynamic.
+
+// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
+// CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0] -> (s0 ceildiv 20)>
+// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
+// CHECK-DAG: #[[$map3:.+]] = affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>
+// CHECK-DAG: #[[$map4:.+]] = affine_map<(d0)[s0] -> (d0 * s0)>
+// CHECK-DAG: #[[$map5:.+]] = affine_map<(d0) -> (d0 * 20)>
+
+// CHECK-LABEL: matmul_tile_size_dynamic_dynamic(
+// CHECK-SAME:    %[[A:[0-9a-z]+]]: tensor<?x?xf32>
+// CHECK-SAME:    %[[B:[0-9a-z]+]]: tensor<?x?xf32>
+// CHECK-SAME:    %[[C:[0-9a-z]+]]: tensor<?x?xf32>
+func.func @matmul_tile_size_dynamic_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+  // CHECK-DAG: %[[tile_size:.*]] = "test.dummy"()
+  // CHECK-DAG: %[[M:.+]] = tensor.dim %[[A]], %[[c0]] :
+  // CHECK-DAG: %[[N:.+]] = tensor.dim %[[B]], %c1 :
+  // CHECK-DAG: %[[NT0:.+]] = affine.apply #[[$map0]]()[%[[M]], %[[tile_size]]]
+  // CHECK-DAG: %[[NT1:.+]] = affine.apply #[[$map1]]()[%[[N]]]
+  // CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NT0]], %[[NT1]])
+  // CHECK:   tensor.extract_slice %[[A]]
+  // CHECK:   tensor.extract_slice %[[B]]
+  // CHECK:   tensor.extract_slice %[[C]]
+  // CHECK:   linalg.matmul
+  // CHECK:   scf.foreach_thread.perform_concurrently
+  // CHECK-NEXT: tensor.parallel_insert_slice
+  %tile_size = "test.dummy"() : () -> (index)
+  %0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+                     outs(%C : tensor<?x?xf32>) -> (tensor<?x?xf32>)
+  return %0 : tensor<?x?xf32>
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
+    %sz = transform.structured.match ops{["test.dummy"]} in %arg1
+    %1:2 = transform.structured.tile_to_foreach_thread_op %0 tile_sizes [%sz, 20]
+  }
+}
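The new test exercises the dynamic `tile_sizes` path. The dynamic `num_threads` path added in `apply()` accepts the same kind of operand; the following is a minimal sketch (not part of the patch) of that usage, reusing `test.dummy` as a stand-in for any payload op with a single `index` result. The handle name `%nt` and the static count `20` are illustrative only:

```
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
%nt = transform.structured.match ops{["test.dummy"]} in %arg1
// Dimension 0 gets a dynamically computed number of threads, dimension 1 a
// static count of 20; per-dimension tile sizes are then derived as
// ceilDiv(dimSize[i], num_threads[i]).
%1:2 = transform.structured.tile_to_foreach_thread_op %0 num_threads [%nt, 20]
```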