diff --git a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h
--- a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h
+++ b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h
@@ -39,11 +39,11 @@
 /// Dynamic, `scf.forall` trip counts are currently not supported.
 /// Dynamic block dim sizes are currently not supported.
 DiagnosedSilenceableFailure mapForallToBlocksImpl(
-    RewriterBase &rewriter, scf::ForallOp forallOp,
+    RewriterBase &rewriter, TransformOpInterface transformOp,
+    scf::ForallOp forallOp, SmallVectorImpl<int64_t> &gridDims,
+    const ArrayRef<DeviceMappingAttrInterface> &mappingAttributes,
     function_ref<void(RewriterBase &, scf::ForallOp,
                       SmallVectorImpl<Value> &)>
-        blockIdGenerator,
-    SmallVectorImpl<int64_t> &gridDims, TransformOpInterface transformOp,
-    const ArrayRef<DeviceMappingAttrInterface> &mappingAttributes);
+        blockIdGenerator);
 
 /// Search `scf.forall` ops nested under `target` and map each such op to GPU
 /// threads. Mapping is one-to-one and the induction variables of `scf.forall`
@@ -54,12 +54,12 @@
 /// Dynamic, `scf.forall` trip counts are currently not supported.
 /// Dynamic block dim sizes are currently not supported.
 DiagnosedSilenceableFailure mapNestedForallToThreadsImpl(
-    RewriterBase &rewriter, Operation *target,
-    const SmallVectorImpl<int64_t> &kernelBlockDims,
+    RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
+    Operation *target, const SmallVectorImpl<int64_t> &kernelBlockDims,
+    bool syncAfterDistribute,
+    const ArrayRef<DeviceMappingAttrInterface> &threadMappingAttributes,
     function_ref<void(RewriterBase &, scf::ForallOp,
                       SmallVectorImpl<Value> &)>
-        threadIdGenerator,
-    bool syncAfterDistribute, std::optional<TransformOpInterface> transformOp,
-    const ArrayRef<DeviceMappingAttrInterface> &threadMappingAttributes);
+        threadIdGenerator);
 
 /// Find the unique top level scf::ForallOp within a given target op.
 DiagnosedSilenceableFailure
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -124,6 +124,9 @@
 SmallVector<OpFoldResult>
 getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<OpFoldResult> values,
                      llvm::function_ref<bool(Attribute, Attribute)> compare);
+SmallVector<int64_t>
+getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<int64_t> values,
+                     llvm::function_ref<bool(Attribute, Attribute)> compare);
 
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -16,17 +16,26 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Support/LLVM.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Debug.h"
 
 using namespace mlir;
 using namespace mlir::gpu;
 using namespace mlir::transform;
 
+#define DEBUG_TYPE "gpu-transforms"
+
+#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
 namespace {
 
-/// Helper type forfunctions that generate ids for the mapping of a scf.forall.
+/// Helper type for functions that generate ids for the mapping of a scf.forall.
 using IdGeneratorFnType = llvm::function_ref<void(
     RewriterBase &, scf::ForallOp, SmallVectorImpl<Value> &)>;
@@ -86,7 +95,7 @@
 failureHelper(std::optional<TransformOpInterface> transformOp,
               scf::ForallOp forallOp, const Twine &message) {
   if (transformOp.has_value())
-    return transformOp->emitSilenceableError() << message;
+    return emitDefiniteFailure(*transformOp, message);
   return emitDefiniteFailure(forallOp, message);
 }
 
@@ -273,30 +282,35 @@
 // MapForallToBlocks
 //===----------------------------------------------------------------------===//
 
-DiagnosedSilenceableFailure mlir::transform::gpu::mapForallToBlocksImpl(
-    RewriterBase &rewriter, scf::ForallOp forallOp,
-    IdGeneratorFnType blockIdGenerator, SmallVectorImpl<int64_t> &gridDims,
-    TransformOpInterface transformOp,
-    const ArrayRef<DeviceMappingAttrInterface> &mappingAttributes) {
+static FailureOr<SmallVector<int64_t>> rewriteOneForallCommonImpl(
+    RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
+    scf::ForallOp forallOp,
+    const SmallVectorImpl<int64_t> &availableMappingSizes,
+    const ArrayRef<DeviceMappingAttrInterface> &allMappingAttributes,
+    IdGeneratorFnType idGenerator) {
+  LDBG("Start rewriteOneForallCommonImpl");
   // Step 0. GPU-specific verifications. There is no better place to anchor
   // those right now: the ForallOp is target-independent and the transform op
   // does not apply to individual ForallOp.
   DiagnosedSilenceableFailure diag = verifyGpuMapping(transformOp, forallOp);
   if (!diag.succeeded())
-    return diag;
-
-  SmallVector<Attribute> blockMapping =
+    return failure();
+
+  // Step 1. Complete the mapping to a full mapping (with 1s) if necessary.
+  SmallVector<int64_t> tmpMappingSizes = llvm::to_vector(
+      llvm::map_range(forallOp.getMixedUpperBound(), [](OpFoldResult ofr) {
+        auto maybeStaticValue = getConstantIntValue(ofr);
+        assert(maybeStaticValue && "expected static value");
+        return maybeStaticValue.value();
+      }));
+  SmallVector<Attribute> forallMappings =
       llvm::to_vector(forallOp.getMapping()->getValue());
-
-  // Step 1. Complete the blockMapping to a full mapping (with 1s) if necessary.
-  SmallVector<OpFoldResult> numBlocks = forallOp.getMixedUpperBound();
-  // Ensure we have 3 block sizes, one for each id.
-  for (auto attr : mappingAttributes) {
-    if (!llvm::is_contained(blockMapping, attr)) {
-      blockMapping.push_back(attr);
-      numBlocks.push_back(rewriter.getIndexAttr(1));
-    }
+  for (auto attr : allMappingAttributes) {
+    if (llvm::is_contained(forallMappings, attr))
+      continue;
+    forallMappings.push_back(attr);
+    tmpMappingSizes.push_back(1);
   }
 
   // Step 2. sort the values by the corresponding DeviceMappingAttrInterface.
@@ -304,43 +318,116 @@
   auto comparator = [&](DeviceMappingAttrInterface a,
                         DeviceMappingAttrInterface b) -> bool {
     return a.getMappingId() < b.getMappingId();
   };
-  SmallVector<OpFoldResult> gridDimValues =
-      getValuesSortedByKey(blockMapping, numBlocks, comparator);
-  gridDims =
-      llvm::to_vector(llvm::map_range(gridDimValues, [](OpFoldResult ofr) {
-        return getConstantIntValue(ofr).value();
-      }));
+  SmallVector<int64_t> mappingSizes =
+      getValuesSortedByKey(forallMappings, tmpMappingSizes, comparator);
+  LLVM_DEBUG(llvm::interleaveComma(mappingSizes, DBGS() << "mappingSizes: ");
+             llvm::dbgs() << "\n";
+             llvm::interleaveComma(forallMappings, DBGS() << "mappingAttrs: ");
+             llvm::dbgs() << "\n");
+
+  // Step 3. Generate the mappingIdOps using the provided generator and map
+  // the induction variables to the newly created ops. Replace ids of
+  // dimensions known to be of size 1 by zero to simplify the IR.
+  SmallVector<Value> mappingIdOps;
+  Location loc = forallOp.getLoc();
+  idGenerator(rewriter, forallOp, mappingIdOps);
+  LLVM_DEBUG(llvm::interleaveComma(mappingIdOps, DBGS() << "mappingIdOps: ");
+             llvm::dbgs() << "\n");
+  assert(mappingIdOps.size() == mappingSizes.size() && "expect equal sizes");
+  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  if (!availableMappingSizes.empty()) {
+    for (size_t i : llvm::seq(size_t(0), availableMappingSizes.size())) {
+      if (availableMappingSizes[i] == 1)
+        mappingIdOps[i] = zero;
+    }
+  }
 
-  // Step 3. Generate the blockids using the provided generator and map the
-  // induction variables to the newly created ops.
-  SmallVector<Value> blockOps;
-  blockIdGenerator(rewriter, forallOp, blockOps);
   IRMapping bvm;
-  for (auto [blockIdx, blockDim] :
-       llvm::zip(forallOp.getInductionVars(), blockMapping)) {
-    bvm.map(blockIdx,
-            blockOps[static_cast<int64_t>(
-                blockDim.cast<DeviceMappingAttrInterface>().getMappingId())]);
+  for (auto [iv, dim] :
+       llvm::zip_equal(forallOp.getInductionVars(),
                        ArrayRef<Attribute>{forallMappings}.take_front(
+                           forallOp.getInductionVars().size()))) {
+    Value peIdOp = mappingIdOps[static_cast<int64_t>(
+        dim.cast<DeviceMappingAttrInterface>().getMappingId())];
+    bvm.map(iv, peIdOp);
   }
 
-  // Step 4. Move the body of forallOp.
-  // Erase the terminator first, it will not be used since we are on buffers.
+  // Step 4. Maybe create conditionals to predicate the region.
+  // Skip this step when availableMappingSizes is empty.
+  Value predicate;
+  if (!availableMappingSizes.empty()) {
+    LLVM_DEBUG(llvm::interleaveComma(availableMappingSizes,
+                                     DBGS() << "availableMappingSizes: ");
+               llvm::dbgs() << "\n");
+    for (auto [id, mappingSize, availableMappingSize] :
+         llvm::zip_equal(mappingIdOps, mappingSizes, availableMappingSizes)) {
+      if (mappingSize > availableMappingSize) {
+        (void)failureHelper(
+            transformOp, forallOp,
+            "Trying to map to fewer GPU threads than loop iterations but "
+            "overprovisioning is not yet supported. "
+            "Try additional tiling before mapping or map to more threads.");
+        return failure();
+      }
+      if (mappingSize == availableMappingSize)
+        continue;
+      Value idx = rewriter.create<arith::ConstantIndexOp>(loc, mappingSize);
+      Value tmpPredicate = rewriter.create<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::ult, id, idx);
+      LDBG("predicate: " << tmpPredicate);
+      predicate = predicate ? rewriter.create<arith::AndIOp>(loc, predicate,
+                                                             tmpPredicate)
+                            : tmpPredicate;
+    }
+  }
+
+  // Step 5. Move the body of forallOp.
+  // Erase the terminator first, it will not be used.
   rewriter.eraseOp(forallOp.getTerminator());
-  Block *targetBlock = forallOp->getBlock();
-  Block::iterator insertionPoint = Block::iterator(forallOp);
+  Block *targetBlock;
+  Block::iterator insertionPoint;
+  if (predicate) {
+    // Step 5.a. If predicated, move at the beginning.
+    auto ifOp =
+        rewriter.create<scf::IfOp>(loc, predicate, /*withElseRegion=*/false);
+    targetBlock = ifOp.thenBlock();
+    insertionPoint = ifOp.thenBlock()->begin();
+  } else {
+    // Step 5.b. Otherwise, move inline just at the rewriter insertion point.
+    targetBlock = forallOp->getBlock();
+    insertionPoint = rewriter.getInsertionPoint();
+  }
   Block &sourceBlock = forallOp.getRegion().front();
   targetBlock->getOperations().splice(insertionPoint,
                                       sourceBlock.getOperations());
 
-  // Step 5. RAUW thread indices to thread ops.
+  // Step 6. RAUW thread indices to thread ops.
   for (Value loopIndex : forallOp.getInductionVars()) {
-    Value blockIdx = bvm.lookup(loopIndex);
-    rewriter.replaceAllUsesWith(loopIndex, blockIdx);
+    Value threadIdx = bvm.lookup(loopIndex);
+    rewriter.replaceAllUsesWith(loopIndex, threadIdx);
   }
 
-  // Step 6. Erase old op.
+  // Step 7. Erase old op.
   rewriter.eraseOp(forallOp);
+  return mappingSizes;
+}
+
+DiagnosedSilenceableFailure mlir::transform::gpu::mapForallToBlocksImpl(
+    RewriterBase &rewriter, TransformOpInterface transformOp,
+    scf::ForallOp forallOp, SmallVectorImpl<int64_t> &gridDims,
+    const ArrayRef<DeviceMappingAttrInterface> &allMappingAttributes,
+    IdGeneratorFnType idGenerator) {
+  // Pass an empty anyAvailableMappingSizes.
+  SmallVector<int64_t> anyAvailableMappingSizes;
+  FailureOr<SmallVector<int64_t>> maybeMappingSizes =
+      rewriteOneForallCommonImpl(rewriter, transformOp, forallOp,
+                                 anyAvailableMappingSizes,
+                                 allMappingAttributes, idGenerator);
+  if (failed(maybeMappingSizes))
+    return DiagnosedSilenceableFailure::definiteFailure();
+  gridDims = *maybeMappingSizes;
   return DiagnosedSilenceableFailure::success();
 }
 
@@ -389,8 +476,8 @@
     return diag;
   }
 
-  SmallVector<int64_t> gridDim = extractFromI64ArrayAttr(getGridDim());
-  if (!getGenerateGpuLaunch() && gridDim.size() != 3)
+  SmallVector<int64_t> gridDims = extractFromI64ArrayAttr(getGridDim());
+  if (!getGenerateGpuLaunch() && gridDims.size() != 3)
     return transformOp.emitDefiniteFailure("transform require size-3 mapping");
 
   OpBuilder::InsertionGuard guard(rewriter);
@@ -415,14 +502,14 @@
   MappingToGpuBlocksHelper helper(getContext());
 
   diag = mlir::transform::gpu::mapForallToBlocksImpl(
-      rewriter, topLevelForallOp, helper.idGenerator, gridDim, transformOp,
-      helper.mappingAttributes);
+      rewriter, transformOp, topLevelForallOp, gridDims,
+      helper.mappingAttributes, helper.idGenerator);
   if (!diag.succeeded())
     return diag;
 
   diag = alterGpuLaunch(rewriter, gpuLaunch,
-                        cast<TransformOpInterface>(getOperation()), gridDim[0],
-                        gridDim[1], gridDim[2]);
+                        cast<TransformOpInterface>(getOperation()),
+                        gridDims[0], gridDims[1], gridDims[2]);
 
   results.push_back(gpuLaunch);
   return diag;
@@ -432,147 +519,33 @@
 // MapNestedForallToThreads
 //===----------------------------------------------------------------------===//
 
-static DiagnosedSilenceableFailure rewriteOneForallToGpuThreads(
-    RewriterBase &rewriter, scf::ForallOp forallOp,
-    const SmallVectorImpl<int64_t> &kernelBlockDims,
-    const SmallVectorImpl<Value> &threadOps, bool syncAfterDistribute,
-    std::optional<TransformOpInterface> transformOp,
-    const ArrayRef<DeviceMappingAttrInterface> &mappingAttributes) {
-
-  // Step 0. GPU-specific verifications. There is no better place to anchor
-  // those right now: the ForallOp is target-independent and the transform op
-  // does not apply to individual ForallOp.
-  DiagnosedSilenceableFailure diag = verifyGpuMapping(transformOp, forallOp);
-  if (!diag.succeeded())
-    return diag;
-
-  Location loc = forallOp->getLoc();
-
-  SmallVector<Attribute> mapping =
-      llvm::to_vector(forallOp.getMapping()->getValue());
-
-  // Step 1. Complete the mapping to a full mapping (with 1s) if
-  // necessary.
-  SmallVector<OpFoldResult> numThreads = forallOp.getMixedUpperBound();
-  Attribute one = rewriter.getIndexAttr(1);
-  for (auto attr : mappingAttributes) {
-    if (std::find(mapping.begin(), mapping.end(), attr) == mapping.end()) {
-      mapping.push_back(attr);
-      numThreads.push_back(one);
-    }
-  }
-
-  // Step 2. sort the values by the corresponding DeviceMappingAttrInterface.
-  auto comparator = [&](DeviceMappingAttrInterface a,
-                        DeviceMappingAttrInterface b) -> bool {
-    return a.getMappingId() < b.getMappingId();
-  };
-  SmallVector<OpFoldResult> blockDimValues =
-      getValuesSortedByKey(mapping, numThreads, comparator);
-  SmallVector<int64_t> blockDims =
-      llvm::to_vector(llvm::map_range(blockDimValues, [](OpFoldResult ofr) {
-        return getConstantIntValue(ofr).value();
-      }));
-
-  // Step 3. Create the gpu.thread ops and map the induction variables to the
-  // newly created ops.
-  // Replace ids of dimension size 1 by zero to simplify the IR.
-  // TODO
-  SmallVector<Value> threadOpsUpdated(threadOps.begin(), threadOps.end());
-  assert(threadOps.size() == kernelBlockDims.size());
-  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-  for (size_t i : llvm::seq(size_t(0), kernelBlockDims.size())) {
-    if (kernelBlockDims[i] == 1)
-      threadOpsUpdated[i] = zero;
-  }
-  IRMapping bvm;
-  for (auto [threadIdx, blockDim] :
-       llvm::zip(forallOp.getInductionVars(), mapping)) {
-    bvm.map(threadIdx,
-            threadOpsUpdated[blockDim.cast<DeviceMappingAttrInterface>()
-                                 .getMappingId()]);
-  }
-
-  // Step 4. Maybe create conditionals to predicate the region.
-  Value predicate;
-  for (auto [threadId, blockDim, globalBlockDim] :
-       llvm::zip(threadOpsUpdated, blockDims, kernelBlockDims)) {
-    if (blockDim > globalBlockDim) {
-      return failureHelper(
-          transformOp, forallOp,
-          "Trying to map to fewer GPU threads than loop iterations but "
-          "overprovisioning is not yet supported. "
-          "Try additional tiling of the before mapping or map to more "
-          "threads.");
-    }
-    if (blockDim == globalBlockDim)
-      continue;
-    Value threadIdx = rewriter.create<arith::ConstantIndexOp>(loc, blockDim);
-    Value tmpPredicate = rewriter.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::ult, threadId, threadIdx);
-    predicate =
-        predicate ? rewriter.create<arith::AndIOp>(loc, predicate, tmpPredicate)
-                  : tmpPredicate;
-  }
-
-  // Step 5. Move the body of forallOp.
-  // Erase the terminator first, it will not be used.
-  rewriter.eraseOp(forallOp.getTerminator());
-  Block *targetBlock;
-  Block::iterator insertionPoint;
-  if (predicate) {
-    // Step 5.a. If predicated, move at the beginning.
-    auto ifOp =
-        rewriter.create<scf::IfOp>(loc, predicate, /*withElseRegion=*/false);
-    targetBlock = ifOp.thenBlock();
-    insertionPoint = ifOp.thenBlock()->begin();
-  } else {
-    // Step 5.b. Otherwise, move inline just before forallOp.
-    targetBlock = forallOp->getBlock();
-    insertionPoint = Block::iterator(forallOp);
-  }
-  Block &sourceBlock = forallOp.getRegion().front();
-  targetBlock->getOperations().splice(insertionPoint,
-                                      sourceBlock.getOperations());
-
-  // Step 6. RAUW thread indices to thread ops.
-  for (Value loopIndex : forallOp.getInductionVars()) {
-    Value threadIdx = bvm.lookup(loopIndex);
-    rewriter.replaceAllUsesWith(loopIndex, threadIdx);
-  }
-
-  // Step 7. syncthreads.
-  // TODO: Need warpsync
-  if (syncAfterDistribute)
-    rewriter.create<BarrierOp>(loc);
-
-  // Step 8. Erase old op.
-  rewriter.eraseOp(forallOp);
-
-  return DiagnosedSilenceableFailure::success();
-}
-
 DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForallToThreadsImpl(
-    RewriterBase &rewriter, Operation *target,
-    const SmallVectorImpl<int64_t> &blockDim, IdGeneratorFnType idGenerator,
-    bool syncAfterDistribute, std::optional<TransformOpInterface> transformOp,
-    const ArrayRef<DeviceMappingAttrInterface> &mappingAttributes) {
+    RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp,
+    Operation *target, const SmallVectorImpl<int64_t> &kernelBlockDims,
+    bool syncAfterDistribute,
+    const ArrayRef<DeviceMappingAttrInterface> &allMappingAttributes,
+    IdGeneratorFnType idGenerator) {
   DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success();
   target->walk([&](scf::ForallOp forallOp) {
     // Ignore cases with different attributes.
     for (Attribute map : forallOp.getMapping()->getValue()) {
-      if (!llvm::is_contained(mappingAttributes, map)) {
+      if (!llvm::is_contained(allMappingAttributes, map)) {
         return WalkResult::skip();
       }
     }
     diag = verifyGpuMapping(transformOp, forallOp);
     if (diag.succeeded()) {
-      rewriter.setInsertionPoint(forallOp);
-      SmallVector<Value> threadOps;
-      idGenerator(rewriter, forallOp, threadOps);
-      diag = rewriteOneForallToGpuThreads(rewriter, forallOp, blockDim,
-                                          threadOps, syncAfterDistribute,
-                                          transformOp, mappingAttributes);
+      // Take the loc ahead of time; the forallOp is erased by the rewrite.
+      Location loc = forallOp.getLoc();
+      OpBuilder::InsertionGuard g(rewriter);
+      rewriter.setInsertionPointAfter(forallOp);
+      if (failed(rewriteOneForallCommonImpl(rewriter, transformOp, forallOp,
                                             kernelBlockDims,
+                                            allMappingAttributes,
+                                            idGenerator)))
+        diag = DiagnosedSilenceableFailure::definiteFailure();
+      // Add a syncthreads if needed. TODO: warpsync.
+      if (syncAfterDistribute)
+        rewriter.create<BarrierOp>(loc);
     }
     return diag.succeeded() ? WalkResult::advance() : WalkResult::interrupt();
   });
@@ -588,13 +561,13 @@
   if (!gpuLaunch)
     return emitSilenceableError() << "Given target is not a gpu.launch";
 
-  SmallVector<int64_t> blockDim = extractFromI64ArrayAttr(getBlockDim());
-  if (blockDim.size() != 3)
+  SmallVector<int64_t> blockDims = extractFromI64ArrayAttr(getBlockDim());
+  if (blockDims.size() != 3)
     return transformOp.emitDefiniteFailure("transform require size-3 mapping");
 
   DiagnosedSilenceableFailure diag =
       checkGpuLimits(transformOp, std::nullopt, std::nullopt, std::nullopt,
-                     blockDim[0], blockDim[1], blockDim[2]);
+                     blockDims[0], blockDims[1], blockDims[2]);
   if (diag.isSilenceableFailure()) {
     diag.attachNote(getLoc()) << getBlockDimAttrName() << " is too large";
     return diag;
   }
 
@@ -602,18 +575,17 @@
   MLIRContext *ctx = getContext();
   IRRewriter rewriter(ctx);
-  rewriter.setInsertionPoint(target);
   MappingToGpuThreadsHelper helper(ctx);
   diag = mlir::transform::gpu::mapNestedForallToThreadsImpl(
-      rewriter, target, blockDim, helper.idGenerator, getSyncAfterDistribute(),
-      transformOp, helper.mappingAttributes);
+      rewriter, transformOp, target, blockDims, getSyncAfterDistribute(),
+      helper.mappingAttributes, helper.idGenerator);
 
   if (!diag.succeeded())
     return diag;
 
   diag = alterGpuLaunch(rewriter, gpuLaunch, transformOp, std::nullopt,
-                        std::nullopt, std::nullopt, blockDim[0], blockDim[1],
-                        blockDim[2]);
+                        std::nullopt, std::nullopt, blockDims[0], blockDims[1],
+                        blockDims[2]);
 
   results.push_back(gpuLaunch.getOperation());
   return diag;
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -222,4 +222,10 @@
   return getValuesSortedByKeyImpl(keys, values, compare);
 }
 
+SmallVector<int64_t>
+getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<int64_t> values,
+                     llvm::function_ref<bool(Attribute, Attribute)> compare) {
+  return getValuesSortedByKeyImpl(keys, values, compare);
+}
+
 } // namespace mlir
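
Note: the new getValuesSortedByKey overload mirrors the existing OpFoldResult
overload and is what rewriteOneForallCommonImpl relies on to put the mapping
sizes into mapping-id order. Below is a minimal standalone sketch of that
behavior in plain C++17; the integer keys stand in for the
DeviceMappingAttrInterface attributes compared via getMappingId(), and the
helper is illustrative, not the MLIR implementation:

// Reorder `values` so they follow the ascending order of their `keys`.
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

static std::vector<int64_t>
getValuesSortedByKey(const std::vector<int64_t> &keys,
                     const std::vector<int64_t> &values) {
  assert(keys.size() == values.size() && "expected same sizes");
  // Sort indices by key, then apply the permutation to the values.
  std::vector<std::size_t> indices(keys.size());
  std::iota(indices.begin(), indices.end(), std::size_t{0});
  std::sort(indices.begin(), indices.end(),
            [&](std::size_t i, std::size_t j) { return keys[i] < keys[j]; });
  std::vector<int64_t> result;
  result.reserve(values.size());
  for (std::size_t i : indices)
    result.push_back(values[i]);
  return result;
}

int main() {
  // Mapping ids {z=2, x=0, y=1} with trip counts {4, 32, 8} sort into the
  // canonical x,y,z order {32, 8, 4} -- exactly how mappingSizes is built.
  std::vector<int64_t> keys = {2, 0, 1};
  std::vector<int64_t> values = {4, 32, 8};
  assert((getValuesSortedByKey(keys, values) ==
          std::vector<int64_t>{32, 8, 4}));
  return 0;
}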
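
Note: Step 4 of rewriteOneForallCommonImpl predicates only the dimensions
where the loop uses strictly fewer ids than the launch provides, and rejects
overprovisioned dimensions. A minimal sketch of that size comparison under the
same assumptions (plain C++17, hypothetical helper name, no MLIR dependency):

#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Returns the dimensions that must be guarded by `id < mappingSize`, or
// std::nullopt when the mapping overprovisions some dimension.
static std::optional<std::vector<std::size_t>>
dimsNeedingPredicate(const std::vector<int64_t> &mappingSizes,
                     const std::vector<int64_t> &availableMappingSizes) {
  assert(mappingSizes.size() == availableMappingSizes.size());
  std::vector<std::size_t> guardedDims;
  for (std::size_t i = 0; i < mappingSizes.size(); ++i) {
    if (mappingSizes[i] > availableMappingSizes[i])
      return std::nullopt; // more iterations than ids: hard failure
    if (mappingSizes[i] < availableMappingSizes[i])
      guardedDims.push_back(i); // partial use: guard with id < mappingSize
    // Equal sizes need no guard.
  }
  return guardedDims;
}

int main() {
  // A 32x4 loop on a 32x8x1 block: only dim 1 (y) needs `tid.y < 4`.
  auto dims = dimsNeedingPredicate({32, 4, 1}, {32, 8, 1});
  assert(dims && dims->size() == 1 && (*dims)[0] == 1);
  // A 64x1x1 loop on a 32x1x1 block overprovisions dim 0: rejected.
  assert(!dimsNeedingPredicate({64, 1, 1}, {32, 1, 1}));
  return 0;
}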