diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -846,6 +846,63 @@
   }];
 }
 
+def MapNestedForeachThreadToGpuBlocks : Op<Transform_Dialect,
+    "structured.map_nested_foreach_thread_to_gpu_blocks",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait]> {
+  let description = [{
+    Target the gpu.launch op and rewrite the top level `scf.foreach_thread`
+    so that it is distributed over gpu.block_id. If the `generate_gpu_launch`
+    attribute is set, the op first generates a `gpu.launch` and moves the top
+    level `scf.foreach_thread` inside it.
+
+    The operation searches for the top level `scf.foreach_thread` ops under
+    `gpu.launch` and maps each such op to GPU blocks. The mapping is
+    one-to-one and the induction variables of `scf.foreach_thread` are
+    rewritten to gpu.block_id according to the `thread_dim_mapping` attribute.
+
+    Dynamic `scf.foreach_thread` trip counts are currently not supported.
+    Dynamic block dim sizes are currently not supported.
+
+    Only **bufferized** scf.foreach_thread are currently supported.
+    Only scf.foreach_thread distributed to **at most 3 dimensions** are
+    currently supported.
+
+    The operation alters the grid size of the given gpu.launch using the
+    `gridDim` argument.
+
+    Return modes:
+    =============
+    This operation ignores non-gpu.launch ops and drops them in the return.
+
+    If any scf.foreach_thread with tensors is found, the transform definitely
+    fails.
+
+    If all the scf.foreach_thread operations contained within the LaunchOp
+    referred to by the `target` PDLOperation lower to GPU properly, the
+    transform succeeds. Otherwise the transform definitely fails.
+
+    The returned handle points to the same LaunchOp operand, consuming it and
+    producing a new SSA value to satisfy chaining and linearity of the IR
+    properties.
+  }];
+
+  let arguments = (ins PDL_Operation:$target,
+                   DefaultValuedAttr<I64ArrayAttr, "{}">:$gridDim,
+                   UnitAttr:$generate_gpu_launch);
+  let results = (outs PDL_Operation:$result);
+
+  let assemblyFormat = "$target attr-dict";
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::Operation *target,
+        ::llvm::SmallVectorImpl<::mlir::Operation *> &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
     [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
      TransformEachOpTrait, TransformOpInterface]> {
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -125,6 +125,21 @@
 FailureOr<Operation *> fuseElementwiseOps(RewriterBase &rewriter,
                                           OpOperand *fusedOperand);
 
+/// Maps the top level `scf.foreach_thread` op to GPU thread blocks. Mapping is
+/// one-to-one and the induction variables of `scf.foreach_thread` are
+/// rewritten to gpu.block_id according to the thread_dim_mapping attribute.
+/// Dynamic `scf.foreach_thread` trip counts are currently not supported.
+/// Dynamic block dim sizes are currently not supported.
+LogicalResult rewriteTopLevelForeachThreadToGpuBlocks(
+    RewriterBase &rewriter, scf::ForeachThreadOp foreachThreadOp,
+    function_ref<void(Operation *, const SmallVector<int64_t> &, IndexType,
+                      SmallVector<Value> &)>
+        blockIdGenerator,
+    SmallVector<int64_t> &gridDims);
+
+/// Finds the top level scf::ForeachThreadOp of the given target.
+FailureOr<scf::ForeachThreadOp> findTopLevelForeachThreadOp(Operation *target);
+
 /// Searches `scf.foreach_thread` ops nested under `target` and maps each such
 /// op to GPU threads. Mapping is one-to-one and the induction variables of
 /// `scf.foreach_thread` are rewritten to gpu.thread_id according to the
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -18,9 +18,11 @@
 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/Value.h"
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/StringSet.h"
+#include "llvm/Support/raw_ostream.h"
 
 using namespace mlir;
 using namespace mlir::linalg;
@@ -1285,25 +1287,36 @@
   return walkResult;
 }
 
-// Alter blockDim of the given kernel
-static LogicalResult alterGpuLaunchBlockDim(SimpleRewriter &rewriter,
-                                            gpu::LaunchOp gpuLaunch,
-                                            SmallVector<int64_t> blockDim) {
-  gpu::KernelDim3 currentBlockdim = gpuLaunch.getBlockSizeOperandValues();
-  if (blockDim[0] < 1 || blockDim[1] < 1 || blockDim[2] < 1) {
-    gpuLaunch->emitError() << "Given blockDim(" << blockDim[0] << ","
-                           << blockDim[1] << "," << blockDim[2]
-                           << ") is invalid";
+/// Alters the grid or block dimensions of the given kernel launch. A value of
+/// 1 leaves the corresponding dimension unchanged.
+static LogicalResult
+alterGpuLaunch(SimpleRewriter &rewriter, gpu::LaunchOp gpuLaunch,
+               unsigned int gridDimX = 1, unsigned int gridDimY = 1,
+               unsigned int gridDimZ = 1, unsigned int blockDimX = 1,
+               unsigned int blockDimY = 1, unsigned int blockDimZ = 1) {
+  // TODO: These limits should be read from the gpu dialect, but they are not
+  // available there yet.
+  if ((blockDimX * blockDimY * blockDimZ) > 1024 || gridDimY > 65535 ||
+      gridDimZ > 65535 || gridDimX > 2147483647)
     return failure();
-  }
+
+  gpu::KernelDim3 currentBlockdim = gpuLaunch.getBlockSizeOperandValues();
   rewriter.setInsertionPointAfterValue(currentBlockdim.x);
-  auto createBlockDimValue = [&](int64_t dim) {
+  auto createConstValue = [&](int dim) {
     return rewriter.create<arith::ConstantIndexOp>(currentBlockdim.x.getLoc(),
                                                    dim);
   };
-  gpuLaunch.blockSizeXMutable().assign(createBlockDimValue(blockDim[0]));
-  gpuLaunch.blockSizeYMutable().assign(createBlockDimValue(blockDim[1]));
-  gpuLaunch.blockSizeZMutable().assign(createBlockDimValue(blockDim[2]));
+  if (gridDimX != 1)
+    gpuLaunch.gridSizeXMutable().assign(createConstValue(gridDimX));
+  if (gridDimY != 1)
+    gpuLaunch.gridSizeYMutable().assign(createConstValue(gridDimY));
+  if (gridDimZ != 1)
+    gpuLaunch.gridSizeZMutable().assign(createConstValue(gridDimZ));
+  if (blockDimX != 1)
+    gpuLaunch.blockSizeXMutable().assign(createConstValue(blockDimX));
+  if (blockDimY != 1)
+    gpuLaunch.blockSizeYMutable().assign(createConstValue(blockDimY));
+  if (blockDimZ != 1)
+    gpuLaunch.blockSizeZMutable().assign(createConstValue(blockDimZ));
   return success();
 }
 
@@ -1327,7 +1340,8 @@
   if (walkResult.wasInterrupted())
     return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
 
-  LogicalResult result = alterGpuLaunchBlockDim(rewriter, gpuLaunch, blockDim);
+  LogicalResult result = alterGpuLaunch(rewriter, gpuLaunch, 1, 1, 1,
+                                        blockDim[0], blockDim[1], blockDim[2]);
   if (failed(result))
     return DiagnosedSilenceableFailure::definiteFailure();
 
@@ -1335,6 +1349,164 @@
   return DiagnosedSilenceableFailure(success());
 }
 
+//===----------------------------------------------------------------------===//
+// MapNestedForeachThreadToGpuBlocks
+//===----------------------------------------------------------------------===//
+
+LogicalResult mlir::linalg::rewriteTopLevelForeachThreadToGpuBlocks(
+    RewriterBase &rewriter, scf::ForeachThreadOp foreachThreadOp,
+    function_ref<void(Operation *, const SmallVector<int64_t> &, IndexType,
+                      SmallVector<Value> &)>
+        blockIdGenerator,
+    SmallVector<int64_t> &gridDims) {
+  if (foreachThreadOp.getNumResults() > 0)
+    return foreachThreadOp->emitError(
+        "only bufferized scf.foreach_thread lowers to gpu.block_id");
+  if (foreachThreadOp.getNumThreads().size() > 3)
+    return foreachThreadOp->emitError(
+        "scf.foreach_thread with rank > 3 does not lower to gpu.block_id");
+
+  // Step 0. Collect the grid dimensions; only static trip counts are
+  // supported.
+  auto potentialGridDim = foreachThreadOp.getPermutedNumThreads(rewriter);
+  if (failed(potentialGridDim) ||
+      llvm::any_of(*potentialGridDim, [](OpFoldResult ofr) {
+        return !getConstantIntValue(ofr).has_value();
+      }))
+    return foreachThreadOp->emitError("unsupported dynamic gridDim");
+
+  for (OpFoldResult ofr : *potentialGridDim)
+    gridDims.push_back(getConstantIntValue(ofr).value());
+
+  IndexType indexType = rewriter.getIndexType();
+  SmallVector<Value> blockOps;
+  blockIdGenerator(foreachThreadOp, gridDims, indexType, blockOps);
+
+  // Step 1. Move the body of foreachThreadOp.
+  // Erase the terminator first; it will not be used since we are on buffers.
+  rewriter.eraseOp(foreachThreadOp.getTerminator());
+  Block *targetBlock = foreachThreadOp->getBlock();
+  Block::iterator insertionPoint = Block::iterator(foreachThreadOp);
+  Block &sourceBlock = foreachThreadOp.getRegion().front();
+  targetBlock->getOperations().splice(insertionPoint,
+                                      sourceBlock.getOperations());
+
+  // Step 2. RAUW thread indices to the newly created block id ops.
+  SmallVector<Value> threadIndices =
+      *foreachThreadOp.getPermutedThreadIndices();
+  assert(blockOps.size() == 3 && "3 block id ops are required");
+  for (auto it : llvm::zip(threadIndices, blockOps)) {
+    Value val = std::get<0>(it);
+    if (!val)
+      continue;
+    for (Operation *user : llvm::make_early_inc_range(val.getUsers())) {
+      rewriter.updateRootInPlace(
+          user, [&]() { user->replaceUsesOfWith(val, std::get<1>(it)); });
+    }
+  }
+
+  // Step 3. Erase the old op.
+  rewriter.eraseOp(foreachThreadOp);
+
+  return success();
+}
+
+FailureOr<scf::ForeachThreadOp>
+mlir::linalg::findTopLevelForeachThreadOp(Operation *target) {
+  scf::ForeachThreadOp topLevelForeachThreadOp;
+  auto walkResult = target->walk([&](scf::ForeachThreadOp foreachThreadOp) {
+    if (foreachThreadOp->getParentOfType<scf::ForeachThreadOp>())
+      return WalkResult::advance();
+    if (topLevelForeachThreadOp)
+      // TODO: Handle multiple foreach_thread ops if there are no dependences
+      // between them.
+      return WalkResult::interrupt();
+    topLevelForeachThreadOp = foreachThreadOp;
+    return WalkResult::advance();
+  });
+
+  if (walkResult.wasInterrupted())
+    return target->emitError(
+        "could not find a unique top-level scf.foreach_thread");
+
+  return topLevelForeachThreadOp;
+}
+
+/// Creates a gpu::LaunchOp with the given kernel configuration.
+static gpu::LaunchOp createGpuLaunch(RewriterBase &rewriter, Location loc,
+                                     int gridDimX = 1, int gridDimY = 1,
+                                     int gridDimZ = 1, int blockDimX = 1,
+                                     int blockDimY = 1, int blockDimZ = 1) {
+  auto createConstant = [&](int dim) {
+    return rewriter.create<arith::ConstantIndexOp>(loc, dim);
+  };
+  Value gridSizeX = createConstant(gridDimX);
+  Value gridSizeY = createConstant(gridDimY);
+  Value gridSizeZ = createConstant(gridDimZ);
+  Value blockSizeX = createConstant(blockDimX);
+  Value blockSizeY = createConstant(blockDimY);
+  Value blockSizeZ = createConstant(blockDimZ);
+  auto launchOp = rewriter.create<gpu::LaunchOp>(
+      loc, gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ);
+  rewriter.setInsertionPointToEnd(&launchOp.body().front());
+  rewriter.create<gpu::TerminatorOp>(loc);
+  return launchOp;
+}
+
+DiagnosedSilenceableFailure
+transform::MapNestedForeachThreadToGpuBlocks::applyToOne(
+    Operation *target, SmallVectorImpl<Operation *> &results,
+    transform::TransformState &state) {
+  gpu::LaunchOp gpuLaunch = dyn_cast<gpu::LaunchOp>(target);
+  SimpleRewriter rewriter(getContext());
+
+  if (!getGenerateGpuLaunch() && !gpuLaunch) {
+    target->emitError("given target is not a gpu.launch op; set the "
                       "`generate_gpu_launch` attribute");
+    return DiagnosedSilenceableFailure::definiteFailure();
+  }
+
+  auto res = mlir::linalg::findTopLevelForeachThreadOp(target);
+  if (failed(res))
+    return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
+
+  scf::ForeachThreadOp topLevelForeachThreadOp = *res;
+  rewriter.setInsertionPoint(topLevelForeachThreadOp);
+
+  // Generate the gpu.launch here and move the top level foreach_thread inside.
+  if (getGenerateGpuLaunch()) {
+    gpuLaunch = createGpuLaunch(rewriter, target->getLoc());
+    rewriter.setInsertionPointToStart(&gpuLaunch.body().front());
+    Operation *newForeachThreadOp = rewriter.clone(*topLevelForeachThreadOp);
+    rewriter.eraseOp(topLevelForeachThreadOp);
+    topLevelForeachThreadOp =
+        dyn_cast<scf::ForeachThreadOp>(newForeachThreadOp);
+  }
+
+  auto generateBlocks = [&](Operation *op,
+                            const SmallVector<int64_t> &gridDims,
+                            IndexType indexType, SmallVector<Value> &blockOps) {
+    Location loc = op->getLoc();
+    rewriter.setInsertionPoint(op);
+    SmallVector<gpu::Dimension> gpuDims{gpu::Dimension::x, gpu::Dimension::y,
                                        gpu::Dimension::z};
+    for (int64_t idx : llvm::seq<int64_t>(0, gridDims.size())) {
+      blockOps.push_back(
+          rewriter.create<gpu::BlockIdOp>(loc, indexType, gpuDims[idx]));
+    }
+  };
+
+  SmallVector<int64_t> gridDim = extractFromI64ArrayAttr(getGridDim());
+  if (failed(mlir::linalg::rewriteTopLevelForeachThreadToGpuBlocks(
          rewriter, topLevelForeachThreadOp, generateBlocks, gridDim)))
+    return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
+
+  if (failed(alterGpuLaunch(rewriter, gpuLaunch, gridDim[0], gridDim[1],
                            gridDim[2])))
+    return DiagnosedSilenceableFailure::definiteFailure();
+
+  results.assign({gpuLaunch});
+  return DiagnosedSilenceableFailure(success());
+}
+
 //===----------------------------------------------------------------------===//
 // TileToForeachThreadOp
 //===----------------------------------------------------------------------===//
@@ -1548,6 +1720,7 @@
   declareGeneratedDialect();
   declareGeneratedDialect();
   declareGeneratedDialect();
+  declareGeneratedDialect<gpu::GPUDialect>();
 
   registerTransformOps<
 #define GET_OP_LIST
diff --git a/mlir/test/Dialect/Linalg/transform-gpu.mlir b/mlir/test/Dialect/Linalg/transform-gpu.mlir
--- a/mlir/test/Dialect/Linalg/transform-gpu.mlir
+++ b/mlir/test/Dialect/Linalg/transform-gpu.mlir
@@ -3,6 +3,47 @@
 !type = memref<2 x 32 x f32>
 !type1d = memref<32 x f32>
 
+// CHECK-LABEL: func.func @saxpy2dblock(
+// CHECK-SAME: %[[ARGX:[0-9a-z]+]]: memref<2x32xf32>
+// CHECK-SAME: %[[ARGY:[0-9a-z]+]]: memref<2x32xf32>
+// CHECK-SAME: %[[ARGT:[0-9a-z]+]]: memref<32xf32>
+func.func @saxpy2dblock(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %stream : !gpu.async.token) -> !type {
+  %c9 = arith.constant 9 : index
+  %c7 = arith.constant 7 : index
+  %one = arith.constant 1 : index
+// CHECK: gpu.launch
+// CHECK:   %[[BLKX:.*]] = gpu.block_id x
+// CHECK:   %[[BLKY:.*]] = gpu.block_id y
+// CHECK:   memref.load %[[ARGX]][%[[BLKX]], %[[BLKY]]]
+// CHECK:   memref.load %[[ARGY]][%[[BLKX]], %[[BLKY]]]
+  %name = gpu.launch async[%stream] blocks(%arg3, %arg4, %arg5) in (%arg9 = %one, %arg10 = %one, %arg11 = %one)
+            threads(%arg6, %arg7, %arg8) in (%arg12 = %one, %arg13 = %one, %arg14 = %one)
+  {
+    scf.foreach_thread (%i, %j) in (%c7, %c9) {
+      %4 = memref.load %x[%i, %j] : !type
+      %5 = memref.load %y[%i, %j] : !type
+      %6 = math.fma %alpha, %4, %5 : f32
+      memref.store %6, %y[%i, %j] : !type
+    } {thread_dim_mapping = [0, 1, 2]}
+    gpu.terminator
+  }
+  return %y : !type
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0
+    transform.structured.map_nested_foreach_thread_to_gpu_blocks %funcop { gridDim = [12, 9, 1]}
+  }
+}
+
+// -----
+
+!type = memref<2 x 32 x f32>
+!type1d = memref<32 x f32>
+
 // CHECK-LABEL: func.func @saxpy2d(
 // CHECK-SAME: %[[ARGX:[0-9a-z]+]]: memref<2x32xf32>
 // CHECK-SAME: %[[ARGY:[0-9a-z]+]]: memref<2x32xf32>
 // CHECK-SAME: %[[ARGT:[0-9a-z]+]]: memref<32xf32>
@@ -56,3 +97,41 @@
   }
 }
 
+// -----
+
+!type4d = memref<32x64x4x32xf32>
+
+// CHECK-LABEL: func.func @saxpy4d(
+// CHECK-SAME: %[[ARGX:[0-9a-z]+]]: memref<32x64x4x32xf32>
+// CHECK-SAME: %[[ARGY:[0-9a-z]+]]: memref<32x64x4x32xf32>
+func.func @saxpy4d(%x: !type4d, %y: !type4d, %alpha : f32) -> !type4d {
+  %c32 = arith.constant 32 : index
+  %c64 = arith.constant 64 : index
+  %c4 = arith.constant 4 : index
+// CHECK: gpu.launch
+// CHECK:   %[[BLKX:.*]] = gpu.block_id x
+// CHECK:   %[[BLKY:.*]] = gpu.block_id y
+// CHECK:   %[[TIDX:.*]] = gpu.thread_id x
+// CHECK:   %[[TIDY:.*]] = gpu.thread_id y
+// CHECK:   memref.load %[[ARGX]][%[[BLKX]], %[[BLKY]], %[[TIDY]], %[[TIDX]]]
+// CHECK:   memref.load %[[ARGY]][%[[BLKX]], %[[BLKY]], %[[TIDY]], %[[TIDX]]]
+  scf.foreach_thread (%i, %j) in (%c32, %c64) {
+    scf.foreach_thread (%k, %l) in (%c4, %c32) {
+      %4 = memref.load %x[%i, %j, %k, %l] : !type4d
+      %5 = memref.load %y[%i, %j, %k, %l] : !type4d
+      %6 = math.fma %alpha, %4, %5 : f32
+      memref.store %6, %y[%i, %j, %k, %l] : !type4d
+    } {thread_dim_mapping = [1, 0, 2]}
+  } {thread_dim_mapping = [0, 1, 2]}
+  return %y : !type4d
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %funcop = transform.structured.match ops{["func.func"]} in %arg0
+    %gpuLaunch = transform.structured.map_nested_foreach_thread_to_gpu_blocks %funcop { generate_gpu_launch }
+    transform.structured.map_nested_foreach_thread_to_gpu_threads %gpuLaunch { blockDim = [32, 4, 1] }
+  }
+}
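
Illustrative sketch (not part of the patch): assuming the semantics described in the
op documentation above, the blocks mapping rewrites a bufferized top level
`scf.foreach_thread` inside a `gpu.launch` region roughly as follows, with the grid
size of the surrounding launch updated from `gridDim`. The values %x and %y here are
arbitrary memref SSA values, not names taken from the tests.

  // Before: top level foreach_thread inside the gpu.launch region.
  scf.foreach_thread (%i, %j) in (%c2, %c32) {
    %v = memref.load %x[%i, %j] : memref<2x32xf32>
    memref.store %v, %y[%i, %j] : memref<2x32xf32>
  } {thread_dim_mapping = [0, 1, 2]}

  // After: induction variables replaced by block ids, body inlined, op erased.
  %bx = gpu.block_id x
  %by = gpu.block_id y
  %v = memref.load %x[%bx, %by] : memref<2x32xf32>
  memref.store %v, %y[%bx, %by] : memref<2x32xf32>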