diff --git a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h
--- a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h
+++ b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h
@@ -42,8 +42,11 @@
 /// supported. Dynamic block dim sizes are currently not supported.
 DiagnosedSilenceableFailure mapNestedForeachToThreadsImpl(
     RewriterBase &rewriter, Operation *target,
-    const SmallVectorImpl<int64_t> &blockDim, bool syncAfterDistribute,
-    std::optional<TransformOpInterface> transformOp,
+    const SmallVectorImpl<int64_t> &blockDim,
+    function_ref<void(RewriterBase &, scf::ForeachThreadOp,
+                      SmallVectorImpl<Value> &)>
+        threadIdGenerator,
+    bool syncAfterDistribute, std::optional<TransformOpInterface> transformOp,
     const ArrayRef<DeviceMappingAttrInterface> &threadMappingAttributes);
 
 /// Maps the top level `scf.foreach_thread` op to GPU Thread Blocks. Mapping is
diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -502,8 +502,11 @@
 
 DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForeachToThreadsImpl(
     RewriterBase &rewriter, Operation *target,
-    const SmallVectorImpl<int64_t> &blockDim, bool syncAfterDistribute,
-    std::optional<TransformOpInterface> transformOp,
+    const SmallVectorImpl<int64_t> &blockDim,
+    function_ref<void(RewriterBase &, scf::ForeachThreadOp,
+                      SmallVectorImpl<Value> &)>
+        threadIdGenerator,
+    bool syncAfterDistribute, std::optional<TransformOpInterface> transformOp,
     const ArrayRef<DeviceMappingAttrInterface> &threadMappingAttributes) {
   DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success();
   target->walk([&](scf::ForeachThreadOp foreachThreadOp) {
@@ -517,14 +520,8 @@
                               foreachThreadOp.getMapping(), transformOp);
     if (diag.succeeded()) {
       rewriter.setInsertionPoint(foreachThreadOp);
-      IndexType indexType = rewriter.getIndexType();
-      SmallVector<Value> threadOps{
-          rewriter.create<ThreadIdOp>(foreachThreadOp.getLoc(), indexType,
-                                      Dimension::x),
-          rewriter.create<ThreadIdOp>(foreachThreadOp.getLoc(), indexType,
-                                      Dimension::y),
-          rewriter.create<ThreadIdOp>(foreachThreadOp.getLoc(), indexType,
-                                      Dimension::z)};
+      SmallVector<Value> threadOps;
+      threadIdGenerator(rewriter, foreachThreadOp, threadOps);
       diag = rewriteOneForeachThreadToGpuThreads(
           rewriter, foreachThreadOp, blockDim, threadOps, syncAfterDistribute,
           transformOp, threadMappingAttributes);
@@ -562,10 +559,20 @@
       GPUThreadMappingAttr::get(ctx, Threads::DimX),
       GPUThreadMappingAttr::get(ctx, Threads::DimY),
       GPUThreadMappingAttr::get(ctx, Threads::DimZ)};
-
+  auto threadIdGenerator = [](RewriterBase &rewriter,
+                              scf::ForeachThreadOp foreachThreadOp,
+                              SmallVectorImpl<Value> &threadIds) {
+    IndexType indexType = rewriter.getIndexType();
+    threadIds.assign({rewriter.create<ThreadIdOp>(foreachThreadOp->getLoc(),
+                                                  indexType, Dimension::x),
+                      rewriter.create<ThreadIdOp>(foreachThreadOp->getLoc(),
+                                                  indexType, Dimension::y),
+                      rewriter.create<ThreadIdOp>(foreachThreadOp->getLoc(),
+                                                  indexType, Dimension::z)});
+  };
   diag = mlir::transform::gpu::mapNestedForeachToThreadsImpl(
-      rewriter, target, blockDim, getSyncAfterDistribute(), transformOp,
-      threadMappingAttributes);
+      rewriter, target, blockDim, threadIdGenerator, getSyncAfterDistribute(),
+      transformOp, threadMappingAttributes);
   if (diag.succeeded()) {
     diag = alterGpuLaunch(rewriter, gpuLaunch, transformOp, std::nullopt,
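
Net effect of the patch: mapNestedForeachToThreadsImpl no longer hard-codes the creation of gpu.thread_id ops; callers inject thread-id generation through the new threadIdGenerator callback, and the in-tree transform op passes a lambda reproducing the old x/y/z behavior. A minimal sketch of what an out-of-tree caller could pass instead is below. The warp-id generator and the fixed warp size of 32 are purely illustrative (not part of this patch), and rewriter, target, blockDim, transformOp, and threadMappingAttributes are assumed to exist in the caller's scope; only mapNestedForeachToThreadsImpl and the gpu/arith ops are actual MLIR API.

  // Sketch only: distribute the x dimension over warps (threadIdx.x / 32)
  // instead of individual threads; y and z keep plain thread ids.
  auto warpIdGenerator = [](RewriterBase &rewriter,
                            scf::ForeachThreadOp foreachThreadOp,
                            SmallVectorImpl<Value> &threadIds) {
    Location loc = foreachThreadOp->getLoc();
    IndexType indexType = rewriter.getIndexType();
    Value tidX = rewriter.create<ThreadIdOp>(loc, indexType, Dimension::x);
    // Assumed warp size of 32 for illustration.
    Value warpSize = rewriter.create<arith::ConstantIndexOp>(loc, 32);
    threadIds.assign(
        {rewriter.create<arith::DivUIOp>(loc, tidX, warpSize),
         rewriter.create<ThreadIdOp>(loc, indexType, Dimension::y),
         rewriter.create<ThreadIdOp>(loc, indexType, Dimension::z)});
  };
  DiagnosedSilenceableFailure diag =
      mlir::transform::gpu::mapNestedForeachToThreadsImpl(
          rewriter, target, blockDim, warpIdGenerator,
          /*syncAfterDistribute=*/true, transformOp, threadMappingAttributes);

Taking the hook as a function_ref keeps it non-owning and allocation-free, per the usual LLVM convention that the callback must outlive the call; that holds here since the lambda lives on the caller's stack for the duration of the walk.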