diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -796,6 +796,9 @@ If non-empty, the `mapping` is added as an attribute to the resulting `scf.foreach_thread`. + Note: tile_sizes and num_threads handles can refer to multiple ops; they + are unpacked dynamically into one value per payload op. + #### Return modes This operation ignores ops that do not implement the TilingInterface and @@ -857,7 +860,7 @@ "ArrayRef":$mixedNumThreads, CArg<"::mlir::transform::NumThreadsSpec", "::mlir::transform::NumThreadsSpec()">, - CArg<"ArrayAttr", "{}">:$mapping)>, + CArg<"ArrayAttr", "{}">:$mapping)> ]; let assemblyFormat = [{ diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -19,6 +19,7 @@ #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/IR/OpDefinition.h" #include "mlir/Interfaces/TilingInterface.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/StringSet.h" @@ -1326,16 +1327,19 @@ Value target, ArrayRef staticTileSizes, transform::TileSizesSpec, - ArrayAttr mapping) { - return build(builder, result, target, + ArrayAttr mappingAttr) { + return build(builder, result, + /*target=*/target, + /*mixedTileSizes=*/ 
getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)), - TileSizesSpec(), mapping); + /*_=*/TileSizesSpec(), + /*mapping=*/mappingAttr); } void transform::TileToForeachThreadOp::build( OpBuilder &builder, OperationState &result, Value target, ArrayRef mixedTileSizes, transform::TileSizesSpec, - ArrayAttr mapping) { + ArrayAttr mappingAttr) { SmallVector staticTileSizes; SmallVector dynamicTileSizes; dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes, @@ -1346,9 +1350,14 @@ MLIRContext *ctx = builder.getContext(); auto operationType = pdl::OperationType::get(ctx); auto staticTileSizesAttr = builder.getI64ArrayAttr(staticTileSizes); - build(builder, result, TypeRange{operationType, operationType}, target, - /*numThreads=*/ValueRange{}, dynamicTileSizes, - /*staticNumThreads=*/ArrayAttr(), staticTileSizesAttr, mapping); + build(builder, result, + /*resultTypes=*/TypeRange{operationType, operationType}, + /*target=*/target, + /*numThreads=*/ValueRange{}, + /*tileSizes=*/dynamicTileSizes, + /*staticNumThreads=*/builder.getI64ArrayAttr({}), + /*staticTileSizes=*/staticTileSizesAttr, + /*mapping=*/mappingAttr); } void transform::TileToForeachThreadOp::build(OpBuilder &builder, @@ -1356,16 +1365,16 @@ Value target, ArrayRef staticNumThreads, transform::NumThreadsSpec, - ArrayAttr mapping) { + ArrayAttr mappingAttr) { return build(builder, result, target, getAsOpFoldResult(builder.getI64ArrayAttr(staticNumThreads)), - NumThreadsSpec(), mapping); + NumThreadsSpec(), mappingAttr); } void transform::TileToForeachThreadOp::build( OpBuilder &builder, OperationState &result, Value target, ArrayRef mixedNumThreads, transform::NumThreadsSpec, - ArrayAttr mapping) { + ArrayAttr mappingAttr) { SmallVector staticNumThreads; SmallVector dynamicNumThreads; dispatchIndexOpFoldResults(mixedNumThreads, dynamicNumThreads, @@ -1376,11 +1385,49 @@ MLIRContext *ctx = builder.getContext(); auto operationType = pdl::OperationType::get(ctx); auto 
staticNumThreadsAttr = builder.getI64ArrayAttr(staticNumThreads); - build(builder, result, TypeRange{operationType, operationType}, target, - dynamicNumThreads, /*tileSizes=*/ValueRange{}, staticNumThreadsAttr, - /*staticTileSizes=*/ArrayAttr(), mapping); + build(builder, result, + /*resultTypes=*/TypeRange{operationType, operationType}, + /*target=*/target, + /*numThreads=*/dynamicNumThreads, + /*tileSizes=*/ValueRange{}, + /*staticNumThreads=*/staticNumThreadsAttr, + /*staticTileSizes=*/builder.getI64ArrayAttr({}), + /*mapping=*/mappingAttr); } +// Appends `ofrs` to `result`, replacing every !pdl.operation handle by the +// single index-typed OpResult of each payload op it maps to (in order); index +// attributes and other values are forwarded unchanged. A handle may map to +// several payload ops (or none); each payload op must have exactly one +// result, of index type, otherwise a silenceable failure is returned. +static DiagnosedSilenceableFailure unpackPDLOperations( + transform::TransformState &state, TransformOpInterface transformOp, + SmallVector &result, ArrayRef ofrs) { + for (OpFoldResult ofr : ofrs) { + // Forward index attributes and non-!pdl.operation values unchanged. 
+ if (ofr.is() || + !ofr.get().getType().isa()) { + result.push_back(ofr); + continue; + } + ArrayRef payloadOps = + state.getPayloadOps(ofr.get()); + for (Operation *op : payloadOps) { + if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) { + DiagnosedSilenceableFailure diag = + transformOp.emitSilenceableError() + << "payload op must have exactly 1 index result"; + diag.attachNote(op->getLoc()) + << "has " << op->getNumResults() << " results"; + return diag; + } + result.push_back(op->getResult(0)); + } + } + + return DiagnosedSilenceableFailure(success()); +} + DiagnosedSilenceableFailure transform::tileToForeachThreadOpImpl( RewriterBase &rewriter, transform::TransformState &state, TransformOpInterface transformOp, ArrayRef targets, @@ -1390,56 +1437,18 @@ if (targets.empty()) return DiagnosedSilenceableFailure(success()); - // Given a list of OpFoldResults that are either index attrs or op handles, - // return a list of OpFoldResults where all op handles are replaced with the - // first (and only) OpResult of that payload op. (There must be exactly one - // mapped payload op and it must have exactly one index result.) 
- auto getOpResultsOrIndexAttrs = - [&](SmallVector &result, - ArrayRef opHandlesOrIndexAttrs) { - for (OpFoldResult ofr : opHandlesOrIndexAttrs) { - if (ofr.is()) { - result.push_back(ofr); - continue; - } - ArrayRef dynamicNumThreads = - state.getPayloadOps(ofr.get()); - if (dynamicNumThreads.size() != 1) { - DiagnosedSilenceableFailure diag = - transformOp.emitSilenceableError() - << "handle must be mapped to exactly 1 payload op"; - diag.attachNote(ofr.get().getLoc()) - << "mapped to " << dynamicNumThreads.size() << " ops"; - return diag; - } - Operation *op = dynamicNumThreads[0]; - if (op->getNumResults() != 1 || - !op->getResult(0).getType().isIndex()) { - DiagnosedSilenceableFailure diag = - transformOp.emitSilenceableError() - << "payload op must have exactly 1 index result"; - diag.attachNote(op->getLoc()) - << "has " << op->getNumResults() << " results"; - return diag; - } - result.push_back(op->getResult(0)); - } - - return DiagnosedSilenceableFailure(success()); - }; - // getMixedNumThreads are OpFoldResults[index attributes or PDL operation]. // Convert to OpFoldResults[index attributes or payload op]. SmallVector numThreads; DiagnosedSilenceableFailure status = - getOpResultsOrIndexAttrs(numThreads, mixedNumThreads); + unpackPDLOperations(state, transformOp, numThreads, mixedNumThreads); if (!status.succeeded()) return status; // getMixedTileSizes are OpFoldResults[index attributes or PDL operation]. // Convert to OpFoldResults[index attributes or payload op]. 
SmallVector tileSizes; - status = getOpResultsOrIndexAttrs(tileSizes, mixedTileSizes); + status = unpackPDLOperations(state, transformOp, tileSizes, mixedTileSizes); if (!status.succeeded()) return status; @@ -1488,8 +1497,11 @@ getMixedNumThreads(), getMixedTileSizes(), getMapping(), tileOps, tiledOps); - if (!diag.succeeded()) + if (!diag.succeeded()) { + transformResults.set(getForeachThreadOp().cast(), {}); + transformResults.set(getTiledOp().cast(), {}); return diag; + } transformResults.set(getForeachThreadOp().cast(), tileOps); transformResults.set(getTiledOp().cast(), tiledOps); diff --git a/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir --- a/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir +++ b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir @@ -41,6 +41,48 @@ // ----- +// Dynamic matmul dims; the single %sz handle unpacks to two tile sizes.
+ +// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0, s1] -> (s0 ceildiv s1)> +// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)> +// CHECK-DAG: #[[$map4:.+]] = affine_map<(d0)[s0] -> (d0 * s0)> + +// CHECK-LABEL: matmul_tile_size_dynamic_dynamic( +// CHECK-SAME: %[[A:[0-9a-z]+]]: tensor +// CHECK-SAME: %[[B:[0-9a-z]+]]: tensor +// CHECK-SAME: %[[C:[0-9a-z]+]]: tensor +func.func @matmul_tile_size_dynamic_dynamic(%A: tensor, %B: tensor, %C: tensor) -> tensor { + // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index + // CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index + // CHECK-DAG: %[[tile_size_1:.*]] = "test.dummy"() + // CHECK-DAG: %[[tile_size_2:.*]] = "test.dummy"() + // CHECK-DAG: %[[M:.+]] = tensor.dim %[[A]], %[[c0]] : + // CHECK-DAG: %[[N:.+]] = tensor.dim %[[B]], %[[c1]] : + // CHECK-DAG: %[[NT0:.+]] = affine.apply #[[$map0]]()[%[[M]], %[[tile_size_1]]] + // CHECK-DAG: %[[NT1:.+]] = affine.apply #[[$map0]]()[%[[N]], %[[tile_size_2]]] + // CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NT0]], %[[NT1]]) shared_outs(%[[C_BLK:.*]] = %[[C]]) + // CHECK: tensor.extract_slice %[[A]] + // CHECK: tensor.extract_slice %[[B]] + // CHECK: tensor.extract_slice %[[C_BLK]] + // CHECK: linalg.matmul + // CHECK: scf.foreach_thread.perform_concurrently + // CHECK-NEXT: tensor.parallel_insert_slice + %tile_size_1 = "test.dummy"() : () -> (index) + %tile_size_2 = "test.dummy"() : () -> (index) + %0 = linalg.matmul ins(%A, %B : tensor, tensor) + outs(%C : tensor) -> (tensor) + return %0 : tensor +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 + %sz = transform.structured.match ops{["test.dummy"]} in %arg1 + %1:2 = transform.structured.tile_to_foreach_thread_op %0 tile_sizes [%sz] +} + +// ----- + // Tests that dimension 0 can eliminate affine.min/max, dimension 1 cannot. // CHECK-DAG: #[[$map0:.+]] = affine_map<(d0) -> (d0 * -15 + 300, 15)>