diff --git a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td --- a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td +++ b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td @@ -29,7 +29,7 @@ The operation searches for `scf.foreach_thread` ops nested under `target` and maps each such op to GPU threads. Mapping is one-to-one and the induction variables of `scf.foreach_thread` are rewritten to - `gpu.thread_id` according to the `thread_dim_mapping` attribute. + `gpu.thread_id` according to the `mapping` attribute. Sibling `scf.foreach_thread` are supported in which case, the union of the number of threads is computed and may result in predication. @@ -73,7 +73,7 @@ threads(%tx, %ty, %tz) in (%tx = %3, %ty = %4, %tz = %5) { scf.foreach_thread (%i, %j) in (7, 9) { ... // body 1 - } {thread_dim_mapping = [1, 0, 2]} + } {mapping = ["thread_y", "thread_x"]} scf.foreach_thread (%i) in (12) { ... // body 2 } diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -727,7 +727,7 @@ Variadic:$tile_sizes, DefaultValuedAttr:$static_num_threads, DefaultValuedAttr:$static_tile_sizes, - OptionalAttr:$thread_dim_mapping); + OptionalAttr:$mapping); let results = (outs PDL_Operation:$foreach_thread_op, PDL_Operation:$tiled_op); let assemblyFormat = [{ @@ -738,7 +738,7 @@ `tile_sizes` custom($tile_sizes, $static_tile_sizes, "ShapedType::kDynamicSize")) - (`(` `mapped` `to` `dims` $thread_dim_mapping^ `)`)? attr-dict + (`(` `mapped` `to` `dims` $mapping^ `)`)? 
attr-dict }]; let hasVerifier = 1; diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -439,14 +439,14 @@ FailureOr tileToForeachThreadOp(RewriterBase &builder, TilingInterface op, ArrayRef numThreads, - ArrayRef threadDimMapping = {}); + ArrayRef threadDimMapping = {}); /// Same as `tileToForeachThreadOp`, but calculate the number of threads /// required using the given tileSizes. FailureOr tileToForeachThreadOpUsingTileSizes(RewriterBase &builder, TilingInterface op, ArrayRef tileSizes, - ArrayRef threadDimMapping = {}); + ArrayRef threadDimMapping = {}); /// All indices returned by IndexOp should be invariant with respect to tiling. /// Therefore, if an operation is tiled, we have to transform the indices diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td --- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td @@ -371,14 +371,13 @@ application per thread. Further lowerings are responsible for specifying how this is materialized on concrete hardware resources. - An optional thread_dim_mapping index array attribute specifies for each - virtual thread dimension, how it remaps 1-1 to a set of concrete processing - element resources (e.g. a CUDA grid dimension or a level of concrete nested - async parallelism). At this time, the specification is backend-dependent and - is not verified by the op, beyond being an index array attribute. - It is the reponsibility of the lowering to interpret the index array in the - context of the concrete target the op is lowered to, or to ignore it when - the specification is ill-formed or unsupported for a particular target. 
+ An optional `mapping` string array attribute specifies for each virtual + thread dimension, how it remaps 1-1 to a set of concrete processing element + resources (e.g. a CUDA grid dimension or a level of concrete nested async + parallelism). Mapping can be one of the following: thread_x, thread_y, thread_z, + block_x, block_y, and block_z. In this way, one can use 3-dimensional threads + and blocks that contain threads. The threads and blocks cannot be mixed at + this point. The only allowed terminator is `scf.foreach_thread.perform_concurrently`. `scf.foreach_thread` returns one value per `shared_out` operand. The @@ -433,7 +432,7 @@ // ``` - Example with thread_dim_mapping attribute: + Example with mapping attribute: ```mlir // @@ -449,7 +448,7 @@ scf.foreach_thread.perform_concurrently { ... } - } { thread_dim_mapping = [1, 0] } + } { mapping = ["thread_y", "thread_x"] } // Implicit synchronization point. // Sequential context. // @@ -473,7 +472,7 @@ }]; let arguments = (ins Variadic:$num_threads, Variadic:$outputs, - DefaultValuedAttr:$thread_dim_mapping); + DefaultValuedAttr:$mapping); let results = (outs Variadic:$results); let regions = (region SizedRegion<1>:$region); @@ -486,10 +485,10 @@ let builders = [ // Bodyless builder, outputs must be specified. OpBuilder<(ins "ValueRange":$outputs, "ValueRange":$num_threads, - CArg<"ArrayRef", "{}">:$thread_dim_mapping)>, + CArg<"ArrayRef", "{}">:$mapping)>, // Builder that takes a bodyBuilder lambda. OpBuilder<(ins "ValueRange":$outputs, "ValueRange":$num_threads, - "ArrayRef":$thread_dim_mapping, + "ArrayRef":$mapping, "function_ref":$bodyBuilder)> ]; let extraClassDeclaration = [{ @@ -528,13 +527,12 @@ } /// Return the thread indices in the order specified by the - /// thread_dim_mapping attribute. Return failure is - /// thread_dim_mapping is not a valid permutation. + /// mapping attribute. Return failure if mapping is not a valid permutation. 
FailureOr> getPermutedThreadIndices(); /// Return the number of threads in the order specified by the - /// thread_dim_mapping attribute. - /// Return failure is thread_dim_mapping is not a valid permutation. + /// mapping attribute. + /// Return failure if mapping is not a valid permutation. FailureOr> getPermutedNumThreads(OpBuilder &b); // The ensureTerminator method generated by SingleBlockImplicitTerminator is diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h --- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h @@ -57,6 +57,9 @@ /// Extract int64_t values from the assumed ArrayAttr of IntegerAttr. SmallVector extractFromI64ArrayAttr(Attribute attr); +/// Extract StringRef values from the assumed ArrayAttr of StringAttr. +SmallVector extractFromStringArrayAttr(Attribute attr); + /// Given a value, try to extract a constant Attribute. If this fails, return /// the original value. OpFoldResult getAsOpFoldResult(Value val); diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp --- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp +++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp @@ -295,6 +295,8 @@ rewriter, topLevelForeachThreadOp, generateGpuBlockIds, gridDim, transformOp); if (diag.succeeded()) { + // Set the dims that aren't used to 1. + gridDim.resize(3, 1); diag = alterGpuLaunch(rewriter, gpuLaunch, cast(getOperation()), gridDim[0], gridDim[1], gridDim[2]); @@ -347,6 +349,7 @@ llvm::to_vector(llvm::map_range(*potentialBlockDim, [](OpFoldResult ofr) { return getConstantIntValue(ofr).value(); })); + blockDim.resize(3, 1); // Step 1. 
Create the gpu.thread ops Location loc = foreachThreadOp.getLoc(); diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -1302,8 +1302,8 @@ auto maybeThreadDimMappingAttr = threadDimMapping; auto dimMapping = llvm::to_vector( maybeThreadDimMappingAttr - ? extractFromI64ArrayAttr(*maybeThreadDimMappingAttr) - : ArrayRef{}); + ? extractFromStringArrayAttr(*maybeThreadDimMappingAttr) + : ArrayRef{}); FailureOr tilingResult = failure(); if (!mixedNumThreads.empty()) { @@ -1336,7 +1336,7 @@ DiagnosedSilenceableFailure diag = tileToForeachThreadOpImpl( rewriter, state, cast(getOperation()), targets, - getMixedNumThreads(), getMixedTileSizes(), getThreadDimMapping(), tileOps, + getMixedNumThreads(), getMixedTileSizes(), getMapping(), tileOps, tiledOps); if (!diag.succeeded()) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -226,7 +226,7 @@ static FailureOr tileToForeachThreadOpImpl( RewriterBase &b, TilingInterface op, ArrayRef numThreads, Optional> nominalTileSizes, - ArrayRef threadDimMapping, bool omitTileOffsetBoundsCheck) { + ArrayRef threadDimMapping, bool omitTileOffsetBoundsCheck) { Location loc = op->getLoc(); OpBuilder::InsertionGuard g(b); SmallVector loopRanges = op.getIterationDomain(b); @@ -363,7 +363,7 @@ FailureOr linalg::tileToForeachThreadOp(RewriterBase &b, TilingInterface op, ArrayRef numThreads, - ArrayRef threadDimMapping) { + ArrayRef threadDimMapping) { return tileToForeachThreadOpImpl(b, op, numThreads, /*nominalTileSizes=*/None, threadDimMapping, /*omitTileOffsetBoundsCheck=*/false); @@ -372,7 +372,7 @@ FailureOr linalg::tileToForeachThreadOpUsingTileSizes( RewriterBase 
&b, TilingInterface op, ArrayRef tileSizes, - ArrayRef threadDimMapping) { + ArrayRef threadDimMapping) { SmallVector loopRanges = op.getIterationDomain(b); unsigned nLoops = loopRanges.size(); SmallVector numThreads; diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -1069,6 +1069,29 @@ // ForeachThreadOp //===----------------------------------------------------------------------===// +// Thread mapping +enum class ThreadMap { + none = 0, + tidx = 1, + tidy = 2, + tidz = 4, + bidx = 8, + bidy = 10, + bidz = 20 +}; + +static Optional symbolizeThreadMap(StringRef str) { + return StringSwitch>(str) + .Case("none", ThreadMap::none) + .Case("thread_x", ThreadMap::tidx) + .Case("thread_y", ThreadMap::tidy) + .Case("thread_z", ThreadMap::tidz) + .Case("block_x", ThreadMap::bidx) + .Case("block_y", ThreadMap::bidy) + .Case("block_z", ThreadMap::bidz) + .Default(None); +} + LogicalResult ForeachThreadOp::verify() { // Call terminator's verify to produce most informative error messages. 
if (failed(getTerminator().verify())) @@ -1093,6 +1116,26 @@ return emitOpError("type mismatch between ") << i << "-th output and corresponding block argument"; + bool mapBlock = false, mapThread = false; + for (auto map : getMapping()) { + auto threadMap = symbolizeThreadMap(map.cast().getValue()); + if (!threadMap.has_value()) + return emitOpError("expected ") + << "mapping to be one of: none, thread_x, thread_y, thread_z, " + "block_x, block_y, and block_z"; + if (threadMap.value() == ThreadMap::tidx || + threadMap.value() == ThreadMap::tidy || + threadMap.value() == ThreadMap::tidz) + mapThread = true; + if (threadMap.value() == ThreadMap::bidx || + threadMap.value() == ThreadMap::bidy || + threadMap.value() == ThreadMap::bidz) + mapBlock = true; + if (mapBlock && mapThread) + return emitOpError("Mapping blocks and threads cannot be mixed in the " + "same foreach_thread"); + } + return success(); } @@ -1181,11 +1224,11 @@ void ForeachThreadOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, ValueRange outputs, ValueRange numThreads, - ArrayRef threadDimMapping) { + ArrayRef mapping) { result.addOperands(numThreads); result.addOperands(outputs); - result.addAttribute(ForeachThreadOp::getThreadDimMappingAttrName(result.name), - builder.getI64ArrayAttr(threadDimMapping)); + result.addAttribute(ForeachThreadOp::getMappingAttrName(result.name), + builder.getStrArrayAttr(mapping)); result.addAttribute( "operand_segment_sizes", builder.getDenseI32ArrayAttr({static_cast(numThreads.size()), @@ -1212,12 +1255,12 @@ // Builder that takes a bodyBuilder lambda. 
void ForeachThreadOp::build( mlir::OpBuilder &builder, mlir::OperationState &result, ValueRange outputs, - ValueRange numThreads, ArrayRef threadDimMapping, + ValueRange numThreads, ArrayRef mapping, function_ref bodyBuilder) { result.addOperands(numThreads); result.addOperands(outputs); - result.addAttribute(ForeachThreadOp::getThreadDimMappingAttrName(result.name), - builder.getI64ArrayAttr(threadDimMapping)); + result.addAttribute(ForeachThreadOp::getMappingAttrName(result.name), + builder.getStrArrayAttr(mapping)); result.addAttribute( "operand_segment_sizes", builder.getDenseI32ArrayAttr({static_cast(numThreads.size()), @@ -1265,34 +1308,48 @@ template static FailureOr> permute(const SmallVector &vals, - ArrayRef perm) { + ArrayRef perm) { if (vals.size() != perm.size()) return failure(); SmallVector result(vals.size()); SmallVector seen(vals.size()); - for (auto [idx, val] : llvm::zip(perm, vals)) { - // Already seen, invalid thread_dim_mapping. - if (seen[idx]) + for (auto [mappingName, val] : llvm::zip(perm, vals)) { + auto maybeMap = symbolizeThreadMap(mappingName); + if (!maybeMap.has_value()) return failure(); - result[idx] = val; + ThreadMap mapping = maybeMap.value(); + int idx = 0; + if (mapping == ThreadMap::tidx || mapping == ThreadMap::bidx) + idx = 0; + else if (mapping == ThreadMap::tidy || mapping == ThreadMap::bidy) + idx = 1; + else if (mapping == ThreadMap::tidz || mapping == ThreadMap::bidz) + idx = 2; + // Already seen, invalid mapping. + if (seen[idx]) { + return failure(); + } + if (mapping != ThreadMap::none) + result[idx] = val; seen[idx] = true; } - // Some not seen, invalid thread_dim_mapping. + // Some not seen, invalid mapping. if (!llvm::all_of(seen, [](bool b) { return b; })) return failure(); return result; } -/// Helper to get apply the `thread_dim_mapping` permutation of a +/// Helper to get apply the `mapping` permutation of a /// `foreachThreadOp` to `values`. 
template static FailureOr> getValuesPermutedByThreadMapping(scf::ForeachThreadOp foreachThreadOp, const SmallVector &values) { // Apply mapping permutation if specified. - auto mapping = foreachThreadOp.getThreadDimMapping(); + auto mapping = foreachThreadOp.getMapping(); + auto maps = extractFromStringArrayAttr(mapping); if (mapping && !mapping.empty()) { - auto maybePermuted = permute(values, extractFromI64ArrayAttr(mapping)); + auto maybePermuted = permute(values, maps); if (failed(maybePermuted)) return foreachThreadOp->emitError("invalid permutation"); return *maybePermuted; @@ -1300,21 +1357,19 @@ return values; } -/// Return the thread indices in the order specified by the thread_dim_mapping -/// attribute. Return failure is thread_dim_mapping is not a valid permutation. +/// Return the thread indices in the order specified by the mapping +/// attribute. Return failure if mapping is not a valid permutation. FailureOr> ForeachThreadOp::getPermutedThreadIndices() { SmallVector threadCountValues = this->getThreadIndices(); - threadCountValues.resize(3, Value()); return getValuesPermutedByThreadMapping(*this, threadCountValues); } /// Return the number of threads in the order specified by the -/// thread_dim_mapping attribute. -/// Return failure is thread_dim_mapping is not a valid permutation. +/// mapping attribute. +/// Return failure if mapping is not a valid permutation. 
FailureOr> ForeachThreadOp::getPermutedNumThreads(OpBuilder &b) { SmallVector threadCountValues = this->getNumThreads(); - threadCountValues.resize(3, b.getIndexAttr(1)); return getValuesPermutedByThreadMapping(*this, threadCountValues); } diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -1144,7 +1144,7 @@ auto newForeachThreadOp = rewriter.create( foreachThreadOp.getLoc(), /*outputs=*/ValueRange(), foreachThreadOp.getNumThreads(), - extractFromI64ArrayAttr(foreachThreadOp.getThreadDimMapping())); + extractFromStringArrayAttr(foreachThreadOp.getMapping())); newForeachThreadOp.getBody()->getTerminator()->erase(); // Move over block contents of the old op. diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp --- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp +++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp @@ -65,6 +65,14 @@ })); } +/// Extract StringRef values from the assumed ArrayAttr of StringAttr. +SmallVector extractFromStringArrayAttr(Attribute attr) { + return llvm::to_vector( + llvm::map_range(attr.cast(), [](Attribute a) -> StringRef { + return a.cast().getValue(); + })); +} + /// Given a value, try to extract a constant Attribute. If this fails, return /// the original value. 
OpFoldResult getAsOpFoldResult(Value val) { diff --git a/mlir/test/Dialect/GPU/transform-gpu-failing.mlir b/mlir/test/Dialect/GPU/transform-gpu-failing.mlir --- a/mlir/test/Dialect/GPU/transform-gpu-failing.mlir +++ b/mlir/test/Dialect/GPU/transform-gpu-failing.mlir @@ -26,7 +26,7 @@ %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32> %6 = math.fma %alpha, %4, %5 : f32 memref.store %6, %y[%i, %j] : memref<2 x 32 x f32> - } {thread_dim_mapping = [1, 0, 2]} + } { mapping = ["thread_y", "thread_x"] } gpu.terminator } @@ -38,7 +38,7 @@ %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32> %6 = math.fma %alpha, %4, %5 : f32 memref.store %6, %y[%i, %j] : memref<2 x 32 x f32> - } {thread_dim_mapping = [1, 0, 2]} + } { mapping = ["thread_y", "thread_x"] } gpu.terminator } @@ -67,7 +67,7 @@ %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32> %6 = math.fma %alpha, %4, %5 : f32 memref.store %6, %y[%i, %j] : memref<2 x 32 x f32> - } {thread_dim_mapping = [1, 0, 2]} + } { mapping = ["thread_y", "thread_x"] } gpu.terminator } @@ -79,7 +79,7 @@ %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32> %6 = math.fma %alpha, %4, %5 : f32 memref.store %6, %y[%i, %j] : memref<2 x 32 x f32> - } {thread_dim_mapping = [1, 0, 2]} + } { mapping = ["thread_y", "thread_x"] } gpu.terminator } @@ -106,7 +106,7 @@ %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32> %6 = math.fma %alpha, %4, %5 : f32 memref.store %6, %y[%i, %j] : memref<2 x 32 x f32> - } {thread_dim_mapping = [1, 0, 2]} + } { mapping = ["thread_y", "thread_x"] } gpu.terminator } return %y : memref<2 x 32 x f32> @@ -131,7 +131,7 @@ scf.foreach_thread (%i, %j, %k, %l) in (%c2, %c32,%c32,%c32) { %4 = memref.load %x[%i, %j, %k, %l] : memref<2x32x32x32xf32> memref.store %4, %y[%i, %j, %k, %l] : memref<2x32x32x32xf32> - } {thread_dim_mapping = [1, 0, 2]} + } { mapping = ["thread_y", "thread_x", "thread_z"] } gpu.terminator } return %y : memref<2x32x32x32xf32> @@ -197,14 +197,14 @@ %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32> %6 = math.fma 
%alpha, %4, %5 : f32 memref.store %6, %y[%i, %j] : memref<2 x 32 x f32> - } {thread_dim_mapping = [1, 0, 2]} + } { mapping = ["thread_y", "thread_x"] } scf.foreach_thread (%i, %j) in (%c7, %c9) { %4 = memref.load %x[%i, %j] : memref<2 x 32 x f32> %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32> %6 = math.fma %alpha, %4, %5 : f32 memref.store %6, %y[%i, %j] : memref<2 x 32 x f32> - } {thread_dim_mapping = [1, 0, 2]} + } { mapping = ["thread_y", "thread_x"] } gpu.terminator } @@ -232,14 +232,14 @@ %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32> %6 = math.fma %alpha, %4, %5 : f32 memref.store %6, %y[%i, %j] : memref<2 x 32 x f32> - } {thread_dim_mapping = [0, 1, 2]} + } { mapping = ["thread_x", "thread_y"] } scf.foreach_thread (%i, %j) in (%c7, %c9) { %4 = memref.load %x[%i, %j] : memref<2 x 32 x f32> %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32> %6 = math.fma %alpha, %4, %5 : f32 memref.store %6, %y[%i, %j] : memref<2 x 32 x f32> - } {thread_dim_mapping = [1, 0, 2]} + } { mapping = ["thread_y", "thread_x"] } return %y : memref<2 x 32 x f32> } @@ -261,7 +261,7 @@ %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32> %6 = math.fma %alpha, %4, %5 : f32 memref.store %6, %y[%i, %j] : memref<2 x 32 x f32> - } {thread_dim_mapping = [0, 1, 2]} + } { mapping = ["thread_x", "thread_y"] } return %y : memref<2 x 32 x f32> } diff --git a/mlir/test/Dialect/GPU/transform-gpu.mlir b/mlir/test/Dialect/GPU/transform-gpu.mlir --- a/mlir/test/Dialect/GPU/transform-gpu.mlir +++ b/mlir/test/Dialect/GPU/transform-gpu.mlir @@ -24,7 +24,7 @@ %5 = memref.load %y[%i, %j] : !type %6 = math.fma %alpha, %4, %5 : f32 memref.store %6, %y[%i, %j] : !type - } {thread_dim_mapping = [0, 1, 2]} + } { mapping = ["block_x", "block_y"]} gpu.terminator } return %y : !type @@ -33,7 +33,7 @@ transform.sequence failures(propagate) { ^bb1(%arg0: !pdl.operation): %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 - transform.gpu.map_foreach_to_blocks %funcop { blockDim = [12, 9, 1]} + 
transform.gpu.map_foreach_to_blocks %funcop { gridDim = [12, 9]} } // ----- @@ -73,12 +73,12 @@ %5 = memref.load %y[%i, %j] : !type %6 = math.fma %alpha, %4, %5 : f32 memref.store %6, %y[%i, %j] : !type - } {thread_dim_mapping = [1, 0, 2]} + } { mapping = ["thread_y", "thread_x"]} scf.foreach_thread (%i) in (%c12) { %7 = memref.load %t[%i] : !type1d %8 = arith.addf %alpha, %7 : f32 memref.store %8, %t[%i] : !type1d - } {thread_dim_mapping = [0, 1, 2]} + } {mapping = ["thread_x"] } gpu.terminator } return %y : !type @@ -87,7 +87,7 @@ transform.sequence failures(propagate) { ^bb1(%arg0: !pdl.operation): %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 - transform.gpu.map_nested_foreach_to_threads %funcop { blockDim = [12, 9, 1] } + transform.gpu.map_nested_foreach_to_threads %funcop { blockDim = [12, 9] } } // ----- @@ -118,8 +118,8 @@ %5 = memref.load %y[%i, %j, %k, %l] : !type4d %6 = math.fma %alpha, %4, %5 : f32 memref.store %6, %y[%i, %j, %k, %l] : !type4d - } {thread_dim_mapping = [1, 0, 2]} - } {thread_dim_mapping = [0, 1, 2]} + } { mapping = ["thread_y", "thread_x"] } + } { mapping = ["block_x", "block_y"] } return %y : !type4d } @@ -127,7 +127,7 @@ ^bb1(%arg0: !pdl.operation): %funcop = transform.structured.match ops{["func.func"]} in %arg0 %gpuLaunch = transform.gpu.map_foreach_to_blocks %funcop { generate_gpu_launch } - transform.gpu.map_nested_foreach_to_threads %gpuLaunch { blockDim = [32, 4, 1] } + transform.gpu.map_nested_foreach_to_threads %gpuLaunch { blockDim = [32, 4] } } // ----- @@ -151,7 +151,7 @@ %5 = memref.load %y[%i, %j] : !type %6 = math.fma %alpha, %4, %5 : f32 memref.store %6, %y[%i, %j] : !type - } {thread_dim_mapping = [1, 0, 2]} + } { mapping = ["thread_y", "thread_x"] } gpu.terminator } return %y : !type @@ -160,5 +160,5 @@ transform.sequence failures(propagate) { ^bb1(%arg0: !pdl.operation): %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 - transform.gpu.map_nested_foreach_to_threads %funcop { 
blockDim = [12, 9, 1], syncAfterDistribute = false } + transform.gpu.map_nested_foreach_to_threads %funcop { blockDim = [12, 9], syncAfterDistribute = false } } diff --git a/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir --- a/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir +++ b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir @@ -26,7 +26,7 @@ // CHECK-NEXT: tensor.parallel_insert_slice %[[RES]] into %[[C_BLK]]{{.*}} : // CHECK-SAME: tensor into tensor // CHECK-NEXT: } - // CHECK-NEXT: } {thread_dim_mapping = [1, 0]} + // CHECK-NEXT: } {mapping = ["thread_y", "thread_x"]} %0 = linalg.matmul ins(%A, %B : tensor, tensor) outs(%C : tensor) -> (tensor) return %0 : tensor @@ -35,7 +35,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 - %1:2 = transform.structured.tile_to_foreach_thread_op %0 num_threads [10, 20] (mapped to dims [1, 0]) + %1:2 = transform.structured.tile_to_foreach_thread_op %0 num_threads [10, 20] (mapped to dims ["thread_y", "thread_x"]) } } @@ -177,7 +177,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 - %1:2 = transform.structured.tile_to_foreach_thread_op %0 num_threads [2] (mapped to dims [0]) + %1:2 = transform.structured.tile_to_foreach_thread_op %0 num_threads [2] (mapped to dims ["thread_x"]) } } // CHECK-DAG: #[[$map0:.+]] = affine_map<(d0) -> (d0 * 2)> diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize-analysis.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize-analysis.mlir --- a/mlir/test/Dialect/SCF/one-shot-bufferize-analysis.mlir +++ b/mlir/test/Dialect/SCF/one-shot-bufferize-analysis.mlir @@ -628,7 +628,7 @@ // CHECK: tensor.parallel_insert_slice {{.*}} {__inplace_operands_attr__ = ["true", "true", "none"]} tensor.parallel_insert_slice %8 into %arg1[%arg0] [1] [1] : tensor<1xf32> into 
tensor<320xf32> } - } {thread_dim_mapping = []} + } {mapping = []} return %4 : tensor<320xf32> } diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir --- a/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir +++ b/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir @@ -128,7 +128,7 @@ tensor.parallel_insert_slice %1 into %o[%thread_idx][1][1] : tensor<1xf32> into tensor<100xf32> } - // CHECK: } {thread_dim_mapping = [5]} - } {thread_dim_mapping = [5]} + // CHECK: } {mapping = ["none"]} + } {mapping = ["none"]} return } diff --git a/mlir/test/Dialect/SCF/ops.mlir b/mlir/test/Dialect/SCF/ops.mlir --- a/mlir/test/Dialect/SCF/ops.mlir +++ b/mlir/test/Dialect/SCF/ops.mlir @@ -338,12 +338,12 @@ %num_threads = arith.constant 100 : index // CHECK: scf.foreach_thread - // CHECK-NEXT: } {thread_dim_mapping = [42]} + // CHECK-NEXT: } {mapping = ["none"]} // CHECK-NEXT: return scf.foreach_thread (%thread_idx) in (%num_threads) { scf.foreach_thread.perform_concurrently { } - } {thread_dim_mapping = [42]} + } {mapping = ["none"]} return } diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp --- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp +++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp @@ -217,7 +217,7 @@ Location loc = op.getLoc(); auto foreachOp = rewriter.create( loc, /*outputs=*/dest, /*numThreads=*/helper.getIterationSpaceSizes(), - /*threadDimMapping=*/ArrayRef{}, + /*mapping=*/ArrayRef{}, [&](OpBuilder &nestedBuilder, Location loc, ValueRange regionArgs) { unsigned numThreadIdRegionArgs = helper.getIterationSpaceSizes().size();