diff --git a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td
--- a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td
+++ b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td
@@ -27,7 +27,8 @@
 }
 
 def GPUThreadMappingAttr
-    : GPU_Attr<"GPUThreadMapping", "thread", [ DeviceMappingAttrInterface ]> {
+    : GPU_Attr<"GPUThreadMapping", "thread", [
+      DeclareAttrInterfaceMethods<DeviceMappingAttrInterface> ]> {
   let parameters = (ins
     EnumParameter<ThreadsEnum>:$thread
   );
@@ -47,7 +48,8 @@
   let cppNamespace = "::mlir::gpu";
 }
 
-def GPUBlockMappingAttr : GPU_Attr<"GPUBlockMapping", "block", [ DeviceMappingAttrInterface ] > {
+def GPUBlockMappingAttr : GPU_Attr<"GPUBlockMapping", "block", [
+  DeclareAttrInterfaceMethods<DeviceMappingAttrInterface> ] > {
   let parameters = (ins
     EnumParameter<BlocksEnum>:$block
   );
diff --git a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h
--- a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h
+++ b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h
@@ -40,11 +40,11 @@
 /// which case, the union of the number of threads is computed and may result in
 /// predication. Dynamic, `scf.foreach_thread` trip counts are currently not
 /// supported. Dynamic block dim sizes are currently not supported.
-DiagnosedSilenceableFailure
-mapNestedForeachToThreadsImpl(RewriterBase &rewriter, Operation *target,
-                              const SmallVectorImpl<int64_t> &blockDim,
-                              bool syncAfterDistribute,
-                              llvm::Optional<TransformOpInterface> transformOp);
+DiagnosedSilenceableFailure mapNestedForeachToThreadsImpl(
+    RewriterBase &rewriter, Operation *target,
+    const SmallVectorImpl<int64_t> &blockDim, bool syncAfterDistribute,
+    llvm::Optional<TransformOpInterface> transformOp,
+    const ArrayRef<DeviceMappingAttrInterface> &threadMappingAttributes);
 
 /// Maps the top level `scf.foreach_thread` op to GPU Thread Blocks. Mapping is
 /// one-to-one and the induction variables of `scf.foreach_thread` are rewritten
@@ -56,7 +56,8 @@
     function_ref<void(RewriterBase &, scf::ForeachThreadOp,
                       SmallVectorImpl<Value> &)>
         blockIdGenerator,
-    SmallVectorImpl<int64_t> &gridDims, TransformOpInterface transformOp);
+    SmallVectorImpl<int64_t> &gridDims, TransformOpInterface transformOp,
+    const ArrayRef<DeviceMappingAttrInterface> &mappingAttributes);
 
 /// Finds the top level scf::ForeachThreadOp of given target.
 DiagnosedSilenceableFailure
diff --git a/mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td b/mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td
--- a/mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td
@@ -34,6 +34,14 @@
     of the loops it contains to the GPU's parallelism units such as threads and
     thread blocks.
   }];
+
+  let methods = [
+    InterfaceMethod<[{
+        Returns the mapping as an integer from the attribute.
+      }],
+      "int64_t", "getMappingId", (ins)
+    >
+  ];
 }
 
 def DeviceMappingArrayAttr :
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -33,6 +33,18 @@
 
 #include "mlir/Dialect/GPU/IR/GPUOpsDialect.cpp.inc"
 
+//===----------------------------------------------------------------------===//
+// GPU Device Mapping Attributes
+//===----------------------------------------------------------------------===//
+
+int64_t GPUBlockMappingAttr::getMappingId() const {
+  return static_cast<int64_t>(getBlock());
+}
+
+int64_t GPUThreadMappingAttr::getMappingId() const {
+  return static_cast<int64_t>(getThread());
+}
+
 //===----------------------------------------------------------------------===//
 // MMAMatrixType
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h"
 #include "mlir/Dialect/PDL/IR/PDL.h"
+#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
@@ -33,6 +34,24 @@
 };
 } // namespace
 
+/// Checks if the given mapping attributes are among the desired attributes.
+static DiagnosedSilenceableFailure checkAttributeType(
+    const ArrayRef<DeviceMappingAttrInterface> &threadMappingAttributes,
+    const Optional<ArrayAttr> &foreachMapping,
+    llvm::Optional<TransformOpInterface> transformOp) {
+  if (!foreachMapping.has_value())
+    return transformOp->emitSilenceableError() << "mapping must be present";
+
+  if (llvm::any_of(foreachMapping->getValue(),
+                   [&](DeviceMappingAttrInterface map) {
+                     return llvm::find(threadMappingAttributes, map) ==
+                            threadMappingAttributes.end();
+                   }))
+    return transformOp->emitDefiniteFailure()
+           << "mapping must be one of " << threadMappingAttributes;
+  return DiagnosedSilenceableFailure::success();
+}
+
 /// Determines if the size of the kernel configuration is supported by the GPU
 /// architecture being used. It presently makes use of CUDA limitations, however
 /// that aspect may be enhanced for other GPUs.
@@ -157,15 +176,13 @@
     function_ref<void(RewriterBase &, scf::ForeachThreadOp,
                       SmallVectorImpl<Value> &)>
         blockIdGenerator,
-    SmallVectorImpl<int64_t> &gridDims, TransformOpInterface transformOp) {
+    SmallVectorImpl<int64_t> &gridDims, TransformOpInterface transformOp,
+    const ArrayRef<DeviceMappingAttrInterface> &mappingAttributes) {
   // Step 0. Target-specific verifications. There is no good place to anchor
   // those right now: the ForeachThreadOp is target-independent and the
   // transform op does not apply to individual ForeachThreadOp.
-  MLIRContext *ctx = foreachThreadOp->getContext();
   Location loc = foreachThreadOp->getLoc();
-  Attribute bX = GPUBlockMappingAttr::get(ctx, Blocks::DimX);
-  Attribute bY = GPUBlockMappingAttr::get(ctx, Blocks::DimY);
-  Attribute bZ = GPUBlockMappingAttr::get(ctx, Blocks::DimZ);
+
   if (foreachThreadOp.getNumResults() > 0)
     return transformOp.emitSilenceableError()
            << "only bufferized scf.foreach_thread lowers to "
@@ -180,23 +197,15 @@
     return transformOp.emitSilenceableError()
            << "unsupported dynamic griddim size";
   }
-  if (!foreachThreadOp.getMapping().has_value())
-    return transformOp.emitSilenceableError() << "mapping must be present";
   SmallVector<Attribute> blockMapping =
       llvm::to_vector(foreachThreadOp.getMapping()->getValue());
-  if (llvm::any_of(blockMapping, [](DeviceMappingAttrInterface map) {
-        return !map.isa<GPUBlockMappingAttr>();
-      })) {
-    return transformOp.emitSilenceableError()
-           << "mapping must be #gpu.block<x/y/z>";
-  }
 
   // Step 1. Complete the blockMapping to a full mapping (with 1s) if necessary.
   SmallVector<Value> numBlocks =
       llvm::to_vector(foreachThreadOp.getNumThreads());
   // Ensure we have 3 block sizes, one for each id.
   Value one;
-  for (auto attr : {bX, bY, bZ}) {
+  for (auto attr : mappingAttributes) {
     if (std::find(blockMapping.begin(), blockMapping.end(), attr) ==
         blockMapping.end()) {
       blockMapping.push_back(attr);
@@ -205,10 +214,10 @@
     }
   }
 
-  // Step 2. sort the values by the corresponding GPUBlockMappingAttr.
-  auto comparator = [](Attribute a, Attribute b) -> bool {
-    return static_cast<int64_t>(a.cast<GPUBlockMappingAttr>().getBlock()) <
-           static_cast<int64_t>(b.cast<GPUBlockMappingAttr>().getBlock());
+  // Step 2. Sort the values by the corresponding DeviceMappingAttrInterface.
+  auto comparator = [&](DeviceMappingAttrInterface a,
+                        DeviceMappingAttrInterface b) -> bool {
+    return a.getMappingId() < b.getMappingId();
   };
   SmallVector<Value> gridDimValues = scf::ForeachThreadOp::getValuesSortedByKey(
       blockMapping, numBlocks, comparator);
@@ -222,8 +231,9 @@
   BlockAndValueMapping bvm;
   for (auto [blockIdx, blockDim] :
        llvm::zip(foreachThreadOp.getThreadIndices(), blockMapping)) {
-    bvm.map(blockIdx, blockOps[static_cast<int64_t>(
-                          blockDim.cast<GPUBlockMappingAttr>().getBlock())]);
+    bvm.map(blockIdx,
+            blockOps[static_cast<int64_t>(
+                blockDim.cast<DeviceMappingAttrInterface>().getMappingId())]);
   }
 
   // Step 4. Move the body of foreachThreadOp.
@@ -331,9 +341,17 @@
   }
   SmallVector<int64_t> gridDim = extractFromI64ArrayAttr(getGridDim());
-  diag = mlir::transform::gpu::mapForeachToBlocksImpl(
-      rewriter, topLevelForeachThreadOp, generateGpuBlockIds, gridDim,
-      transformOp);
+  SmallVector<DeviceMappingAttrInterface> blockMappingAttributes = {
+      GPUBlockMappingAttr::get(getContext(), Blocks::DimX),
+      GPUBlockMappingAttr::get(getContext(), Blocks::DimY),
+      GPUBlockMappingAttr::get(getContext(), Blocks::DimZ)};
+
+  diag = checkAttributeType(blockMappingAttributes,
+                            topLevelForeachThreadOp.getMapping(), transformOp);
+  if (diag.succeeded())
+    diag = mlir::transform::gpu::mapForeachToBlocksImpl(
+        rewriter, topLevelForeachThreadOp, generateGpuBlockIds, gridDim,
+        transformOp, blockMappingAttributes);
 
   if (diag.succeeded()) {
     diag = alterGpuLaunch(rewriter, gpuLaunch,
                           cast<TransformOpInterface>(getOperation()),
@@ -358,7 +376,8 @@
 static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads(
     RewriterBase &rewriter, scf::ForeachThreadOp foreachThreadOp,
     const SmallVectorImpl<int64_t> &globalBlockDims, bool syncAfterDistribute,
-    llvm::Optional<TransformOpInterface> transformOp) {
+    llvm::Optional<TransformOpInterface> transformOp,
+    const ArrayRef<DeviceMappingAttrInterface> &threadMappingAttributes) {
   // Step 0. Target-specific verifications. There is no good place to anchor
  // those right now: the ForeachThreadOp is target-independent and the
  // transform op does not apply to individual ForeachThreadOp.
@@ -369,11 +388,7 @@
     }
     return emitDefiniteFailure(foreachThreadOp, message);
   };
-  MLIRContext *ctx = foreachThreadOp->getContext();
   Location loc = foreachThreadOp->getLoc();
-  Attribute tX = GPUThreadMappingAttr::get(ctx, Threads::DimX);
-  Attribute tY = GPUThreadMappingAttr::get(ctx, Threads::DimY);
-  Attribute tZ = GPUThreadMappingAttr::get(ctx, Threads::DimZ);
   if (foreachThreadOp.getNumResults() > 0)
     return failureHelper(
         "only bufferized scf.foreach_thread lowers to gpu.thread_id");
@@ -389,12 +404,6 @@
     return failureHelper("mapping must be present");
   SmallVector<Attribute> threadMapping =
       llvm::to_vector(foreachThreadOp.getMapping()->getValue());
-  if (llvm::any_of(threadMapping, [](DeviceMappingAttrInterface map) {
-        return !map.isa<GPUThreadMappingAttr>();
-      })) {
-    return transformOp->emitSilenceableError()
-           << "mapping must be #gpu.thread<x/y/z>";
-  }
 
   // Step 1. Complete the threadMapping to a full mapping (with 1s) if
   // necessary.
@@ -402,7 +411,7 @@
       llvm::to_vector(foreachThreadOp.getNumThreads());
   // Ensure we have 3 block sizes, one for each id.
   Value one;
-  for (auto attr : {tX, tY, tZ}) {
+  for (auto attr : threadMappingAttributes) {
     if (std::find(threadMapping.begin(), threadMapping.end(), attr) ==
         threadMapping.end()) {
       threadMapping.push_back(attr);
@@ -411,10 +420,10 @@
     }
   }
 
-  // Step 2. sort the values by the corresponding GPUThreadMappingAttr.
-  auto comparator = [](Attribute a, Attribute b) -> bool {
-    return static_cast<int64_t>(a.cast<GPUThreadMappingAttr>().getThread()) <
-           static_cast<int64_t>(b.cast<GPUThreadMappingAttr>().getThread());
+  // Step 2. Sort the values by the corresponding DeviceMappingAttrInterface.
+  auto comparator = [&](DeviceMappingAttrInterface a,
+                        DeviceMappingAttrInterface b) -> bool {
+    return a.getMappingId() < b.getMappingId();
   };
   SmallVector<Value> blockDimValues =
       scf::ForeachThreadOp::getValuesSortedByKey(threadMapping, numThreads,
@@ -434,8 +443,9 @@
   BlockAndValueMapping bvm;
   for (auto [blockIdx, blockDim] :
        llvm::zip(foreachThreadOp.getThreadIndices(), threadMapping)) {
-    bvm.map(blockIdx, threadOps[static_cast<int64_t>(
-                          blockDim.cast<GPUThreadMappingAttr>().getThread())]);
+    bvm.map(
+        blockIdx,
+        threadOps[blockDim.cast<DeviceMappingAttrInterface>().getMappingId()]);
   }
 
   // Step 4. Maybe create conditionals to predicate the region.
@@ -501,12 +511,18 @@
 DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForeachToThreadsImpl(
     RewriterBase &rewriter, Operation *target,
     const SmallVectorImpl<int64_t> &blockDim, bool syncAfterDistribute,
-    llvm::Optional<TransformOpInterface> transformOp) {
+    llvm::Optional<TransformOpInterface> transformOp,
+    const ArrayRef<DeviceMappingAttrInterface> &threadMappingAttributes) {
   DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success();
   target->walk([&](scf::ForeachThreadOp foreachThreadOp) {
-    rewriter.setInsertionPoint(foreachThreadOp);
-    diag = rewriteOneForeachThreadToGpuThreads(
-        rewriter, foreachThreadOp, blockDim, syncAfterDistribute, transformOp);
+    diag = checkAttributeType(threadMappingAttributes,
+                              foreachThreadOp.getMapping(), transformOp);
+    if (diag.succeeded()) {
+      rewriter.setInsertionPoint(foreachThreadOp);
+      diag = rewriteOneForeachThreadToGpuThreads(
+          rewriter, foreachThreadOp, blockDim, syncAfterDistribute, transformOp,
+          threadMappingAttributes);
+    }
     return diag.succeeded() ? WalkResult::advance()
                             : WalkResult::interrupt();
   });
   return diag;
 }
@@ -536,11 +552,19 @@
     return diag;
   }
 
-  SimpleRewriter rewriter(getContext());
+  MLIRContext *ctx = getContext();
+  SimpleRewriter rewriter(ctx);
   rewriter.setInsertionPoint(target);
 
+  SmallVector<DeviceMappingAttrInterface> threadMappingAttributes = {
+      GPUThreadMappingAttr::get(ctx, Threads::DimX),
+      GPUThreadMappingAttr::get(ctx, Threads::DimY),
+      GPUThreadMappingAttr::get(ctx, Threads::DimZ)};
+
   diag = mlir::transform::gpu::mapNestedForeachToThreadsImpl(
-      rewriter, target, blockDim, getSyncAfterDistribute(), transformOp);
+      rewriter, target, blockDim, getSyncAfterDistribute(), transformOp,
+      threadMappingAttributes);
+
   if (diag.succeeded()) {
     diag = alterGpuLaunch(rewriter, gpuLaunch, transformOp, llvm::None,
                           llvm::None,
diff --git a/mlir/test/Dialect/GPU/transform-gpu-failing.mlir b/mlir/test/Dialect/GPU/transform-gpu-failing.mlir
--- a/mlir/test/Dialect/GPU/transform-gpu-failing.mlir
+++ b/mlir/test/Dialect/GPU/transform-gpu-failing.mlir
@@ -160,7 +160,7 @@
 transform.sequence failures(propagate) {
 ^bb1(%arg0: !pdl.operation):
   %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg0
-  %foreach, %tiled = transform.structured.tile_to_foreach_thread_op %matmul num_threads [10, 20, 30]
+  %foreach, %tiled = transform.structured.tile_to_foreach_thread_op %matmul num_threads [10, 20, 30] (mapping = [ #gpu.thread<x>, #gpu.thread<y>, #gpu.thread<z> ] )
   %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0
   // expected-error @below {{only bufferized scf.foreach_thread lowers to gpu.thread_id}}
   transform.gpu.map_nested_foreach_to_threads %funcop { blockDim = [128, 4, 1] }
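
The comparator change in Step 2 is the crux of this patch: the distribution logic now orders and indexes mapping attributes purely through DeviceMappingAttrInterface::getMappingId(), so any attribute implementing the interface can drive it, not just the hard-coded GPU ones. A minimal sketch of a generic client built on the new hook; sortByMappingId is a hypothetical helper for illustration, not part of the patch:

// Minimal sketch, assuming this patch is applied. Orders mapping attributes
// by their interface-provided id (0 = x, 1 = y, 2 = z for the GPU mapping
// attributes) without casting to GPUBlockMappingAttr / GPUThreadMappingAttr.
#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

static void
sortByMappingId(SmallVectorImpl<DeviceMappingAttrInterface> &attrs) {
  // Same comparator shape as Step 2 in GPUTransformOps.cpp above.
  llvm::sort(attrs, [](DeviceMappingAttrInterface a,
                       DeviceMappingAttrInterface b) {
    return a.getMappingId() < b.getMappingId();
  });
}

A downstream dialect attribute that implements getMappingId() thus participates in the same sorting, completion, and predication logic with no further changes to the transform ops.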