Index: mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h =================================================================== --- mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h +++ mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h @@ -42,7 +42,8 @@ /// supported. Dynamic block dim sizes are currently not supported. DiagnosedSilenceableFailure mapNestedForeachToThreadsImpl( RewriterBase &rewriter, Operation *target, - const SmallVectorImpl &blockDim, bool syncAfterDistribute, + const SmallVectorImpl &blockDim, + const SmallVectorImpl &threadOps, bool syncAfterDistribute, std::optional transformOp, const ArrayRef &threadMappingAttributes); Index: mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp =================================================================== --- mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp +++ mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp @@ -502,7 +502,8 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForeachToThreadsImpl( RewriterBase &rewriter, Operation *target, - const SmallVectorImpl &blockDim, bool syncAfterDistribute, + const SmallVectorImpl &blockDim, + const SmallVectorImpl &threadOps, bool syncAfterDistribute, std::optional transformOp, const ArrayRef &threadMappingAttributes) { DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success(); @@ -517,14 +518,6 @@ foreachThreadOp.getMapping(), transformOp); if (diag.succeeded()) { rewriter.setInsertionPoint(foreachThreadOp); - IndexType indexType = rewriter.getIndexType(); - SmallVector threadOps{ - rewriter.create(foreachThreadOp.getLoc(), indexType, - Dimension::x), - rewriter.create(foreachThreadOp.getLoc(), indexType, - Dimension::y), - rewriter.create(foreachThreadOp.getLoc(), indexType, - Dimension::z)}; diag = rewriteOneForeachThreadToGpuThreads( rewriter, foreachThreadOp, blockDim, threadOps, syncAfterDistribute, transformOp, threadMappingAttributes); @@ -562,10 +555,15 @@ GPUThreadMappingAttr::get(ctx, Threads::DimX), GPUThreadMappingAttr::get(ctx, Threads::DimY), GPUThreadMappingAttr::get(ctx, Threads::DimZ)}; + IndexType indexType = rewriter.getIndexType(); + SmallVector threadOps{ + rewriter.create(target->getLoc(), indexType, Dimension::x), + rewriter.create(target->getLoc(), indexType, Dimension::y), + rewriter.create(target->getLoc(), indexType, Dimension::z)}; diag = mlir::transform::gpu::mapNestedForeachToThreadsImpl( - rewriter, target, blockDim, getSyncAfterDistribute(), transformOp, - threadMappingAttributes); + rewriter, target, blockDim, threadOps, getSyncAfterDistribute(), + transformOp, threadMappingAttributes); if (diag.succeeded()) { diag = alterGpuLaunch(rewriter, gpuLaunch, transformOp, std::nullopt,