diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -846,6 +846,63 @@
   }];
 }
 
+def MapNestedForeachThreadToGpuBlocks : Op<Transform_Dialect,
+    "structured.map_nested_foreach_thread_to_gpu_blocks",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformEachOpTrait, TransformOpInterface]> {
+  let description = [{
+    Target the gpu.launch op and rewrite the top level `scf.foreach_thread`
+    to distributed gpu.block_id ops. If the `generate_gpu_launch` attribute
+    is set, a `gpu.launch` op is generated first and the top level
+    `scf.foreach_thread` is moved inside it.
+
+    The operation searches top level `scf.foreach_thread` ops under
+    `gpu.launch` and maps each such op to GPU blocks. Mapping is
+    one-to-one and the induction variables of `scf.foreach_thread` are
+    rewritten to gpu.block_id according to the `thread_dim_mapping` attribute.
+
+    Dynamic `scf.foreach_thread` trip counts are currently not supported.
+    Dynamic block dim sizes are currently not supported.
+
+    Only **bufferized** scf.foreach_thread ops are currently supported.
+    Only scf.foreach_thread ops distributed over **at most 3 dimensions** are
+    currently supported.
+
+    The operation alters the grid size of the given gpu.launch using the
+    `gridDim` argument.
+
+    #### Return modes:
+
+    This operation ignores non-gpu.launch ops and drops them in the return.
+
+    If any scf.foreach_thread with tensors is found, the transform definitely
+    fails.
+
+    If all the scf.foreach_thread operations contained within the LaunchOp
+    referred to by the `target` PDLOperation lower to GPU properly, the
+    transform succeeds. Otherwise the transform definitely fails.
+
+    The returned handle points to the same LaunchOp operand, consuming it and
+    producing a new SSA value to satisfy chaining and linearity of the IR
+    properties.
+  }];
+
+  let arguments = (ins PDL_Operation:$target,
+                   DefaultValuedAttr<I64ArrayAttr, "{}">:$gridDim,
+                   UnitAttr:$generate_gpu_launch);
+  let results = (outs PDL_Operation:$result);
+
+  let assemblyFormat = "$target attr-dict";
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::Operation *target,
+        ::llvm::SmallVectorImpl<::mlir::Operation *> &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
     [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
      TransformEachOpTrait, TransformOpInterface]> {
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -125,6 +125,21 @@
 FailureOr<Operation *> fuseElementwiseOps(RewriterBase &rewriter,
                                           OpOperand *fusedOperand);
 
+/// Maps the top level `scf.foreach_thread` op to GPU Thread Blocks. Mapping is
+/// one-to-one and the induction variables of `scf.foreach_thread` are rewritten
+/// to gpu.block_id according to the thread_dim_mapping attribute. Dynamic
+/// `scf.foreach_thread` trip counts are currently not supported. Dynamic block
+/// dim sizes are currently not supported.
+LogicalResult rewriteTopLevelForeachThreadToGpuBlocks(
+    RewriterBase &rewriter, scf::ForeachThreadOp foreachThreadOp,
+    function_ref<void(Operation *, const SmallVector<int64_t> &, IndexType,
+                      SmallVector<Value> &)>
+        blockIdGenerator,
+    SmallVector<int64_t> &gridDims);
+
+/// Finds the top level scf::ForeachThreadOp of the given target.
+FailureOr<scf::ForeachThreadOp> findTopLevelForeachThreadOp(Operation *target);
+
 /// Searches `scf.foreach_thread` ops nested under `target` and maps each such
 /// op to GPU threads. Mapping is one-to-one and the induction variables of
 /// `scf.foreach_thread` are rewritten to gpu.thread_id according to the
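To make the intent of the new op concrete, here is a minimal before/after sketch of the rewrite on a bufferized payload, modeled on the saxpy2dblock test added below. The function names, shapes, and SSA names are invented for illustration, and the grid is assumed to come from the static trip counts (no explicit `gridDim` override).

```mlir
// Input: a bufferized scf.foreach_thread nested in an existing gpu.launch.
func.func @before(%x: memref<2x32xf32>, %y: memref<2x32xf32>) {
  %one = arith.constant 1 : index
  %c7 = arith.constant 7 : index
  %c9 = arith.constant 9 : index
  gpu.launch blocks(%bx, %by, %bz) in (%gx = %one, %gy = %one, %gz = %one)
             threads(%tx, %ty, %tz) in (%sx = %one, %sy = %one, %sz = %one) {
    scf.foreach_thread (%i, %j) in (%c7, %c9) {
      %v = memref.load %x[%i, %j] : memref<2x32xf32>
      memref.store %v, %y[%i, %j] : memref<2x32xf32>
    } {thread_dim_mapping = [0, 1, 2]}
    gpu.terminator
  }
  return
}

// Roughly the expected output: the loop body is inlined into the launch
// region, the induction variables are replaced by gpu.block_id ops, and the
// grid operands of the launch are rewritten to the static trip counts.
func.func @after(%x: memref<2x32xf32>, %y: memref<2x32xf32>) {
  %one = arith.constant 1 : index
  %c7 = arith.constant 7 : index
  %c9 = arith.constant 9 : index
  gpu.launch blocks(%bx, %by, %bz) in (%gx = %c7, %gy = %c9, %gz = %one)
             threads(%tx, %ty, %tz) in (%sx = %one, %sy = %one, %sz = %one) {
    %bidx = gpu.block_id x
    %bidy = gpu.block_id y
    %v = memref.load %x[%bidx, %bidy] : memref<2x32xf32>
    memref.store %v, %y[%bidx, %bidy] : memref<2x32xf32>
    gpu.terminator
  }
  return
}
```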
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1285,25 +1285,56 @@
   return walkResult;
 }
 
-// Alter blockDim of the given kernel
-static LogicalResult alterGpuLaunchBlockDim(SimpleRewriter &rewriter,
-                                            gpu::LaunchOp gpuLaunch,
-                                            SmallVector<int64_t> blockDim) {
-  gpu::KernelDim3 currentBlockdim = gpuLaunch.getBlockSizeOperandValues();
-  if (blockDim[0] < 1 || blockDim[1] < 1 || blockDim[2] < 1) {
-    gpuLaunch->emitError() << "Given blockDim(" << blockDim[0] << ","
-                           << blockDim[1] << "," << blockDim[2]
-                           << ") is invalid";
+static LogicalResult
+checkGpuLimits(Optional<int64_t> gridDimX, Optional<int64_t> gridDimY,
+               Optional<int64_t> gridDimZ, Optional<int64_t> blockDimX,
+               Optional<int64_t> blockDimY, Optional<int64_t> blockDimZ) {
+  // TODO: These limits should live in the gpu dialect, but they are not
+  // modeled there yet; read them from the common gpu dialect once they are.
+  if ((blockDimX.value_or(1) * blockDimY.value_or(1) * blockDimZ.value_or(1)) >
+          1024 ||
+      gridDimY.value_or(1) > 65535 || gridDimZ.value_or(1) > 65535 ||
+      gridDimX.value_or(1) > 2147483647)
+    return failure();
+  return success();
+}
+
+/// Alters the grid or block dimensions of the given kernel.
+static LogicalResult alterGpuLaunch(SimpleRewriter &rewriter,
+                                    gpu::LaunchOp gpuLaunch,
+                                    Optional<int64_t> gridDimX = llvm::None,
+                                    Optional<int64_t> gridDimY = llvm::None,
+                                    Optional<int64_t> gridDimZ = llvm::None,
+                                    Optional<int64_t> blockDimX = llvm::None,
+                                    Optional<int64_t> blockDimY = llvm::None,
+                                    Optional<int64_t> blockDimZ = llvm::None) {
+  if (failed(checkGpuLimits(gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY,
+                            blockDimZ))) {
+    gpuLaunch->emitError(
+        "Requested kernel thread configuration is larger than the limits");
     return failure();
   }
+
+  gpu::KernelDim3 currentBlockdim = gpuLaunch.getBlockSizeOperandValues();
+  OpBuilder::InsertionGuard guard(rewriter);
   rewriter.setInsertionPointAfterValue(currentBlockdim.x);
-  auto createBlockDimValue = [&](int64_t dim) {
+  auto createConstValue = [&](int dim) {
     return rewriter.create<arith::ConstantIndexOp>(currentBlockdim.x.getLoc(),
                                                    dim);
   };
-  gpuLaunch.blockSizeXMutable().assign(createBlockDimValue(blockDim[0]));
-  gpuLaunch.blockSizeYMutable().assign(createBlockDimValue(blockDim[1]));
-  gpuLaunch.blockSizeZMutable().assign(createBlockDimValue(blockDim[2]));
+
+  if (gridDimX.has_value())
+    gpuLaunch.gridSizeXMutable().assign(createConstValue(gridDimX.value()));
+  if (gridDimY.has_value())
+    gpuLaunch.gridSizeYMutable().assign(createConstValue(gridDimY.value()));
+  if (gridDimZ.has_value())
+    gpuLaunch.gridSizeZMutable().assign(createConstValue(gridDimZ.value()));
+  if (blockDimX.has_value())
+    gpuLaunch.blockSizeXMutable().assign(createConstValue(blockDimX.value()));
+  if (blockDimY.has_value())
+    gpuLaunch.blockSizeYMutable().assign(createConstValue(blockDimY.value()));
+  if (blockDimZ.has_value())
+    gpuLaunch.blockSizeZMutable().assign(createConstValue(blockDimZ.value()));
   return success();
 }
 
@@ -1327,7 +1358,9 @@
   if (walkResult.wasInterrupted())
     return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
 
-  LogicalResult result = alterGpuLaunchBlockDim(rewriter, gpuLaunch, blockDim);
+  LogicalResult result =
+      alterGpuLaunch(rewriter, gpuLaunch, llvm::None, llvm::None, llvm::None,
+                     blockDim[0], blockDim[1], blockDim[2]);
   if (failed(result))
     return DiagnosedSilenceableFailure::definiteFailure();
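As a quick illustration of the refactored helper (a sketch, not output copied from a test): called as above with only block dimensions, `alterGpuLaunch` leaves the grid operands of the launch untouched and replaces the block-size operands with freshly created constants. `checkGpuLimits` rejects the request up front if x*y*z exceeds 1024 threads per block, if the y or z grid dimension exceeds 65535, or if the x grid dimension exceeds 2^31 - 1.

```mlir
// Before: placeholder sizes everywhere.
func.func @launch_before() {
  %c1 = arith.constant 1 : index
  gpu.launch blocks(%bx, %by, %bz) in (%gx = %c1, %gy = %c1, %gz = %c1)
             threads(%tx, %ty, %tz) in (%sx = %c1, %sy = %c1, %sz = %c1) {
    gpu.terminator
  }
  return
}

// After alterGpuLaunch(rewriter, launch, None, None, None, /*blockDim=*/12, 9, 1):
// only the threads(...) operands change.
func.func @launch_after() {
  %c1 = arith.constant 1 : index
  %c12 = arith.constant 12 : index
  %c9 = arith.constant 9 : index
  gpu.launch blocks(%bx, %by, %bz) in (%gx = %c1, %gy = %c1, %gz = %c1)
             threads(%tx, %ty, %tz) in (%sx = %c12, %sy = %c9, %sz = %c1) {
    gpu.terminator
  }
  return
}
```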
@@ -1335,6 +1368,184 @@
   return DiagnosedSilenceableFailure(success());
 }
 
+//===----------------------------------------------------------------------===//
+// MapNestedForeachThreadToGpuBlocks
+//===----------------------------------------------------------------------===//
+
+LogicalResult mlir::linalg::rewriteTopLevelForeachThreadToGpuBlocks(
+    RewriterBase &rewriter, scf::ForeachThreadOp foreachThreadOp,
+    function_ref<void(Operation *, const SmallVector<int64_t> &, IndexType,
+                      SmallVector<Value> &)>
+        blockIdGenerator,
+    SmallVector<int64_t> &gridDims) {
+  if (foreachThreadOp.getNumResults() > 0)
+    return foreachThreadOp->emitError(
+        "only bufferized scf.foreach_thread lowers to gpu.block_id");
+  if (foreachThreadOp.getNumThreads().size() > 3)
+    return foreachThreadOp->emitError(
+        "scf.foreach_thread with rank > 3 does not lower to gpu.block_id");
+
+  // Step 0. Derive the static grid dimensions from the foreach_thread trip
+  // counts.
+  auto potentialGridDim = foreachThreadOp.getPermutedNumThreads(rewriter);
+  if (failed(potentialGridDim) ||
+      llvm::any_of(*potentialGridDim, [](OpFoldResult ofr) {
+        return !getConstantIntValue(ofr).has_value();
+      }))
+    return foreachThreadOp->emitError("unsupported dynamic gridDim");
+
+  for (OpFoldResult ofr : *potentialGridDim)
+    gridDims.push_back(getConstantIntValue(ofr).value());
+
+  IndexType indexType = rewriter.getIndexType();
+  SmallVector<Value> blockOps;
+  blockIdGenerator(foreachThreadOp, gridDims, indexType, blockOps);
+
+  // Step 1. Move the body of foreachThreadOp.
+  // Erase the terminator first; it will not be used since we are on buffers.
+  rewriter.eraseOp(foreachThreadOp.getTerminator());
+  Block *targetBlock = foreachThreadOp->getBlock();
+  Block::iterator insertionPoint = Block::iterator(foreachThreadOp);
+  Block &sourceBlock = foreachThreadOp.getRegion().front();
+  targetBlock->getOperations().splice(insertionPoint,
+                                      sourceBlock.getOperations());
+
+  // Step 2. RAUW thread indices with the newly created block id ops.
+  SmallVector<Value> threadIndices =
+      *foreachThreadOp.getPermutedThreadIndices();
+  assert(blockOps.size() == 3 && "3 block id ops are required");
+  for (auto it : llvm::zip(threadIndices, blockOps)) {
+    Value val = std::get<0>(it);
+    if (!val)
+      continue;
+    for (Operation *user : llvm::make_early_inc_range(val.getUsers())) {
+      rewriter.updateRootInPlace(
+          user, [&]() { user->replaceUsesOfWith(val, std::get<1>(it)); });
+    }
+  }
+
+  // Step 3. Erase the old op.
+  rewriter.eraseOp(foreachThreadOp);
+
+  return success();
+}
+
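One subtlety of the RAUW step worth spelling out (a hypothetical example, not taken from the tests): `getPermutedNumThreads` and `getPermutedThreadIndices` both honor `thread_dim_mapping`, so loop dimension i is assigned to GPU dimension `thread_dim_mapping[i]` and the grid sizes are permuted the same way. With the mapping [1, 0, 2] on a 7x9 iteration space:

```mlir
func.func @permuted(%buf: memref<7x9xf32>, %v: f32) {
  %one = arith.constant 1 : index
  %c7 = arith.constant 7 : index
  %c9 = arith.constant 9 : index
  gpu.launch blocks(%bx, %by, %bz) in (%gx = %one, %gy = %one, %gz = %one)
             threads(%tx, %ty, %tz) in (%sx = %one, %sy = %one, %sz = %one) {
    // Loop dim 0 (%i, trip count 7) goes to GPU dim 1 (y); loop dim 1
    // (%j, trip count 9) goes to GPU dim 0 (x).
    scf.foreach_thread (%i, %j) in (%c7, %c9) {
      memref.store %v, %buf[%i, %j] : memref<7x9xf32>
    } {thread_dim_mapping = [1, 0, 2]}
    gpu.terminator
  }
  return
}
// Expected outcome of the rewrite: gridDims becomes (x = 9, y = 7, z = 1) and
// the inlined body uses the block ids in permuted order:
//   %bidx = gpu.block_id x
//   %bidy = gpu.block_id y
//   memref.store %v, %buf[%bidy, %bidx] : memref<7x9xf32>
```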
+FailureOr<scf::ForeachThreadOp>
+mlir::linalg::findTopLevelForeachThreadOp(Operation *target) {
+  scf::ForeachThreadOp topLevelForeachThreadOp;
+  auto walkResult = target->walk([&](scf::ForeachThreadOp foreachThreadOp) {
+    if (foreachThreadOp->getParentOfType<scf::ForeachThreadOp>())
+      return WalkResult::advance();
+    if (topLevelForeachThreadOp)
+      // TODO: Handle multiple foreach_thread ops if there are no dependences
+      // between them.
+      return WalkResult::interrupt();
+    topLevelForeachThreadOp = foreachThreadOp;
+    return WalkResult::advance();
+  });
+
+  if (walkResult.wasInterrupted())
+    return target->emitError(
+        "could not find a unique topLevel scf.foreach_thread");
+
+  return topLevelForeachThreadOp;
+}
+
+/// Creates a gpu::LaunchOp with the given kernel configuration.
+static FailureOr<gpu::LaunchOp>
+createGpuLaunch(RewriterBase &rewriter, Location loc,
+                Optional<int64_t> gridDimX = llvm::None,
+                Optional<int64_t> gridDimY = llvm::None,
+                Optional<int64_t> gridDimZ = llvm::None,
+                Optional<int64_t> blockDimX = llvm::None,
+                Optional<int64_t> blockDimY = llvm::None,
+                Optional<int64_t> blockDimZ = llvm::None) {
+  if (failed(checkGpuLimits(gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY,
+                            blockDimZ)))
+    return failure();
+  auto createConstant = [&](int dim) {
+    return rewriter.create<arith::ConstantIndexOp>(loc, dim);
+  };
+  Value one = createConstant(1);
+  Value gridSizeX =
+      gridDimX.has_value() ? createConstant(gridDimX.value()) : one;
+  Value gridSizeY =
+      gridDimY.has_value() ? createConstant(gridDimY.value()) : one;
+  Value gridSizeZ =
+      gridDimZ.has_value() ? createConstant(gridDimZ.value()) : one;
+  Value blockSizeX =
+      blockDimX.has_value() ? createConstant(blockDimX.value()) : one;
+  Value blockSizeY =
+      blockDimY.has_value() ? createConstant(blockDimY.value()) : one;
+  Value blockSizeZ =
+      blockDimZ.has_value() ? createConstant(blockDimZ.value()) : one;
+  auto launchOp = rewriter.create<gpu::LaunchOp>(
+      loc, gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ);
+  rewriter.setInsertionPointToEnd(&launchOp.body().front());
+  rewriter.create<gpu::TerminatorOp>(loc);
+  return launchOp;
+}
+
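For reference, when `generate_gpu_launch` is set the op first calls `createGpuLaunch` with no dimensions, so the generated launch is a trivially sized kernel whose sizes are fixed up later by `alterGpuLaunch`. A sketch of the generated op (SSA names invented):

```mlir
func.func @generated_launch() {
  %c1 = arith.constant 1 : index
  // All six sizes default to the same constant 1; the body only contains the
  // mandatory terminator until the foreach_thread is cloned into it.
  gpu.launch blocks(%bx, %by, %bz) in (%gx = %c1, %gy = %c1, %gz = %c1)
             threads(%tx, %ty, %tz) in (%sx = %c1, %sy = %c1, %sz = %c1) {
    gpu.terminator
  }
  return
}
```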
+DiagnosedSilenceableFailure
+transform::MapNestedForeachThreadToGpuBlocks::applyToOne(
+    Operation *target, SmallVectorImpl<Operation *> &results,
+    transform::TransformState &state) {
+  gpu::LaunchOp gpuLaunch = dyn_cast<gpu::LaunchOp>(target);
+  SimpleRewriter rewriter(getContext());
+
+  if (!getGenerateGpuLaunch() && !gpuLaunch) {
+    target->emitError("Given target is not a gpu.launch op; set the "
+                      "`generate_gpu_launch` attribute");
+    return DiagnosedSilenceableFailure::definiteFailure();
+  }
+
+  auto res = mlir::linalg::findTopLevelForeachThreadOp(target);
+  if (failed(res))
+    return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
+
+  scf::ForeachThreadOp topLevelForeachThreadOp = *res;
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(topLevelForeachThreadOp);
+
+  // Generate the gpu.launch here and move the foreach_thread inside it.
+  if (getGenerateGpuLaunch()) {
+    FailureOr<gpu::LaunchOp> maybeGpuLaunch =
+        createGpuLaunch(rewriter, target->getLoc());
+    if (failed(maybeGpuLaunch))
+      return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
+    gpuLaunch = *maybeGpuLaunch;
+    rewriter.setInsertionPointToStart(&gpuLaunch.body().front());
+    Operation *newForeachThreadOp = rewriter.clone(*topLevelForeachThreadOp);
+    rewriter.eraseOp(topLevelForeachThreadOp);
+    topLevelForeachThreadOp =
+        dyn_cast<scf::ForeachThreadOp>(newForeachThreadOp);
+  }
+
+  auto generateBlocks = [&](Operation *op, const SmallVector<int64_t> &gridDims,
+                            IndexType indexType, SmallVector<Value> &blockOps) {
+    Location loc = op->getLoc();
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPoint(op);
+    SmallVector<gpu::Dimension> gpuDims{gpu::Dimension::x, gpu::Dimension::y,
+                                        gpu::Dimension::z};
+    for (int64_t idx : llvm::seq<int64_t>(0, gridDims.size())) {
+      blockOps.push_back(
+          rewriter.create<gpu::BlockIdOp>(loc, indexType, gpuDims[idx]));
+    }
+  };
+
+  SmallVector<int64_t> gridDim = extractFromI64ArrayAttr(getGridDim());
+  if (failed(mlir::linalg::rewriteTopLevelForeachThreadToGpuBlocks(
+          rewriter, topLevelForeachThreadOp, generateBlocks, gridDim)))
+    return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
+
+  if (failed(alterGpuLaunch(rewriter, gpuLaunch, gridDim[0], gridDim[1],
+                            gridDim[2])))
+    return DiagnosedSilenceableFailure::definiteFailure();
+
+  results.assign({gpuLaunch});
+  return DiagnosedSilenceableFailure(success());
+}
+
 //===----------------------------------------------------------------------===//
 // TileToForeachThreadOp
 //===----------------------------------------------------------------------===//
@@ -1562,6 +1773,7 @@
     declareGeneratedDialect<arith::ArithmeticDialect>();
     declareGeneratedDialect<scf::SCFDialect>();
     declareGeneratedDialect<vector::VectorDialect>();
+    declareGeneratedDialect<gpu::GPUDialect>();
 
     registerTransformOps<
 #define GET_OP_LIST
diff --git a/mlir/test/Dialect/Linalg/transform-gpu.mlir b/mlir/test/Dialect/Linalg/transform-gpu.mlir
--- a/mlir/test/Dialect/Linalg/transform-gpu.mlir
+++ b/mlir/test/Dialect/Linalg/transform-gpu.mlir
@@ -1,4 +1,45 @@
-// RUN: mlir-opt --test-transform-dialect-interpreter --split-input-file %s | FileCheck %s
+// RUN: mlir-opt --test-transform-dialect-interpreter --split-input-file -canonicalize -cse %s | FileCheck %s
+
+!type = memref<2 x 32 x f32>
+!type1d = memref<32 x f32>
+
+// CHECK-LABEL: func.func @saxpy2dblock(
+// CHECK-SAME:   %[[ARGX:[0-9a-z]+]]: memref<2x32xf32>
+// CHECK-SAME:   %[[ARGY:[0-9a-z]+]]: memref<2x32xf32>
+// CHECK-SAME:   %[[ARGT:[0-9a-z]+]]: memref<32xf32>
+func.func @saxpy2dblock(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %stream : !gpu.async.token) -> !type {
+  %c9 = arith.constant 9 : index
+  %c7 = arith.constant 7 : index
+  %one = arith.constant 1 : index
+// CHECK:   gpu.launch
+// CHECK:   %[[BLKX:.*]] = gpu.block_id x
+// CHECK:   %[[BLKY:.*]] = gpu.block_id y
+// CHECK:   memref.load %[[ARGX]][%[[BLKX]], %[[BLKY]]]
+// CHECK:   memref.load %[[ARGY]][%[[BLKX]], %[[BLKY]]]
+  %name = gpu.launch async[%stream] blocks(%arg3, %arg4, %arg5) in (%arg9 = %one, %arg10 = %one, %arg11 = %one)
+            threads(%arg6, %arg7, %arg8) in (%arg12 = %one, %arg13 = %one, %arg14 = %one)
+  {
+    scf.foreach_thread (%i, %j) in (%c7, %c9) {
+      %4 = memref.load %x[%i, %j] : !type
+      %5 = memref.load %y[%i, %j] : !type
+      %6 = math.fma %alpha, %4, %5 : f32
+      memref.store %6, %y[%i, %j] : !type
+    } {thread_dim_mapping = [0, 1, 2]}
+    gpu.terminator
+  }
+  return %y : !type
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0
+    transform.structured.map_nested_foreach_thread_to_gpu_blocks %funcop { gridDim = [12, 9, 1]}
+  }
+}
+
+// -----
 
 !type = memref<2 x 32 x f32>
 !type1d = memref<32 x f32>
@@ -12,21 +53,20 @@
   %c12 = arith.constant 12 : index
   %c9 = arith.constant 9 : index
   %c7 = arith.constant 7 : index
-// CHECK:   gpu.launch
+// CHECK:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK:   %[[C12:.*]] = arith.constant 12 : index
+// CHECK:   %[[C9:.*]] = arith.constant 9 : index
+// CHECK:   %[[C7:.*]] = arith.constant 7 : index
+// CHECK:   gpu.launch async [%{{.*}}] blocks(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C1]], %{{.*}} = %[[C1]], %{{.*}} = %[[C1]]) threads(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C12]], %{{.*}} = %[[C9]], %{{.*}} = %[[C1]])
 // CHECK:   %[[TIDX:.*]] = gpu.thread_id x
 // CHECK:   %[[TIDY:.*]] = gpu.thread_id y
-// CHECK:   %[[C9:.*]] = arith.constant 9 : index
 // CHECK:   arith.cmpi ult, %[[TIDX]], %[[C9]] : index
-// CHECK:   %[[C7:.*]] = arith.constant 7 : index
 // CHECK:   arith.cmpi ult, %[[TIDY]], %[[C7]] : index
 // CHECK:   memref.load %[[ARGX]][%[[TIDY]], %[[TIDX]]]
 // CHECK:   memref.load %[[ARGY]][%[[TIDY]], %[[TIDX]]]
 // CHECK:   gpu.barrier
-// CHECK:   %[[TIDX2:.*]] = gpu.thread_id x
-// CHECK:   %[[TIDY2:.*]] = gpu.thread_id y
-// CHECK:   %[[C1:.*]] = arith.constant 1 : index
-// CHECK:   arith.cmpi ult, %[[TIDY2]], %[[C1]] : index
-// CHECK:   memref.load %[[ARGT]][%[[TIDX2]]]
+// CHECK:   arith.cmpi ult, %[[TIDY]], %[[C1]] : index
+// CHECK:   memref.load %[[ARGT]][%[[TIDX]]]
 // CHECK:   gpu.barrier
   %name = gpu.launch async[%stream] blocks(%arg3, %arg4, %arg5) in (%arg9 = %one, %arg10 = %one, %arg11 = %one)
            threads(%arg6, %arg7, %arg8) in (%arg12 = %one, %arg13 = %one, %arg14 = %one)
   {
@@ -56,3 +96,45 @@
   }
 }
 
+// -----
+
+!type4d = memref<32x64x4x32xf32>
+
+// CHECK-LABEL: func.func @saxpy4d(
+// CHECK-SAME:   %[[ARGX:[0-9a-z]+]]: memref<32x64x4x32xf32>
+// CHECK-SAME:   %[[ARGY:[0-9a-z]+]]: memref<32x64x4x32xf32>
+func.func @saxpy4d(%x: !type4d, %y: !type4d, %alpha : f32) -> !type4d {
+  %c32 = arith.constant 32 : index
+  %c64 = arith.constant 64 : index
+  %c4 = arith.constant 4 : index
+// CHECK:   %[[C32:.*]] = arith.constant 32 : index
+// CHECK:   %[[C64:.*]] = arith.constant 64 : index
+// CHECK:   %[[C4:.*]] = arith.constant 4 : index
+// CHECK:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK:   gpu.launch blocks(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C32]], %{{.*}} = %[[C64]], %{{.*}} = %[[C1]]) threads(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C32]], %{{.*}} = %[[C4]], %{{.*}} = %[[C1]])
+// CHECK:   %[[BLKX:.*]] = gpu.block_id x
+// CHECK:   %[[BLKY:.*]] = gpu.block_id y
+// CHECK:   %[[TIDX:.*]] = gpu.thread_id x
+// CHECK:   %[[TIDY:.*]] = gpu.thread_id y
+// CHECK:   memref.load %[[ARGX]][%[[BLKX]], %[[BLKY]], %[[TIDY]], %[[TIDX]]]
+// CHECK:   memref.load %[[ARGY]][%[[BLKX]], %[[BLKY]], %[[TIDY]], %[[TIDX]]]
+  scf.foreach_thread (%i, %j) in (%c32, %c64) {
+    scf.foreach_thread (%k, %l) in (%c4, %c32) {
+      %4 = memref.load %x[%i, %j, %k, %l] : !type4d
+      %5 = memref.load %y[%i, %j, %k, %l] : !type4d
+      %6 = math.fma %alpha, %4, %5 : f32
+      memref.store %6, %y[%i, %j, %k, %l] : !type4d
+    } {thread_dim_mapping = [1, 0, 2]}
+  } {thread_dim_mapping = [0, 1, 2]}
+  return %y : !type4d
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %funcop = transform.structured.match ops{["func.func"]} in %arg0
+    %gpuLaunch = transform.structured.map_nested_foreach_thread_to_gpu_blocks %funcop { generate_gpu_launch }
+    transform.structured.map_nested_foreach_thread_to_gpu_threads %gpuLaunch { blockDim = [32, 4, 1] }
+  }
+}
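The tests above only exercise the supported cases. As a reminder of the guard rails in `rewriteTopLevelForeachThreadToGpuBlocks`, an input like the following sketch (hypothetical, not part of the test file) is expected to be rejected with "unsupported dynamic gridDim" because the trip count is not a compile-time constant; a `scf.foreach_thread` producing tensor results would similarly fail with "only bufferized scf.foreach_thread lowers to gpu.block_id".

```mlir
func.func @dynamic_trip_count(%x: memref<?x32xf32>, %y: memref<?x32xf32>, %stream: !gpu.async.token) {
  %one = arith.constant 1 : index
  %c0 = arith.constant 0 : index
  %c32 = arith.constant 32 : index
  %n = memref.dim %x, %c0 : memref<?x32xf32>
  %token = gpu.launch async[%stream] blocks(%bx, %by, %bz) in (%gx = %one, %gy = %one, %gz = %one)
             threads(%tx, %ty, %tz) in (%sx = %one, %sy = %one, %sz = %one) {
    // %n is not a constant, so map_nested_foreach_thread_to_gpu_blocks refuses
    // to derive a static grid from this foreach_thread.
    scf.foreach_thread (%i, %j) in (%n, %c32) {
      %v = memref.load %x[%i, %j] : memref<?x32xf32>
      memref.store %v, %y[%i, %j] : memref<?x32xf32>
    } {thread_dim_mapping = [0, 1, 2]}
    gpu.terminator
  }
  return
}
```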