diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -796,6 +796,10 @@
     If non-empty, the `mapping` is added as an attribute to the resulting `scf.foreach_thread`.
 
+    A `packed_sizes` unit attribute may be specified that controls whether the
+    `num_threads` or `tile_sizes` specification refers to a handle that contains
+    multiple PDL_Operation values that are unpacked dynamically.
+
     #### Return modes
 
     This operation ignores ops that do not implement the TilingInterface and
@@ -825,6 +829,13 @@
     %0 = pdl_match @match_matmul in %arg1
     %sz = pdl_match @match_size_op in %arg1
     %3:2 = transform.structured.tile_to_foreach_thread_op %0 tile_sizes [0, %sz, 20]
+
+    #### Example using `tile_sizes` and `packed_sizes`
+
+    ```
+    %0 = pdl_match @match_matmul in %arg1
+    %sz = pdl_match @match_multiple_size_ops in %arg1
+    %3:2 = transform.structured.tile_to_foreach_thread_op %0 tile_sizes [%sz] {packed_sizes}
     ```
   }];
@@ -848,6 +859,11 @@
                    CArg<"::mlir::transform::TileSizesSpec",
                         "::mlir::transform::TileSizesSpec()">,
                    CArg<"ArrayAttr", "{}">:$mapping)>,
+    OpBuilder<(ins "Value":$target,
+                   "Value":$packedTileSizesHandle,
+                   CArg<"::mlir::transform::TileSizesSpec",
+                        "::mlir::transform::TileSizesSpec()">,
+                   CArg<"ArrayAttr", "{}">:$mapping)>,
     OpBuilder<(ins "Value":$target,
                    "ArrayRef<int64_t>":$staticNumThreads,
                    CArg<"::mlir::transform::NumThreadsSpec",
@@ -858,6 +874,11 @@
                    CArg<"::mlir::transform::NumThreadsSpec",
                         "::mlir::transform::NumThreadsSpec()">,
                    CArg<"ArrayAttr", "{}">:$mapping)>,
+    OpBuilder<(ins "Value":$target,
+                   "Value":$packedNumThreadsHandle,
+                   CArg<"::mlir::transform::NumThreadsSpec",
+                        "::mlir::transform::NumThreadsSpec()">,
+                   CArg<"ArrayAttr", "{}">:$mapping)>
   ];
 
   let assemblyFormat = [{
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/StringSet.h"
@@ -1326,16 +1327,19 @@
     Value target, ArrayRef<int64_t> staticTileSizes, transform::TileSizesSpec,
-    ArrayAttr mapping) {
-  return build(builder, result, target,
+    ArrayAttr mappingAttr) {
+  return build(builder, result,
+               /*target=*/target,
+               /*mixedTileSizes=*/
                getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
-               TileSizesSpec(), mapping);
+               /*_=*/TileSizesSpec(),
+               /*mapping=*/mappingAttr);
 }
 
 void transform::TileToForeachThreadOp::build(
     OpBuilder &builder, OperationState &result, Value target,
     ArrayRef<OpFoldResult> mixedTileSizes, transform::TileSizesSpec,
-    ArrayAttr mapping) {
+    ArrayAttr mappingAttr) {
   SmallVector<int64_t> staticTileSizes;
   SmallVector<Value> dynamicTileSizes;
   dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes,
@@ -1346,9 +1350,37 @@
   MLIRContext *ctx = builder.getContext();
   auto operationType = pdl::OperationType::get(ctx);
   auto staticTileSizesAttr = builder.getI64ArrayAttr(staticTileSizes);
-  build(builder, result, TypeRange{operationType, operationType}, target,
-        /*numThreads=*/ValueRange{}, dynamicTileSizes,
-        /*staticNumThreads=*/ArrayAttr(), staticTileSizesAttr, mapping);
+  build(builder, result,
+        /*resultTypes=*/TypeRange{operationType, operationType},
+        /*target=*/target,
+        /*numThreads=*/ValueRange{},
+        /*tileSizes=*/dynamicTileSizes,
+        /*staticNumThreads=*/builder.getI64ArrayAttr({}),
+        /*staticTileSizes=*/staticTileSizesAttr,
+        /*mapping=*/mappingAttr);
+}
+
+void transform::TileToForeachThreadOp::build(OpBuilder &builder,
+                                             OperationState &result,
+                                             Value target,
+                                             Value packedTileSizesHandle,
+                                             transform::TileSizesSpec,
+                                             ArrayAttr mappingAttr) {
+  // Call the default builder which sets up the proper operands segment sizes
+  // attributes for multiple variadic operands. In the absence of this, horrible
+  // bugs ensue.
+  MLIRContext *ctx = builder.getContext();
+  auto operationType = pdl::OperationType::get(ctx);
+  build(builder, result,
+        /*resultTypes=*/TypeRange{operationType, operationType},
+        /*target=*/target,
+        /*numThreads=*/ValueRange{},
+        /*tileSizes=*/ValueRange{packedTileSizesHandle},
+        /*staticNumThreads=*/builder.getI64ArrayAttr({}),
+        /*staticTileSizes=*/
+        builder.getI64ArrayAttr(ShapedType::kDynamicSize),
+        /*mapping=*/mappingAttr,
+        /*packedSizes=*/builder.getUnitAttr());
 }
 
 void transform::TileToForeachThreadOp::build(OpBuilder &builder,
@@ -1356,16 +1388,16 @@
                                              Value target,
                                              ArrayRef<int64_t> staticNumThreads,
                                              transform::NumThreadsSpec,
-                                             ArrayAttr mapping) {
+                                             ArrayAttr mappingAttr) {
   return build(builder, result, target,
                getAsOpFoldResult(builder.getI64ArrayAttr(staticNumThreads)),
-               NumThreadsSpec(), mapping);
+               NumThreadsSpec(), mappingAttr);
 }
 
 void transform::TileToForeachThreadOp::build(
     OpBuilder &builder, OperationState &result, Value target,
     ArrayRef<OpFoldResult> mixedNumThreads, transform::NumThreadsSpec,
-    ArrayAttr mapping) {
+    ArrayAttr mappingAttr) {
   SmallVector<int64_t> staticNumThreads;
   SmallVector<Value> dynamicNumThreads;
   dispatchIndexOpFoldResults(mixedNumThreads, dynamicNumThreads,
@@ -1376,11 +1408,72 @@
   MLIRContext *ctx = builder.getContext();
   auto operationType = pdl::OperationType::get(ctx);
   auto staticNumThreadsAttr = builder.getI64ArrayAttr(staticNumThreads);
-  build(builder, result, TypeRange{operationType, operationType}, target,
-        dynamicNumThreads, /*tileSizes=*/ValueRange{}, staticNumThreadsAttr,
-        /*staticTileSizes=*/ArrayAttr(), mapping);
+  build(builder, result,
+        /*resultTypes=*/TypeRange{operationType, operationType},
+        /*target=*/target,
+        /*numThreads=*/dynamicNumThreads,
+        /*tileSizes=*/ValueRange{},
+        /*staticNumThreads=*/staticNumThreadsAttr,
+        /*staticTileSizes=*/builder.getI64ArrayAttr({}),
+        /*mapping=*/mappingAttr);
 }
 
+void transform::TileToForeachThreadOp::build(OpBuilder &builder,
+                                             OperationState &result,
+                                             Value target,
+                                             Value packedNumThreadsHandle,
+                                             transform::NumThreadsSpec,
+                                             ArrayAttr mappingAttr) {
+  // Call the default builder which sets up the proper operands segment sizes
+  // attributes for multiple variadic operands. In the absence of this, horrible
+  // bugs ensue.
+  MLIRContext *ctx = builder.getContext();
+  auto operationType = pdl::OperationType::get(ctx);
+  build(builder, result,
+        /*resultTypes=*/TypeRange{operationType, operationType},
+        /*target=*/target,
+        /*numThreads=*/ValueRange{packedNumThreadsHandle},
+        /*tileSizes=*/ValueRange{},
+        /*staticNumThreads=*/
+        builder.getI64ArrayAttr(ShapedType::kDynamicSize),
+        /*staticTileSizes=*/builder.getI64ArrayAttr({}),
+        /*mapping=*/mappingAttr,
+        /*packedSizes=*/builder.getUnitAttr());
+}
+
+// Given a list of OpFoldResults that are either index attrs or op
+// handles, return a list of OpFoldResults where all op handles are
+// replaced by the index results of all the payload ops mapped to each
+// handle, in order. Each mapped payload op must have exactly one
+// result of index type.
+static DiagnosedSilenceableFailure unpackPDLOperations(
+    transform::TransformState &state, TransformOpInterface transformOp,
+    SmallVector<OpFoldResult> &result, ArrayRef<OpFoldResult> ofrs) {
+  for (OpFoldResult ofr : ofrs) {
+    // Don't try to unpack non-PDL operation.
+    if (ofr.is<Attribute>() ||
+        !ofr.get<Value>().getType().isa<pdl::OperationType>()) {
+      result.push_back(ofr);
+      continue;
+    }
+    ArrayRef<Operation *> dynamicNumThreads =
+        state.getPayloadOps(ofr.get<Value>());
+    for (Operation *op : dynamicNumThreads) {
+      if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
+        DiagnosedSilenceableFailure diag =
+            transformOp.emitSilenceableError()
+            << "payload op must have exactly 1 index result";
+        diag.attachNote(op->getLoc())
+            << "has " << op->getNumResults() << " results";
+        return diag;
+      }
+      result.push_back(op->getResult(0));
+    }
+  }
+
+  return DiagnosedSilenceableFailure(success());
+}
+
 DiagnosedSilenceableFailure transform::tileToForeachThreadOpImpl(
     RewriterBase &rewriter, transform::TransformState &state,
     TransformOpInterface transformOp, ArrayRef<Operation *> targets,
@@ -1390,56 +1483,18 @@
   if (targets.empty())
     return DiagnosedSilenceableFailure(success());
 
-  // Given a list of OpFoldResults that are either index attrs or op handles,
-  // return a list of OpFoldResults where all op handles are replaced with the
-  // first (and only) OpResult of that payload op. (There must be exactly one
-  // mapped payload op and it must have exactly one index result.)
-  auto getOpResultsOrIndexAttrs =
-      [&](SmallVector<OpFoldResult> &result,
-          ArrayRef<OpFoldResult> opHandlesOrIndexAttrs) {
-        for (OpFoldResult ofr : opHandlesOrIndexAttrs) {
-          if (ofr.is<Attribute>()) {
-            result.push_back(ofr);
-            continue;
-          }
-          ArrayRef<Operation *> dynamicNumThreads =
-              state.getPayloadOps(ofr.get<Value>());
-          if (dynamicNumThreads.size() != 1) {
-            DiagnosedSilenceableFailure diag =
-                transformOp.emitSilenceableError()
-                << "handle must be mapped to exactly 1 payload op";
-            diag.attachNote(ofr.get<Value>().getLoc())
-                << "mapped to " << dynamicNumThreads.size() << " ops";
-            return diag;
-          }
-          Operation *op = dynamicNumThreads[0];
-          if (op->getNumResults() != 1 ||
-              !op->getResult(0).getType().isIndex()) {
-            DiagnosedSilenceableFailure diag =
-                transformOp.emitSilenceableError()
-                << "payload op must have exactly 1 index result";
-            diag.attachNote(op->getLoc())
-                << "has " << op->getNumResults() << " results";
-            return diag;
-          }
-          result.push_back(op->getResult(0));
-        }
-
-        return DiagnosedSilenceableFailure(success());
-      };
-
   // getMixedNumThreads are OpFoldResults[index attributes or PDL operation].
   // Convert to OpFoldResults[index attributes or payload op].
   SmallVector<OpFoldResult> numThreads;
   DiagnosedSilenceableFailure status =
-      getOpResultsOrIndexAttrs(numThreads, mixedNumThreads);
+      unpackPDLOperations(state, transformOp, numThreads, mixedNumThreads);
   if (!status.succeeded())
     return status;
 
   // getMixedTileSizes are OpFoldResults[index attributes or PDL operation].
   // Convert to OpFoldResults[index attributes or payload op].
   SmallVector<OpFoldResult> tileSizes;
-  status = getOpResultsOrIndexAttrs(tileSizes, mixedTileSizes);
+  status = unpackPDLOperations(state, transformOp, tileSizes, mixedTileSizes);
   if (!status.succeeded())
     return status;
 
@@ -1488,8 +1543,11 @@
       getMixedNumThreads(), getMixedTileSizes(), getMapping(), tileOps,
       tiledOps);
 
-  if (!diag.succeeded())
+  if (!diag.succeeded()) {
+    transformResults.set(getForeachThreadOp().cast<OpResult>(), {});
+    transformResults.set(getTiledOp().cast<OpResult>(), {});
     return diag;
+  }
 
   transformResults.set(getForeachThreadOp().cast<OpResult>(), tileOps);
   transformResults.set(getTiledOp().cast<OpResult>(), tiledOps);
@@ -1514,8 +1572,9 @@
 }
 
 LogicalResult TileToForeachThreadOp::verify() {
-  if (getMixedNumThreads().empty() == getMixedTileSizes().empty())
+  if (getMixedNumThreads().empty() == getMixedTileSizes().empty()) {
     return emitOpError("either num_threads or tile_sizes must be specified");
+  }
   return success();
 }
diff --git a/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
--- a/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
+++ b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
@@ -41,6 +41,48 @@
 
 // -----
 
+// In this test case, matmul dims and tile size are dynamic.
+
+// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
+// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
+// CHECK-DAG: #[[$map4:.+]] = affine_map<(d0)[s0] -> (d0 * s0)>
+
+// CHECK-LABEL: matmul_tile_size_dynamic_dynamic(
+// CHECK-SAME:    %[[A:[0-9a-z]+]]: tensor<?x?xf32>
+// CHECK-SAME:    %[[B:[0-9a-z]+]]: tensor<?x?xf32>
+// CHECK-SAME:    %[[C:[0-9a-z]+]]: tensor<?x?xf32>
+func.func @matmul_tile_size_dynamic_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+  // CHECK-DAG: %[[tile_size_1:.*]] = "test.dummy"()
+  // CHECK-DAG: %[[tile_size_2:.*]] = "test.dummy"()
+  // CHECK-DAG: %[[M:.+]] = tensor.dim %[[A]], %[[c0]] :
+  // CHECK-DAG: %[[N:.+]] = tensor.dim %[[B]], %c1 :
+  // CHECK-DAG: %[[NT0:.+]] = affine.apply #[[$map0]]()[%[[M]], %[[tile_size_1]]]
+  // CHECK-DAG: %[[NT1:.+]] = affine.apply #[[$map0]]()[%[[N]], %[[tile_size_2]]]
+  // CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NT0]], %[[NT1]]) shared_outs(%[[C_BLK:.*]] = %[[C]])
+  // CHECK:   tensor.extract_slice %[[A]]
+  // CHECK:   tensor.extract_slice %[[B]]
+  // CHECK:   tensor.extract_slice %[[C_BLK]]
+  // CHECK:   linalg.matmul
+  // CHECK:   scf.foreach_thread.perform_concurrently
+  // CHECK-NEXT:   tensor.parallel_insert_slice
+  %tile_size_1 = "test.dummy"() : () -> (index)
+  %tile_size_2 = "test.dummy"() : () -> (index)
+  %0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+                     outs(%C : tensor<?x?xf32>) -> (tensor<?x?xf32>)
+  return %0 : tensor<?x?xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
+  %sz = transform.structured.match ops{["test.dummy"]} in %arg1
+  %1:2 = transform.structured.tile_to_foreach_thread_op %0 tile_sizes [%sz]
+}
+
+// -----
+
 // Tests that dimension 0 can eliminate affine.min/max, dimension 1 cannot.
 
 // CHECK-DAG: #[[$map0:.+]] = affine_map<(d0) -> (d0 * -15 + 300, 15)>