diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -434,6 +434,12 @@
       rewriter.create<ThreadIdOp>(loc, indexType, Dimension::x),
       rewriter.create<ThreadIdOp>(loc, indexType, Dimension::y),
       rewriter.create<ThreadIdOp>(loc, indexType, Dimension::z)};
+  // Replace ids of dimension size 1 by zero to simplify the IR.
+  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  for (size_t i : llvm::seq(size_t(0), globalBlockDims.size())) {
+    if (globalBlockDims[i] == 1)
+      threadOps[i] = zero;
+  }
   IRMapping bvm;
   for (auto [blockIdx, blockDim] :
        llvm::zip(foreachThreadOp.getThreadIndices(), threadMapping)) {
diff --git a/mlir/test/Dialect/GPU/transform-gpu.mlir b/mlir/test/Dialect/GPU/transform-gpu.mlir
--- a/mlir/test/Dialect/GPU/transform-gpu.mlir
+++ b/mlir/test/Dialect/GPU/transform-gpu.mlir
@@ -194,3 +194,39 @@
   %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0
   transform.gpu.map_nested_foreach_to_threads %funcop { blockDim = [32]}
 }
+
+// -----
+
+!type = memref<3 x 2 x 32 x f32>
+!type1d = memref<32 x f32>
+
+// CHECK-LABEL: func.func @saxpy3d_fold_id_z(
+func.func @saxpy3d_fold_id_z(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %stream : !gpu.async.token) -> !type {
+  %one = arith.constant 1 : index
+  %c12 = arith.constant 12 : index
+  %c9 = arith.constant 9 : index
+  %c7 = arith.constant 7 : index
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-NOT: gpu.thread_id z
+  %name = gpu.launch async[%stream] blocks(%arg3, %arg4, %arg5) in (%arg9 = %one, %arg10 = %one, %arg11 = %one)
+            threads(%arg6, %arg7, %arg8) in (%arg12 = %one, %arg13 = %one, %arg14 = %one)
+  {
+    scf.foreach_thread (%i, %j, %k) in (%one, %c7, %c9) {
+// CHECK: memref.load %{{.*}}[%[[C0]],
+// CHECK: memref.load %{{.*}}[%[[C0]],
+      %4 = memref.load %x[%i, %j, %k] : !type
+      %5 = memref.load %y[%i, %j, %k] : !type
+      %6 = math.fma %alpha, %4, %5 : f32
+// CHECK: memref.store %{{.*}}, %{{.*}}[%[[C0]]
+      memref.store %6, %y[%i, %j, %k] : !type
+    } { mapping = [#gpu.thread<z>, #gpu.thread<y>, #gpu.thread<x>] }
+    gpu.terminator
+  }
+  return %y : !type
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg0: !pdl.operation):
+  %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0
+  transform.gpu.map_nested_foreach_to_threads %funcop { blockDim = [12, 9, 1], syncAfterDistribute = false }
+}