diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
@@ -172,6 +172,8 @@
 #include "mlir/Dialect/GPU/IR/GPUOpInterfaces.h.inc"
 
+#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
+
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/GPU/IR/GPUOpsAttributes.h.inc"
 
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -16,6 +16,7 @@
 include "mlir/Dialect/DLTI/DLTIBase.td"
 include "mlir/Dialect/GPU/IR/GPUBase.td"
 include "mlir/Dialect/GPU/IR/ParallelLoopMapperAttr.td"
+include "mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td"
 include "mlir/IR/EnumAttr.td"
 include "mlir/IR/FunctionInterfaces.td"
 include "mlir/IR/SymbolInterfaces.td"
diff --git a/mlir/include/mlir/Dialect/GPU/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/GPU/TransformOps/CMakeLists.txt
--- a/mlir/include/mlir/Dialect/GPU/TransformOps/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/GPU/TransformOps/CMakeLists.txt
@@ -4,3 +4,8 @@
 add_public_tablegen_target(MLIRGPUTransformOpsIncGen)
 
 add_mlir_doc(GPUTransformOps GPUTransformOps Dialects/ -gen-op-doc)
+
+set(LLVM_TARGET_DEFINITIONS GPUDeviceMappingAttr.td)
+mlir_tablegen(GPUDeviceMapperEnums.h.inc -gen-enum-decls)
+mlir_tablegen(GPUDeviceMapperEnums.cpp.inc -gen-enum-defs)
+add_public_tablegen_target(MLIRGPUDeviceMapperEnumsGen)
diff --git a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td
@@ -0,0 +1,65 @@
+//===-- GPUDeviceMappingAttr.td - Attribute definition -----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines the attributes used to map loops to GPU processing units.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef GPU_DEVICE_MAPPING_ATTR
+#define GPU_DEVICE_MAPPING_ATTR
+
+include "mlir/Dialect/GPU/IR/GPUBase.td"
+include "mlir/IR/EnumAttr.td"
+include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td"
+
+def DimX : I64EnumAttrCase<"DimX", 0, "x">;
+def DimY : I64EnumAttrCase<"DimY", 1, "y">;
+def DimZ : I64EnumAttrCase<"DimZ", 2, "z">;
+
+def ThreadsEnum : I64EnumAttr<"Threads", "threads for loop mapping", [
+    DimX, DimY, DimZ]> {
+  let cppNamespace = "::mlir::gpu";
+}
+
+def GPUThreadMappingAttr
+    : GPU_Attr<"GPUThreadMapping", "thread", [ DeviceMappingAttrInterface ]> {
+  let parameters = (ins
+    EnumParameter<ThreadsEnum>:$thread
+  );
+  let assemblyFormat = "`<` params `>`";
+  let description = [{
+    An attribute that allows defining thread parallelism for GPU devices.
+
+    Threads (aka work items) are grouped into thread blocks, where a block may
+    be described by a 1-, 2-, or 3-dimensional rectangle. This attribute
+    indicates that thread parallelism is desired. It can be consumed by
+    lowerings to generate GPU code.
+ }]; +} + +def BlocksEnum : I64EnumAttr<"Blocks", "threads for loop mapping", [ + DimX, DimY, DimZ]> { + let cppNamespace = "::mlir::gpu"; +} + +def GPUBlockMappingAttr : GPU_Attr<"GPUBlockMapping", "block", [ DeviceMappingAttrInterface ] > { + let parameters = (ins + EnumParameter:$block + ); + let assemblyFormat = "`<` params `>`"; + let description = [{ + An attribute that allows defining thread block parallelism for GPU devices. + + Thread blocks (aka work-group) are grouped into a grid where grid may be + described by a 1-, 2-, or 3-dimensional rectangle. This attribute indicates + that thread block parallelism is desired. It can be consumed by lowering to + generate GPU code. + }]; +} + +#endif // GPU_DEVICE_MAPPING_ATTR diff --git a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td --- a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td +++ b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td @@ -29,7 +29,7 @@ The operation searches for `scf.foreach_thread` ops nested under `target` and maps each such op to GPU threads. Mapping is one-to-one and the induction variables of `scf.foreach_thread` are rewritten to - `gpu.thread_id` according to the `thread_dim_mapping` attribute. + `gpu.thread_id` according to the `mapping` attribute. Sibling `scf.foreach_thread` are supported in which case, the union of the number of threads is computed and may result in predication. @@ -73,10 +73,10 @@ threads(%tx, %ty, %tz) in (%tx = %3, %ty = %4, %tz = %5) { scf.foreach_thread (%i, %j) in (7, 9) { ... // body 1 - } {thread_dim_mapping = [1, 0, 2]} + } {mapping = [#gpu.thread, #gpu.thread, #gpu.thread]} scf.foreach_thread (%i) in (12) { ... // body 2 - } + } {mapping = [#gpu.thread]} gpu.terminator } ``` diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h @@ -47,7 +47,7 @@ RewriterBase &rewriter, transform::TransformState &state, TransformOpInterface transformOp, ArrayRef targets, ArrayRef mixedNumThreads, - ArrayRef mixedTileSizes, Optional threadDimMapping, + ArrayRef mixedTileSizes, Optional mapping, SmallVector &tileOps, SmallVector &tiledOps); } // namespace transform diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -13,6 +13,7 @@ include "mlir/Dialect/Transform/IR/TransformEffects.td" include "mlir/Dialect/Transform/IR/TransformInterfaces.td" include "mlir/Dialect/PDL/IR/PDLTypes.td" +include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/EnumAttr.td" include "mlir/IR/OpBase.td" @@ -792,7 +793,7 @@ a valid tiling specification (i.e. that only tiles parallel dimensions, e.g. in the Linalg case). - If non-empty, the `thread_dim_mapping` is added as an attribute to the + If non-empty, the `mapping` is added as an attribute to the resulting `scf.foreach_thread`. 
 
     #### Return modes
 
@@ -832,7 +833,7 @@
                    Variadic<PDL_Operation>:$tile_sizes,
                    DefaultValuedAttr<I64ArrayAttr, "{}">:$static_num_threads,
                    DefaultValuedAttr<I64ArrayAttr, "{}">:$static_tile_sizes,
-                   OptionalAttr<I64ArrayAttr>:$thread_dim_mapping);
+                   OptionalAttr<DeviceMappingArrayAttr>:$mapping);
   let results = (outs PDL_Operation:$foreach_thread_op,
                       PDL_Operation:$tiled_op);
@@ -841,22 +842,22 @@
                    "ArrayRef<int64_t>":$staticTileSizes,
                    CArg<"::mlir::transform::TileSizesSpec",
                         "::mlir::transform::TileSizesSpec()">,
-                   CArg<"ArrayRef<int64_t>", "{}">:$threadDimMapping)>,
+                   CArg<"ArrayRef<int64_t>", "{}">:$mapping)>,
     OpBuilder<(ins "Value":$target,
                    "ArrayRef<OpFoldResult>":$mixedTileSizes,
                    CArg<"::mlir::transform::TileSizesSpec",
                         "::mlir::transform::TileSizesSpec()">,
-                   CArg<"ArrayRef<int64_t>", "{}">:$threadDimMapping)>,
+                   CArg<"ArrayRef<int64_t>", "{}">:$mapping)>,
     OpBuilder<(ins "Value":$target,
                    "ArrayRef<int64_t>":$staticNumThreads,
                    CArg<"::mlir::transform::NumThreadsSpec",
                         "::mlir::transform::NumThreadsSpec()">,
-                   CArg<"ArrayRef<int64_t>", "{}">:$threadDimMapping)>,
+                   CArg<"ArrayRef<int64_t>", "{}">:$mapping)>,
     OpBuilder<(ins "Value":$target,
                    "ArrayRef<OpFoldResult>":$mixedNumThreads,
                    CArg<"::mlir::transform::NumThreadsSpec",
                         "::mlir::transform::NumThreadsSpec()">,
-                   CArg<"ArrayRef<int64_t>", "{}">:$threadDimMapping)>,
+                   CArg<"ArrayRef<int64_t>", "{}">:$mapping)>,
   ];
 
   let assemblyFormat = [{
@@ -867,7 +868,7 @@
            `tile_sizes` custom<DynamicIndexList>($tile_sizes,
                                                  $static_tile_sizes,
                                                  "ShapedType::kDynamicSize"))
-    (`(` `mapped` `to` `dims` $thread_dim_mapping^ `)`)? attr-dict
+    (`(` `mapping` `=` $mapping^ `)`)? attr-dict
   }];
 
   let hasVerifier = 1;
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -423,7 +423,7 @@
 /// Rewrite a TilingInterface `op` to a tiled `scf.foreach_thread`, applying
 /// tiling by `numThreads`.
-/// If non-empty, the `threadDimMapping` is added as an attribute to the
+/// If non-empty, the `mapping` is added as an attribute to the
 /// resulting `scf.foreach_thread`.
 /// Zero tile sizes indicate that the dimension is not tiled, and can be
 /// thought of as tiling by the full size of data. It is the user's
@@ -436,14 +436,14 @@
 FailureOr<ForeachThreadTilingResult>
 tileToForeachThreadOp(RewriterBase &builder, TilingInterface op,
                       ArrayRef<OpFoldResult> numThreads,
-                      ArrayRef<int64_t> threadDimMapping = {});
+                      Optional<ArrayAttr> mapping);
 
 /// Same as `tileToForeachThreadOp`, but calculate the number of threads
 /// required using the given tileSizes.
 FailureOr<ForeachThreadTilingResult>
 tileToForeachThreadOpUsingTileSizes(RewriterBase &builder, TilingInterface op,
                                     ArrayRef<OpFoldResult> tileSizes,
-                                    ArrayRef<int64_t> threadDimMapping = {});
+                                    Optional<ArrayAttr> mapping);
 
 /// All indices returned by IndexOp should be invariant with respect to
 /// tiling. Therefore, if an operation is tiled, we have to transform the
diff --git a/mlir/include/mlir/Dialect/SCF/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/SCF/IR/CMakeLists.txt
--- a/mlir/include/mlir/Dialect/SCF/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/SCF/IR/CMakeLists.txt
@@ -1,3 +1,10 @@
 add_mlir_dialect(SCFOps scf Ops)
 add_mlir_doc(SCFOps SCFDialect Dialects/ -gen-dialect-doc)
 
+set(LLVM_TARGET_DEFINITIONS DeviceMappingInterface.td)
+mlir_tablegen(DeviceMappingAttrInterface.h.inc -gen-attr-interface-decls)
+mlir_tablegen(DeviceMappingAttrInterface.cpp.inc -gen-attr-interface-defs)
+mlir_tablegen(DeviceMappingAttributes.h.inc -gen-attrdef-decls)
+mlir_tablegen(DeviceMappingAttributes.cpp.inc -gen-attrdef-defs)
+add_public_tablegen_target(MLIRDeviceMappingInterfacesIncGen)
+add_dependencies(mlir-generic-headers MLIRDeviceMappingInterfacesIncGen)
\ No newline at end of file
diff --git a/mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.h b/mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.h
@@ -0,0 +1,22 @@
+//===- DeviceMappingInterface.h ---------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains the definitions of the device mapping interface defined in
+// `DeviceMappingInterface.td`.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DEVICEMAPPINGINTERFACE_H
+#define MLIR_DEVICEMAPPINGINTERFACE_H
+
+#include "mlir/IR/OpDefinition.h"
+
+/// Include the generated interface declarations.
+#include "mlir/Dialect/SCF/IR/DeviceMappingAttrInterface.h.inc"
+
+#endif // MLIR_DEVICEMAPPINGINTERFACE_H
diff --git a/mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td b/mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td
@@ -0,0 +1,43 @@
+//===- DeviceMappingInterface.td - Device mapping interfaces -*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines the interfaces for the device mapping specification for the loops.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DEVICEMAPPINGINTERFACE
+#define MLIR_DEVICEMAPPINGINTERFACE
+
+include "mlir/IR/OpBase.td"
+
+//===----------------------------------------------------------------------===//
+// Attribute interfaces
+//===----------------------------------------------------------------------===//
+
+def DeviceMappingAttrInterface : AttrInterface<"DeviceMappingAttrInterface"> {
+  let cppNamespace = "::mlir";
+  let description = [{
+    Attribute interface describing how to map a region to a processing unit.
+
+    It is intended to be a generic mechanism for binding regions to execution
+    units of an actual or virtual device. Each device first expresses its own
+    mappings, and those mappings must implement this interface. These mappings
+    can be used by the device-specific code generators, and the desired regions
+    can be connected to the given processing unit.
+
+    Currently, `scf.foreach_thread` uses this interface to express the mapping
+    of the loops it contains to the GPU's parallelism units such as threads and
+    thread blocks.
+  }];
+}
+
+def DeviceMappingArrayAttr :
+  TypedArrayAttrBase<DeviceMappingAttrInterface,
+                     "device mapping array attribute"> { }
+
+#endif // MLIR_DEVICEMAPPINGINTERFACE
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCF.h b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
--- a/mlir/include/mlir/Dialect/SCF/IR/SCF.h
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
@@ -13,6 +13,7 @@
 #ifndef MLIR_DIALECT_SCF_SCF_H
 #define MLIR_DIALECT_SCF_SCF_H
 
+#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/RegionKindInterface.h"
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -16,6 +16,7 @@
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/LoopLikeInterface.td"
 include "mlir/IR/RegionKindInterface.td"
+include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td"
 include "mlir/Interfaces/ParallelCombiningOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/ViewLikeInterface.td"
@@ -378,14 +379,14 @@
     application per thread. Further lowerings are responsible for specifying
     how this is materialized on concrete hardware resources.
 
-    An optional thread_dim_mapping index array attribute specifies for each
-    virtual thread dimension, how it remaps 1-1 to a set of concrete processing
+    An optional `mapping` attribute array specifies, for each virtual thread
+    dimension, how it maps 1-1 to a set of concrete processing
     element resources (e.g. a CUDA grid dimension or a level of concrete nested
-    async parallelism). At this time, the specification is backend-dependent and
-    is not verified by the op, beyond being an index array attribute.
-    It is the reponsibility of the lowering to interpret the index array in the
-    context of the concrete target the op is lowered to, or to ignore it when
-    the specification is ill-formed or unsupported for a particular target.
+    async parallelism). It is expressed via any attribute that implements the
+    device mapping interface. It is the responsibility of the lowering mechanism
+    to interpret the `mapping` attributes in the context of the concrete target
+    the op is lowered to, or to ignore it when the specification is ill-formed
+    or unsupported for a particular target.
 
     The only allowed terminator is `scf.foreach_thread.perform_concurrently`.
     `scf.foreach_thread` returns one value per `shared_out` operand. The
@@ -440,11 +441,12 @@
     //      ```
 
-    Example with thread_dim_mapping attribute:
+    Example with mapping attribute:
 
     ```mlir
     //
-    // Sequential context.
+    // Sequential context. Here `mapping` is expressed as GPU thread mapping
+    // attributes.
     //
     %matmul_and_pointwise:2 = scf.foreach_thread (%thread_id_1, %thread_id_2) in
         (%num_threads_1, %numthread_id_2) shared_outs(...)
@@ -456,7 +458,7 @@
       scf.foreach_thread.perform_concurrently {
         ...
       }
-    } { thread_dim_mapping = [1, 0] }
+    } { mapping = [#gpu.thread<y>, #gpu.thread<x>] }
     // Implicit synchronization point.
     // Sequential context.
     //
@@ -480,7 +482,7 @@
   }];
   let arguments = (ins Variadic<Index>:$num_threads,
                        Variadic<AnyRankedTensor>:$outputs,
-                       DefaultValuedAttr<I64ArrayAttr, "{}">:$thread_dim_mapping);
+                       OptionalAttr<DeviceMappingArrayAttr>:$mapping);
   let results = (outs Variadic<AnyType>:$results);
   let regions = (region SizedRegion<1>:$region);
 
@@ -493,10 +495,10 @@
   let builders = [
     // Bodyless builder, outputs must be specified.
     OpBuilder<(ins "ValueRange":$outputs, "ValueRange":$num_threads,
-               CArg<"ArrayRef<int64_t>", "{}">:$thread_dim_mapping)>,
+               "Optional<ArrayAttr>":$mapping)>,
     // Builder that takes a bodyBuilder lambda.
     OpBuilder<(ins "ValueRange":$outputs, "ValueRange":$num_threads,
-               "ArrayRef<int64_t>":$thread_dim_mapping,
+               "ArrayRef<Attribute>":$mapping,
                "function_ref<void(OpBuilder &, Location, ValueRange)>":$bodyBuilder)>
   ];
   let extraClassDeclaration = [{
@@ -535,14 +537,14 @@
     }
 
     /// Return the thread indices in the order specified by the
-    /// thread_dim_mapping attribute. Return failure is
-    /// thread_dim_mapping is not a valid permutation.
-    FailureOr<SmallVector<Value>> getPermutedThreadIndices();
+    /// given mapping argument. Return failure if the
+    /// mapping is not a valid permutation.
+    FailureOr<SmallVector<Value>> getPermutedThreadIndices(ArrayRef<int64_t> mapping);
 
     /// Return the number of threads in the order specified by the
-    /// thread_dim_mapping attribute.
-    /// Return failure is thread_dim_mapping is not a valid permutation.
-    FailureOr<SmallVector<OpFoldResult>> getPermutedNumThreads(OpBuilder &b);
+    /// given mapping argument.
+    /// Return failure if the mapping is not a valid permutation.
+    FailureOr<SmallVector<OpFoldResult>> getPermutedNumThreads(OpBuilder &b, ArrayRef<int64_t> mapping);
 
     // The ensureTerminator method generated by SingleBlockImplicitTerminator is
     // unaware of the fact that our terminator also needs a region to be
diff --git a/mlir/lib/Dialect/GPU/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/GPU/TransformOps/CMakeLists.txt
--- a/mlir/lib/Dialect/GPU/TransformOps/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/TransformOps/CMakeLists.txt
@@ -3,10 +3,13 @@
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/GPU/TransformOps
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Interfaces
 
   DEPENDS
   MLIRGPUTransformOpsIncGen
-
+  MLIRDeviceMappingInterfacesIncGen
+  MLIRGPUDeviceMapperEnumsGen
+
   LINK_LIBS PUBLIC
   MLIRIR
   MLIRGPUTransforms
diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -166,8 +166,20 @@
 
   // Step 0. Outline the compute workload region and set up the workload
   // operands.
+  SmallVector<int64_t> mapping;
+  if (!foreachThreadOp.getMapping().has_value())
+    return transformOp.emitSilenceableError() << "mapping must be present";
+  for (DeviceMappingAttrInterface map : *foreachThreadOp.getMapping()) {
+    if (auto blockMap = map.dyn_cast<GPUBlockMappingAttr>()) {
+      mapping.push_back((int64_t)blockMap.getBlock());
+    } else {
+      return transformOp.emitSilenceableError()
+             << "mapping must be #gpu.block";
+    }
+  }
+
   FailureOr<SmallVector<OpFoldResult>> potentialGridDim =
-      foreachThreadOp.getPermutedNumThreads(rewriter);
+      foreachThreadOp.getPermutedNumThreads(rewriter, mapping);
   if (failed(potentialGridDim) ||
       llvm::any_of(*potentialGridDim, [](OpFoldResult ofr) {
@@ -193,7 +205,7 @@
 
   // Step 2. RAUW thread indices to thread ops.
   SmallVector<Value> threadIndices =
-      *foreachThreadOp.getPermutedThreadIndices();
+      *foreachThreadOp.getPermutedThreadIndices(mapping);
   assert(blockOps.size() == 3 && "3 block id ops are required");
   for (auto [blockIdx, blockOp] : llvm::zip(threadIndices, blockOps)) {
     Value val = blockIdx;
@@ -230,7 +242,8 @@
 }
 
 /// This is a helper that is only used in
-/// rewriteTopLevelForeachThreadToGpuBlocks. It generates GPU dialects block_id.
+/// rewriteTopLevelForeachThreadToGpuBlocks. It generates GPU dialects
+/// block_id.
 static void generateGpuBlockIds(RewriterBase &rewriter,
                                 scf::ForeachThreadOp foreachOp,
                                 SmallVectorImpl<Value> &blockOps) {
@@ -295,6 +308,7 @@
       rewriter, topLevelForeachThreadOp, generateGpuBlockIds, gridDim,
       transformOp);
   if (diag.succeeded()) {
+    gridDim.resize(3, 1);
     diag = alterGpuLaunch(rewriter, gpuLaunch,
                           cast<TransformOpInterface>(getOperation()),
                           gridDim[0], gridDim[1], gridDim[2]);
@@ -335,7 +349,18 @@
     return failureHelper(
         "scf.foreach_thread with rank > 3 does not lower to gpu.thread_id");
 
-  auto potentialBlockDim = foreachThreadOp.getPermutedNumThreads(rewriter);
+  SmallVector<int64_t> mapping;
+  if (!foreachThreadOp.getMapping().has_value())
+    return failureHelper("mapping must be present");
+  for (DeviceMappingAttrInterface map : *foreachThreadOp.getMapping()) {
+    if (auto threadMap = map.dyn_cast<GPUThreadMappingAttr>()) {
+      mapping.push_back((int64_t)threadMap.getThread());
+    } else {
+      return failureHelper("mapping must be #gpu.thread");
+    }
+  }
+  auto potentialBlockDim =
+      foreachThreadOp.getPermutedNumThreads(rewriter, mapping);
   if (failed(potentialBlockDim) ||
       llvm::any_of(*potentialBlockDim, [](OpFoldResult ofr) {
         return !getConstantIntValue(ofr).has_value();
@@ -347,6 +372,7 @@
       llvm::to_vector(llvm::map_range(*potentialBlockDim, [](OpFoldResult ofr) {
         return getConstantIntValue(ofr).value();
       }));
+  blockDim.resize(3, 1);
 
   // Step 1. Create the gpu.thread ops
   Location loc = foreachThreadOp.getLoc();
@@ -365,7 +391,8 @@
     if (blockDim > globalBlockDim) {
       return failureHelper(
           "The requested GPU threads are fewer than the number of loop trip "
-          "counts. Try to tile scf.foreach_thread before mapping or set small "
+          "counts. Try to tile scf.foreach_thread before mapping or set "
+          "small "
           "blockDim.");
     }
     if (blockDim == globalBlockDim)
@@ -400,7 +427,7 @@
 
   // Step 4. RAUW thread indices to thread ops.
   SmallVector<Value> threadIndices =
-      *foreachThreadOp.getPermutedThreadIndices();
+      *foreachThreadOp.getPermutedThreadIndices(mapping);
   for (auto [threadIdx, threadOp] : llvm::zip(threadIndices, threadOps)) {
     Value val = threadIdx;
     Value op = threadOp;
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1321,19 +1321,21 @@
 // TileToForeachThreadOp
 //===----------------------------------------------------------------------===//
 
-void transform::TileToForeachThreadOp::build(
-    OpBuilder &builder, OperationState &result, Value target,
-    ArrayRef<int64_t> staticTileSizes, transform::TileSizesSpec,
-    ArrayRef<int64_t> threadDimMapping) {
+void transform::TileToForeachThreadOp::build(OpBuilder &builder,
+                                             OperationState &result,
+                                             Value target,
+                                             ArrayRef<int64_t> staticTileSizes,
+                                             transform::TileSizesSpec,
+                                             ArrayRef<int64_t> mapping) {
   return build(builder, result, target,
                getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
-               TileSizesSpec(), threadDimMapping);
+               TileSizesSpec(), mapping);
 }
 
 void transform::TileToForeachThreadOp::build(
     OpBuilder &builder, OperationState &result, Value target,
     ArrayRef<OpFoldResult> mixedTileSizes, transform::TileSizesSpec,
-    ArrayRef<int64_t> threadDimMapping) {
+    ArrayRef<int64_t> mapping) {
   SmallVector<int64_t> staticTileSizes;
   SmallVector<Value> dynamicTileSizes;
   dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes,
@@ -1344,28 +1346,29 @@
   MLIRContext *ctx = builder.getContext();
   auto operationType = pdl::OperationType::get(ctx);
   auto staticTileSizesAttr = builder.getI64ArrayAttr(staticTileSizes);
-  ArrayAttr threadDimMappingAttr;
-  if (!threadDimMapping.empty())
-    threadDimMappingAttr = builder.getI64ArrayAttr(threadDimMapping);
+  ArrayAttr mappingAttr;
+  if (!mapping.empty())
+    mappingAttr = builder.getI64ArrayAttr(mapping);
   build(builder, result, TypeRange{operationType, operationType}, target,
         /*numThreads=*/ValueRange{}, dynamicTileSizes,
-        /*staticNumThreads=*/ArrayAttr(), staticTileSizesAttr,
-        threadDimMappingAttr);
+        /*staticNumThreads=*/ArrayAttr(), staticTileSizesAttr, mappingAttr);
 }
 
-void transform::TileToForeachThreadOp::build(
-    OpBuilder &builder, OperationState &result, Value target,
-    ArrayRef<int64_t> staticNumThreads, transform::NumThreadsSpec,
-    ArrayRef<int64_t> threadDimMapping) {
+void transform::TileToForeachThreadOp::build(OpBuilder &builder,
+                                             OperationState &result,
+                                             Value target,
+                                             ArrayRef<int64_t> staticNumThreads,
+                                             transform::NumThreadsSpec,
+                                             ArrayRef<int64_t> mapping) {
   return build(builder, result, target,
                getAsOpFoldResult(builder.getI64ArrayAttr(staticNumThreads)),
-               NumThreadsSpec(), threadDimMapping);
+               NumThreadsSpec(), mapping);
 }
 
 void transform::TileToForeachThreadOp::build(
     OpBuilder &builder, OperationState &result, Value target,
     ArrayRef<OpFoldResult> mixedNumThreads, transform::NumThreadsSpec,
-    ArrayRef<int64_t> threadDimMapping) {
+    ArrayRef<int64_t> mapping) {
   SmallVector<int64_t> staticNumThreads;
   SmallVector<Value> dynamicNumThreads;
   dispatchIndexOpFoldResults(mixedNumThreads, dynamicNumThreads,
@@ -1376,19 +1379,19 @@
   MLIRContext *ctx = builder.getContext();
   auto operationType = pdl::OperationType::get(ctx);
   auto staticNumThreadsAttr = builder.getI64ArrayAttr(staticNumThreads);
-  ArrayAttr threadDimMappingAttr;
-  if (!threadDimMapping.empty())
-    threadDimMappingAttr = builder.getI64ArrayAttr(threadDimMapping);
+  ArrayAttr mappingAttr;
+  if (!mapping.empty())
+    mappingAttr = builder.getI64ArrayAttr(mapping);
   build(builder,
         result, TypeRange{operationType, operationType}, target,
         dynamicNumThreads, /*tileSizes=*/ValueRange{}, staticNumThreadsAttr,
-        /*staticTileSizes=*/ArrayAttr(), threadDimMappingAttr);
+        /*staticTileSizes=*/ArrayAttr(), mappingAttr);
 }
 
 DiagnosedSilenceableFailure transform::tileToForeachThreadOpImpl(
     RewriterBase &rewriter, transform::TransformState &state,
     TransformOpInterface transformOp, ArrayRef<Operation *> targets,
     ArrayRef<OpFoldResult> mixedNumThreads,
-    ArrayRef<OpFoldResult> mixedTileSizes, Optional<ArrayAttr> threadDimMapping,
+    ArrayRef<OpFoldResult> mixedTileSizes, Optional<ArrayAttr> mapping,
     SmallVector<Operation *> &tileOps, SmallVector<Operation *> &tiledOps) {
   if (targets.empty())
     return DiagnosedSilenceableFailure(success());
@@ -1457,19 +1460,13 @@
       return diag;
     }
     rewriter.setInsertionPoint(tilableOp);
-    auto maybeThreadDimMappingAttr = threadDimMapping;
-    auto dimMapping = llvm::to_vector(
-        maybeThreadDimMappingAttr
-            ? extractFromI64ArrayAttr(*maybeThreadDimMappingAttr)
-            : ArrayRef<int64_t>{});
-
     FailureOr<ForeachThreadTilingResult> tilingResult = failure();
     if (!mixedNumThreads.empty()) {
       tilingResult = linalg::tileToForeachThreadOp(rewriter, tilableOp,
-                                                   numThreads, dimMapping);
+                                                   numThreads, mapping);
     } else {
       tilingResult = linalg::tileToForeachThreadOpUsingTileSizes(
-          rewriter, tilableOp, tileSizes, dimMapping);
+          rewriter, tilableOp, tileSizes, mapping);
     }
 
     if (failed(tilingResult))
@@ -1494,7 +1491,7 @@
 
   DiagnosedSilenceableFailure diag = tileToForeachThreadOpImpl(
       rewriter, state, cast<TransformOpInterface>(getOperation()), targets,
-      getMixedNumThreads(), getMixedTileSizes(), getThreadDimMapping(), tileOps,
+      getMixedNumThreads(), getMixedTileSizes(), getMapping(), tileOps,
       tiledOps);
 
   if (!diag.succeeded())
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -215,7 +215,7 @@
 /// tiling is specified by the number of tiles/threads `numThreads` and the
 /// optional nominal tile size `nominalTileSizes`. If `nominalTilSizes` is
 /// not specified, then it is derived from `numThreads` as `ceilDiv(dimSize[i],
-/// numThreads[i])`. If non-empty, the `threadDimMapping` is added as an
+/// numThreads[i])`. If non-empty, the `mapping` is added as an
 /// attribute to the resulting `scf.foreach_thread`. A zero tile sizes indicate
 /// that the dimension is not tiled, and can be thought of as tiling by the full
 /// size of data.
@@ -226,7 +226,7 @@
 static FailureOr<ForeachThreadTilingResult> tileToForeachThreadOpImpl(
     RewriterBase &b, TilingInterface op, ArrayRef<OpFoldResult> numThreads,
     Optional<ArrayRef<OpFoldResult>> nominalTileSizes,
-    ArrayRef<int64_t> threadDimMapping, bool omitTileOffsetBoundsCheck) {
+    Optional<ArrayAttr> mapping, bool omitTileOffsetBoundsCheck) {
   Location loc = op->getLoc();
   OpBuilder::InsertionGuard g(b);
   SmallVector<Range> loopRanges = op.getIterationDomain(b);
@@ -256,7 +256,7 @@
   // version because we require the use of RewriterBase in the body, so we
   // manually move the insertion point to the body below.
   scf::ForeachThreadOp foreachThreadOp = b.create<scf::ForeachThreadOp>(
-      loc, dest, ValueRange(materializedNonZeroNumThreads), threadDimMapping);
+      loc, dest, ValueRange(materializedNonZeroNumThreads), mapping);
 
   // Fill out the ForeachThreadOp body.
   b.setInsertionPointToStart(foreachThreadOp.getBody(0));
@@ -363,16 +363,16 @@
 FailureOr<ForeachThreadTilingResult>
 linalg::tileToForeachThreadOp(RewriterBase &b, TilingInterface op,
                               ArrayRef<OpFoldResult> numThreads,
-                              ArrayRef<int64_t> threadDimMapping) {
+                              Optional<ArrayAttr> mapping) {
   return tileToForeachThreadOpImpl(b, op, numThreads,
                                    /*nominalTileSizes=*/None,
-                                   threadDimMapping,
+                                   mapping,
                                    /*omitTileOffsetBoundsCheck=*/false);
 }
 
 FailureOr<ForeachThreadTilingResult>
-linalg::tileToForeachThreadOpUsingTileSizes(
-    RewriterBase &b, TilingInterface op, ArrayRef<OpFoldResult> tileSizes,
-    ArrayRef<int64_t> threadDimMapping) {
+linalg::tileToForeachThreadOpUsingTileSizes(RewriterBase &b, TilingInterface op,
+                                            ArrayRef<OpFoldResult> tileSizes,
+                                            Optional<ArrayAttr> mapping) {
   SmallVector<Range> loopRanges = op.getIterationDomain(b);
   unsigned nLoops = loopRanges.size();
   SmallVector<OpFoldResult> numThreads;
@@ -388,8 +388,7 @@
     numThreads.push_back(numTiles);
   }
   return tileToForeachThreadOpImpl(b, op, numThreads,
-                                   /*nominalTileSizes=*/tileSizes,
-                                   threadDimMapping,
+                                   /*nominalTileSizes=*/tileSizes, mapping,
                                    /*omitTileOffsetBoundsCheck=*/true);
 }
diff --git a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
--- a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRSCFDialect
   SCF.cpp
+  DeviceMappingInterface.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/SCF
diff --git a/mlir/lib/Dialect/SCF/IR/DeviceMappingInterface.cpp b/mlir/lib/Dialect/SCF/IR/DeviceMappingInterface.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/IR/DeviceMappingInterface.cpp
@@ -0,0 +1,17 @@
+//===- DeviceMappingInterface.cpp -----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Table-generated class definitions
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/IR/DeviceMappingAttrInterface.cpp.inc"
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/FunctionInterfaces.h"
@@ -19,6 +20,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/InliningUtils.h"
+#include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
 using namespace mlir::scf;
@@ -1111,6 +1113,12 @@
     if (body->getArgument(i + getRank()).getType() != getOutputs()[i].getType())
       return emitOpError("type mismatch between ")
              << i << "-th output and corresponding block argument";
+  if (getMapping().has_value())
+    for (auto map : getMapping().value()) {
+      if (!isa<DeviceMappingAttrInterface>(map))
+        return emitOpError()
+               << getMappingAttrName() << " is not a device mapping attribute";
+    }
 
   return success();
 }
@@ -1200,11 +1208,14 @@
 void ForeachThreadOp::build(mlir::OpBuilder &builder,
                             mlir::OperationState &result, ValueRange outputs,
                             ValueRange numThreads,
-                            ArrayRef<int64_t> threadDimMapping) {
+                            Optional<ArrayAttr> mapping) {
   result.addOperands(numThreads);
   result.addOperands(outputs);
-  result.addAttribute(ForeachThreadOp::getThreadDimMappingAttrName(result.name),
-                      builder.getI64ArrayAttr(threadDimMapping));
+  if (mapping.has_value()) {
+    result.addAttribute(ForeachThreadOp::getMappingAttrName(result.name),
+                        mapping.value());
+  }
+
   result.addAttribute(
       "operand_segment_sizes",
       builder.getDenseI32ArrayAttr({static_cast<int32_t>(numThreads.size()),
@@ -1231,12 +1242,12 @@
 // Builder that takes a bodyBuilder lambda.
 void ForeachThreadOp::build(
     mlir::OpBuilder &builder, mlir::OperationState &result, ValueRange outputs,
-    ValueRange numThreads, ArrayRef<int64_t> threadDimMapping,
+    ValueRange numThreads, ArrayRef<Attribute> mapping,
     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
   result.addOperands(numThreads);
   result.addOperands(outputs);
-  result.addAttribute(ForeachThreadOp::getThreadDimMappingAttrName(result.name),
-                      builder.getI64ArrayAttr(threadDimMapping));
+  result.addAttribute(ForeachThreadOp::getMappingAttrName(result.name),
+                      builder.getArrayAttr(mapping));
   result.addAttribute(
       "operand_segment_sizes",
       builder.getDenseI32ArrayAttr({static_cast<int32_t>(numThreads.size()),
@@ -1290,51 +1301,49 @@
   SmallVector<T> result(vals.size());
   SmallVector<bool> seen(vals.size());
   for (auto [idx, val] : llvm::zip(perm, vals)) {
-    // Already seen, invalid thread_dim_mapping.
+    // Already seen, invalid mapping.
     if (seen[idx])
       return failure();
     result[idx] = val;
     seen[idx] = true;
   }
-  // Some not seen, invalid thread_dim_mapping.
+  // Some not seen, invalid mapping.
   if (!llvm::all_of(seen, [](bool b) { return b; }))
     return failure();
   return result;
 }
 
-/// Helper to get apply the `thread_dim_mapping` permutation of a
+/// Helper to apply the `mapping` permutation of a
 /// `foreachThreadOp` to `values`.
 template <typename T>
 static FailureOr<SmallVector<T>>
 getValuesPermutedByThreadMapping(scf::ForeachThreadOp foreachThreadOp,
-                                 const SmallVector<T> &values) {
+                                 const SmallVector<T> &values,
+                                 ArrayRef<int64_t> mapping) {
   // Apply mapping permutation if specified.
-  auto mapping = foreachThreadOp.getThreadDimMapping();
-  if (mapping && !mapping.empty()) {
-    auto maybePermuted = permute(values, extractFromI64ArrayAttr(mapping));
-    if (failed(maybePermuted))
-      return foreachThreadOp->emitError("invalid permutation");
-    return *maybePermuted;
-  }
+  auto maybePermuted = permute(values, mapping);
+  if (failed(maybePermuted))
+    return foreachThreadOp->emitError("invalid permutation");
+  return *maybePermuted;
   return values;
 }
 
-/// Return the thread indices in the order specified by the thread_dim_mapping
-/// attribute. Return failure is thread_dim_mapping is not a valid permutation.
-FailureOr<SmallVector<Value>> ForeachThreadOp::getPermutedThreadIndices() {
+/// Return the thread indices in the order specified by the mapping
+/// attribute. Return failure if the mapping is not a valid permutation.
+FailureOr<SmallVector<Value>>
+ForeachThreadOp::getPermutedThreadIndices(ArrayRef<int64_t> mapping) {
   SmallVector<Value> threadCountValues = this->getThreadIndices();
-  threadCountValues.resize(3, Value());
-  return getValuesPermutedByThreadMapping(*this, threadCountValues);
+  return getValuesPermutedByThreadMapping(*this, threadCountValues, mapping);
 }
 
 /// Return the number of threads in the order specified by the
-/// thread_dim_mapping attribute.
-/// Return failure is thread_dim_mapping is not a valid permutation.
+/// mapping attribute.
+/// Return failure if the mapping is not a valid permutation.
 FailureOr<SmallVector<OpFoldResult>>
-ForeachThreadOp::getPermutedNumThreads(OpBuilder &b) {
+ForeachThreadOp::getPermutedNumThreads(OpBuilder &b,
+                                       ArrayRef<int64_t> mapping) {
   SmallVector<OpFoldResult> threadCountValues = this->getNumThreads();
-  threadCountValues.resize(3, b.getIndexAttr(1));
-  return getValuesPermutedByThreadMapping(*this, threadCountValues);
+  return getValuesPermutedByThreadMapping(*this, threadCountValues, mapping);
 }
 
 ForeachThreadOp mlir::scf::getForeachThreadOpThreadIndexOwner(Value val) {
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -1141,10 +1141,11 @@
     // Create new ForeachThreadOp without any results and drop the automatically
     // introduced terminator.
     rewriter.setInsertionPoint(foreachThreadOp);
-    auto newForeachThreadOp = rewriter.create<ForeachThreadOp>(
+    ForeachThreadOp newForeachThreadOp;
+    newForeachThreadOp = rewriter.create<ForeachThreadOp>(
         foreachThreadOp.getLoc(), /*outputs=*/ValueRange(),
-        foreachThreadOp.getNumThreads(),
-        extractFromI64ArrayAttr(foreachThreadOp.getThreadDimMapping()));
+        foreachThreadOp.getNumThreads(), foreachThreadOp.getMapping());
+
     newForeachThreadOp.getBody()->getTerminator()->erase();
 
     // Move over block contents of the old op.
diff --git a/mlir/test/Dialect/GPU/transform-gpu-failing.mlir b/mlir/test/Dialect/GPU/transform-gpu-failing.mlir
--- a/mlir/test/Dialect/GPU/transform-gpu-failing.mlir
+++ b/mlir/test/Dialect/GPU/transform-gpu-failing.mlir
@@ -26,7 +26,7 @@
       %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32>
       %6 = math.fma %alpha, %4, %5 : f32
       memref.store %6, %y[%i, %j] : memref<2 x 32 x f32>
-    } {thread_dim_mapping = [1, 0, 2]}
+    } { mapping = [#gpu.thread<y>, #gpu.thread<x>] }
     gpu.terminator
   }
@@ -38,7 +38,7 @@
       %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32>
       %6 = math.fma %alpha, %4, %5 : f32
       memref.store %6, %y[%i, %j] : memref<2 x 32 x f32>
-    } {thread_dim_mapping = [1, 0, 2]}
+    } { mapping = [#gpu.thread<y>, #gpu.thread<x>] }
     gpu.terminator
   }
@@ -67,7 +67,7 @@
       %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32>
       %6 = math.fma %alpha, %4, %5 : f32
       memref.store %6, %y[%i, %j] : memref<2 x 32 x f32>
-    } {thread_dim_mapping = [1, 0, 2]}
+    } { mapping = [#gpu.thread<y>, #gpu.thread<x>] }
     gpu.terminator
   }
@@ -79,7 +79,7 @@
       %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32>
       %6 = math.fma %alpha, %4, %5 : f32
       memref.store %6, %y[%i, %j] : memref<2 x 32 x f32>
-    } {thread_dim_mapping = [1, 0, 2]}
+    } { mapping = [#gpu.thread<y>, #gpu.thread<x>] }
     gpu.terminator
   }
@@ -106,7 +106,7 @@
       %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32>
       %6 = math.fma %alpha, %4, %5 : f32
       memref.store %6, %y[%i, %j] : memref<2 x 32 x f32>
-    } {thread_dim_mapping = [1, 0, 2]}
+    } { mapping = [#gpu.thread<y>, #gpu.thread<x>] }
     gpu.terminator
   }
   return %y : memref<2 x 32 x f32>
@@ -131,7 +131,7 @@
     scf.foreach_thread (%i, %j, %k, %l) in (%c2, %c32,%c32,%c32) {
       %4 = memref.load %x[%i, %j, %k, %l] : memref<2x32x32x32xf32>
       memref.store %4, %y[%i, %j, %k, %l] : memref<2x32x32x32xf32>
-    } {thread_dim_mapping = [1, 0, 2]}
+    } { mapping = [#gpu.thread<y>, #gpu.thread<x>, #gpu.thread<z>] }
     gpu.terminator
   }
   return %y : memref<2x32x32x32xf32>
@@ -197,14 +197,14 @@
       %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32>
       %6 = math.fma %alpha, %4, %5 : f32
       memref.store %6, %y[%i, %j] : memref<2 x 32 x f32>
-    } {thread_dim_mapping = [1, 0, 2]}
+    } { mapping = [#gpu.thread<y>, #gpu.thread<x>] }
 
     scf.foreach_thread (%i, %j) in (%c7, %c9) {
       %4 = memref.load %x[%i, %j] : memref<2 x 32 x f32>
       %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32>
       %6 = math.fma %alpha, %4, %5 : f32
       memref.store %6, %y[%i, %j] : memref<2 x 32 x f32>
-    } {thread_dim_mapping = [1, 0, 2]}
+    } { mapping = [#gpu.thread<y>, #gpu.thread<x>] }
     gpu.terminator
   }
@@ -232,14 +232,14 @@
       %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32>
       %6 = math.fma %alpha, %4, %5 : f32
       memref.store %6, %y[%i, %j] : memref<2 x 32 x f32>
-    } {thread_dim_mapping = [0, 1, 2]}
+    } { mapping = [#gpu.thread<x>, #gpu.thread<y>] }
 
     scf.foreach_thread (%i, %j) in (%c7, %c9) {
       %4 = memref.load %x[%i, %j] : memref<2 x 32 x f32>
       %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32>
       %6 = math.fma %alpha, %4, %5 : f32
       memref.store %6, %y[%i, %j] : memref<2 x 32 x f32>
-    } {thread_dim_mapping = [1, 0, 2]}
+    } { mapping = [#gpu.thread<y>, #gpu.thread<x>] }
 
   return %y : memref<2 x 32 x f32>
 }
@@ -261,7 +261,7 @@
       %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32>
       %6 = math.fma %alpha, %4, %5 : f32
       memref.store %6, %y[%i, %j] : memref<2 x 32 x f32>
-    } {thread_dim_mapping = [0, 1, 2]}
+    } { mapping = [#gpu.block<x>, #gpu.block<y>] }
 
   return %y : memref<2 x 32 x f32>
 }
@@ -273,4 +273,3 @@
 }
 
 // -----
-
diff --git a/mlir/test/Dialect/GPU/transform-gpu.mlir b/mlir/test/Dialect/GPU/transform-gpu.mlir
--- a/mlir/test/Dialect/GPU/transform-gpu.mlir
+++ b/mlir/test/Dialect/GPU/transform-gpu.mlir
@@ -24,7 +24,7 @@
       %5 = memref.load %y[%i, %j] : !type
       %6 = math.fma %alpha, %4, %5 : f32
       memref.store %6, %y[%i, %j] : !type
-    } {thread_dim_mapping = [0, 1, 2]}
+    } { mapping = [#gpu.block<x>, #gpu.block<y>]}
     gpu.terminator
   }
   return %y : !type
@@ -33,7 +33,7 @@
 transform.sequence failures(propagate) {
 ^bb1(%arg0: !pdl.operation):
   %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0
-  transform.gpu.map_foreach_to_blocks %funcop { blockDim = [12, 9, 1]}
+  transform.gpu.map_foreach_to_blocks %funcop { gridDim = [12, 9]}
 }
 
 // -----
@@ -73,12 +73,12 @@
       %5 = memref.load %y[%i, %j] : !type
       %6 = math.fma %alpha, %4, %5 : f32
       memref.store %6, %y[%i, %j] : !type
-    } {thread_dim_mapping = [1, 0, 2]}
+    } { mapping = [#gpu.thread<y>, #gpu.thread<x>]}
     scf.foreach_thread (%i) in (%c12) {
       %7 = memref.load %t[%i] : !type1d
      %8 = arith.addf %alpha, %7 : f32
      memref.store %8, %t[%i] : !type1d
-    } {thread_dim_mapping = [0, 1, 2]}
+    } {mapping = [#gpu.thread<x>] }
     gpu.terminator
   }
   return %y : !type
@@ -87,7 +87,7 @@
 transform.sequence failures(propagate) {
 ^bb1(%arg0: !pdl.operation):
   %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0
-  transform.gpu.map_nested_foreach_to_threads %funcop { blockDim = [12, 9, 1] }
+  transform.gpu.map_nested_foreach_to_threads %funcop { blockDim = [12, 9] }
 }
 
 // -----
@@ -118,8 +118,8 @@
       %5 = memref.load %y[%i, %j, %k, %l] : !type4d
       %6 = math.fma %alpha, %4, %5 : f32
       memref.store %6, %y[%i, %j, %k, %l] : !type4d
-    } {thread_dim_mapping = [1, 0, 2]}
-  } {thread_dim_mapping = [0, 1, 2]}
+    } { mapping = [#gpu.thread<y>, #gpu.thread<x>] }
+  } { mapping = [#gpu.block<x>, #gpu.block<y>] }
   return %y : !type4d
 }
@@ -151,7 +151,7 @@
       %5 = memref.load %y[%i, %j] : !type
       %6 = math.fma %alpha, %4, %5 : f32
       memref.store %6, %y[%i, %j] : !type
-    } {thread_dim_mapping = [1, 0, 2]}
+    } { mapping = [#gpu.thread<y>, #gpu.thread<x>] }
     gpu.terminator
   }
   return %y : !type
diff --git a/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
--- a/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
+++ b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
@@ -26,7 +26,7 @@
   //      CHECK-NEXT:   tensor.parallel_insert_slice %[[RES]] into %[[C_BLK]]{{.*}} :
   //      CHECK-SAME:     tensor<?x?xf32> into tensor<?x?xf32>
   //      CHECK-NEXT:  }
-  //      CHECK-NEXT: } {thread_dim_mapping = [1, 0]}
+  //      CHECK-NEXT: } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
   %0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
                     outs(%C : tensor<?x?xf32>) -> (tensor<?x?xf32>)
   return %0 : tensor<?x?xf32>
@@ -35,7 +35,7 @@
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !pdl.operation):
   %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
-  %1:2 = transform.structured.tile_to_foreach_thread_op %0 num_threads [10, 20] (mapped to dims [1, 0])
+  %1:2 = transform.structured.tile_to_foreach_thread_op %0 num_threads [10, 20] (mapping = [ #gpu.thread<y>, #gpu.thread<x> ] )
 }
 }
@@ -177,7 +177,7 @@
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !pdl.operation):
   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
-  %1:2 = transform.structured.tile_to_foreach_thread_op %0 num_threads [2] (mapped to dims [0])
+  %1:2 = transform.structured.tile_to_foreach_thread_op %0 num_threads [2] ( mapping = [#gpu.thread<x>])
 }
 }
 // CHECK-DAG: #[[$map0:.+]] = affine_map<(d0) -> (d0 * 2)>
diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize-analysis.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize-analysis.mlir
--- a/mlir/test/Dialect/SCF/one-shot-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize-analysis.mlir
@@ -628,7 +628,7 @@
       // CHECK: tensor.parallel_insert_slice {{.*}} {__inplace_operands_attr__ = ["true", "true", "none"]}
       tensor.parallel_insert_slice %8 into %arg1[%arg0] [1] [1] : tensor<1xf32> into tensor<320xf32>
     }
-  } {thread_dim_mapping = []}
+  }
   return %4 : tensor<320xf32>
 }
diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir
--- a/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir
@@ -128,7 +128,7 @@
       tensor.parallel_insert_slice %1 into %o[%thread_idx][1][1] :
         tensor<1xf32> into tensor<100xf32>
     }
-  // CHECK: } {thread_dim_mapping = [5]}
-  } {thread_dim_mapping = [5]}
+  // CHECK: } {mapping = [#gpu.thread]}
+  } {mapping = [#gpu.thread]}
   return
 }
diff --git a/mlir/test/Dialect/SCF/ops.mlir b/mlir/test/Dialect/SCF/ops.mlir
--- a/mlir/test/Dialect/SCF/ops.mlir
+++ b/mlir/test/Dialect/SCF/ops.mlir
@@ -338,12 +338,12 @@
   %num_threads = arith.constant 100 : index
 
   // CHECK: scf.foreach_thread
-  // CHECK-NEXT: } {thread_dim_mapping = [42]}
+  // CHECK-NEXT: } {mapping = [#gpu.thread]}
   // CHECK-NEXT: return
   scf.foreach_thread (%thread_idx) in (%num_threads) {
     scf.foreach_thread.perform_concurrently {
     }
-  } {thread_dim_mapping = [42]}
+  } {mapping = [#gpu.thread]}
   return
 }
diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -217,7 +217,7 @@
   Location loc = op.getLoc();
   auto foreachOp = rewriter.create<scf::ForeachThreadOp>(
       loc, /*outputs=*/dest, /*numThreads=*/helper.getIterationSpaceSizes(),
-      /*threadDimMapping=*/ArrayRef<int64_t>{},
+      /*mapping=*/ArrayRef<Attribute>{},
       [&](OpBuilder &nestedBuilder, Location loc, ValueRange regionArgs) {
         unsigned numThreadIdRegionArgs =
             helper.getIterationSpaceSizes().size();
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -1833,7 +1833,10 @@
 
 td_library(
     name = "SCFTdFiles",
-    srcs = ["include/mlir/Dialect/SCF/IR/SCFOps.td"],
+    srcs = [
+        "include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td",
+        "include/mlir/Dialect/SCF/IR/SCFOps.td",
+    ],
     includes = ["include"],
     deps = [
         ":ControlFlowInterfacesTdFiles",
@@ -1870,6 +1873,32 @@
     deps = [":SCFTdFiles"],
 )
 
+gentbl_cc_library(
+    name = "SCFDeviceMappingInterfacesIncGen",
+    strip_include_prefix = "include",
+    tbl_outs = [
+        (
+            ["-gen-attr-interface-decls"],
+            "include/mlir/Dialect/SCF/IR/DeviceMappingAttrInterface.h.inc",
+        ),
+        (
+            ["-gen-attr-interface-defs"],
+            "include/mlir/Dialect/SCF/IR/DeviceMappingAttrInterface.cpp.inc",
+        ),
+        (
+            ["-gen-attrdef-decls"],
+            "include/mlir/Dialect/SCF/IR/DeviceMappingAttributes.h.inc",
+        ),
+        (
+            ["-gen-attrdef-defs"],
+            "include/mlir/Dialect/SCF/IR/DeviceMappingAttributes.cpp.inc",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td",
+    deps = [":SCFTdFiles"],
+)
+
 gentbl_cc_library(
     name = "SCFPassIncGen",
     strip_include_prefix = "include",
@@ -2794,6 +2823,7 @@
         ":MemRefDialect",
         ":ParallelCombiningOpInterface",
         ":Pass",
+        ":SCFDeviceMappingInterfacesIncGen",
        ":SCFIncGen",
        ":SCFPassIncGen",
        ":Support",
@@ -3623,6 +3653,7 @@
        "include/mlir/Dialect/GPU/IR/GPUBase.td",
        "include/mlir/Dialect/GPU/IR/GPUOps.td",
"include/mlir/Dialect/GPU/IR/ParallelLoopMapperAttr.td", + "include/mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td", ], includes = ["include"], deps = [ @@ -3632,10 +3663,32 @@ ":InferIntRangeInterfaceTdFiles", ":LLVMOpsTdFiles", ":OpBaseTdFiles", + ":SCFTdFiles", ":SideEffectInterfacesTdFiles", ], ) +gentbl_cc_library( + name = "GPUDeviceMapperEnumsGen", + strip_include_prefix = "include", + tbl_outs = [ + ( + ["-gen-enum-decls"], + "include/mlir/Dialect/GPU/TransformOps/GPUDeviceMapperEnums.h.inc", + ), + ( + ["-gen-enum-defs"], + "include/mlir/Dialect/GPU/TransformOps/GPUDeviceMapperEnums.cpp.inc", + ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td", + deps = [ + ":GPUOpsTdFiles", + ":OpBaseTdFiles", + ], +) + gentbl_cc_library( name = "GPUBaseIncGen", strip_include_prefix = "include", @@ -3725,6 +3778,7 @@ ":InferTypeOpInterface", ":LLVMDialect", ":MemRefDialect", + ":SCFDialect", ":SideEffectInterfaces", "//llvm:Support", ], @@ -7694,6 +7748,7 @@ includes = ["include"], deps = [ ":PDLDialectTdFiles", + ":SCFTdFiles", ":TransformDialectTdFiles", ], ) @@ -7781,6 +7836,7 @@ td_file = "include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td", deps = [ ":LinalgTransformOpsTdFiles", + ":SCFDeviceMappingInterfacesIncGen", ], )