diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -900,4 +900,61 @@
 }];
 }
 
+def MapNestedForeachThreadToGpuBlocks :
+    Op<Transform_Dialect, "structured.map_nested_foreach_thread_to_gpu_blocks",
+      [FunctionalStyleTransformOpTrait,
+       MemoryEffectsOpInterface,
+       TransformEachOpTrait,
+       TransformOpInterface]> {
+  let description = [{
+    Target the gpu.launch op and rewrite the top-level `scf.foreach_thread`
+    to distributed gpu.block_id attributes.
+
+    The operation searches top-level `scf.foreach_thread` ops nested under
+    `target` and maps each such op to GPU thread blocks. The mapping is
+    one-to-one and the induction variables of `scf.foreach_thread` are
+    rewritten to gpu.block_id according to the thread_dim_mapping attribute.
+
+    A top-level `scf.foreach_thread` is required; sibling `scf.foreach_thread`
+    ops are not supported.
+
+    Dynamic `scf.foreach_thread` trip counts are currently not supported.
+    Dynamic grid dim sizes are currently not supported.
+
+    Only **bufferized** scf.foreach_thread ops are currently supported.
+    Only scf.foreach_thread ops distributed over **at most 3 dimensions** are
+    currently supported.
+
+    The operation alters the grid size of the given gpu.launch using the
+    gridDim argument.
+
+    Return modes:
+    =============
+    This operation ignores non-gpu.launch ops and drops them in the return.
+
+    If any scf.foreach_thread operating on tensors is found, the transform
+    definitely fails.
+
+    If all the scf.foreach_thread operations contained within the LaunchOp
+    referred to by the `target` PDLOperation lower to GPU properly, the
+    transform succeeds. Otherwise the transform definitely fails.
+
+    The returned handle points to the same LaunchOp operand, consuming it and
+    producing a new SSA value to satisfy chaining and linearity of the IR
+    properties.
+    ...
+  }];
+
+  let arguments = (ins PDL_Operation:$target,
+                   DefaultValuedAttr<I64ArrayAttr, "{}">:$gridDim);
+  let results = (outs PDL_Operation:$result);
+
+  let assemblyFormat = "$target attr-dict";
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::Operation *target,
+        ::llvm::SmallVectorImpl<::mlir::Operation *> &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 #endif // LINALG_TRANSFORM_OPS
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1166,6 +1166,111 @@
   modifiesPayload(effects);
 }
 
+//===----------------------------------------------------------------------===//
+// MapNestedForeachThreadToGpuBlocks
+//===----------------------------------------------------------------------===//
+
+static LogicalResult
+rewriteTopLevelForeachThreadToGpuBlocks(RewriterBase &rewriter,
+                                        scf::ForeachThreadOp foreachThreadOp) {
+  if (foreachThreadOp->getNumResults() > 0)
+    return foreachThreadOp.emitOpError(
+        "only bufferized scf.foreach_thread lowers to gpu.block_id");
+  if (foreachThreadOp.getNumThreads().size() > 3)
+    return foreachThreadOp.emitOpError(
+        "scf.foreach_thread with rank > 3 does not lower to gpu.block_id");
+
+  // Step 0. Collect the grid dimensions; only static sizes are supported.
+  auto potentialGridDim = foreachThreadOp.getPermutedNumThreads(rewriter);
+  if (failed(potentialGridDim) ||
+      llvm::any_of(*potentialGridDim, [](OpFoldResult ofr) {
+        return !getConstantIntValue(ofr).has_value();
+      }))
+    return foreachThreadOp.emitOpError("unsupported dynamic gridDim size");
+
+  SmallVector<int64_t> gridDims =
+      llvm::to_vector(llvm::map_range(*potentialGridDim, [](OpFoldResult ofr) {
+        return getConstantIntValue(ofr).value();
+      }));
+
+  // Step 1. Create one gpu.block_id op per grid dimension.
+  Location loc = foreachThreadOp.getLoc();
+  IndexType indexType = rewriter.getIndexType();
+  rewriter.setInsertionPoint(foreachThreadOp);
+  SmallVector<gpu::Dimension> gpuDims{gpu::Dimension::x, gpu::Dimension::y,
+                                      gpu::Dimension::z};
+  SmallVector<Value> blockOps;
+  for (int64_t idx : llvm::seq<int64_t>(0, gridDims.size())) {
+    blockOps.push_back(
+        rewriter.create<gpu::BlockIdOp>(loc, indexType, gpuDims[idx]));
+  }
+
+  // Step 2. Move the body of foreachThreadOp.
+  // Erase the terminator first; it will not be used since we are on buffers.
+  rewriter.eraseOp(foreachThreadOp.getTerminator());
+  Block *targetBlock = foreachThreadOp->getBlock();
+  Block::iterator insertionPoint = Block::iterator(foreachThreadOp);
+  Block &sourceBlock = foreachThreadOp.getRegion().front();
+  targetBlock->getOperations().splice(insertionPoint,
+                                      sourceBlock.getOperations());
+
+  // Step 3. RAUW thread indices to the newly created gpu.block_id ops.
+  SmallVector<Value> threadIndices =
+      *foreachThreadOp.getPermutedThreadIndices();
+  assert(blockOps.size() == 3 && "3 workgroup id ops are required");
+  assert(threadIndices.size() == 3 && "3 thread id dimensions are required");
+  for (auto it : llvm::zip(threadIndices, blockOps)) {
+    Value val = std::get<0>(it);
+    if (!val)
+      continue;
+    for (Operation *user : llvm::make_early_inc_range(val.getUsers())) {
+      rewriter.updateRootInPlace(
+          user, [&]() { user->replaceUsesOfWith(val, std::get<1>(it)); });
+    }
+  }
+
+  // Step 4. Erase the old op.
+  rewriter.eraseOp(foreachThreadOp);
+
+  return success();
+}
+
+DiagnosedSilenceableFailure
+transform::MapNestedForeachThreadToGpuBlocks::applyToOne(
+    Operation *target, SmallVectorImpl<Operation *> &results,
+    transform::TransformState &state) {
+  gpu::LaunchOp gpuLaunch = dyn_cast<gpu::LaunchOp>(target);
+  if (!gpuLaunch) {
+    target->emitError("given target is not gpu.launch");
+    return DiagnosedSilenceableFailure::definiteFailure();
+  }
+
+  scf::ForeachThreadOp topLevelForeachThreadOp;
+  auto walkResult = target->walk([&](scf::ForeachThreadOp foreachThreadOp) {
+    if (foreachThreadOp->getParentOfType<scf::ForeachThreadOp>())
+      return WalkResult::advance();
+    if (topLevelForeachThreadOp)
+      return WalkResult::interrupt();
+    topLevelForeachThreadOp = foreachThreadOp;
+    return WalkResult::advance();
+  });
+
+  if (walkResult.wasInterrupted()) {
+    target->emitOpError(
+        "could not find a unique top-level scf.foreach_thread");
+    return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
+  }
+
+  SimpleRewriter rewriter(getContext());
+  rewriter.setInsertionPoint(target);
+
+  if (failed(rewriteTopLevelForeachThreadToGpuBlocks(rewriter,
+                                                     topLevelForeachThreadOp)))
+    return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
+
+  results.assign({target});
+  return DiagnosedSilenceableFailure(success());
+}
+
 //===----------------------------------------------------------------------===//
 // MapNestedForeachThreadToGpuThreads
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/transform-gpu.mlir b/mlir/test/Dialect/Linalg/transform-gpu.mlir
--- a/mlir/test/Dialect/Linalg/transform-gpu.mlir
+++ b/mlir/test/Dialect/Linalg/transform-gpu.mlir
@@ -3,6 +3,48 @@
 !type = memref<2 x 32 x f32>
 !type1d = memref<32 x f32>
 
+// CHECK-LABEL: func.func @saxpy2dblock(
+// CHECK-SAME:  %[[ARGX:[0-9a-z]+]]: memref<2x32xf32>
+// CHECK-SAME:  %[[ARGY:[0-9a-z]+]]: memref<2x32xf32>
+// CHECK-SAME:  %[[ARGT:[0-9a-z]+]]: memref<32xf32>
+func.func @saxpy2dblock(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %stream : !gpu.async.token) -> !type {
+  %one = arith.constant 1 : index
+  %c12 = arith.constant 12 : index
+  %c9 = arith.constant 9 : index
+  %c7 = arith.constant 7 : index
+// CHECK: gpu.launch
+// CHECK: %[[BLKX:.*]] = gpu.block_id x
+// CHECK: %[[BLKY:.*]] = gpu.block_id y
+// CHECK: memref.load %[[ARGX]][%[[BLKX]], %[[BLKY]]]
+// CHECK: memref.load %[[ARGY]][%[[BLKX]], %[[BLKY]]]
+  %name = gpu.launch async[%stream] blocks(%arg3, %arg4, %arg5) in (%arg9 = %one, %arg10 = %one, %arg11 = %one)
+            threads(%arg6, %arg7, %arg8) in (%arg12 = %one, %arg13 = %one, %arg14 = %one)
+  {
+    scf.foreach_thread (%i, %j) in (%c7, %c9) {
+      %4 = memref.load %x[%i, %j] : !type
+      %5 = memref.load %y[%i, %j] : !type
+      %6 = math.fma %alpha, %4, %5 : f32
+      memref.store %6, %y[%i, %j] : !type
+    } {thread_dim_mapping = [0, 1, 2]}
+    gpu.terminator
+  }
+  return %y : !type
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0
+    transform.structured.map_nested_foreach_thread_to_gpu_blocks %funcop { gridDim = [12, 9, 1] }
+  }
+}
+
+// -----
+
+!type = memref<2 x 32 x f32>
+!type1d = memref<32 x f32>
+
 // CHECK-LABEL: func.func @saxpy2d(
 // CHECK-SAME:  %[[ARGX:[0-9a-z]+]]: memref<2x32xf32>
 // CHECK-SAME:  %[[ARGY:[0-9a-z]+]]: memref<2x32xf32>
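
Note (illustrative, not part of the diff): assuming the rewrite behaves as described above, the gpu.launch region of the @saxpy2dblock test is expected to look roughly as follows after applying transform.structured.map_nested_foreach_thread_to_gpu_blocks; the scf.foreach_thread body is spliced into the launch region and its induction variables are replaced with gpu.block_id ops. SSA names here are hypothetical, and an unused gpu.block_id z may also be materialized.

  gpu.launch async[%stream] blocks(%bx, %by, %bz) in (%g0 = %one, %g1 = %one, %g2 = %one)
             threads(%tx, %ty, %tz) in (%b0 = %one, %b1 = %one, %b2 = %one) {
    // Former induction variables %i/%j of the erased scf.foreach_thread.
    %blkx = gpu.block_id x
    %blky = gpu.block_id y
    %4 = memref.load %x[%blkx, %blky] : memref<2x32xf32>
    %5 = memref.load %y[%blkx, %blky] : memref<2x32xf32>
    %6 = math.fma %alpha, %4, %5 : f32
    memref.store %6, %y[%blkx, %blky] : memref<2x32xf32>
    gpu.terminator
  }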