diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -2338,16 +2338,17 @@ def MapCopyToThreadsOp : Op { let description = [{ - Targeted mapping of a copy operation on tensors to a GPU thread mapping. + Targeted mapping of a linalg.copy / tensor.pad operation on tensors to a GPU + thread mapping. - This operation implements a greedy heuristic that determines a good - distribution of threads to break down the copy operation into. - The heuristic is driven by considerations related to the underlying + This operation implements a greedy heuristic that determines a good + distribution of threads to break down the copy/pad operation into. + The heuristic is driven by considerations related to the underlying architecture for which good high-level decisions are needed assuming certain hardware features. Relevant features are exposed via first-class attributes to control the behavior of the transformation at a high level. @@ -2355,22 +2356,25 @@ For now, a single heuristic is implemented and can be extended on a per-need basis. - #### Return modes: + #### Return modes - The operation always succeeds and returns a handle to the relevant tiled - linalg.copy op. + This operation fails definitely if there is an unsupported op (i.e., not + linalg.copy / tensor.pad) among the targeted ops. Otherwise, the operation + always succeeds and returns a handle to the relevant tiled linalg.copy / + tensor.pad op and the enclosing scf.forall op. 
}]; let arguments = (ins TransformHandleTypeInterface:$target, I64Attr:$total_num_threads, I64Attr:$desired_bit_alignment); - let results = (outs TransformHandleTypeInterface:$transformed); + let results = (outs TransformHandleTypeInterface:$forall_op, + TransformHandleTypeInterface:$tiled_op); let assemblyFormat = [{ $target `total_num_threads` `=` $total_num_threads `desired_bit_alignment` `=` $desired_bit_alignment - attr-dict + attr-dict `:` functional-type(operands, results) }]; @@ -2380,7 +2384,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( ::mlir::transform::TransformRewriter &rewriter, - ::mlir::linalg::CopyOp copyOp, + ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -3378,21 +3378,26 @@ //===----------------------------------------------------------------------===// // MapCopyToThreadsOp //===----------------------------------------------------------------------===// + DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne( - transform::TransformRewriter &rewriter, linalg::CopyOp copyOp, + transform::TransformRewriter &rewriter, Operation *target, transform::ApplyToEachResultList &results, transform::TransformState &state) { - auto transformOp = cast(getOperation()); - ShapedType resultShapedType; - if (copyOp) { - resultShapedType = - cast(copyOp.getDpsInitOperand(0)->get().getType()); + // Check if the op is supported. 
+ if (!isa(target)) { + DiagnosedSilenceableFailure diag = + emitSilenceableError() + << "only linalg.copy and tensor.pad target ops are supported"; + diag.attachNote(target->getLoc()) << "target op"; + return diag; } - if (!copyOp || !resultShapedType.hasStaticShape()) { + assert(target->getNumResults() == 1 && "expected single result"); + auto resultShapedType = cast(target->getResult(0).getType()); + if (!resultShapedType.hasStaticShape()) { DiagnosedSilenceableFailure diag = - transformOp.emitSilenceableError() - << "only statically sized linalg.copy ops of rank <= 3 are supported"; - diag.attachNote(copyOp->getLoc()) << "target op"; + emitSilenceableError() + << "only statically sized ops of rank <= 3 are supported"; + diag.attachNote(target->getLoc()) << "target op"; return diag; } @@ -3414,11 +3419,11 @@ resultShapedType.getElementType().getIntOrFloatBitWidth()); if (mapping.status == gpu::CopyMappingInfo::Status::Invalid) { DiagnosedSilenceableFailure diag = - transformOp.emitSilenceableError() + emitSilenceableError() << "too few threads to map copy op to threads on the most minor " "dimension, given alignment and vector size constraints, try " "smaller tile size of mapping to more threads"; - diag.attachNote(copyOp->getLoc()) << "target op"; + diag.attachNote(target->getLoc()) << "target op"; return diag; } @@ -3428,8 +3433,8 @@ DiagnosedSilenceableFailure diag = tileToForallOpImpl( /*rewriter=*/rewriter, /*state=*/state, - /*transformOp=*/transformOp, - /*target=*/copyOp, + /*transformOp=*/*this, + /*target=*/target, /*mixedNumThreads=*/getMixedValues(mapping.numThreads, {}, b), /*mixedTileSizes=*/ArrayRef{}, /*mapping=*/b.getArrayAttr(mapping.threadMapping), @@ -3437,6 +3442,7 @@ if (!diag.succeeded()) return diag; + results.push_back(tilingResult.tileOp); results.push_back(tilingResult.tiledOp); return DiagnosedSilenceableFailure::success(); } diff --git a/mlir/test/Dialect/Linalg/transform-op-gpu-map-copy-to-threads.mlir 
b/mlir/test/Dialect/Linalg/transform-op-gpu-map-copy-to-threads.mlir --- a/mlir/test/Dialect/Linalg/transform-op-gpu-map-copy-to-threads.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-gpu-map-copy-to-threads.mlir @@ -20,7 +20,37 @@ : (!transform.any_op) -> !transform.any_op transform.structured.gpu.map_copy_to_threads %0 total_num_threads = 32 desired_bit_alignment = 128 - : (!transform.any_op) -> (!transform.op<"linalg.copy">) + : (!transform.any_op) -> (!transform.op<"scf.forall">, !transform.op<"linalg.copy">) +} + +// ----- + +!tt = tensor<8xf16> +!tin = tensor + +// CHECK-LABEL: func @pad_1d_8xf16 +func.func @pad_1d_8xf16(%t0: !tin, %sz: index) -> !tt { + %cst = arith.constant 0.0 : f16 + /// Too little data for all threads, needs predication, while keeping most + /// minor transfer size -> 1 thread. + // CHECK: scf.forall {{.*}} in (1) {{.*}} + // CHECK: %[[padded:.*]] = tensor.pad {{.*}} + // CHECK: tensor.cast %[[padded]] : tensor to tensor<8xf16> + // CHECK: {mapping = [#gpu.linear]} + %0 = tensor.pad %t0 low[0] high[%sz] { + ^bb0(%arg0: index): + tensor.yield %cst : f16 + } : !tin to !tt + return %0 : !tt +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["tensor.pad"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + transform.structured.gpu.map_copy_to_threads %0 + total_num_threads = 32 desired_bit_alignment = 128 + : (!transform.any_op) -> (!transform.op<"scf.forall">, !transform.op<"tensor.pad">) } // ----- @@ -44,7 +74,7 @@ : (!transform.any_op) -> !transform.any_op transform.structured.gpu.map_copy_to_threads %0 total_num_threads = 32 desired_bit_alignment = 128 - : (!transform.any_op) -> (!transform.op<"linalg.copy">) + : (!transform.any_op) -> (!transform.op<"scf.forall">, !transform.op<"linalg.copy">) } // ----- @@ -68,7 +98,7 @@ : (!transform.any_op) -> !transform.any_op transform.structured.gpu.map_copy_to_threads %0 total_num_threads = 32 desired_bit_alignment = 128 
- : (!transform.any_op) -> (!transform.op<"linalg.copy">) + : (!transform.any_op) -> (!transform.op<"scf.forall">, !transform.op<"linalg.copy">) } @@ -93,7 +123,7 @@ : (!transform.any_op) -> !transform.any_op transform.structured.gpu.map_copy_to_threads %0 total_num_threads = 32 desired_bit_alignment = 128 - : (!transform.any_op) -> (!transform.op<"linalg.copy">) + : (!transform.any_op) -> (!transform.op<"scf.forall">, !transform.op<"linalg.copy">) } // ----- @@ -117,7 +147,7 @@ : (!transform.any_op) -> !transform.any_op transform.structured.gpu.map_copy_to_threads %0 total_num_threads = 32 desired_bit_alignment = 128 - : (!transform.any_op) -> (!transform.op<"linalg.copy">) + : (!transform.any_op) -> (!transform.op<"scf.forall">, !transform.op<"linalg.copy">) } // ----- @@ -140,7 +170,7 @@ : (!transform.any_op) -> !transform.any_op transform.structured.gpu.map_copy_to_threads %0 total_num_threads = 32 desired_bit_alignment = 128 - : (!transform.any_op) -> (!transform.op<"linalg.copy">) + : (!transform.any_op) -> (!transform.op<"scf.forall">, !transform.op<"linalg.copy">) } // ----- @@ -162,7 +192,7 @@ : (!transform.any_op) -> !transform.any_op transform.structured.gpu.map_copy_to_threads %0 total_num_threads = 32 desired_bit_alignment = 128 - : (!transform.any_op) -> (!transform.op<"linalg.copy">) + : (!transform.any_op) -> (!transform.op<"scf.forall">, !transform.op<"linalg.copy">) } // ----- @@ -184,7 +214,7 @@ : (!transform.any_op) -> !transform.any_op transform.structured.gpu.map_copy_to_threads %0 total_num_threads = 32 desired_bit_alignment = 64 - : (!transform.any_op) -> (!transform.op<"linalg.copy">) + : (!transform.any_op) -> (!transform.op<"scf.forall">, !transform.op<"linalg.copy">) } // ----- @@ -206,7 +236,7 @@ : (!transform.any_op) -> !transform.any_op transform.structured.gpu.map_copy_to_threads %0 total_num_threads = 32 desired_bit_alignment = 128 - : (!transform.any_op) -> (!transform.op<"linalg.copy">) + : (!transform.any_op) -> 
(!transform.op<"scf.forall">, !transform.op<"linalg.copy">) } // ----- @@ -228,7 +258,7 @@ : (!transform.any_op) -> !transform.any_op transform.structured.gpu.map_copy_to_threads %0 total_num_threads = 32 desired_bit_alignment = 8 - : (!transform.any_op) -> (!transform.op<"linalg.copy">) + : (!transform.any_op) -> (!transform.op<"scf.forall">, !transform.op<"linalg.copy">) } // ----- @@ -254,7 +284,7 @@ : (!transform.any_op) -> !transform.any_op transform.structured.gpu.map_copy_to_threads %0 total_num_threads = 32 desired_bit_alignment = 8 - : (!transform.any_op) -> (!transform.op<"linalg.copy">) + : (!transform.any_op) -> (!transform.op<"scf.forall">, !transform.op<"linalg.copy">) } // ----- @@ -277,7 +307,7 @@ : (!transform.any_op) -> !transform.any_op transform.structured.gpu.map_copy_to_threads %0 total_num_threads = 128 desired_bit_alignment = 8 - : (!transform.any_op) -> (!transform.op<"linalg.copy">) + : (!transform.any_op) -> (!transform.op<"scf.forall">, !transform.op<"linalg.copy">) } // ----- @@ -300,7 +330,7 @@ : (!transform.any_op) -> !transform.any_op transform.structured.gpu.map_copy_to_threads %0 total_num_threads = 128 desired_bit_alignment = 64 - : (!transform.any_op) -> (!transform.op<"linalg.copy">) + : (!transform.any_op) -> (!transform.op<"scf.forall">, !transform.op<"linalg.copy">) } @@ -330,7 +360,7 @@ // expected-error @below {{too few threads to map copy op to threads on the most minor dimension, given alignment and vector size constraints}} transform.structured.gpu.map_copy_to_threads %0 total_num_threads = 32 desired_bit_alignment = 128 - : (!transform.any_op) -> (!transform.op<"linalg.copy">) + : (!transform.any_op) -> (!transform.op<"scf.forall">, !transform.op<"linalg.copy">) } // ----- @@ -355,7 +385,7 @@ // expected-error @below {{too few threads to map copy op to threads on the most minor dimension, given alignment and vector size constraints}} transform.structured.gpu.map_copy_to_threads %0 total_num_threads = 32 
desired_bit_alignment = 128 - : (!transform.any_op) -> (!transform.op<"linalg.copy">) + : (!transform.any_op) -> (!transform.op<"scf.forall">, !transform.op<"linalg.copy">) } // ----- @@ -379,7 +409,7 @@ // expected-error @below {{too few threads to map copy op to threads on the most minor dimension, given alignment and vector size constraints}} transform.structured.gpu.map_copy_to_threads %0 total_num_threads = 32 desired_bit_alignment = 8 - : (!transform.any_op) -> (!transform.op<"linalg.copy">) + : (!transform.any_op) -> (!transform.op<"scf.forall">, !transform.op<"linalg.copy">) } // ----- @@ -403,5 +433,5 @@ // expected-error @below {{too few threads to map copy op to threads on the most minor dimension, given alignment and vector size constraints}} transform.structured.gpu.map_copy_to_threads %0 total_num_threads = 32 desired_bit_alignment = 8 - : (!transform.any_op) -> (!transform.op<"linalg.copy">) + : (!transform.any_op) -> (!transform.op<"scf.forall">, !transform.op<"linalg.copy">) }