Index: mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
===================================================================
--- mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -366,7 +366,8 @@
 /// not supported. Dynamic block dim sizes are currently not supported.
 static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads(
     RewriterBase &rewriter, scf::ForeachThreadOp foreachThreadOp,
-    const SmallVectorImpl<int64_t> &globalBlockDims, bool syncAfterDistribute,
+    const SmallVectorImpl<int64_t> &globalBlockDims,
+    const SmallVectorImpl<Value> &threadOps, bool syncAfterDistribute,
     std::optional<TransformOpInterface> transformOp,
     const ArrayRef<DeviceMappingAttrInterface> &threadMappingAttributes) {
   // Step 0. Target-specific verifications. There is no good place to anchor
@@ -427,28 +428,26 @@
   // Step 3. Create the gpu.thread ops and map the induction variables to the
   // newly created ops.
   IndexType indexType = rewriter.getIndexType();
-  SmallVector<Value> threadOps{
-      rewriter.create<ThreadIdOp>(loc, indexType, Dimension::x),
-      rewriter.create<ThreadIdOp>(loc, indexType, Dimension::y),
-      rewriter.create<ThreadIdOp>(loc, indexType, Dimension::z)};
   // Replace ids of dimension size 1 by zero to simplify the IR.
+  SmallVector<Value> threadOpsUpdated(threadOps.begin(), threadOps.end());
+  assert(threadOps.size() == globalBlockDims.size());
   Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
   for (size_t i : llvm::seq(size_t(0), globalBlockDims.size())) {
     if (globalBlockDims[i] == 1)
-      threadOps[i] = zero;
+      threadOpsUpdated[i] = zero;
   }
   IRMapping bvm;
   for (auto [blockIdx, blockDim] :
        llvm::zip(foreachThreadOp.getThreadIndices(), threadMapping)) {
-    bvm.map(
-        blockIdx,
-        threadOps[blockDim.cast<DeviceMappingAttrInterface>().getMappingId()]);
+    bvm.map(blockIdx,
+            threadOpsUpdated[blockDim.cast<DeviceMappingAttrInterface>()
+                                 .getMappingId()]);
   }
 
   // Step 4. Maybe create conditionals to predicate the region.
   Value predicate;
   for (auto [threadId, blockDim, globalBlockDim] :
-       llvm::zip(threadOps, blockDims, globalBlockDims)) {
+       llvm::zip(threadOpsUpdated, blockDims, globalBlockDims)) {
     if (blockDim > globalBlockDim) {
       return failureHelper(
           "The requested GPU threads are fewer than the number of loop trip "
@@ -519,9 +518,17 @@
         foreachThreadOp.getMapping(), transformOp);
     if (diag.succeeded()) {
       rewriter.setInsertionPoint(foreachThreadOp);
+      IndexType indexType = rewriter.getIndexType();
+      SmallVector<Value> threadOps{
+          rewriter.create<ThreadIdOp>(foreachThreadOp.getLoc(), indexType,
+                                      Dimension::x),
+          rewriter.create<ThreadIdOp>(foreachThreadOp.getLoc(), indexType,
+                                      Dimension::y),
+          rewriter.create<ThreadIdOp>(foreachThreadOp.getLoc(), indexType,
+                                      Dimension::z)};
       diag = rewriteOneForeachThreadToGpuThreads(
-          rewriter, foreachThreadOp, blockDim, syncAfterDistribute, transformOp,
-          threadMappingAttributes);
+          rewriter, foreachThreadOp, blockDim, threadOps, syncAfterDistribute,
+          transformOp, threadMappingAttributes);
     }
     return diag.succeeded() ? WalkResult::advance() : WalkResult::interrupt();
   });