diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -147,6 +147,47 @@
   return DiagnosedSilenceableFailure::success();
 }
 
+/// Assigns `potentialBlockDim` from the trip counts of `foreachThreadOp`.
+/// If it already holds values, the per-dimension maximum is kept.
+static LogicalResult
+setLargestTripCount(RewriterBase &rewriter,
+                    scf::ForeachThreadOp foreachThreadOp,
+                    SmallVectorImpl<int64_t> &potentialBlockDim) {
+  auto tripCounts = foreachThreadOp.getPermutedNumThreads(rewriter);
+  // Bail out if any trip count is not a compile-time constant.
+  if (failed(tripCounts) || llvm::any_of(*tripCounts, [](OpFoldResult ofr) {
+        return !getConstantIntValue(ofr).has_value();
+      })) {
+    return failure();
+  }
+
+  SmallVector<int64_t> tripCountVals =
+      llvm::to_vector(llvm::map_range(*tripCounts, [](OpFoldResult ofr) {
+        return getConstantIntValue(ofr).value();
+      }));
+
+  // Grow to the rank of this foreach_thread, then keep the per-dimension
+  // maximum across all foreach_thread ops seen so far.
+  if (potentialBlockDim.size() < tripCountVals.size())
+    potentialBlockDim.resize(tripCountVals.size(), 1);
+  for (size_t i = 0, e = tripCountVals.size(); i < e; ++i)
+    potentialBlockDim[i] = std::max(potentialBlockDim[i], tripCountVals[i]);
+  return success();
+}
+
+/// Walks all foreach_thread ops nested under `target` and records the
+/// largest trip count seen for each thread dimension.
+static LogicalResult setBlockDim(RewriterBase &rewriter, Operation *target,
+                                 SmallVectorImpl<int64_t> &blockDim) {
+  auto walkResult = target->walk([&](scf::ForeachThreadOp foreachThreadOp) {
+    if (failed(setLargestTripCount(rewriter, foreachThreadOp, blockDim)))
+      return WalkResult::interrupt();
+    return WalkResult::advance();
+  });
+  return walkResult.wasInterrupted() ? failure() : success();
+}
+
 //===----------------------------------------------------------------------===//
 // MapForeachToBlocks
 //===----------------------------------------------------------------------===//
@@ -448,27 +489,28 @@
     return emitSilenceableError() << "Given target is not gpu.launch";
   }
 
+  SimpleRewriter rewriter(getContext());
   SmallVector<int64_t> blockDim = extractFromI64ArrayAttr(getBlockDim());
-  blockDim.resize(/*size=*/3, /*value=*/1);
+  if (blockDim.empty() && failed(setBlockDim(rewriter, target, blockDim))) {
+    return emitSilenceableError()
+           << "cannot deduce blockDim: trip counts are not "
+              "known at compile-time";
+  }
+  // Pad to 3 dimensions so the blockDim[0..2] accesses below stay in bounds.
+  blockDim.resize(/*size=*/3, /*value=*/1);
 
   DiagnosedSilenceableFailure diag =
       checkGpuLimits(transformOp, llvm::None, llvm::None, llvm::None,
                      blockDim[0], blockDim[1], blockDim[2]);
-  if (diag.isSilenceableFailure()) {
-    results.assign({target});
+  if (!diag.succeeded()) {
     diag.attachNote(getLoc()) << getBlockDimAttrName() << " is very large";
-    return diag;
-  }
-
-  SimpleRewriter rewriter(getContext());
-  rewriter.setInsertionPoint(target);
-
-  diag = mlir::transform::gpu::mapNestedForeachToThreadsImpl(
-      rewriter, target, blockDim, getSyncAfterDistribute(), transformOp);
-  if (diag.succeeded()) {
-    diag =
-        alterGpuLaunch(rewriter, gpuLaunch, transformOp, llvm::None, llvm::None,
-                       llvm::None, blockDim[0], blockDim[1], blockDim[2]);
+  } else {
+    rewriter.setInsertionPoint(target);
+    diag = mlir::transform::gpu::mapNestedForeachToThreadsImpl(
+        rewriter, target, blockDim, getSyncAfterDistribute(), transformOp);
+    if (diag.succeeded()) {
+      diag = alterGpuLaunch(rewriter, gpuLaunch, transformOp, llvm::None,
+                            llvm::None, llvm::None, blockDim[0], blockDim[1],
+                            blockDim[2]);
+    }
   }
 
   results.assign({gpuLaunch});
diff --git a/mlir/test/Dialect/GPU/transform-gpu.mlir b/mlir/test/Dialect/GPU/transform-gpu.mlir
--- a/mlir/test/Dialect/GPU/transform-gpu.mlir
+++ b/mlir/test/Dialect/GPU/transform-gpu.mlir
@@ -162,3 +162,43 @@
   %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0
   transform.gpu.map_nested_foreach_to_threads %funcop { blockDim = [12, 9, 1], syncAfterDistribute = false }
 }
+
+// -----
+
+!type = memref<2 x 32 x f32>
+!type1d = memref<32 x f32>
+
+// CHECK-LABEL: func.func @map_nested_foreach_to_threads_without_blockdim(
+func.func @map_nested_foreach_to_threads_without_blockdim(%x: !type, %y: !type, %t: !type1d, %alpha : f32, %stream : !gpu.async.token) -> !type {
+  %one = arith.constant 1 : index
+  %c12 = arith.constant 12 : index
+  %c9 = arith.constant 9 : index
+  %c7 = arith.constant 7 : index
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[C12:.*]] = arith.constant 12 : index
+// CHECK: %[[C7:.*]] = arith.constant 7 : index
+// CHECK: gpu.launch async [%{{.*}}] blocks(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C1]], %{{.*}} = %[[C1]], %{{.*}} = %[[C1]]) threads(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C12]], %{{.*}} = %[[C7]], %{{.*}} = %[[C1]])
+  %name = gpu.launch async[%stream] blocks(%arg3, %arg4, %arg5) in (%arg9 = %one, %arg10 = %one, %arg11 = %one)
+            threads(%arg6, %arg7, %arg8) in (%arg12 = %one, %arg13 = %one, %arg14 = %one)
+  {
+    scf.foreach_thread (%i, %j) in (%c7, %c9) {
+      %4 = memref.load %x[%i, %j] : !type
+      %5 = memref.load %y[%i, %j] : !type
+      %6 = math.fma %alpha, %4, %5 : f32
+      memref.store %6, %y[%i, %j] : !type
+    } {thread_dim_mapping = [1, 0, 2]}
+    scf.foreach_thread (%i) in (%c12) {
+      %7 = memref.load %t[%i] : !type1d
+      %8 = arith.addf %alpha, %7 : f32
+      memref.store %8, %t[%i] : !type1d
+    } {thread_dim_mapping = [0, 1, 2]}
+    gpu.terminator
+  }
+  return %y : !type
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg0: !pdl.operation):
+  %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0
+  transform.gpu.map_nested_foreach_to_threads %funcop
+}
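
Note for reviewers: the deduction exercised by the new test reduces to a per-dimension maximum over the static, permuted trip counts of all nested foreach_thread ops, padded to three dimensions. A minimal standalone sketch of that merge in plain C++ (mergeTripCounts is an illustrative name only; the patch itself operates on SmallVectorImpl<int64_t> built from OpFoldResult inside setLargestTripCount):

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // Merge one foreach_thread's static trip counts into the running
    // blockDim, keeping the per-dimension maximum.
    static void mergeTripCounts(std::vector<int64_t> &blockDim,
                                const std::vector<int64_t> &tripCounts) {
      if (blockDim.size() < tripCounts.size())
        blockDim.resize(tripCounts.size(), 1);
      for (size_t i = 0; i < tripCounts.size(); ++i)
        blockDim[i] = std::max(blockDim[i], tripCounts[i]);
    }

    int main() {
      std::vector<int64_t> blockDim;
      mergeTripCounts(blockDim, {9, 7}); // 1st foreach_thread: (c7, c9) permuted by [1, 0, 2]
      mergeTripCounts(blockDim, {12});   // 2nd foreach_thread: (c12)
      blockDim.resize(3, 1);             // pad to the 3 GPU thread dimensions
      // blockDim is now {12, 7, 1}, matching the CHECK lines in the test above.
    }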