diff --git a/mlir/include/mlir/Dialect/GPU/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/GPUDialect.h
--- a/mlir/include/mlir/Dialect/GPU/GPUDialect.h
+++ b/mlir/include/mlir/Dialect/GPU/GPUDialect.h
@@ -53,6 +53,11 @@
   /// 'gpu.kernel' attribute.
   static bool isKernel(Operation *op);
 
+  /// Returns the number of workgroup (thread, block) dimensions supported in
+  /// the GPU dialect.
+  // TODO(zinenko,herhut): consider generalizing this.
+  static unsigned getNumWorkgroupDimensions() { return 3; }
+
   /// Returns the numeric value used to identify the workgroup memory address
   /// space.
   static unsigned getWorkgroupAddressSpace() { return 3; }
diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td
--- a/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -117,6 +117,10 @@
   ];
 
   let extraClassDeclaration = [{
+    /// Adds a workgroup attribution of the MemRef type with the given shape
+    /// and element type to this op.
+    Value addWorkgroupAttribution(ArrayRef<int64_t> shape, Type elementType);
+
     /// Returns `true` if the GPU function defined by this Op is a kernel, i.e.
     /// it is intended to be launched from host.
     bool isKernel() {
diff --git a/mlir/include/mlir/Dialect/GPU/MemoryPromotion.h b/mlir/include/mlir/Dialect/GPU/MemoryPromotion.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/GPU/MemoryPromotion.h
@@ -0,0 +1,29 @@
+//===- MemoryPromotion.h - Utilities for moving data across GPU -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file declares the utility functions that generate IR copying
+// data between different levels of the memory hierarchy.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_GPU_MEMORYPROMOTION_H
+#define MLIR_DIALECT_GPU_MEMORYPROMOTION_H
+
+namespace mlir {
+
+namespace gpu {
+class GPUFuncOp;
+}
+
+/// Promotes a function argument to workgroup memory in the given function. The
+/// copies will be inserted at the beginning and at the end of the function.
+void promoteToWorkgroupMemory(gpu::GPUFuncOp op, unsigned arg);
+
+} // end namespace mlir
+
+#endif // MLIR_DIALECT_GPU_MEMORYPROMOTION_H
diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h
--- a/mlir/include/mlir/IR/Block.h
+++ b/mlir/include/mlir/IR/Block.h
@@ -79,6 +79,11 @@
   /// Add one value to the argument list.
   BlockArgument addArgument(Type type);
 
+  /// Insert one value into the argument list at the position indicated by the
+  /// given iterator. The existing arguments are shifted. The block is expected
+  /// not to have predecessors.
+  BlockArgument insertArgument(args_iterator it, Type type);
+
   /// Add one argument to the argument list for each type specified in the list.
   iterator_range<args_iterator> addArguments(ArrayRef<Type> types);
diff --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt
--- a/mlir/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/CMakeLists.txt
@@ -2,9 +2,23 @@
   IR/GPUDialect.cpp
   IR/DialectRegistration.cpp
   Transforms/KernelOutlining.cpp
+  Transforms/MemoryPromotion.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/GPU
 )
-add_dependencies(MLIRGPU MLIRGPUOpsIncGen MLIRIR MLIRLLVMIR LLVMSupport)
-target_link_libraries(MLIRGPU MLIRIR MLIRLLVMIR MLIRStandardOps LLVMSupport)
+add_dependencies(MLIRGPU
+  MLIRGPUOpsIncGen
+  MLIREDSC
+  MLIRIR
+  MLIRLLVMIR
+  MLIRLoopOps
+  MLIRSupport
+  LLVMSupport)
+target_link_libraries(MLIRGPU
+  MLIREDSC
+  MLIRIR
+  MLIRLLVMIR
+  MLIRLoopOps
+  MLIRSupport
+  LLVMSupport)
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -593,6 +593,24 @@
 // GPUFuncOp
 //===----------------------------------------------------------------------===//
 
+/// Adds a workgroup attribution of the MemRef type with the given shape and
+/// element type to this op.
+Value GPUFuncOp::addWorkgroupAttribution(ArrayRef<int64_t> shape,
+                                         Type elementType) {
+  unsigned pos = getNumFuncArguments() + getNumWorkgroupAttributions();
+  Block &bodyBlock = body().front();
+  Value attribution = bodyBlock.insertArgument(
+      std::next(bodyBlock.args_begin(), pos),
+      MemRefType::get(shape, elementType, /*affineMapComposition=*/{},
+                      GPUDialect::getWorkgroupAddressSpace()));
+  auto numWorkgroupBuffersAttr =
+      getAttrOfType<IntegerAttr>(getNumWorkgroupAttributionsAttrName());
+  setAttr(getNumWorkgroupAttributionsAttrName(),
+          IntegerAttr::get(numWorkgroupBuffersAttr.getType(),
+                           numWorkgroupBuffersAttr.getValue() + 1));
+  return attribution;
+}
+
 void GPUFuncOp::build(Builder *builder, OperationState &result, StringRef name,
                       FunctionType type, ArrayRef<Type> workgroupAttributions,
                       ArrayRef<Type> privateAttributions,
diff --git a/mlir/lib/Dialect/GPU/Transforms/MemoryPromotion.cpp b/mlir/lib/Dialect/GPU/Transforms/MemoryPromotion.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/GPU/Transforms/MemoryPromotion.cpp
@@ -0,0 +1,214 @@
+//===- MemoryPromotion.cpp - Utilities for moving data across GPU memories ===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements utilities that allow one to create IR moving the data
+// across different levels of the GPU memory hierarchy.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/GPU/MemoryPromotion.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/LoopOps/LoopOps.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/EDSC/Builders.h"
+#include "mlir/EDSC/Helpers.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/Functional.h"
+#include "mlir/Transforms/FoldUtils.h"
+
+using namespace mlir;
+using namespace mlir::gpu;
+
+namespace {
+/// Container structure for the bounds and steps of the loops in a nest.
+struct InsertCopiesLoopBounds {
+  SmallVector<Value, 8> lbs;
+  SmallVector<Value, 8> ubs;
+  SmallVector<Value, 8> steps;
+
+  void reserve(size_t size) {
+    lbs.reserve(size);
+    ubs.reserve(size);
+    steps.reserve(size);
+  }
+};
+
+/// Simple cache around a builder and a folder that uses the op folding
+/// infrastructure to avoid the emission of duplicate constants.
+class ConstantGenerator {
+public:
+  ConstantGenerator(OpBuilder &b, Location l)
+      : builder(b), loc(l), folder(b.getContext()) {}
+
+  Value get(int64_t value) {
+    return folder.create<ConstantIndexOp>(builder, loc, value);
+  }
+
+private:
+  OpBuilder &builder;
+  Location loc;
+  OperationFolder folder;
+};
+} // namespace
+
+/// Returns the textual name of a GPU dimension.
+static StringRef getDimName(unsigned dim) {
+  if (dim == 0)
+    return "x";
+  if (dim == 1)
+    return "y";
+  if (dim == 2)
+    return "z";
+
+  llvm_unreachable("dimension ID overflow");
+}
+
+/// Emits the operations necessary to construct the bounds of the loops
+/// surrounding the copies to or from `type`. The bounds are stored in the
+/// first argument for future use. The caller must ensure that the loop
+/// operations using these bounds are dominated by the operations emitted here.
+static void emitLoopBounds(InsertCopiesLoopBounds &bounds, OpBuilder &builder,
+                           Location loc, MemRefType type) {
+  unsigned rank = type.getRank();
+  unsigned numRequiredLoops = GPUDialect::getNumWorkgroupDimensions();
+  bounds.reserve(rank);
+  ConstantGenerator constants(builder, loc);
+
+  // Create the bounds for the loops over the leading dimensions of the memref,
+  // i.e. all dimensions except the three innermost ones. These are simple
+  // loops stepping from 0 to the number of elements in the corresponding
+  // dimension.
+  if (type.getRank() > numRequiredLoops) {
+    for (int64_t s : type.getShape().drop_back(numRequiredLoops)) {
+      bounds.lbs.push_back(constants.get(0));
+      bounds.ubs.push_back(constants.get(s));
+      bounds.steps.push_back(constants.get(1));
+    }
+  }
+
+  // Always create the loops for thread dimensions. If the rank of the memref
+  // is not large enough, create trivial loops of size 1 anyway; they
+  // essentially act as a condition so that only the first thread along that
+  // dimension executes the work.
+  SmallVector<int64_t, 3> innermostShape(numRequiredLoops, /*Value=*/1);
+  for (unsigned i = 0,
+                e = std::min(rank, GPUDialect::getNumWorkgroupDimensions());
+       i < e; ++i)
+    innermostShape[numRequiredLoops - 1 - i] = type.getShape()[rank - i - 1];
+
+  // Attach loops to dimensions in reverse order so that the "x" dimension
+  // iterates over the fastest varying dimension of the memref, leading to
+  // memory access coalescing. These loops have the form
+  //
+  //   for %i = thread_id() to <dimension size> step block_dim()
+  //
+  // so that each block of threads copies the corresponding block of elements
+  // and steps over to the next block.
+  auto indexType = builder.getIndexType();
+  for (unsigned i = 0; i < numRequiredLoops; ++i) {
+    auto dimName = builder.getStringAttr(getDimName(numRequiredLoops - i - 1));
+    bounds.lbs.push_back(
+        builder.create<gpu::ThreadIdOp>(loc, indexType, dimName));
+    bounds.ubs.push_back(constants.get(innermostShape[i]));
+    bounds.steps.push_back(
+        builder.create<gpu::BlockDimOp>(loc, indexType, dimName));
+  }
+}
+
+/// Emits the loop nest performing the copy between the "from" and "to" values
+/// with the loop bounds specified as arguments.
+static void insertCopyLoops(OpBuilder &builder, Location loc, Value from,
+                            Value to, ArrayRef<Value> lbs, ArrayRef<Value> ubs,
+                            ArrayRef<Value> steps) {
+  assert(lbs.size() == ubs.size());
+  assert(lbs.size() == steps.size());
+
+  unsigned rank = from.getType().cast<MemRefType>().getRank();
+
+  // Delegate the rank-polymorphic loop nest construction to EDSC.
+  edsc::ScopedContext context(builder, loc);
+  auto ivs = edsc::makeIndexHandles(lbs.size());
+  auto ivPtrs =
+      edsc::makeHandlePointers(MutableArrayRef<edsc::IndexHandle>(ivs));
+  edsc::LoopNestBuilder(
+      ivPtrs, edsc::makeValueHandles(lbs), edsc::makeValueHandles(ubs),
+      edsc::makeValueHandles(steps))([&ivs, from, to, rank]() {
+    auto activeIvs = llvm::makeArrayRef(ivs).take_back(rank);
+    edsc::StdIndexedValue fromHandle(from);
+    edsc::StdIndexedValue toHandle(to);
+    toHandle(activeIvs) = fromHandle(activeIvs);
+  });
+}
+
+/// Emits the loop nests performing the copy to the designated location at the
+/// beginning of the region, and back from the designated location immediately
+/// before the terminator of the first block of the region. The region is
+/// expected to have one block. This boils down to the following structure:
+///
+///   ^bb(...):
+///
+///     for %arg0 = ... to ... step ... {
+///       ...
+///       for %argN = ... to ... step ... {
+///         %0 = load %from[%arg0, ..., %argN]
+///         store %0, %to[%arg0, ..., %argN]
+///       }
+///       ...
+///     }
+///     gpu.barrier
+///     <... original body ...>
+///     gpu.barrier
+///     for %arg0 = ... to ... step ... {
+///       ...
+///       for %argN = ... to ... step ... {
+///         %1 = load %to[%arg0, ..., %argN]
+///         store %1, %from[%arg0, ..., %argN]
+///       }
+///       ...
+///     }
+///
+/// Inserts the barriers unconditionally since different threads may be copying
+/// values that other threads read. An analysis would be required to eliminate
+/// the barriers in cases where a value is only used by the thread that copied
+/// it. Both copies are also inserted unconditionally; an analysis would be
+/// required to copy only live-in and live-out values when necessary. This
+/// copies the entire memref pointed to by "from". In case a smaller block
+/// would be sufficient, the caller can create a subview of the memref and
+/// promote it instead.
+static void insertCopies(Region &region, Location loc, Value from, Value to) {
+  auto fromType = from.getType().cast<MemRefType>();
+  auto toType = to.getType().cast<MemRefType>();
+  (void)toType;
+  assert(fromType.getShape() == toType.getShape());
+  assert(fromType.getRank() != 0);
+  assert(llvm::has_single_element(region) &&
+         "unstructured control flow not supported");
+
+  OpBuilder builder(region.getContext());
+  builder.setInsertionPointToStart(&region.front());
+
+  InsertCopiesLoopBounds bounds;
+  emitLoopBounds(bounds, builder, loc, fromType);
+
+  insertCopyLoops(builder, loc, from, to, bounds.lbs, bounds.ubs, bounds.steps);
+  builder.create<gpu::BarrierOp>(loc);
+
+  builder.setInsertionPoint(&region.front().back());
+  builder.create<gpu::BarrierOp>(loc);
+  insertCopyLoops(builder, loc, to, from, bounds.lbs, bounds.ubs, bounds.steps);
+}
+
+/// Promotes a function argument to workgroup memory in the given function. The
+/// copies will be inserted at the beginning and at the end of the function.
+void mlir::promoteToWorkgroupMemory(GPUFuncOp op, unsigned arg) {
+  Value value = op.getArgument(arg);
+  auto type = value.getType().dyn_cast<MemRefType>();
+  assert(type && type.hasStaticShape() && "can only promote memrefs");
+
+  Value attribution =
+      op.addWorkgroupAttribution(type.getShape(), type.getElementType());
+
+  // Replace the uses first since only the original uses are currently present.
+  // Then insert the copies.
+  value.replaceAllUsesWith(attribution);
+  insertCopies(op.getBody(), op.getLoc(), value, attribution);
+}
diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp
--- a/mlir/lib/IR/Block.cpp
+++ b/mlir/lib/IR/Block.cpp
@@ -179,6 +179,20 @@
   }
 }
 
+/// Insert one value into the argument list at the given position. The existing
+/// arguments are shifted.
+/// The block is expected not to have predecessors.
+BlockArgument Block::insertArgument(args_iterator it, Type type) {
+  assert(llvm::empty(getPredecessors()) &&
+         "cannot insert arguments to blocks with predecessors");
+
+  // Use the args_iterator (on the BlockArgListType) to compute the insertion
+  // iterator in the underlying argument storage.
+  size_t distance = std::distance(args_begin(), it);
+  auto arg = BlockArgument::create(type, this);
+  arguments.insert(std::next(arguments.begin(), distance), arg);
+  return arg;
+}
+
 //===----------------------------------------------------------------------===//
 // Terminator management
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/GPU/promotion.mlir b/mlir/test/Dialect/GPU/promotion.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/GPU/promotion.mlir
@@ -0,0 +1,117 @@
+// RUN: mlir-opt -test-gpu-memory-promotion -split-input-file %s | FileCheck %s
+
+module @foo attributes {gpu.kernel_module} {
+  // Verify that the attribution was indeed introduced.
+  // CHECK-LABEL: @memref3d
+  // CHECK-SAME: (%[[arg:.*]]: memref<5x4xf32>
+  // CHECK-SAME: workgroup(%[[promoted:.*]] : memref<5x4xf32, 3>)
+  gpu.func @memref3d(%arg0: memref<5x4xf32> {gpu.test_promote_workgroup}) kernel {
+    // Verify that the loop bounds are emitted; the order does not matter.
+    // CHECK-DAG: %[[c1:.*]] = constant 1
+    // CHECK-DAG: %[[c4:.*]] = constant 4
+    // CHECK-DAG: %[[c5:.*]] = constant 5
+    // CHECK-DAG: %[[tx:.*]] = "gpu.thread_id"() {dimension = "x"}
+    // CHECK-DAG: %[[ty:.*]] = "gpu.thread_id"() {dimension = "y"}
+    // CHECK-DAG: %[[tz:.*]] = "gpu.thread_id"() {dimension = "z"}
+    // CHECK-DAG: %[[bdx:.*]] = "gpu.block_dim"() {dimension = "x"}
+    // CHECK-DAG: %[[bdy:.*]] = "gpu.block_dim"() {dimension = "y"}
+    // CHECK-DAG: %[[bdz:.*]] = "gpu.block_dim"() {dimension = "z"}
+
+    // Verify that the loops for the copy are emitted, including the loop for
+    // dimension "z" that does not correspond to a memref dimension.
+    // CHECK: loop.for %[[i0:.*]] = %[[tz]] to %[[c1]] step %[[bdz]]
+    // CHECK: loop.for %[[i1:.*]] = %[[ty]] to %[[c5]] step %[[bdy]]
+    // CHECK: loop.for %[[i2:.*]] = %[[tx]] to %[[c4]] step %[[bdx]]
+
+    // Verify that the copy is emitted and uses only the last two loops.
+    // CHECK: %[[v:.*]] = load %[[arg]][%[[i1]], %[[i2]]]
+    // CHECK: store %[[v]], %[[promoted]][%[[i1]], %[[i2]]]
+
+    // Verify that the use has been rewritten.
+    // CHECK: "use"(%[[promoted]]) : (memref<5x4xf32, 3>)
+    "use"(%arg0) : (memref<5x4xf32>) -> ()
+
+    // Verify that the loops for the copy back are emitted, including the loop
+    // for dimension "z" that does not correspond to a memref dimension.
+    // CHECK: loop.for %[[i0:.*]] = %[[tz]] to %[[c1]] step %[[bdz]]
+    // CHECK: loop.for %[[i1:.*]] = %[[ty]] to %[[c5]] step %[[bdy]]
+    // CHECK: loop.for %[[i2:.*]] = %[[tx]] to %[[c4]] step %[[bdx]]
+
+    // Verify that the copy back is emitted and uses only the last two loops.
+    // CHECK: %[[v:.*]] = load %[[promoted]][%[[i1]], %[[i2]]]
+    // CHECK: store %[[v]], %[[arg]][%[[i1]], %[[i2]]]
+    gpu.return
+  }
+}
+
+// -----
+
+module @foo attributes {gpu.kernel_module} {
+  // Verify that the attribution was indeed introduced.
+  // CHECK-LABEL: @memref5d
+  // CHECK-SAME: (%[[arg:.*]]: memref<8x7x6x5x4xf32>
+  // CHECK-SAME: workgroup(%[[promoted:.*]] : memref<8x7x6x5x4xf32, 3>)
+  gpu.func @memref5d(%arg0: memref<8x7x6x5x4xf32> {gpu.test_promote_workgroup}) kernel {
+    // Verify that the loop bounds are emitted; the order does not matter.
+    // CHECK-DAG: %[[c0:.*]] = constant 0
+    // CHECK-DAG: %[[c1:.*]] = constant 1
+    // CHECK-DAG: %[[c4:.*]] = constant 4
+    // CHECK-DAG: %[[c5:.*]] = constant 5
+    // CHECK-DAG: %[[c6:.*]] = constant 6
+    // CHECK-DAG: %[[c7:.*]] = constant 7
+    // CHECK-DAG: %[[c8:.*]] = constant 8
+    // CHECK-DAG: %[[tx:.*]] = "gpu.thread_id"() {dimension = "x"}
+    // CHECK-DAG: %[[ty:.*]] = "gpu.thread_id"() {dimension = "y"}
+    // CHECK-DAG: %[[tz:.*]] = "gpu.thread_id"() {dimension = "z"}
+    // CHECK-DAG: %[[bdx:.*]] = "gpu.block_dim"() {dimension = "x"}
+    // CHECK-DAG: %[[bdy:.*]] = "gpu.block_dim"() {dimension = "y"}
+    // CHECK-DAG: %[[bdz:.*]] = "gpu.block_dim"() {dimension = "z"}
+
+    // Verify that the loops for the copy are emitted.
+    // CHECK: loop.for %[[i0:.*]] = %[[c0]] to %[[c8]] step %[[c1]]
+    // CHECK: loop.for %[[i1:.*]] = %[[c0]] to %[[c7]] step %[[c1]]
+    // CHECK: loop.for %[[i2:.*]] = %[[tz]] to %[[c6]] step %[[bdz]]
+    // CHECK: loop.for %[[i3:.*]] = %[[ty]] to %[[c5]] step %[[bdy]]
+    // CHECK: loop.for %[[i4:.*]] = %[[tx]] to %[[c4]] step %[[bdx]]
+
+    // Verify that the copy is emitted.
+    // CHECK: %[[v:.*]] = load %[[arg]][%[[i0]], %[[i1]], %[[i2]], %[[i3]], %[[i4]]]
+    // CHECK: store %[[v]], %[[promoted]][%[[i0]], %[[i1]], %[[i2]], %[[i3]], %[[i4]]]
+
+    // Verify that the use has been rewritten.
+    // CHECK: "use"(%[[promoted]]) : (memref<8x7x6x5x4xf32, 3>)
+    "use"(%arg0) : (memref<8x7x6x5x4xf32>) -> ()
+
+    // Verify that the loops for the copy back are emitted.
+    // CHECK: loop.for %[[i0:.*]] = %[[c0]] to %[[c8]] step %[[c1]]
+    // CHECK: loop.for %[[i1:.*]] = %[[c0]] to %[[c7]] step %[[c1]]
+    // CHECK: loop.for %[[i2:.*]] = %[[tz]] to %[[c6]] step %[[bdz]]
+    // CHECK: loop.for %[[i3:.*]] = %[[ty]] to %[[c5]] step %[[bdy]]
+    // CHECK: loop.for %[[i4:.*]] = %[[tx]] to %[[c4]] step %[[bdx]]
+
+    // Verify that the copy back is emitted.
+    // CHECK: %[[v:.*]] = load %[[promoted]][%[[i0]], %[[i1]], %[[i2]], %[[i3]], %[[i4]]]
+    // CHECK: store %[[v]], %[[arg]][%[[i0]], %[[i1]], %[[i2]], %[[i3]], %[[i4]]]
+    gpu.return
+  }
+}
+
+// -----
+
+module @foo attributes {gpu.kernel_module} {
+  // Check that attribution insertion works fine.
+  // CHECK-LABEL: @insert
+  // CHECK-SAME: (%{{.*}}: memref<4xf32>
+  // CHECK-SAME: workgroup(%{{.*}}: memref<1x1xf64, 3>
+  // CHECK-SAME: %[[wg2:.*]] : memref<4xf32, 3>)
+  // CHECK-SAME: private(%{{.*}}: memref<1x1xi64, 5>)
+  gpu.func @insert(%arg0: memref<4xf32> {gpu.test_promote_workgroup})
+      workgroup(%arg1: memref<1x1xf64, 3>)
+      private(%arg2: memref<1x1xi64, 5>)
+      kernel {
+    // CHECK: "use"(%[[wg2]])
+    "use"(%arg0) : (memref<4xf32>) -> ()
+    gpu.return
+  }
+}
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@
   TestCallGraph.cpp
   TestConstantFold.cpp
   TestLoopFusion.cpp
+  TestGpuMemoryPromotion.cpp
   TestInlining.cpp
   TestLinalgTransforms.cpp
   TestLiveness.cpp
@@ -26,6 +27,8 @@
 target_link_libraries(MLIRTestTransforms
   MLIRAffineOps
   MLIRAnalysis
+  MLIREDSC
+  MLIRGPU
   MLIRLoopOps
   MLIRPass
   MLIRTestDialect
diff --git a/mlir/test/lib/Transforms/TestGpuMemoryPromotion.cpp b/mlir/test/lib/Transforms/TestGpuMemoryPromotion.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestGpuMemoryPromotion.cpp
@@ -0,0 +1,40 @@
+//===- TestGPUMemoryPromotionPass.cpp - Test pass for GPU promotion ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the pass testing the utilities for moving data across
+// different levels of the GPU memory hierarchy.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/GPU/MemoryPromotion.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+/// Simple pass for testing the promotion to workgroup memory in GPU functions.
+/// Promotes all arguments with the "gpu.test_promote_workgroup" attribute.
+/// This does not check whether the promotion is legal (e.g., the amount of
+/// memory used) or beneficial (e.g., makes previously uncoalesced loads
+/// coalesced).
+class TestGpuMemoryPromotionPass
+    : public OperationPass<TestGpuMemoryPromotionPass, gpu::GPUFuncOp> {
+  void runOnOperation() override {
+    gpu::GPUFuncOp op = getOperation();
+    for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i) {
+      if (op.getArgAttrOfType<UnitAttr>(i, "gpu.test_promote_workgroup"))
+        promoteToWorkgroupMemory(op, i);
+    }
+  }
+};
+} // end namespace
+
+static PassRegistration<TestGpuMemoryPromotionPass> registration(
+    "test-gpu-memory-promotion",
+    "Promotes the annotated arguments of gpu.func to workgroup memory.");
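
Note (not part of the patch): for readers unfamiliar with the transformation, the sketch below shows roughly what promoteToWorkgroupMemory produces for the @memref3d test case above. It is a hand-written illustration only; SSA value names such as %wg, %tx or %i0 are made up here, and the exact order of the constant, gpu.thread_id and gpu.block_dim operations is whatever the CHECK-DAG lines in promotion.mlir accept, not necessarily this order.

  gpu.func @memref3d(%arg0: memref<5x4xf32>)
      workgroup(%wg: memref<5x4xf32, 3>)
      kernel {
    %c1 = constant 1 : index
    %c4 = constant 4 : index
    %c5 = constant 5 : index
    %tx = "gpu.thread_id"() {dimension = "x"} : () -> index
    %ty = "gpu.thread_id"() {dimension = "y"} : () -> index
    %tz = "gpu.thread_id"() {dimension = "z"} : () -> index
    %bdx = "gpu.block_dim"() {dimension = "x"} : () -> index
    %bdy = "gpu.block_dim"() {dimension = "y"} : () -> index
    %bdz = "gpu.block_dim"() {dimension = "z"} : () -> index

    // Copy in: global memory -> workgroup memory. The "z" loop has a trip
    // count of 1 because the memref has only two dimensions.
    loop.for %i0 = %tz to %c1 step %bdz {
      loop.for %i1 = %ty to %c5 step %bdy {
        loop.for %i2 = %tx to %c4 step %bdx {
          %v = load %arg0[%i1, %i2] : memref<5x4xf32>
          store %v, %wg[%i1, %i2] : memref<5x4xf32, 3>
        }
      }
    }
    gpu.barrier

    // Original body, now using the workgroup buffer.
    "use"(%wg) : (memref<5x4xf32, 3>) -> ()

    // Copy out: workgroup memory -> global memory.
    gpu.barrier
    loop.for %i0 = %tz to %c1 step %bdz {
      loop.for %i1 = %ty to %c5 step %bdy {
        loop.for %i2 = %tx to %c4 step %bdx {
          %v = load %wg[%i1, %i2] : memref<5x4xf32, 3>
          store %v, %arg0[%i1, %i2] : memref<5x4xf32>
        }
      }
    }
    gpu.return
  }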