diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td --- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td @@ -339,6 +339,15 @@ application per thread. Further lowerings are responsible for specifying how this is materialized on concrete hardware resources. + An optional thread_dim_mapping index array attribute specifies for each + virtual thread dimension, how it remaps 1-1 to a set of concrete processing + element resources (e.g. a CUDA grid dimension or a level of concrete nested + async parallelism). At this time, the specification is backend-dependent and + is not verified by the op, beyond being an index array attribute. + It is the responsibility of the lowering to interpret the index array in the + context of the concrete target the op is lowered to, or to ignore it when + the specification is ill-formed or unsupported for a particular target. + The only allowed terminator is `scf.foreach_thread.perform_concurrently`, which dictates how the partial results of all parallel invocations should be reconciled into a full value. @@ -398,8 +407,27 @@ // Sequential context. // ``` + + Example with thread_dim_mapping attribute: + // + // Sequential context. + // + %matmul_and_pointwise:2 = scf.foreach_thread (%thread_id_1, %thread_id_2) in + (%num_threads_1, %num_threads_2) -> (tensor, tensor) { + // + // Parallel context, each thread with id = **(%thread_id_2, %thread_id_1)** + // runs its version of the code. + // + scf.foreach_thread.perform_concurrently { + ... + } + } { thread_dim_mapping = [1, 0] } + // Implicit synchronization point. + // Sequential context. 
+ // }]; - let arguments = (ins Variadic:$num_threads); + let arguments = (ins Variadic:$num_threads, + DefaultValuedAttr:$thread_dim_mapping); let results = (outs Variadic:$results); let regions = (region SizedRegion<1>:$region); @@ -411,11 +439,13 @@ let skipDefaultBuilders = 1; let builders = [ // Bodyless builder, result types must be specified. - OpBuilder<(ins "TypeRange":$resultTypes, "ValueRange":$num_threads)>, + OpBuilder<(ins "TypeRange":$resultTypes, "ValueRange":$num_threads, + CArg<"ArrayRef", "{}">:$thread_dim_mapping)>, // Builder that takes a bodyBuilder lambda, result types are inferred from // the terminator. OpBuilder<(ins "ValueRange":$num_threads, - "function_ref":$bodyBuilder)> + "ArrayRef":$thread_dim_mapping, + "function_ref":$bodyBuilder)> ]; let extraClassDeclaration = [{ int64_t getRank() { return getNumThreads().size(); } diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -1135,8 +1135,12 @@ // Bodyless builder, result types must be specified. void ForeachThreadOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, TypeRange resultTypes, - ValueRange numThreads) { + ValueRange numThreads, + ArrayRef threadDimMapping) { result.addOperands(numThreads); + result.addAttribute( + // TODO: getThreadDimMappingAttrName() but it is not a static member. + "thread_dim_mapping", builder.getI64ArrayAttr(threadDimMapping)); Region *bodyRegion = result.addRegion(); OpBuilder::InsertionGuard g(builder); @@ -1156,9 +1160,12 @@ // the terminator. void ForeachThreadOp::build( mlir::OpBuilder &builder, mlir::OperationState &result, - ValueRange numThreads, + ValueRange numThreads, ArrayRef threadDimMapping, function_ref bodyBuilder) { result.addOperands(numThreads); + result.addAttribute( + // TODO: getThreadDimMappingAttrName() but it is not a static member. 
+ "thread_dim_mapping", builder.getI64ArrayAttr(threadDimMapping)); OpBuilder::InsertionGuard g(builder); Region *bodyRegion = result.addRegion(); diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Operation.h" #include "mlir/IR/PatternMatch.h" @@ -999,7 +1000,8 @@ TypeRange newResultTypes; auto newForeachThreadOp = rewriter.create( foreachThreadOp.getLoc(), newResultTypes, - foreachThreadOp.getNumThreads()); + foreachThreadOp.getNumThreads(), + extractFromI64ArrayAttr(foreachThreadOp.getThreadDimMapping())); newForeachThreadOp.getBody()->getTerminator()->erase(); // Move over block contents of the old op. 
diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir --- a/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir +++ b/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir @@ -130,6 +130,7 @@ scf.foreach_thread.parallel_insert_slice %1 into %out[%thread_idx][1][1] : tensor<1xf32> into tensor<100xf32> } - } + // CHECK: } {thread_dim_mapping = [5]} + } {thread_dim_mapping = [5]} return } diff --git a/mlir/test/Dialect/SCF/ops.mlir b/mlir/test/Dialect/SCF/ops.mlir --- a/mlir/test/Dialect/SCF/ops.mlir +++ b/mlir/test/Dialect/SCF/ops.mlir @@ -338,11 +338,11 @@ %num_threads = arith.constant 100 : index // CHECK: scf.foreach_thread - // CHECK-NEXT: } + // CHECK-NEXT: } {thread_dim_mapping = [42]} // CHECK-NEXT: return scf.foreach_thread (%thread_idx) in (%num_threads) -> () { scf.foreach_thread.perform_concurrently { } - } + } {thread_dim_mapping = [42]} return }