diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -132,7 +132,7 @@
 /// dim sizes are currently not supported.
 LogicalResult rewriteTopLevelForeachThreadToGpuBlocks(
     RewriterBase &rewriter, scf::ForeachThreadOp foreachThreadOp,
-    function_ref<void(Operation *, const SmallVector<int64_t> &, IndexType,
+    function_ref<void(RewriterBase &, scf::ForeachThreadOp,
                       SmallVector<Value> &)>
         blockIdGenerator,
     SmallVector<int64_t> &gridDims);
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1374,7 +1374,7 @@
 
 LogicalResult mlir::linalg::rewriteTopLevelForeachThreadToGpuBlocks(
     RewriterBase &rewriter, scf::ForeachThreadOp foreachThreadOp,
-    function_ref<void(Operation *, const SmallVector<int64_t> &, IndexType,
+    function_ref<void(RewriterBase &, scf::ForeachThreadOp,
                       SmallVector<Value> &)>
         blockIdGenerator,
     SmallVector<int64_t> &gridDims) {
@@ -1397,9 +1397,8 @@
   for (OpFoldResult ofr : *potentialGridDim)
     gridDims.push_back(getConstantIntValue(ofr).value());
 
-  IndexType indexType = rewriter.getIndexType();
   SmallVector<Value> blockOps;
-  blockIdGenerator(foreachThreadOp, gridDims, indexType, blockOps);
+  blockIdGenerator(rewriter, foreachThreadOp, blockOps);
 
   // Step 1. Move the body of foreachThreadOp.
   // Erase the terminator first, it will not be used since we are on buffers.
@@ -1485,6 +1484,22 @@
   return launchOp;
 }
 
+/// Helper used only by rewriteTopLevelForeachThreadToGpuBlocks. Creates the
+/// x, y and z gpu.block_id ops (always all three) right before `foreachOp` and
+/// appends them, in that order, to `blockOps`.
+static void generateGpuBlockIds(RewriterBase &rewriter,
+                                scf::ForeachThreadOp foreachOp,
+                                SmallVector<Value> &blockOps) {
+  Location loc = foreachOp->getLoc();
+  // Restore the caller's insertion point when this helper returns.
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(foreachOp);
+  IndexType indexType = rewriter.getIndexType();
+  for (gpu::Dimension dim :
+       {gpu::Dimension::x, gpu::Dimension::y, gpu::Dimension::z})
+    blockOps.push_back(rewriter.create<gpu::BlockIdOp>(loc, indexType, dim));
+}
+
 DiagnosedSilenceableFailure
 transform::MapNestedForeachThreadToGpuBlocks::applyToOne(
     Operation *target, SmallVectorImpl<Operation *> &results,
@@ -1520,22 +1536,9 @@
         dyn_cast<scf::ForeachThreadOp>(newForeachThreadOp);
   }
 
-  auto generateBlocks = [&](Operation *op, const SmallVector<int64_t> &gridDims,
-                            IndexType indexType, SmallVector<Value> &blockOps) {
-    Location loc = op->getLoc();
-    OpBuilder::InsertionGuard guard(rewriter);
-    rewriter.setInsertionPoint(op);
-    SmallVector<gpu::Dimension> gpuDims{gpu::Dimension::x, gpu::Dimension::y,
-                                        gpu::Dimension::z};
-    for (int64_t idx : llvm::seq<int64_t>(0, gridDims.size())) {
-      blockOps.push_back(
-          rewriter.create<gpu::BlockIdOp>(loc, indexType, gpuDims[idx]));
-    }
-  };
-
   SmallVector<int64_t> gridDim = extractFromI64ArrayAttr(getGridDim());
   if (failed(mlir::linalg::rewriteTopLevelForeachThreadToGpuBlocks(
-          rewriter, topLevelForeachThreadOp, generateBlocks, gridDim)))
+          rewriter, topLevelForeachThreadOp, generateGpuBlockIds, gridDim)))
     return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
 
   if (failed(alterGpuLaunch(rewriter, gpuLaunch, gridDim[0], gridDim[1],