diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -748,6 +748,104 @@ }]; } +def MapNestedForeachThreadToGpuThreads : + Op { + let description = [{ + Target the gpu_launch op and rewrite all nested scf.foreach_thread ops + to be distributed over GPU threads using gpu.thread_id. + + The operation searches `scf.foreach_thread` ops nested under `target` + and maps each such op to GPU threads. Mapping is one-to-one and the + induction variables of `scf.foreach_thread` are rewritten to + gpu.thread_id according to the thread_dim_mapping attribute. + + Sibling `scf.foreach_thread` are supported in which case, the union of + the number of threads is computed and may result in predication. + + Multiple scf.foreach_thread are supported per function in which case, the + max of all the threads is computed and taken for the global gpu.thread_id. + If necessary, scf.foreach_thread that do not use the whole thread range + result in predicated computations. + + Dynamic `scf.foreach_thread` trip counts are currently not supported. + Dynamic block dim sizes are currently not supported. + + Only **bufferized** scf.foreach_thread are currently supported. + Only scf.foreach_thread distributed to **at most 3 dimensions** are + currently supported. + + Barriers are inserted after each scf.foreach_thread op for now. + + The operation alters the block size of the given gpu_launch using + the blockDim attribute. + + Return modes: + ============= + If `target` is not a gpu_launch op, the transform definitely fails. + + If any scf.foreach_thread with tensors is found, the transform definitely + fails. + + If all the scf.foreach_thread operations contained within the LaunchOp + referred to by the `target` PDLOperation lower to GPU properly, the + transform succeeds. 
Otherwise the transform definitely fails. + + The returned handle points to the same LaunchOp operand, consuming it and + producing a new SSA value to satisfy chaining and linearity of the IR + properties. + + Example: + ======== + + ``` + gpu.launch blocks(%bx, %by, %bz) in (%x = %0, %y = %1, %z = %2) + threads(%tx, %ty, %tz) in (%tx = %3, %ty = %4, %tz = %5) { + scf.foreach_thread (%i, %j) in (7, 9) { + ... // body 1 + } {thread_dim_mapping = [1, 0, 2]} + scf.foreach_thread (%i) in (12) { + ... // body 2 + } + gpu.terminator + } + ``` + is translated to: + + ``` + %bdimX = arith.constant 12 : index + %bdimY = arith.constant 9 : index + gpu.launch blocks(%bx, %by, %bz) in (%x = %0, %y = %1, %z = %2) + threads(%tx, %ty, %tz) in (%tx = %bdimX, %ty = %bdimY, %tz = %5) { + if (threadIdx.x < 9 && threadIdx.y < 7) { + ... // body 1 + } + gpu.barrier + if (threadIdx.y < 1) { + ... // body 2 + } + gpu.barrier + gpu.terminator + } + ``` + }]; + + let arguments = (ins PDL_Operation:$target, + DefaultValuedAttr:$blockDim); + let results = (outs PDL_Operation:$result); + + let assemblyFormat = "$target attr-dict"; + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::llvm::SmallVectorImpl<::mlir::Operation *> &results, + ::mlir::transform::TransformState &state); + }]; +} + def VectorizeOp : Op { diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -121,6 +121,17 @@ FailureOr fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand); +/// Searches `scf.foreach_thread` ops nested under `target` and maps each such +/// op to GPU threads. Mapping is one-to-one and the induction variables of +/// `scf.foreach_thread` are rewritten to gpu.thread_id according to the +/// thread_dim_mapping attribute. 
Sibling `scf.foreach_thread` are supported in +/// which case, the union of the number of threads is computed and may result in +/// predication. Dynamic `scf.foreach_thread` trip counts are currently not +/// supported. Dynamic block dim sizes are currently not supported. +mlir::WalkResult rewriteMapNestedForeachThreadToGpuThreads( + RewriterBase &rewriter, Operation *target, + const SmallVector &blockDim, bool syncAfterDistribute); + /// Split the given `op` into two parts along the given iteration space /// `dimension` at the specified `splitPoint`, and return the two parts. /// diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td --- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td @@ -501,6 +501,16 @@ return getBody()->getArguments().drop_front(getRank()); } + /// Return the thread indices in the order specified by the + /// thread_dim_mapping attribute. Return failure if + /// thread_dim_mapping is not a valid permutation. + FailureOr> getPermutedThreadIndices(); + + /// Return the number of threads in the order specified by the + /// thread_dim_mapping attribute. + /// Return failure if thread_dim_mapping is not a valid permutation. + FailureOr> getPermutedNumThreads(OpBuilder &b); + // The ensureTerminator method generated by SingleBlockImplicitTerminator is // unaware of the fact that our terminator also needs a region to be // well-formed. We override it here to ensure that we do the right thing. 
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -11,6 +11,7 @@ #include "mlir/AsmParser/AsmParser.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/PDL/IR/PDL.h" @@ -1165,6 +1166,183 @@ modifiesPayload(effects); } +//===----------------------------------------------------------------------===// +// MapNestedForeachThreadToGpuThreads +//===----------------------------------------------------------------------===// + +/// Rewrites a single bufferized `scf.foreach_thread` op so that it is +/// distributed over GPU threads. Mapping is one-to-one and the induction +/// variables of `scf.foreach_thread` are rewritten to gpu.thread_id according +/// to the thread_dim_mapping attribute. When the op uses fewer threads than +/// `globalBlockDims` along some dimension, its body is predicated. Dynamic +/// `scf.foreach_thread` trip counts are currently not supported. +/// +/// If `syncAfterDistribute` is set, a barrier is inserted after the rewritten +/// body. On success, returns the thread counts of `foreachThreadOp` permuted +/// by thread_dim_mapping; on failure, an error has been emitted. 
+static FailureOr> rewriteOneForeachThreadToGpuThreads( + RewriterBase &rewriter, scf::ForeachThreadOp foreachThreadOp, + const SmallVector &globalBlockDims, bool syncAfterDistribute) { + if (foreachThreadOp.getNumResults() > 0) + return foreachThreadOp->emitError( + "only bufferized scf.foreach_thread lowers to gpu.thread"); + if (foreachThreadOp.getNumThreads().size() > 3) + return foreachThreadOp->emitError( + "scf.foreach_thread with rank > 3 does not lower to gpu.thread"); + + auto potentialBlockDim = foreachThreadOp.getPermutedNumThreads(rewriter); + // An invalid thread_dim_mapping has already been diagnosed by the callee. + if (failed(potentialBlockDim)) + return failure(); + if (llvm::any_of(*potentialBlockDim, [](OpFoldResult ofr) { + return !getConstantIntValue(ofr).has_value(); + })) + return foreachThreadOp->emitError("unsupported dynamic blockdim size"); + + SmallVector blockDim = + llvm::to_vector(llvm::map_range(*potentialBlockDim, [](OpFoldResult ofr) { + return getConstantIntValue(ofr).value(); + })); + + // Step 1. Create the gpu.thread_id ops, one per block dimension. + Location loc = foreachThreadOp.getLoc(); + IndexType indexType = rewriter.getIndexType(); + + SmallVector gpuDims{gpu::Dimension::x, gpu::Dimension::y, + gpu::Dimension::z}; + SmallVector threadOps; + for (int64_t idx : llvm::seq(0, blockDim.size())) { + threadOps.push_back( + rewriter.create(loc, indexType, gpuDims[idx])); + } + // Step 2. Maybe create conditionals to predicate the region. + Value predicate; + for (auto [threadId, dim, globalDim] : + llvm::zip(threadOps, blockDim, globalBlockDims)) { + if (dim > globalDim) { + return foreachThreadOp.emitOpError("blockDim size overflow: ") + << dim << " > " << globalDim; + } + if (dim == globalDim) + continue; + Value tmpPredicate = rewriter.create( + loc, arith::CmpIPredicate::ult, threadId, + rewriter.create(loc, dim)); + predicate = + predicate ? rewriter.create(loc, predicate, tmpPredicate) + : tmpPredicate; + } + + // Step 3. Move the body of foreachThreadOp. 
+ // Erase the terminator first, it will not be used. + rewriter.eraseOp(foreachThreadOp.getTerminator()); + Block *targetBlock; + Block::iterator insertionPoint; + if (predicate) { + // Step 3.a. If predicated, move to the beginning of the `then` block. + auto ifOp = + rewriter.create(loc, predicate, /*withElseRegion=*/false); + targetBlock = ifOp.thenBlock(); + insertionPoint = ifOp.thenBlock()->begin(); + } else { + // Step 3.b. Otherwise, move inline just before foreachThreadOp. + targetBlock = foreachThreadOp->getBlock(); + insertionPoint = Block::iterator(foreachThreadOp); + } + Block &sourceBlock = foreachThreadOp.getRegion().front(); + targetBlock->getOperations().splice(insertionPoint, + sourceBlock.getOperations()); + + // Step 4. Replace all uses of the thread indices with the thread id ops. + SmallVector threadIndices = + *foreachThreadOp.getPermutedThreadIndices(); + for (auto it : llvm::zip(threadIndices, threadOps)) { + Value val = std::get<0>(it); + if (!val) + continue; + for (Operation *user : llvm::make_early_inc_range(val.getUsers())) { + rewriter.updateRootInPlace( + user, [&]() { user->replaceUsesOfWith(val, std::get<1>(it)); }); + } + } + + // Step 5. Insert a barrier (syncthreads) if requested. + // TODO: Need warpsync. + if (syncAfterDistribute) + rewriter.create(loc); + + // Step 6. Erase old op. 
+ rewriter.eraseOp(foreachThreadOp); + + return *potentialBlockDim; +} + +/// Maps every `scf.foreach_thread` nested under `target` to GPU threads. The +/// walk is interrupted as soon as one of them fails to lower. +mlir::WalkResult mlir::linalg::rewriteMapNestedForeachThreadToGpuThreads( + RewriterBase &rewriter, Operation *target, + const SmallVector &blockDim, bool syncAfterDistribute) { + auto walkResult = target->walk([&](scf::ForeachThreadOp foreachThreadOp) { + rewriter.setInsertionPoint(foreachThreadOp); + if (failed(rewriteOneForeachThreadToGpuThreads(rewriter, foreachThreadOp, + blockDim, syncAfterDistribute))) + return WalkResult::interrupt(); + return WalkResult::advance(); + }); + return walkResult; +} + +/// Replaces the block size operands of `gpuLaunch` with constants built from +/// `blockDim`. Fails if any of the three entries is smaller than 1. +static LogicalResult alterGpuLaunchBlockDim(SimpleRewriter &rewriter, + gpu::LaunchOp gpuLaunch, + const SmallVector &blockDim) { + gpu::KernelDim3 currentBlockdim = gpuLaunch.getBlockSizeOperandValues(); + if (blockDim[0] < 1 || blockDim[1] < 1 || blockDim[2] < 1) { + gpuLaunch->emitError() << "Given blockDim(" << blockDim[0] << "," + << blockDim[1] << "," << blockDim[2] + << ") is invalid"; + return failure(); + } + rewriter.setInsertionPointAfterValue(currentBlockdim.x); + auto createBlockDimValue = [&](int64_t dim) { + return rewriter.create(currentBlockdim.x.getLoc(), + dim); + }; + gpuLaunch.blockSizeXMutable().assign(createBlockDimValue(blockDim[0])); + gpuLaunch.blockSizeYMutable().assign(createBlockDimValue(blockDim[1])); + gpuLaunch.blockSizeZMutable().assign(createBlockDimValue(blockDim[2])); + return success(); +} + +DiagnosedSilenceableFailure +transform::MapNestedForeachThreadToGpuThreads::applyToOne( + Operation *target, SmallVectorImpl &results, + transform::TransformState &state) { + + gpu::LaunchOp gpuLaunch = dyn_cast(target); + if (!gpuLaunch) { + target->emitError("Given target is not gpu.launch"); + return DiagnosedSilenceableFailure::definiteFailure(); + } + + SmallVector blockDim = extractFromI64ArrayAttr(getBlockDim()); + blockDim.resize(/*size=*/3, /*value=*/1); + SimpleRewriter rewriter(getContext()); + rewriter.setInsertionPoint(target); + auto 
walkResult = mlir::linalg::rewriteMapNestedForeachThreadToGpuThreads( + rewriter, target, blockDim, /*syncAfterDistribute=*/true); + if (walkResult.wasInterrupted()) + return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); + + LogicalResult result = alterGpuLaunchBlockDim(rewriter, gpuLaunch, blockDim); + if (failed(result)) + return DiagnosedSilenceableFailure::definiteFailure(); + + results.assign({target}); + return DiagnosedSilenceableFailure(success()); +} + //===----------------------------------------------------------------------===// // TileToForeachThreadOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -1244,6 +1244,64 @@ return cast(getBody()->getTerminator()); } +/// Scatters `vals` according to `perm`: the i-th element of `vals` is stored at +/// position `perm[i]` of the result. Returns failure if `perm` is not a valid +/// permutation of the indices of `vals`. +template +static FailureOr> permute(const SmallVector &vals, + ArrayRef perm) { + if (vals.size() != perm.size()) + return failure(); + SmallVector result(vals.size()); + SmallVector seen(vals.size()); + for (auto [idx, val] : llvm::zip(perm, vals)) { + // Out of range or already seen, invalid thread_dim_mapping. + if (idx < 0 || idx >= static_cast<int64_t>(vals.size()) || seen[idx]) + return failure(); + result[idx] = val; + seen[idx] = true; + } + // Some not seen, invalid thread_dim_mapping. + if (!llvm::all_of(seen, [](bool b) { return b; })) + return failure(); + return result; +} + +/// Helper to apply the `thread_dim_mapping` permutation of a +/// `foreachThreadOp` to `values`. +template +static FailureOr> +getValuesPermutedByThreadMapping(scf::ForeachThreadOp foreachThreadOp, + const SmallVector &values) { + // Apply mapping permutation if specified. 
+ auto mapping = foreachThreadOp.getThreadDimMapping(); + if (mapping && !mapping.empty()) { + auto maybePermuted = permute(values, extractFromI64ArrayAttr(mapping)); + if (failed(maybePermuted)) + return foreachThreadOp->emitError("invalid permutation"); + return *maybePermuted; + } + return values; +} + +/// Return the thread indices (padded with null Values to 3 entries) permuted by +/// thread_dim_mapping. Return failure if it is not a valid permutation. +FailureOr> ForeachThreadOp::getPermutedThreadIndices() { + SmallVector threadCountValues = this->getThreadIndices(); + threadCountValues.resize(3, Value()); + return getValuesPermutedByThreadMapping(*this, threadCountValues); +} + +/// Return the number of threads (padded with 1s to 3 entries) in the order +/// specified by the thread_dim_mapping attribute. +/// Return failure if thread_dim_mapping is not a valid permutation. +FailureOr> +ForeachThreadOp::getPermutedNumThreads(OpBuilder &b) { + SmallVector threadCountValues = this->getNumThreads(); + threadCountValues.resize(3, b.getIndexAttr(1)); + return getValuesPermutedByThreadMapping(*this, threadCountValues); +} + ForeachThreadOp mlir::scf::getForeachThreadOpThreadIndexOwner(Value val) { auto tidxArg = val.dyn_cast(); if (!tidxArg) diff --git a/mlir/test/Dialect/Linalg/transform-gpu.mlir b/mlir/test/Dialect/Linalg/transform-gpu.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/transform-gpu.mlir @@ -0,0 +1,58 @@ +// RUN: mlir-opt --test-transform-dialect-interpreter --split-input-file %s | FileCheck %s + +!type = memref<2 x 32 x f32> +!type1d = memref<32 x f32> + +// CHECK-LABEL: func.func @saxpy2d( +// CHECK-SAME: %[[ARGX:[0-9a-z]+]]: memref<2x32xf32> +// CHECK-SAME: %[[ARGY:[0-9a-z]+]]: memref<2x32xf32> +// CHECK-SAME: %[[ARGT:[0-9a-z]+]]: memref<32xf32> +func.func @saxpy2d(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %stream : !gpu.async.token) -> !type { + %one = arith.constant 1 : index + %c12 = arith.constant 12 : index + %c9 = arith.constant 9 : 
index + %c7 = arith.constant 7 : index +// CHECK: gpu.launch +// CHECK: %[[TIDX:.*]] = gpu.thread_id x +// CHECK: %[[TIDY:.*]] = gpu.thread_id y +// CHECK: %[[C9:.*]] = arith.constant 9 : index +// CHECK: arith.cmpi ult, %[[TIDX]], %[[C9]] : index +// CHECK: %[[C7:.*]] = arith.constant 7 : index +// CHECK: arith.cmpi ult, %[[TIDY]], %[[C7]] : index +// CHECK: memref.load %[[ARGX]][%[[TIDY]], %[[TIDX]]] +// CHECK: memref.load %[[ARGY]][%[[TIDY]], %[[TIDX]]] +// CHECK: gpu.barrier +// CHECK: %[[TIDX2:.*]] = gpu.thread_id x +// CHECK: %[[TIDY2:.*]] = gpu.thread_id y +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: arith.cmpi ult, %[[TIDY2]], %[[C1]] : index +// CHECK: memref.load %[[ARGT]][%[[TIDX2]]] +// CHECK: gpu.barrier + %name = gpu.launch async[%stream] blocks(%arg3, %arg4, %arg5) in (%arg9 = %one, %arg10 = %one, %arg11 = %one) + threads(%arg6, %arg7, %arg8) in (%arg12 = %one, %arg13 = %one, %arg14 = %one) + { + scf.foreach_thread (%i, %j) in (%c7, %c9) { + %4 = memref.load %x[%i, %j] : !type + %5 = memref.load %y[%i, %j] : !type + %6 = math.fma %alpha, %4, %5 : f32 + memref.store %6, %y[%i, %j] : !type + } {thread_dim_mapping = [1, 0, 2]} + scf.foreach_thread (%i) in (%c12) { + %7 = memref.load %t[%i] : !type1d + %8 = arith.addf %alpha, %7 : f32 + memref.store %8, %t[%i] : !type1d + } {thread_dim_mapping = [0, 1, 2]} + gpu.terminator + } + return %y : !type +} + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 + transform.structured.map_nested_foreach_thread_to_gpu_threads %funcop { blockDim = [12, 9, 1] } + } +} +