diff --git a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h
--- a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h
+++ b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h
@@ -33,34 +33,41 @@
 namespace transform {
 namespace gpu {
 
-/// Searches `scf.forall` ops nested under `target` and maps each such
-/// op to GPU threads. Mapping is one-to-one and the induction variables of
-/// `scf.forall` are rewritten to gpu.thread_id according to the
-/// thread_dim_mapping attribute. Sibling `scf.forall` are supported in
-/// which case, the union of the number of threads is computed and may result
-/// in predication. Dynamic, `scf.forall` trip counts are currently not
-/// supported. Dynamic block dim sizes are currently not supported.
-DiagnosedSilenceableFailure mapNestedForeachToThreadsImpl(
+/// Map the top level `scf.forall` op to GPU Thread Blocks.
+/// Mapping is one-to-one and the induction variables of `scf.forall` are
+/// rewritten to gpu.block_id according to the thread_dim_mapping attribute.
+/// Dynamic `scf.forall` trip counts are currently not supported.
+/// Dynamic block dim sizes are currently not supported.
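+///
+/// Illustrative example (a sketch only; the exact IR around the kernel is
+/// produced by the enclosing gpu.launch rewrite, and missing dimensions are
+/// assumed to be padded with 1s):
+///
+///   scf.forall (%i, %j) in (8, 4) { ... }
+///       {mapping = [#gpu.block<x>, #gpu.block<y>]}
+///
+/// yields gridDims = [8, 4, 1], with uses of %i and %j rewritten to
+/// gpu.block_id x and gpu.block_id y, respectively.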
+DiagnosedSilenceableFailure mapForallToBlocksImpl(
+    RewriterBase &rewriter, scf::ForallOp forallOp,
+    function_ref<void(RewriterBase &, scf::ForallOp, SmallVectorImpl<Value> &)>
+        blockIdGenerator,
+    SmallVectorImpl<int64_t> &gridDims, TransformOpInterface transformOp,
+    const ArrayRef<DeviceMappingAttrInterface> &mappingAttributes);
+
+/// Search `scf.forall` ops nested under `target` and map each such op to GPU
+/// threads. Mapping is one-to-one and the induction variables of `scf.forall`
+/// are rewritten to gpu.thread_id according to the thread_dim_mapping
+/// attribute.
+/// Sibling `scf.forall` ops are supported, in which case the union of the
+/// number of threads is computed and may result in predication.
+/// Dynamic `scf.forall` trip counts are currently not supported.
+/// Dynamic block dim sizes are currently not supported.
+DiagnosedSilenceableFailure mapNestedForallToThreadsImpl(
     RewriterBase &rewriter, Operation *target,
-    const SmallVectorImpl<int64_t> &blockDim,
+    const SmallVectorImpl<int64_t> &kernelBlockDims,
     function_ref<void(RewriterBase &, scf::ForallOp, SmallVectorImpl<Value> &)>
         threadIdGenerator,
     bool syncAfterDistribute, std::optional<TransformOpInterface> transformOp,
     const ArrayRef<DeviceMappingAttrInterface> &threadMappingAttributes);
 
-/// Maps the top level `scf.forall` op to GPU Thread Blocks. Mapping is
-/// one-to-one and the induction variables of `scf.forall` are rewritten
-/// to gpu.block_id according to the thread_dim_apping attribute. Dynamic,
-/// `scf.forall` trip counts are currently not supported. Dynamic block dim
-/// sizes are currently not supported.
-DiagnosedSilenceableFailure mapForeachToBlocksImpl(
-    RewriterBase &rewriter, scf::ForallOp forallOp,
-    function_ref<void(RewriterBase &, scf::ForallOp, SmallVectorImpl<Value> &)>
-        blockIdGenerator,
-    SmallVectorImpl<int64_t> &gridDims, TransformOpInterface transformOp,
-    const ArrayRef<DeviceMappingAttrInterface> &mappingAttributes);
-
-/// Finds the top level scf::ForallOp of given target.
+/// Find the unique top level scf::ForallOp within a given target op.
 DiagnosedSilenceableFailure
 findTopLevelForallOp(Operation *target, scf::ForallOp &topLevelForallOp,
                      TransformOpInterface transformOp);
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -652,15 +652,6 @@
     /// Checks if the lbs are zeros and steps are ones.
     bool isNormalized();
 
-    /// Helper to sort `values` according to matching `keys`.
-    /// Take a custom `compare` binary comparator which returns true if the
-    /// first element is smaller than the second (i.e. compatible with
-    /// std::sort). This is a helper typically used to sort numThreads values
-    /// before they are mapped to concrete physical dimensions of hardware.
-    static SmallVector<Value> getValuesSortedByKey(
-        ArrayRef<Attribute> keys, ValueRange values,
-        llvm::function_ref<bool(Attribute, Attribute)> compare);
-
     // The ensureTerminator method generated by SingleBlockImplicitTerminator is
     // unaware of the fact that our terminator also needs a region to be
     // well-formed. We override it here to ensure that we do the right thing.
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -117,6 +117,18 @@
 decomposeMixedValues(Builder &b,
                      const SmallVectorImpl<OpFoldResult> &mixedValues);
 
+/// Helper to sort `values` according to matching `keys`.
+/// Takes a custom `compare` binary comparator which returns true if the first
+/// element is smaller than the second (i.e. compatible with std::sort).
+/// This is a helper typically used to sort numThreads values before they are
+/// mapped to concrete physical dimensions of hardware.
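+///
+/// Illustrative example (hypothetical values): with
+///   keys   = [#gpu.thread<y>, #gpu.thread<x>]
+///   values = [%a, %b]
+/// and a comparator ordering keys by increasing mapping id, the result is
+/// [%b, %a].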
+SmallVector<Value>
+getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<Value> values,
+                     llvm::function_ref<bool(Attribute, Attribute)> compare);
+SmallVector<OpFoldResult>
+getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<OpFoldResult> values,
+                     llvm::function_ref<bool(Attribute, Attribute)> compare);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_UTILS_STATICVALUEUTILS_H
diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -18,30 +18,131 @@
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/Dialect/Transform/IR/TransformUtils.h"
 #include "mlir/IR/IRMapping.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Support/LLVM.h"
+
+#include <functional>
 
 using namespace mlir;
 using namespace mlir::gpu;
 using namespace mlir::transform;
 
+namespace {
+
+using IdGeneratorFnType = llvm::function_ref<void(
+    RewriterBase &, scf::ForallOp, SmallVectorImpl<Value> &)>;
+
+struct MappingToGpuHelper {
+  // The generator is stored as an owning std::function: the helpers below
+  // construct it from a temporary lambda, which a function_ref member would
+  // dangle on.
+  using OwningIdGeneratorFnType = std::function<void(
+      RewriterBase &, scf::ForallOp, SmallVectorImpl<Value> &)>;
+
+  MappingToGpuHelper(SmallVector<DeviceMappingAttrInterface> mappingAttributes,
+                     OwningIdGeneratorFnType idGenerator)
+      : mappingAttributes(std::move(mappingAttributes)),
+        idGenerator(std::move(idGenerator)) {}
+
+  SmallVector<DeviceMappingAttrInterface> mappingAttributes;
+  OwningIdGeneratorFnType idGenerator;
+};
+
+struct MappingToGpuBlocksHelper : public MappingToGpuHelper {
+  MappingToGpuBlocksHelper(MLIRContext *ctx)
+      : MappingToGpuHelper(
+            SmallVector<DeviceMappingAttrInterface>{
+                GPUBlockMappingAttr::get(ctx, Blocks::DimX),
+                GPUBlockMappingAttr::get(ctx, Blocks::DimY),
+                GPUBlockMappingAttr::get(ctx, Blocks::DimZ)},
+            [](RewriterBase &rewriter, scf::ForallOp forallOp,
+               SmallVectorImpl<Value> &ids) {
+              Location loc = forallOp->getLoc();
+              IndexType indexType = rewriter.getIndexType();
+              ids.assign(
+                  {rewriter.create<BlockIdOp>(loc, indexType, Dimension::x),
+                   rewriter.create<BlockIdOp>(loc, indexType, Dimension::y),
+                   rewriter.create<BlockIdOp>(loc, indexType, Dimension::z)});
+            }) {}
+};
+
+struct MappingToGpuThreadsHelper : public MappingToGpuHelper {
+  MappingToGpuThreadsHelper(MLIRContext *ctx)
+      : MappingToGpuHelper(
+            SmallVector<DeviceMappingAttrInterface>{
+                GPUThreadMappingAttr::get(ctx, Threads::DimX),
+                GPUThreadMappingAttr::get(ctx, Threads::DimY),
+                GPUThreadMappingAttr::get(ctx, Threads::DimZ)},
+            [](RewriterBase &rewriter, scf::ForallOp forallOp,
+               SmallVectorImpl<Value> &ids) {
+              Location loc = forallOp->getLoc();
+              IndexType indexType = rewriter.getIndexType();
+              ids.assign(
+                  {rewriter.create<ThreadIdOp>(loc, indexType, Dimension::x),
+                   rewriter.create<ThreadIdOp>(loc, indexType, Dimension::y),
+                   rewriter.create<ThreadIdOp>(loc, indexType, Dimension::z)});
+            }) {}
+};
+
+struct MappingToGpuWarpsHelper : public MappingToGpuHelper {
+  MappingToGpuWarpsHelper(MLIRContext *ctx)
+      : MappingToGpuHelper(
+            SmallVector<DeviceMappingAttrInterface>{
+                GPUWarpMappingAttr::get(ctx, Warps::DimX),
+                GPUWarpMappingAttr::get(ctx, Warps::DimY),
+                GPUWarpMappingAttr::get(ctx, Warps::DimZ)},
+            [](RewriterBase &rewriter, scf::ForallOp forallOp,
+               SmallVectorImpl<Value> &ids) {
+              // TODO: compute the warp ids and `ids.assign` them; there is no
+              // dedicated warp id op to create yet.
+            }) {}
+};
+
+} // namespace
+
+static DiagnosedSilenceableFailure
+failureHelper(std::optional<TransformOpInterface> transformOp,
+              scf::ForallOp forallOp, const Twine &message) {
+  if (transformOp.has_value())
+    return transformOp->emitSilenceableError() << message;
+  return emitDefiniteFailure(forallOp, message);
+}
+
 /// Check if given mapping attributes are one of the desired attributes
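+/// They must all be of the same kind (e.g. all #gpu.thread) and must not
+/// contain duplicates. Illustrative examples of rejected mappings (assumed
+/// syntax):
+///   {mapping = [#gpu.thread<x>, #gpu.block<y>]}   // mixed kinds
+///   {mapping = [#gpu.thread<x>, #gpu.thread<x>]}  // duplicate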
 static DiagnosedSilenceableFailure
-checkAttributeType(ArrayRef<DeviceMappingAttrInterface> threadMappingAttributes,
-                   const std::optional<ArrayAttr> &foreachMapping,
-                   std::optional<TransformOpInterface> transformOp) {
-  if (!foreachMapping.has_value())
-    return transformOp->emitSilenceableError() << "mapping must be present";
+checkMappingAttributeTypes(std::optional<TransformOpInterface> transformOp,
+                           scf::ForallOp forallOp) {
+  if (!forallOp.getMapping().has_value())
+    return failureHelper(transformOp, forallOp, "mapping must be present");
+
+  bool hasBlockMapping =
+      llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
+        return attr.isa<GPUBlockMappingAttr>();
+      });
+  bool hasThreadMapping =
+      llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
+        return attr.isa<GPUThreadMappingAttr>();
+      });
+  bool hasWarpMapping =
+      llvm::any_of(forallOp.getMapping().value(), [](Attribute attr) {
+        return attr.isa<GPUWarpMappingAttr>();
+      });
+  int64_t countMappingTypes = 0;
+  countMappingTypes += hasBlockMapping ? 1 : 0;
+  countMappingTypes += hasThreadMapping ? 1 : 0;
+  countMappingTypes += hasWarpMapping ? 1 : 0;
+  if (countMappingTypes > 1) {
+    return failureHelper(transformOp, forallOp,
+                         "cannot mix different mapping types, use nesting");
+  }
 
   DenseSet<Attribute> seen;
-  for (Attribute map : foreachMapping->getValue()) {
-    if (!llvm::is_contained(threadMappingAttributes, map)) {
-      return transformOp->emitDefiniteFailure()
-             << "mapping must be one of " << threadMappingAttributes;
-    }
+  for (Attribute map : forallOp.getMapping()->getValue()) {
     if (llvm::is_contained(seen, map)) {
-      return transformOp->emitDefiniteFailure()
-             << map
-             << " is duplicated, cannot map different "
-                "loops to the same processor";
+      return failureHelper(transformOp, forallOp,
+                           "duplicated attribute, cannot map different loops "
+                           "to the same processor");
     }
     seen.insert(map);
   }
@@ -49,6 +150,33 @@
   return DiagnosedSilenceableFailure::success();
 }
 
+static DiagnosedSilenceableFailure
+verifyGpuMapping(std::optional<TransformOpInterface> transformOp,
                 scf::ForallOp forallOp) {
+  // Check that the types of the mapping attributes match.
+  DiagnosedSilenceableFailure typeRes =
+      checkMappingAttributeTypes(transformOp, forallOp);
+  if (!typeRes.succeeded())
+    return typeRes;
+
+  // Perform other non-type verifications.
+  if (!forallOp.isNormalized())
+    return failureHelper(transformOp, forallOp,
+                         "unsupported non-normalized loops");
+  if (forallOp.getNumResults() > 0)
+    return failureHelper(transformOp, forallOp,
+                         "only bufferized scf.forall can be mapped");
+  if (forallOp.getRank() > 3)
+    return failureHelper(transformOp, forallOp,
+                         "scf.forall with rank > 3 does not lower");
+  if (llvm::any_of(forallOp.getMixedUpperBound(), [](OpFoldResult ofr) {
+        return !getConstantIntValue(ofr).has_value();
+      })) {
+    return failureHelper(transformOp, forallOp, "unsupported dynamic sizes");
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
 /// Determines if the size of the kernel configuration is supported by the GPU
 /// architecture being used. It presently makes use of CUDA limitations, however
 /// that aspect may be enhanced for other GPUs.
@@ -170,9 +298,8 @@
-DiagnosedSilenceableFailure mlir::transform::gpu::mapForeachToBlocksImpl(
+DiagnosedSilenceableFailure mlir::transform::gpu::mapForallToBlocksImpl(
     RewriterBase &rewriter, scf::ForallOp forallOp,
-    function_ref<void(RewriterBase &, scf::ForallOp, SmallVectorImpl<Value> &)>
-        blockIdGenerator,
-    SmallVectorImpl<int64_t> &gridDims, TransformOpInterface transformOp,
+    IdGeneratorFnType blockIdGenerator, SmallVectorImpl<int64_t> &gridDims,
+    TransformOpInterface transformOp,
     const ArrayRef<DeviceMappingAttrInterface> &mappingAttributes) {
   // Step 0. Target-specific verifications. There is no good place to anchor
   // those right now: the ForallOp is target-independent and the
   // transform op does not apply to individual ForallOp.
@@ -222,7 +349,7 @@
   for (Value v : gridDimValues)
     gridDims.push_back(v.getDefiningOp<arith::ConstantIndexOp>().value());
 
-  // Step 3. Generate the blockIds using the provided generator and map the
+  // Step 3. Generate the block ids using the provided generator and map the
   // induction variables to the newly created ops.
   SmallVector<Value> blockOps;
   blockIdGenerator(rewriter, forallOp, blockOps);
@@ -275,21 +402,6 @@
   return DiagnosedSilenceableFailure::success();
 }
 
-/// This is a helper that is only used in
-/// rewriteTopLevelForallToGpuBlocks. It generates GPU dialects
-/// block_id.
-static void generateGpuBlockIds(RewriterBase &rewriter, scf::ForallOp foreachOp,
-                                SmallVectorImpl<Value> &blockOps) {
-  Location loc = foreachOp->getLoc();
-  OpBuilder::InsertionGuard guard(rewriter);
-  rewriter.setInsertionPoint(foreachOp);
-  IndexType indexType = rewriter.getIndexType();
-  blockOps = SmallVector<Value>{
-      rewriter.create<BlockIdOp>(loc, indexType, Dimension::x),
-      rewriter.create<BlockIdOp>(loc, indexType, Dimension::y),
-      rewriter.create<BlockIdOp>(loc, indexType, Dimension::z)};
-}
-
 DiagnosedSilenceableFailure
 transform::MapForeachToBlocks::applyToOne(Operation *target,
                                           ApplyToEachResultList &results,
@@ -331,23 +442,21 @@
     topLevelForallOp = cast<scf::ForallOp>(newForallOp);
   }
 
+  diag = verifyGpuMapping(transformOp, topLevelForallOp);
+  if (!diag.succeeded())
+    return diag;
+
+  MappingToGpuBlocksHelper helper(getContext());
   SmallVector<int64_t> gridDim = extractFromI64ArrayAttr(getGridDim());
-  SmallVector<DeviceMappingAttrInterface> blockMappingAttributes = {
-      GPUBlockMappingAttr::get(getContext(), Blocks::DimX),
-      GPUBlockMappingAttr::get(getContext(), Blocks::DimY),
-      GPUBlockMappingAttr::get(getContext(), Blocks::DimZ)};
-
-  diag = checkAttributeType(blockMappingAttributes,
-                            topLevelForallOp.getMapping(), transformOp);
-  if (diag.succeeded())
-    diag = mlir::transform::gpu::mapForeachToBlocksImpl(
-        rewriter, topLevelForallOp, generateGpuBlockIds, gridDim, transformOp,
-        blockMappingAttributes);
-  if (diag.succeeded()) {
-    diag = alterGpuLaunch(rewriter, gpuLaunch,
-                          cast<TransformOpInterface>(getOperation()),
-                          gridDim[0], gridDim[1], gridDim[2]);
-  }
+  diag = mlir::transform::gpu::mapForallToBlocksImpl(
+      rewriter, topLevelForallOp, helper.idGenerator, gridDim, transformOp,
+      helper.mappingAttributes);
+  if (!diag.succeeded())
+    return diag;
+
+  diag = alterGpuLaunch(rewriter, gpuLaunch,
+                        cast<TransformOpInterface>(getOperation()), gridDim[0],
+                        gridDim[1], gridDim[2]);
 
   results.push_back(gpuLaunch);
   return diag;
@@ -357,57 +466,32 @@
 // MapNestedForeachToThreads
 //===----------------------------------------------------------------------===//
 
-/// Searches `scf.forall` ops nested under `target` and maps each such
-/// op to GPU threads. Mapping is one-to-one and the induction variables of
-/// `scf.forall` are rewritten to gpu.thread_id according to the
-/// thread_dim_mapping attribute. Sibling `scf.forall` are supported in
-/// which case, the union of the number of threads is computed and may result
-/// in predication. Dynamic, `scf.forall` trip counts are currently
-/// not supported. Dynamic block dim sizes are currently not supported.
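+/// Rewrite a single `scf.forall` op to use the given gpu.thread_id values as
+/// its induction variables (a short summary; see mapNestedForallToThreadsImpl
+/// for how candidate ops are collected and the ids are generated).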
 static DiagnosedSilenceableFailure rewriteOneForallToGpuThreads(
     RewriterBase &rewriter, scf::ForallOp forallOp,
-    const SmallVectorImpl<int64_t> &globalBlockDims,
+    const SmallVectorImpl<int64_t> &kernelBlockDims,
     const SmallVectorImpl<Value> &threadOps, bool syncAfterDistribute,
     std::optional<TransformOpInterface> transformOp,
-    const ArrayRef<DeviceMappingAttrInterface> &threadMappingAttributes) {
-  // Step 0. Target-specific verifications. There is no good place to anchor
-  // those right now: the ForallOp is target-independent and the
-  // transform op does not apply to individual ForallOp.
-  auto failureHelper =
-      [&](const Twine &message) -> DiagnosedSilenceableFailure {
-    if (transformOp.has_value()) {
-      return transformOp->emitSilenceableError() << message;
-    }
-    return emitDefiniteFailure(forallOp, message);
-  };
+    const ArrayRef<DeviceMappingAttrInterface> &mappingAttributes) {
+  // Step 0. GPU-specific verifications. There is no better place to anchor
+  // those right now: the ForallOp is target-independent and the transform op
+  // does not apply to individual ForallOp.
+  DiagnosedSilenceableFailure diag = verifyGpuMapping(transformOp, forallOp);
+  if (!diag.succeeded())
+    return diag;
+
   Location loc = forallOp->getLoc();
-  if (!forallOp.isNormalized())
-    return failureHelper("unsupported non-normalized loops");
-  if (forallOp.getNumResults() > 0)
-    return failureHelper("only bufferized scf.forall lowers to gpu.thread_id");
-  if (forallOp.getRank() > 3)
-    return failureHelper(
-        "scf.forall with rank > 3 does not lower to gpu.thread_id");
-  if (llvm::any_of(forallOp.getMixedUpperBound(), [](OpFoldResult ofr) {
-        return !getConstantIntValue(ofr).has_value();
-      })) {
-    return failureHelper("unsupported dynamic blockdim size");
-  }
-  if (!forallOp.getMapping().has_value())
-    return failureHelper("mapping must be present");
-  SmallVector<Attribute> threadMapping =
+
+  SmallVector<Attribute> mapping =
       llvm::to_vector(forallOp.getMapping()->getValue());
 
-  // Step 1. Complete the threadMapping to a full mapping (with 1s) if
-  // necessary.
-  SmallVector<Value> numThreads = forallOp.getUpperBound(rewriter);
-  // Ensure we have 3 block sizes, one for each id.
-  Value one;
-  for (auto attr : threadMappingAttributes) {
-    if (std::find(threadMapping.begin(), threadMapping.end(), attr) ==
-        threadMapping.end()) {
-      threadMapping.push_back(attr);
-      one = one ? one : rewriter.create<arith::ConstantIndexOp>(loc, 1);
+  // Step 1. Complete the mapping to a full mapping (with 1s) if necessary.
+  SmallVector<OpFoldResult> numThreads = forallOp.getMixedUpperBound();
+
+  // Ensure we have 3 blockDims, one for each id.
+  Attribute one = rewriter.getIndexAttr(1);
+  for (auto attr : mappingAttributes) {
+    if (std::find(mapping.begin(), mapping.end(), attr) == mapping.end()) {
+      mapping.push_back(attr);
       numThreads.push_back(one);
     }
   }
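+  // Illustrative example: mapping = [#gpu.thread<x>] with numThreads = [32]
+  // is completed here to mapping = [x, y, z] and numThreads = [32, 1, 1].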
 
   // Step 2. sort the values by the corresponding DeviceMappingAttrInterface.
   auto comparator = [&](DeviceMappingAttrInterface a,
                         DeviceMappingAttrInterface b) -> bool {
     return a.getMappingId() < b.getMappingId();
   };
-  SmallVector<Value> blockDimValues = scf::ForallOp::getValuesSortedByKey(
-      threadMapping, numThreads, comparator);
+  SmallVector<OpFoldResult> blockDimValues =
+      getValuesSortedByKey(mapping, numThreads, comparator);
   SmallVector<int64_t> blockDims =
-      llvm::to_vector(llvm::map_range(blockDimValues, [](Value v) {
-        return v.getDefiningOp<arith::ConstantIndexOp>().value();
+      llvm::to_vector(llvm::map_range(blockDimValues, [](OpFoldResult ofr) {
+        return getConstantIntValue(ofr).value();
       }));
 
   // Step 3. Create the gpu.thread ops and map the induction variables to the
   // newly created ops.
   // Replace ids of dimension size 1 by zero to simplify the IR.
   SmallVector<Value> threadOpsUpdated(threadOps.begin(), threadOps.end());
-  assert(threadOps.size() == globalBlockDims.size());
+  assert(threadOps.size() == kernelBlockDims.size());
   Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-  for (size_t i : llvm::seq(size_t(0), globalBlockDims.size())) {
-    if (globalBlockDims[i] == 1)
+  for (size_t i : llvm::seq(size_t(0), kernelBlockDims.size())) {
+    if (kernelBlockDims[i] == 1)
       threadOpsUpdated[i] = zero;
   }
   IRMapping bvm;
-  for (auto [blockIdx, blockDim] :
-       llvm::zip(forallOp.getInductionVars(), threadMapping)) {
-    bvm.map(blockIdx,
+  for (auto [threadIdx, blockDim] :
+       llvm::zip(forallOp.getInductionVars(), mapping)) {
+    bvm.map(threadIdx,
             threadOpsUpdated[blockDim.cast<DeviceMappingAttrInterface>()
                                  .getMappingId()]);
   }
 
   // Step 4. Maybe create conditionals to predicate the region.
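+  // Illustrative example: with kernelBlockDims = [64, 4, 1] and loop
+  // blockDims = [32, 4, 1], the rewritten body is guarded by
+  //   %pred = arith.cmpi ult, %thread_id_x, %c32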
   Value predicate;
   for (auto [threadId, blockDim, globalBlockDim] :
-       llvm::zip(threadOpsUpdated, blockDims, globalBlockDims)) {
+       llvm::zip(threadOpsUpdated, blockDims, kernelBlockDims)) {
     if (blockDim > globalBlockDim) {
       return failureHelper(
-          "The requested GPU threads are fewer than the number of loop trip "
-          "counts. Try to tile scf.forall before mapping or set "
-          "small blockDim.");
+          transformOp, forallOp,
+          "Trying to map to fewer GPU threads than loop iterations but "
+          "overprovisioning is not yet supported. "
+          "Try additional tiling before mapping or map to more threads.");
     }
     if (blockDim == globalBlockDim)
       continue;
-    Value blockIdx = rewriter.create<arith::ConstantIndexOp>(loc, blockDim);
+    Value threadIdx = rewriter.create<arith::ConstantIndexOp>(loc, blockDim);
     Value tmpPredicate = rewriter.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::ult, threadId, blockIdx);
+        loc, arith::CmpIPredicate::ult, threadId, threadIdx);
     predicate = predicate
                     ? rewriter.create<arith::AndIOp>(loc, predicate,
                                                      tmpPredicate)
                     : tmpPredicate;
 
@@ -501,28 +587,26 @@
 
-DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForeachToThreadsImpl(
+DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForallToThreadsImpl(
     RewriterBase &rewriter, Operation *target,
-    const SmallVectorImpl<int64_t> &blockDim,
-    function_ref<void(RewriterBase &, scf::ForallOp, SmallVectorImpl<Value> &)>
-        threadIdGenerator,
+    const SmallVectorImpl<int64_t> &blockDim, IdGeneratorFnType idGenerator,
     bool syncAfterDistribute, std::optional<TransformOpInterface> transformOp,
-    const ArrayRef<DeviceMappingAttrInterface> &threadMappingAttributes) {
+    const ArrayRef<DeviceMappingAttrInterface> &mappingAttributes) {
   DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success();
   target->walk([&](scf::ForallOp forallOp) {
+    // Ignore ops without a mapping; they may be handled elsewhere.
+    if (!forallOp.getMapping().has_value())
+      return WalkResult::skip();
     // Ignore cases with different attributes.
     for (Attribute map : forallOp.getMapping()->getValue()) {
-      if (!llvm::is_contained(threadMappingAttributes, map)) {
+      if (!llvm::is_contained(mappingAttributes, map)) {
         return WalkResult::skip();
       }
     }
-    diag = checkAttributeType(threadMappingAttributes, forallOp.getMapping(),
-                              transformOp);
+    diag = checkMappingAttributeTypes(transformOp, forallOp);
     if (diag.succeeded()) {
       rewriter.setInsertionPoint(forallOp);
       SmallVector<Value> threadOps;
-      threadIdGenerator(rewriter, forallOp, threadOps);
+      idGenerator(rewriter, forallOp, threadOps);
       diag = rewriteOneForallToGpuThreads(rewriter, forallOp, blockDim,
                                           threadOps, syncAfterDistribute,
-                                          transformOp, threadMappingAttributes);
+                                          transformOp, mappingAttributes);
     }
     return diag.succeeded() ? WalkResult::advance() : WalkResult::interrupt();
   });
   return diag;
 }
 
@@ -534,48 +620,37 @@
   LaunchOp gpuLaunch = dyn_cast<LaunchOp>(target);
   auto transformOp = cast<TransformOpInterface>(getOperation());
 
-  if (!gpuLaunch) {
-    return emitSilenceableError() << "Given target is not gpu.launch";
-  }
+  // Basic high-level verifications.
+  if (!gpuLaunch)
+    return emitSilenceableError() << "Given target is not a gpu.launch";
 
   SmallVector<int64_t> blockDim = extractFromI64ArrayAttr(getBlockDim());
-  blockDim.resize(/*size=*/3, /*value=*/1);
+  if (blockDim.size() != 3)
+    return transformOp.emitDefiniteFailure("requires a blockDim of size 3");
 
   DiagnosedSilenceableFailure diag =
       checkGpuLimits(transformOp, std::nullopt, std::nullopt, std::nullopt,
                      blockDim[0], blockDim[1], blockDim[2]);
   if (diag.isSilenceableFailure()) {
-    diag.attachNote(getLoc()) << getBlockDimAttrName() << " is very large";
+    diag.attachNote(getLoc()) << getBlockDimAttrName() << " is too large";
     return diag;
   }
 
   MLIRContext *ctx = getContext();
   TrivialPatternRewriter rewriter(ctx);
   rewriter.setInsertionPoint(target);
+  MappingToGpuThreadsHelper helper(ctx);
+  diag = mlir::transform::gpu::mapNestedForallToThreadsImpl(
+      rewriter, target, blockDim, helper.idGenerator, getSyncAfterDistribute(),
+      transformOp, helper.mappingAttributes);
+  if (!diag.succeeded())
+    return diag;
 
-  SmallVector<DeviceMappingAttrInterface> threadMappingAttributes = {
-      GPUThreadMappingAttr::get(ctx, Threads::DimX),
-      GPUThreadMappingAttr::get(ctx, Threads::DimY),
-      GPUThreadMappingAttr::get(ctx, Threads::DimZ)};
-  auto threadIdGenerator = [](RewriterBase &rewriter, scf::ForallOp forallOp,
-                              SmallVectorImpl<Value> &threadIds) {
-    IndexType indexType = rewriter.getIndexType();
-    threadIds.assign({rewriter.create<ThreadIdOp>(forallOp->getLoc(), indexType,
-                                                  Dimension::x),
-                      rewriter.create<ThreadIdOp>(forallOp->getLoc(), indexType,
-                                                  Dimension::y),
-                      rewriter.create<ThreadIdOp>(forallOp->getLoc(), indexType,
-                                                  Dimension::z)});
-  };
-  diag = mlir::transform::gpu::mapNestedForeachToThreadsImpl(
-      rewriter, target, blockDim, threadIdGenerator, getSyncAfterDistribute(),
-      transformOp, threadMappingAttributes);
-
-  if (diag.succeeded()) {
-    diag = alterGpuLaunch(rewriter, gpuLaunch, transformOp, std::nullopt,
-                          std::nullopt, std::nullopt, blockDim[0], blockDim[1],
-                          blockDim[2]);
-  }
+  diag = alterGpuLaunch(rewriter, gpuLaunch, transformOp, std::nullopt,
+                        std::nullopt, std::nullopt, blockDim[0], blockDim[1],
+                        blockDim[2]);
 
   results.push_back(gpuLaunch.getOperation());
   return diag;
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1391,23 +1391,6 @@
   return cast<InParallelOp>(getBody()->getTerminator());
 }
 
-/// Helper to sort `values` according to matching `keys`.
-SmallVector<Value> ForallOp::getValuesSortedByKey(
-    ArrayRef<Attribute> keys, ValueRange values,
-    llvm::function_ref<bool(Attribute, Attribute)> compare) {
-  if (keys.empty())
-    return values;
-  assert(keys.size() == values.size() && "unexpected mismatching sizes");
-  auto indices = llvm::to_vector(llvm::seq<int64_t>(0, values.size()));
-  std::sort(indices.begin(), indices.end(),
-            [&](int64_t i, int64_t j) { return compare(keys[i], keys[j]); });
-  SmallVector<Value> res;
-  res.reserve(values.size());
-  for (int64_t i = 0, e = indices.size(); i < e; ++i)
-    res.push_back(values[indices[i]]);
-  return res;
-}
-
 ForallOp mlir::scf::getForallOpThreadIndexOwner(Value val) {
   auto tidxArg = val.dyn_cast<BlockArgument>();
   if (!tidxArg)
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -192,4 +192,36 @@
   return {b.getI64ArrayAttr(staticValues), dynamicValues};
 }
 
+/// Helper to sort `values` according to matching `keys`.
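+/// Defined as a static template, with the concrete overloads below as the
+/// only exported symbols; see StaticValueUtils.h for the user-facing
+/// documentation and an illustrative example.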
+template <typename K, typename V>
+static SmallVector<V>
+getValuesSortedByKeyImpl(ArrayRef<K> keys, ArrayRef<V> values,
+                         llvm::function_ref<bool(K, K)> compare) {
+  if (keys.empty())
+    return SmallVector<V>(values.begin(), values.end());
+  assert(keys.size() == values.size() && "unexpected mismatching sizes");
+  auto indices = llvm::to_vector(llvm::seq<int64_t>(0, values.size()));
+  std::sort(indices.begin(), indices.end(),
+            [&](int64_t i, int64_t j) { return compare(keys[i], keys[j]); });
+  SmallVector<V> res;
+  res.reserve(values.size());
+  for (int64_t i = 0, e = indices.size(); i < e; ++i)
+    res.push_back(values[indices[i]]);
+  return res;
+}
+
+SmallVector<Value>
+getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<Value> values,
+                     llvm::function_ref<bool(Attribute, Attribute)> compare) {
+  return getValuesSortedByKeyImpl(keys, values, compare);
+}
+
+SmallVector<OpFoldResult>
+getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<OpFoldResult> values,
+                     llvm::function_ref<bool(Attribute, Attribute)> compare) {
+  return getValuesSortedByKeyImpl(keys, values, compare);
+}
+
 } // namespace mlir