Index: mlir/include/mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td =================================================================== --- mlir/include/mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td +++ mlir/include/mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td @@ -43,6 +43,27 @@ }]; } +def WarpsEnum : I64EnumAttr<"Warps", "warps for loop mapping", [ + DimX, DimY, DimZ]> { + let cppNamespace = "::mlir::gpu"; +} + +def GPUWarpMappingAttr : GPU_Attr<"GPUWarpMapping", "warp", [ + DeclareAttrInterfaceMethods<DeviceMappingAttrInterface> ] > { + let parameters = (ins + EnumParameter<WarpsEnum>:$warp + ); + let assemblyFormat = "`<` params `>`"; + let description = [{ + An attribute that allows defining warp parallelism for GPU devices. + + Warps (aka subgroups) are grouped into a grid, where the grid may be + described by a 1-, 2-, or 3-dimensional rectangle. This attribute indicates + that warp parallelism is desired. It can be consumed by lowering to + generate GPU code. + }]; +} + def BlocksEnum : I64EnumAttr<"Blocks", "threads for loop mapping", [ DimX, DimY, DimZ]> { let cppNamespace = "::mlir::gpu"; Index: mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td =================================================================== --- mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td +++ mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td @@ -58,9 +58,13 @@ If any scf.foreach_thread with tensors is found, the transform definitely fails. - If all the scf.foreach_thread operations contained within the LaunchOp - referred to by the `target` PDLOperation lower to GPU properly, the - transform succeeds. Otherwise the transform definitely fails. + If all the scf.foreach_thread operations with gpu.thread mapping contained + within the LaunchOp referred to by the `target` PDLOperation lower to GPU + properly, the transform succeeds. Otherwise the transform definitely + fails. 
+ + scf.foreach_thread operations with mappings other than gpu.thread are + ignored. The returned handle points to the same LaunchOp operand, consuming it and producing a new SSA value to satisfy chaining and linearity of the IR Index: mlir/lib/Dialect/GPU/IR/GPUDialect.cpp =================================================================== --- mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -42,6 +42,10 @@ return static_cast<int64_t>(getBlock()); } +int64_t GPUWarpMappingAttr::getMappingId() const { + return static_cast<int64_t>(getWarp()); +} + int64_t GPUThreadMappingAttr::getMappingId() const { return static_cast<int64_t>(getThread()); } Index: mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp =================================================================== --- mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp +++ mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp @@ -509,6 +509,12 @@ const ArrayRef<DeviceMappingAttrInterface> &threadMappingAttributes) { DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success(); target->walk([&](scf::ForeachThreadOp foreachThreadOp) { + // Ignore cases with different attributes. 
+ // A missing mapping is not skipped; it goes to checkAttributeType below. + if (auto mapping = foreachThreadOp.getMapping()) + for (Attribute map : mapping->getValue()) + if (!llvm::is_contained(threadMappingAttributes, map)) + return WalkResult::skip(); diag = checkAttributeType(threadMappingAttributes, foreachThreadOp.getMapping(), transformOp); if (diag.succeeded()) { Index: mlir/test/Dialect/GPU/transform-gpu-failing.mlir =================================================================== --- mlir/test/Dialect/GPU/transform-gpu-failing.mlir +++ mlir/test/Dialect/GPU/transform-gpu-failing.mlir @@ -274,30 +274,4 @@ transform.gpu.map_nested_foreach_to_threads %funcop { blockDim = [32, 32]} } -// ----- - -!type = memref<32x32xf32> -func.func @saxpy2d_wrong_mapping(%x: !type, %y: !type, %stream : !gpu.async.token) -> !type { - %c32 = arith.constant 32 : index - %one = arith.constant 1 : index - %name = gpu.launch async[%stream] blocks(%arg3, %arg4, %arg5) in (%arg9 = %one, %arg10 = %one, %arg11 = %one) - threads(%arg6, %arg7, %arg8) in (%arg12 = %one, %arg13 = %one, %arg14 = %one) - { - scf.foreach_thread (%i, %j) in (%c32, %c32) { - %4 = memref.load %x[%i, %j] : !type - %5 = memref.load %y[%i, %j] : !type - %6 = arith.mulf %4, %5 : f32 - memref.store %6, %y[%i, %j] : !type - } { mapping = [#gpu.block, #gpu.block] } - gpu.terminator - } - return %y : !type -} - -transform.sequence failures(propagate) { -^bb1(%arg0: !pdl.operation): - %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!pdl.operation) -> !pdl.operation - // expected-error @below {{mapping must be one of #gpu.thread, #gpu.thread, #gpu.thread}} - transform.gpu.map_nested_foreach_to_threads %funcop { blockDim = [32, 32]} -} Index: mlir/test/Dialect/GPU/transform-gpu.mlir =================================================================== --- mlir/test/Dialect/GPU/transform-gpu.mlir +++ mlir/test/Dialect/GPU/transform-gpu.mlir @@ -230,3 +230,42 @@ %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!pdl.operation) -> !pdl.operation 
transform.gpu.map_nested_foreach_to_threads %funcop { blockDim = [12, 9, 1], syncAfterDistribute = false } } + +// ----- + +!type = memref<2 x 32 x f32> +!type1d = memref<32 x f32> + +// CHECK-LABEL: func.func @map_multi_level( +func.func @map_multi_level(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %stream : !gpu.async.token) -> !type { + %one = arith.constant 1 : index + %c12 = arith.constant 12 : index + %c9 = arith.constant 9 : index + %c7 = arith.constant 7 : index +// Check that the thread-level loop got distributed but the warp-level one did not. +// CHECK-NOT: #gpu.thread +// CHECK: {mapping = [#gpu.warp]} + %name = gpu.launch async[%stream] blocks(%arg3, %arg4, %arg5) in (%arg9 = %one, %arg10 = %one, %arg11 = %one) + threads(%arg6, %arg7, %arg8) in (%arg12 = %one, %arg13 = %one, %arg14 = %one) + { + scf.foreach_thread (%i, %j) in (%c7, %c9) { + %4 = memref.load %x[%i, %j] : !type + %5 = memref.load %y[%i, %j] : !type + %6 = math.fma %alpha, %4, %5 : f32 + memref.store %6, %y[%i, %j] : !type + } { mapping = [#gpu.thread, #gpu.thread]} + scf.foreach_thread (%i) in (%c12) { + %7 = memref.load %t[%i] : !type1d + %8 = arith.addf %alpha, %7 : f32 + memref.store %8, %t[%i] : !type1d + } {mapping = [#gpu.warp] } + gpu.terminator + } + return %y : !type +} + +transform.sequence failures(propagate) { +^bb1(%arg0: !pdl.operation): + %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!pdl.operation) -> !pdl.operation + transform.gpu.map_nested_foreach_to_threads %funcop { blockDim = [12, 9] } +}