diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h --- a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h +++ b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h @@ -172,6 +172,8 @@ #include "mlir/Dialect/GPU/IR/GPUOpInterfaces.h.inc" +#include "mlir/Interfaces/DeviceMappingInterface.h" + #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/GPU/IR/GPUOpsAttributes.h.inc" diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td --- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td @@ -16,6 +16,7 @@ include "mlir/Dialect/DLTI/DLTIBase.td" include "mlir/Dialect/GPU/IR/GPUBase.td" include "mlir/Dialect/GPU/IR/ParallelLoopMapperAttr.td" +include "mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td" include "mlir/IR/EnumAttr.td" include "mlir/IR/FunctionInterfaces.td" include "mlir/IR/SymbolInterfaces.td" diff --git a/mlir/include/mlir/Dialect/GPU/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/GPU/TransformOps/CMakeLists.txt --- a/mlir/include/mlir/Dialect/GPU/TransformOps/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/GPU/TransformOps/CMakeLists.txt @@ -4,3 +4,8 @@ add_public_tablegen_target(MLIRGPUTransformOpsIncGen) add_mlir_doc(GPUTransformOps GPUTransformOps Dialects/ -gen-op-doc) + +set(LLVM_TARGET_DEFINITIONS GPUDeviceMappingAttr.td) +mlir_tablegen(GPUDeviceMapperEnums.h.inc -gen-enum-decls) +mlir_tablegen(GPUDeviceMapperEnums.cpp.inc -gen-enum-defs) +add_public_tablegen_target(MLIRGPUDeviceMapperEnumsGen) diff --git a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUDeviceMappingAttr.td @@ -0,0 +1,49 @@ +//===-- GPUDeviceMappingAttr.td - Attribute definition -----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines the attribute used to map loops to gpu. 
+// +//===----------------------------------------------------------------------===// + +#ifndef GPU_DEVICE_MAPPING_ATTR +#define GPU_DEVICE_MAPPING_ATTR + +include "mlir/Dialect/GPU/IR/GPUBase.td" +include "mlir/IR/EnumAttr.td" +include "mlir/Interfaces/DeviceMappingInterface.td" + +def DimX : I64EnumAttrCase<"DimX", 0, "x">; +def DimY : I64EnumAttrCase<"DimY", 1, "y">; +def DimZ : I64EnumAttrCase<"DimZ", 2, "z">; + +def ThreadsEnum : I64EnumAttr<"Threads", "threads for loop mapping", [ + DimX, DimY, DimZ]> { + let cppNamespace = "::mlir::gpu"; +} + +def GPUThreadMappingAttr + : GPU_Attr<"GPUThreadMapping", "thread", [ DeviceMappingAttrInterface ]> { + let parameters = (ins + EnumParameter<ThreadsEnum>:$thread + ); + let assemblyFormat = "`<` params `>`"; +} + +def BlocksEnum : I64EnumAttr<"Blocks", "blocks for loop mapping", [ + DimX, DimY, DimZ]> { + let cppNamespace = "::mlir::gpu"; +} + +def GPUBlockMappingAttr : GPU_Attr<"GPUBlockMapping", "block", [ DeviceMappingAttrInterface ] > { + let parameters = (ins + EnumParameter<BlocksEnum>:$block + ); + let assemblyFormat = "`<` params `>`"; +} + +#endif // GPU_DEVICE_MAPPING_ATTR diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -13,6 +13,7 @@ include "mlir/Dialect/Transform/IR/TransformEffects.td" include "mlir/Dialect/Transform/IR/TransformInterfaces.td" include "mlir/Dialect/PDL/IR/PDLTypes.td" +include "mlir/Interfaces/DeviceMappingInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/EnumAttr.td" include "mlir/IR/OpBase.td" @@ -792,7 +793,7 @@ a valid tiling specification (i.e. that only tiles parallel dimensions, e.g. in the Linalg case). - If non-empty, the `thread_dim_mapping` is added as an attribute to the + If non-empty, the `mapping` is added as an attribute to the resulting `scf.foreach_thread`. #### Return modes @@ -832,7 +833,7 @@ Variadic<PDL_Operation>:$tile_sizes, DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_num_threads, DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$static_tile_sizes, - OptionalAttr<I64ArrayAttr>:$thread_dim_mapping); + OptionalAttr<DeviceMappingArrayAttr>:$mapping); let results = (outs PDL_Operation:$foreach_thread_op, PDL_Operation:$tiled_op); @@ -867,7 +868,7 @@ `tile_sizes` custom<DynamicIndexList>($tile_sizes, $static_tile_sizes, "ShapedType::kDynamicSize")) - (`(` `mapped` `to` `dims` $thread_dim_mapping^ `)`)? attr-dict + (`(` `mapping` `=` $mapping^ `)`)? attr-dict }]; let hasVerifier = 1; diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -19,6 +19,7 @@ #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" #include "mlir/Dialect/X86Vector/Transforms.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/TilingInterface.h" #include "mlir/Transforms/DialectConversion.h" @@ -436,14 +437,14 @@ FailureOr<ForeachThreadTilingResult> tileToForeachThreadOp(RewriterBase &builder, TilingInterface op, ArrayRef<OpFoldResult> numThreads, - ArrayRef<int64_t> threadDimMapping = {}); + Optional<ArrayAttr> threadDimMapping); /// Same as `tileToForeachThreadOp`, but calculate the number of threads /// required using the given tileSizes.
FailureOr<ForeachThreadTilingResult> tileToForeachThreadOpUsingTileSizes(RewriterBase &builder, TilingInterface op, ArrayRef<OpFoldResult> tileSizes, - ArrayRef<int64_t> threadDimMapping = {}); + Optional<ArrayAttr> threadDimMapping); /// All indices returned by IndexOp should be invariant with respect to /// tiling. Therefore, if an operation is tiled, we have to transform the diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCF.h b/mlir/include/mlir/Dialect/SCF/IR/SCF.h --- a/mlir/include/mlir/Dialect/SCF/IR/SCF.h +++ b/mlir/include/mlir/Dialect/SCF/IR/SCF.h @@ -17,6 +17,7 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/RegionKindInterface.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/DeviceMappingInterface.h" #include "mlir/Interfaces/LoopLikeInterface.h" #include "mlir/Interfaces/ParallelCombiningOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td --- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td @@ -16,6 +16,7 @@ include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/LoopLikeInterface.td" include "mlir/IR/RegionKindInterface.td" +include "mlir/Interfaces/DeviceMappingInterface.td" include "mlir/Interfaces/ParallelCombiningOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/ViewLikeInterface.td" @@ -378,8 +379,8 @@ application per thread. Further lowerings are responsible for specifying how this is materialized on concrete hardware resources. - An optional thread_dim_mapping index array attribute specifies for each - virtual thread dimension, how it remaps 1-1 to a set of concrete processing + An optional `mapping` attribute array specifies how each virtual thread + dimension remaps 1-1 to a set of concrete processing element resources (e.g. a CUDA grid dimension or a level of concrete nested async parallelism). At this time, the specification is backend-dependent and is not verified by the op, beyond being an index array attribute. @@ -440,7 +441,7 @@ // ``` - Example with thread_dim_mapping attribute: + Example with mapping attribute: ```mlir // @@ -456,7 +457,7 @@ scf.foreach_thread.perform_concurrently { ... } - } { thread_dim_mapping = [1, 0] } + } { mapping = [#gpu.thread<y>, #gpu.thread<x>] } // Implicit synchronization point. // Sequential context. // @@ -480,7 +481,7 @@ }]; let arguments = (ins Variadic<Index>:$num_threads, Variadic<AnyRankedTensor>:$outputs, - DefaultValuedAttr<I64ArrayAttr, "{}">:$thread_dim_mapping); + OptionalAttr<DeviceMappingArrayAttr>:$mapping); let results = (outs Variadic<AnyType>:$results); let regions = (region SizedRegion<1>:$region); @@ -493,10 +494,10 @@ let builders = [ // Bodyless builder, outputs must be specified. OpBuilder<(ins "ValueRange":$outputs, "ValueRange":$num_threads, - CArg<"ArrayRef<int64_t>", "{}">:$thread_dim_mapping)>, + "Optional<ArrayAttr>":$mapping)>, // Builder that takes a bodyBuilder lambda. OpBuilder<(ins "ValueRange":$outputs, "ValueRange":$num_threads, - "ArrayRef<int64_t>":$thread_dim_mapping, + "ArrayRef<Attribute>":$mapping, "function_ref<void(OpBuilder &, Location, ValueRange)>":$bodyBuilder)> ]; let extraClassDeclaration = [{ @@ -535,14 +536,14 @@ } /// Return the thread indices in the order specified by the - /// thread_dim_mapping attribute. Return failure is - /// thread_dim_mapping is not a valid permutation. - FailureOr<SmallVector<Value>> getPermutedThreadIndices(); + /// given mapping argument. Return failure if the + /// mapping is not a valid permutation.
+ FailureOr<SmallVector<Value>> getPermutedThreadIndices(ArrayRef<int64_t> mapping); /// Return the number of threads in the order specified by the - /// thread_dim_mapping attribute. - /// Return failure is thread_dim_mapping is not a valid permutation. - FailureOr<SmallVector<OpFoldResult>> getPermutedNumThreads(OpBuilder &b); + /// given mapping argument. + /// Return failure if the mapping is not a valid permutation. + FailureOr<SmallVector<OpFoldResult>> getPermutedNumThreads(OpBuilder &b, ArrayRef<int64_t> mapping); // The ensureTerminator method generated by SingleBlockImplicitTerminator is // unaware of the fact that our terminator also needs a region to be diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt --- a/mlir/include/mlir/Interfaces/CMakeLists.txt +++ b/mlir/include/mlir/Interfaces/CMakeLists.txt @@ -24,6 +24,14 @@ add_public_tablegen_target(MLIRDataLayoutInterfacesIncGen) add_dependencies(mlir-generic-headers MLIRDataLayoutInterfacesIncGen) +set(LLVM_TARGET_DEFINITIONS DeviceMappingInterface.td) +mlir_tablegen(DeviceMappingAttrInterface.h.inc -gen-attr-interface-decls) +mlir_tablegen(DeviceMappingAttrInterface.cpp.inc -gen-attr-interface-defs) +mlir_tablegen(DeviceMappingAttributes.h.inc -gen-attrdef-decls) +mlir_tablegen(DeviceMappingAttributes.cpp.inc -gen-attrdef-defs) +add_public_tablegen_target(MLIRDeviceMappingInterfacesIncGen) +add_dependencies(mlir-generic-headers MLIRDeviceMappingInterfacesIncGen) + add_mlir_doc(DataLayoutInterfaces DataLayoutAttrInterface Interfaces/ diff --git a/mlir/include/mlir/Interfaces/DeviceMappingInterface.h b/mlir/include/mlir/Interfaces/DeviceMappingInterface.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Interfaces/DeviceMappingInterface.h @@ -0,0 +1,22 @@ +//===- DeviceMappingInterface.h - Device mapping interfaces -----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the declarations of the device mapping interfaces defined +// in `DeviceMappingInterface.td`. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DEVICEMAPPINGINTERFACE_H +#define MLIR_DEVICEMAPPINGINTERFACE_H + +#include "mlir/IR/OpDefinition.h" + +/// Include the generated interface declarations. +#include "mlir/Interfaces/DeviceMappingAttrInterface.h.inc" + +#endif // MLIR_DEVICEMAPPINGINTERFACE_H diff --git a/mlir/include/mlir/Interfaces/DeviceMappingInterface.td b/mlir/include/mlir/Interfaces/DeviceMappingInterface.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Interfaces/DeviceMappingInterface.td @@ -0,0 +1,43 @@ +//===- DeviceMappingInterface.td - Device mapping interfaces -*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines the interfaces for the device mapping specification of loops.
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DEVICEMAPPINGINTERFACE +#define MLIR_DEVICEMAPPINGINTERFACE + +include "mlir/IR/OpBase.td" + +//===----------------------------------------------------------------------===// +// Attribute interfaces +//===----------------------------------------------------------------------===// + +def DeviceMappingAttrInterface : AttrInterface<"DeviceMappingAttrInterface"> { + let cppNamespace = "::mlir"; + let description = [{ + Attribute interface describing how to map a region to a processing unit. + + It is intended to be a generic mechanism for binding regions to execution + units of an actual or virtual device. Each device first expresses its own + mappings, and those mappings must implement this interface. These mappings + can then be used by device-specific code generators to connect the desired + regions to the given processing units. + + Currently, `scf.foreach_thread` uses this interface to express the mapping + of the loops it contains to GPU parallelism units such as threads and + thread blocks. + }]; +} + +def DeviceMappingArrayAttr : + TypedArrayAttrBase<DeviceMappingAttrInterface, "Device Mapping array attribute"> { } + +#endif // MLIR_DEVICEMAPPINGINTERFACE diff --git a/mlir/lib/Dialect/GPU/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/GPU/TransformOps/CMakeLists.txt --- a/mlir/lib/Dialect/GPU/TransformOps/CMakeLists.txt +++ b/mlir/lib/Dialect/GPU/TransformOps/CMakeLists.txt @@ -3,10 +3,13 @@ ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/GPU/TransformOps + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Interfaces DEPENDS MLIRGPUTransformOpsIncGen - + MLIRDeviceMappingInterfacesIncGen + MLIRGPUDeviceMapperEnumsGen + LINK_LIBS PUBLIC MLIRIR MLIRGPUTransforms diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp --- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp +++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp @@ -15,6 +15,8 @@ #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Value.h" #include "llvm/ADT/None.h" @@ -166,8 +168,20 @@ // Step 0. Outline the compute workload region and set up the workload // operands. + SmallVector<int64_t> mapping; + if (!foreachThreadOp.getMapping().has_value()) + return transformOp.emitSilenceableError() << "mapping must be present"; + for (DeviceMappingAttrInterface map : *foreachThreadOp.getMapping()) { + if (auto blockMap = map.dyn_cast<GPUBlockMappingAttr>()) { + mapping.push_back((int64_t)blockMap.getBlock()); + } else { + return transformOp.emitSilenceableError() + << "mapping must be #gpu.block"; + } + } + FailureOr<SmallVector<OpFoldResult>> potentialGridDim = - foreachThreadOp.getPermutedNumThreads(rewriter); + foreachThreadOp.getPermutedNumThreads(rewriter, mapping); if (failed(potentialGridDim) || llvm::any_of(*potentialGridDim, [](OpFoldResult ofr) { @@ -193,7 +207,7 @@ // Step 2. RAUW thread indices to thread ops. SmallVector<Value> threadIndices = - *foreachThreadOp.getPermutedThreadIndices(); + *foreachThreadOp.getPermutedThreadIndices(mapping); assert(blockOps.size() == 3 && "3 block id ops are required"); for (auto [blockIdx, blockOp] : llvm::zip(threadIndices, blockOps)) { Value val = blockIdx; @@ -230,7 +244,8 @@ } /// This is a helper that is only used in -/// rewriteTopLevelForeachThreadToGpuBlocks. It generates GPU dialects block_id.
+/// rewriteTopLevelForeachThreadToGpuBlocks. It generates GPU dialect +/// block_id ops. static void generateGpuBlockIds(RewriterBase &rewriter, scf::ForeachThreadOp foreachOp, SmallVectorImpl<Value> &blockOps) { @@ -295,6 +310,7 @@ rewriter, topLevelForeachThreadOp, generateGpuBlockIds, gridDim, transformOp); if (diag.succeeded()) { + gridDim.resize(3, 1); diag = alterGpuLaunch(rewriter, gpuLaunch, cast<TransformOpInterface>(getOperation()), gridDim[0], gridDim[1], gridDim[2]); @@ -335,7 +351,18 @@ return failureHelper( "scf.foreach_thread with rank > 3 does not lower to gpu.thread_id"); - auto potentialBlockDim = foreachThreadOp.getPermutedNumThreads(rewriter); + SmallVector<int64_t> mapping; + if (!foreachThreadOp.getMapping().has_value()) + return failureHelper("mapping must be present"); + for (DeviceMappingAttrInterface map : *foreachThreadOp.getMapping()) { + if (auto threadMap = map.dyn_cast<GPUThreadMappingAttr>()) { + mapping.push_back((int64_t)threadMap.getThread()); + } else { + return failureHelper("mapping must be #gpu.thread"); + } + } + auto potentialBlockDim = + foreachThreadOp.getPermutedNumThreads(rewriter, mapping); if (failed(potentialBlockDim) || llvm::any_of(*potentialBlockDim, [](OpFoldResult ofr) { return !getConstantIntValue(ofr).has_value(); @@ -347,6 +374,7 @@ llvm::to_vector(llvm::map_range(*potentialBlockDim, [](OpFoldResult ofr) { return getConstantIntValue(ofr).value(); })); + blockDim.resize(3, 1); // Step 1. Create the gpu.thread ops Location loc = foreachThreadOp.getLoc(); @@ -365,7 +393,8 @@ if (blockDim > globalBlockDim) { return failureHelper( "The requested GPU threads are fewer than the number of loop trip " - "counts. Try to tile scf.foreach_thread before mapping or set small " + "counts. Try to tile scf.foreach_thread before mapping or set " + "small " "blockDim."); } if (blockDim == globalBlockDim) @@ -400,7 +429,7 @@ // Step 4. RAUW thread indices to thread ops. SmallVector<Value> threadIndices = - *foreachThreadOp.getPermutedThreadIndices(); + *foreachThreadOp.getPermutedThreadIndices(mapping); for (auto [threadIdx, threadOp] : llvm::zip(threadIndices, threadOps)) { Value val = threadIdx; Value op = threadOp; diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -19,6 +19,9 @@ #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/Interfaces/TilingInterface.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/StringSet.h" @@ -1457,19 +1460,13 @@ return diag; } rewriter.setInsertionPoint(tilableOp); - auto maybeThreadDimMappingAttr = threadDimMapping; - auto dimMapping = llvm::to_vector( - maybeThreadDimMappingAttr - ?
extractFromI64ArrayAttr(*maybeThreadDimMappingAttr) - : ArrayRef<int64_t>{}); - FailureOr<ForeachThreadTilingResult> tilingResult = failure(); if (!mixedNumThreads.empty()) { - tilingResult = linalg::tileToForeachThreadOp(rewriter, tilableOp, - numThreads, dimMapping); + tilingResult = linalg::tileToForeachThreadOp( + rewriter, tilableOp, numThreads, threadDimMapping); } else { tilingResult = linalg::tileToForeachThreadOpUsingTileSizes( - rewriter, tilableOp, tileSizes, dimMapping); + rewriter, tilableOp, tileSizes, threadDimMapping); } if (failed(tilingResult)) @@ -1494,7 +1491,7 @@ DiagnosedSilenceableFailure diag = tileToForeachThreadOpImpl( rewriter, state, cast<TransformOpInterface>(getOperation()), targets, - getMixedNumThreads(), getMixedTileSizes(), getThreadDimMapping(), tileOps, + getMixedNumThreads(), getMixedTileSizes(), getMapping(), tileOps, tiledOps); if (!diag.succeeded()) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -24,6 +24,8 @@ #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/Transforms/FoldUtils.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/Support/CommandLine.h" @@ -226,7 +228,7 @@ static FailureOr<ForeachThreadTilingResult> tileToForeachThreadOpImpl( RewriterBase &b, TilingInterface op, ArrayRef<OpFoldResult> numThreads, Optional<ArrayRef<OpFoldResult>> nominalTileSizes, - ArrayRef<int64_t> threadDimMapping, bool omitTileOffsetBoundsCheck) { + Optional<ArrayAttr> threadDimMapping, bool omitTileOffsetBoundsCheck) { Location loc = op->getLoc(); OpBuilder::InsertionGuard g(b); SmallVector<Range> loopRanges = op.getIterationDomain(b); @@ -363,7 +365,7 @@ FailureOr<ForeachThreadTilingResult> linalg::tileToForeachThreadOp(RewriterBase &b, TilingInterface op, ArrayRef<OpFoldResult> numThreads, - ArrayRef<int64_t> threadDimMapping) { + Optional<ArrayAttr> threadDimMapping) { return tileToForeachThreadOpImpl(b, op, numThreads, /*nominalTileSizes=*/None, threadDimMapping, /*omitTileOffsetBoundsCheck=*/false); @@ -372,7 +374,7 @@ FailureOr<ForeachThreadTilingResult> linalg::tileToForeachThreadOpUsingTileSizes( RewriterBase &b, TilingInterface op, ArrayRef<OpFoldResult> tileSizes, - ArrayRef<int64_t> threadDimMapping) { + Optional<ArrayAttr> threadDimMapping) { SmallVector<Range> loopRanges = op.getIterationDomain(b); unsigned nLoops = loopRanges.size(); SmallVector<OpFoldResult> numThreads; diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -13,12 +13,16 @@ #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/FunctionInterfaces.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/DeviceMappingInterface.h" #include "mlir/Support/MathExtras.h" #include "mlir/Transforms/InliningUtils.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/raw_ostream.h" using namespace mlir; using namespace mlir::scf; @@ -1111,6 +1115,12 @@ if (body->getArgument(i + getRank()).getType() != getOutputs()[i].getType()) return emitOpError("type mismatch between ") << i << "-th output and corresponding block argument"; + if (getMapping().has_value()) + for (auto map : getMapping().value()) { + if (!isa<DeviceMappingAttrInterface>(map)) + return emitOpError() + << getMappingAttrName() << " is not a device mapping attribute"; + } return success(); }
@@ -1200,11 +1210,14 @@ void ForeachThreadOp::build(mlir::OpBuilder &builder, mlir::OperationState &result, ValueRange outputs, ValueRange numThreads, - ArrayRef<int64_t> threadDimMapping) { + Optional<ArrayAttr> threadDimMapping) { result.addOperands(numThreads); result.addOperands(outputs); - result.addAttribute(ForeachThreadOp::getThreadDimMappingAttrName(result.name), - builder.getI64ArrayAttr(threadDimMapping)); + if (threadDimMapping.has_value()) { + result.addAttribute(ForeachThreadOp::getMappingAttrName(result.name), + threadDimMapping.value()); + } + result.addAttribute( "operand_segment_sizes", builder.getDenseI32ArrayAttr({static_cast<int32_t>(numThreads.size()), @@ -1231,12 +1244,12 @@ // Builder that takes a bodyBuilder lambda. void ForeachThreadOp::build( mlir::OpBuilder &builder, mlir::OperationState &result, ValueRange outputs, - ValueRange numThreads, ArrayRef<int64_t> threadDimMapping, + ValueRange numThreads, ArrayRef<Attribute> threadDimMapping, function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) { result.addOperands(numThreads); result.addOperands(outputs); - result.addAttribute(ForeachThreadOp::getThreadDimMappingAttrName(result.name), - builder.getI64ArrayAttr(threadDimMapping)); + result.addAttribute(ForeachThreadOp::getMappingAttrName(result.name), + builder.getArrayAttr(threadDimMapping)); result.addAttribute( "operand_segment_sizes", builder.getDenseI32ArrayAttr({static_cast<int32_t>(numThreads.size()), @@ -1307,34 +1320,32 @@ template <typename T> static FailureOr<SmallVector<T>> getValuesPermutedByThreadMapping(scf::ForeachThreadOp foreachThreadOp, - const SmallVector<T> &values) { + const SmallVector<T> &values, + ArrayRef<int64_t> mapping) { // Apply mapping permutation if specified. - auto mapping = foreachThreadOp.getThreadDimMapping(); - if (mapping && !mapping.empty()) { - auto maybePermuted = permute(values, extractFromI64ArrayAttr(mapping)); - if (failed(maybePermuted)) - return foreachThreadOp->emitError("invalid permutation"); - return *maybePermuted; - } + auto maybePermuted = permute(values, mapping); + if (failed(maybePermuted)) + return foreachThreadOp->emitError("invalid permutation"); + return *maybePermuted; return values; } /// Return the thread indices in the order specified by the thread_dim_mapping /// attribute. Return failure is thread_dim_mapping is not a valid permutation. -FailureOr<SmallVector<Value>> ForeachThreadOp::getPermutedThreadIndices() { +FailureOr<SmallVector<Value>> +ForeachThreadOp::getPermutedThreadIndices(ArrayRef<int64_t> mapping) { SmallVector<Value> threadCountValues = this->getThreadIndices(); - threadCountValues.resize(3, Value()); - return getValuesPermutedByThreadMapping(*this, threadCountValues); + return getValuesPermutedByThreadMapping(*this, threadCountValues, mapping); } /// Return the number of threads in the order specified by the /// thread_dim_mapping attribute. /// Return failure is thread_dim_mapping is not a valid permutation.
FailureOr<SmallVector<OpFoldResult>> -ForeachThreadOp::getPermutedNumThreads(OpBuilder &b) { +ForeachThreadOp::getPermutedNumThreads(OpBuilder &b, + ArrayRef<int64_t> mapping) { SmallVector<OpFoldResult> threadCountValues = this->getNumThreads(); - threadCountValues.resize(3, b.getIndexAttr(1)); - return getValuesPermutedByThreadMapping(*this, threadCountValues); + return getValuesPermutedByThreadMapping(*this, threadCountValues, mapping); } ForeachThreadOp mlir::scf::getForeachThreadOpThreadIndexOwner(Value val) { diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -1141,10 +1141,11 @@ // Create new ForeachThreadOp without any results and drop the automatically // introduced terminator. rewriter.setInsertionPoint(foreachThreadOp); - auto newForeachThreadOp = rewriter.create<ForeachThreadOp>( + ForeachThreadOp newForeachThreadOp; + newForeachThreadOp = rewriter.create<ForeachThreadOp>( foreachThreadOp.getLoc(), /*outputs=*/ValueRange(), - foreachThreadOp.getNumThreads(), - extractFromI64ArrayAttr(foreachThreadOp.getThreadDimMapping())); + foreachThreadOp.getNumThreads(), foreachThreadOp.getMapping()); + newForeachThreadOp.getBody()->getTerminator()->erase(); // Move over block contents of the old op. diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -9,6 +9,7 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt --- a/mlir/lib/Interfaces/CMakeLists.txt +++ b/mlir/lib/Interfaces/CMakeLists.txt @@ -15,6 +15,7 @@ TilingInterface.cpp VectorInterfaces.cpp ViewLikeInterface.cpp + DeviceMappingInterface.cpp ) function(add_mlir_interface_library name) diff --git a/mlir/lib/Interfaces/DeviceMappingInterface.cpp b/mlir/lib/Interfaces/DeviceMappingInterface.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Interfaces/DeviceMappingInterface.cpp @@ -0,0 +1,17 @@ +//===- DeviceMappingInterface.cpp -----------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Interfaces/DeviceMappingInterface.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// Table-generated class definitions +//===----------------------------------------------------------------------===// + +#include "mlir/Interfaces/DeviceMappingAttrInterface.cpp.inc" diff --git a/mlir/test/Dialect/GPU/transform-gpu-failing.mlir b/mlir/test/Dialect/GPU/transform-gpu-failing.mlir --- a/mlir/test/Dialect/GPU/transform-gpu-failing.mlir +++ b/mlir/test/Dialect/GPU/transform-gpu-failing.mlir @@ -26,7 +26,7 @@ %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32> %6 = math.fma %alpha, %4, %5 : f32 memref.store %6, %y[%i, %j] : memref<2 x 32 x f32> - } {thread_dim_mapping = [1, 0, 2]} + } { mapping = [#gpu.thread, #gpu.thread] } gpu.terminator } @@ -38,7 +38,7 @@ %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32> %6 = math.fma %alpha, %4, %5 : f32 memref.store %6, %y[%i, %j] : memref<2 x 32 x f32> - } {thread_dim_mapping = [1, 0, 2]} + } { mapping = [#gpu.thread, #gpu.thread] } gpu.terminator } @@ -67,7 +67,7 @@ %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32> %6 = math.fma %alpha, %4, %5 : f32 memref.store %6, %y[%i, %j] : memref<2 x 32 x f32> - } {thread_dim_mapping = [1, 0, 2]} + } { mapping = [#gpu.thread, #gpu.thread] } gpu.terminator } @@ -79,7 +79,7 @@ %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32> %6 = math.fma %alpha, %4, %5 : f32 memref.store %6, %y[%i, %j] : memref<2 x 32 x f32> - } {thread_dim_mapping = [1, 0, 2]} + } { mapping = [#gpu.thread, #gpu.thread] } gpu.terminator } @@ -106,7 +106,7 @@ %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32> %6 = math.fma %alpha, %4, %5 : f32 memref.store %6, %y[%i, %j] : memref<2 x 32 x f32> - } {thread_dim_mapping = [1, 0, 2]} + } { mapping = [#gpu.thread, #gpu.thread] } gpu.terminator } return %y : memref<2 x 32 x f32> @@ -131,7 +131,7 @@ scf.foreach_thread (%i, %j, %k, %l) in (%c2, %c32,%c32,%c32) { %4 = memref.load %x[%i, %j, %k, %l] : memref<2x32x32x32xf32> memref.store %4, %y[%i, %j, %k, %l] : memref<2x32x32x32xf32> - } {thread_dim_mapping = [1, 0, 2]} + } { mapping = [#gpu.thread, #gpu.thread, #gpu.thread] } gpu.terminator } return %y : memref<2x32x32x32xf32> @@ -197,14 +197,14 @@ %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32> %6 = math.fma %alpha, %4, %5 : f32 memref.store %6, %y[%i, %j] : memref<2 x 32 x f32> - } {thread_dim_mapping = [1, 0, 2]} + } { mapping = [#gpu.thread, #gpu.thread] } scf.foreach_thread (%i, %j) in (%c7, %c9) { %4 = memref.load %x[%i, %j] : memref<2 x 32 x f32> %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32> %6 = math.fma %alpha, %4, %5 : f32 memref.store %6, %y[%i, %j] : memref<2 x 32 x f32> - } {thread_dim_mapping = [1, 0, 2]} + } { mapping = [#gpu.thread, #gpu.thread] } gpu.terminator } @@ -232,14 +232,14 @@ %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32> %6 = math.fma %alpha, %4, %5 : f32 memref.store %6, %y[%i, %j] : memref<2 x 32 x f32> - } {thread_dim_mapping = [0, 1, 2]} + } { mapping = [#gpu.thread, #gpu.thread] } scf.foreach_thread (%i, %j) in (%c7, %c9) { %4 = memref.load %x[%i, %j] : memref<2 x 32 x f32> %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32> %6 = math.fma %alpha, %4, %5 : f32 memref.store %6, %y[%i, %j] : memref<2 x 32 x f32> - } {thread_dim_mapping = [1, 0, 2]} + } { mapping = [#gpu.thread, #gpu.thread] } return %y : memref<2 x 32 x f32> } @@ -261,7 +261,7 
@@ %5 = memref.load %y[%i, %j] : memref<2 x 32 x f32> %6 = math.fma %alpha, %4, %5 : f32 memref.store %6, %y[%i, %j] : memref<2 x 32 x f32> - } {thread_dim_mapping = [0, 1, 2]} + } { mapping = [#gpu.block, #gpu.block] } return %y : memref<2 x 32 x f32> } @@ -273,4 +273,3 @@ } // ----- - diff --git a/mlir/test/Dialect/GPU/transform-gpu.mlir b/mlir/test/Dialect/GPU/transform-gpu.mlir --- a/mlir/test/Dialect/GPU/transform-gpu.mlir +++ b/mlir/test/Dialect/GPU/transform-gpu.mlir @@ -24,7 +24,7 @@ %5 = memref.load %y[%i, %j] : !type %6 = math.fma %alpha, %4, %5 : f32 memref.store %6, %y[%i, %j] : !type - } {thread_dim_mapping = [0, 1, 2]} + } { mapping = [#gpu.block, #gpu.block]} gpu.terminator } return %y : !type @@ -33,7 +33,7 @@ transform.sequence failures(propagate) { ^bb1(%arg0: !pdl.operation): %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 - transform.gpu.map_foreach_to_blocks %funcop { blockDim = [12, 9, 1]} + transform.gpu.map_foreach_to_blocks %funcop { gridDim = [12, 9]} } // ----- @@ -73,12 +73,12 @@ %5 = memref.load %y[%i, %j] : !type %6 = math.fma %alpha, %4, %5 : f32 memref.store %6, %y[%i, %j] : !type - } {thread_dim_mapping = [1, 0, 2]} + } { mapping = [#gpu.thread, #gpu.thread]} scf.foreach_thread (%i) in (%c12) { %7 = memref.load %t[%i] : !type1d %8 = arith.addf %alpha, %7 : f32 memref.store %8, %t[%i] : !type1d - } {thread_dim_mapping = [0, 1, 2]} + } {mapping = [#gpu.thread] } gpu.terminator } return %y : !type @@ -87,7 +87,7 @@ transform.sequence failures(propagate) { ^bb1(%arg0: !pdl.operation): %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 - transform.gpu.map_nested_foreach_to_threads %funcop { blockDim = [12, 9, 1] } + transform.gpu.map_nested_foreach_to_threads %funcop { blockDim = [12, 9] } } // ----- @@ -118,8 +118,8 @@ %5 = memref.load %y[%i, %j, %k, %l] : !type4d %6 = math.fma %alpha, %4, %5 : f32 memref.store %6, %y[%i, %j, %k, %l] : !type4d - } {thread_dim_mapping = [1, 0, 2]} - } {thread_dim_mapping = [0, 1, 2]} + } { mapping = [#gpu.thread, #gpu.thread] } + } { mapping = [#gpu.block, #gpu.block] } return %y : !type4d } @@ -151,7 +151,7 @@ %5 = memref.load %y[%i, %j] : !type %6 = math.fma %alpha, %4, %5 : f32 memref.store %6, %y[%i, %j] : !type - } {thread_dim_mapping = [1, 0, 2]} + } { mapping = [#gpu.thread, #gpu.thread] } gpu.terminator } return %y : !type diff --git a/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir --- a/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir +++ b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir @@ -26,7 +26,7 @@ // CHECK-NEXT: tensor.parallel_insert_slice %[[RES]] into %[[C_BLK]]{{.*}} : // CHECK-SAME: tensor into tensor // CHECK-NEXT: } - // CHECK-NEXT: } {thread_dim_mapping = [1, 0]} + // CHECK-NEXT: } {mapping = [#gpu.thread, #gpu.thread]} %0 = linalg.matmul ins(%A, %B : tensor, tensor) outs(%C : tensor) -> (tensor) return %0 : tensor @@ -35,7 +35,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 - %1:2 = transform.structured.tile_to_foreach_thread_op %0 num_threads [10, 20] (mapped to dims [1, 0]) + %1:2 = transform.structured.tile_to_foreach_thread_op %0 num_threads [10, 20] (mapping = [ #gpu.thread, #gpu.thread ] ) } } @@ -177,7 +177,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !pdl.operation): %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 - %1:2 = transform.structured.tile_to_foreach_thread_op %0 
num_threads [2] (mapped to dims [0]) + %1:2 = transform.structured.tile_to_foreach_thread_op %0 num_threads [2] ( mapping = [#gpu.thread]) } } // CHECK-DAG: #[[$map0:.+]] = affine_map<(d0) -> (d0 * 2)> diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir --- a/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir +++ b/mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir @@ -128,7 +128,7 @@ tensor.parallel_insert_slice %1 into %o[%thread_idx][1][1] : tensor<1xf32> into tensor<100xf32> } - // CHECK: } {thread_dim_mapping = [5]} - } {thread_dim_mapping = [5]} + // CHECK: } {mapping = [#gpu.thread]} + } {mapping = [#gpu.thread]} return } diff --git a/mlir/test/Dialect/SCF/ops.mlir b/mlir/test/Dialect/SCF/ops.mlir --- a/mlir/test/Dialect/SCF/ops.mlir +++ b/mlir/test/Dialect/SCF/ops.mlir @@ -338,12 +338,12 @@ %num_threads = arith.constant 100 : index // CHECK: scf.foreach_thread - // CHECK-NEXT: } {thread_dim_mapping = [42]} + // CHECK-NEXT: } {mapping = [#gpu.thread]} // CHECK-NEXT: return scf.foreach_thread (%thread_idx) in (%num_threads) { scf.foreach_thread.perform_concurrently { } - } {thread_dim_mapping = [42]} + } {mapping = [#gpu.thread]} return } diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp --- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp +++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Transforms/TransformUtils.h" #include "mlir/Dialect/Tensor/Transforms/Transforms.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -217,7 +218,7 @@ Location loc = op.getLoc(); auto foreachOp = rewriter.create( loc, /*outputs=*/dest, /*numThreads=*/helper.getIterationSpaceSizes(), - /*threadDimMapping=*/ArrayRef{}, + /*threadDimMapping=*/ArrayRef{}, [&](OpBuilder &nestedBuilder, Location loc, ValueRange regionArgs) { unsigned numThreadIdRegionArgs = helper.getIterationSpaceSizes().size();
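Illustrative note (not part of the patch): after this change, the old index-array form `{thread_dim_mapping = [1, 0]}` on `scf.foreach_thread` is replaced by an array of device mapping attributes implementing DeviceMappingAttrInterface. A minimal sketch, assuming the `#gpu.thread<...>` / `#gpu.block<...>` syntax introduced by GPUDeviceMappingAttr.td above; the concrete dimension choices here are assumptions for illustration only:

  %c32 = arith.constant 32 : index
  // Each loop induction variable is bound to a GPU processing unit and
  // dimension through the `mapping` attribute (assumed dims: y, then x).
  scf.foreach_thread (%i, %j) in (%c32, %c32) {
    // ... body reading/writing the shared outputs ...
    scf.foreach_thread.perform_concurrently { ... }
  } { mapping = [#gpu.thread<y>, #gpu.thread<x>] }

A top-level loop intended to become the GPU grid would instead carry `{ mapping = [#gpu.block<x>, #gpu.block<y>] }`, which transform.gpu.map_foreach_to_blocks consumes, while transform.gpu.map_nested_foreach_to_threads consumes the `#gpu.thread` form.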