diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
@@ -150,6 +150,12 @@
                               StringRef features, int optLevel);
 
+/// Collect a set of patterns to decompose memref ops.
+void populateGpuDecomposeMemrefsPatterns(RewritePatternSet &patterns);
+
+/// Pass that decomposes memref ops inside the `gpu.launch` body.
+std::unique_ptr<Pass> createGpuDecomposeMemrefsPass();
+
 /// Generate the code for registering passes.
 #define GEN_PASS_REGISTRATION
 #include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
@@ -37,4 +37,22 @@
   let dependentDialects = ["mlir::gpu::GPUDialect"];
 }
 
+def GpuDecomposeMemrefsPass : Pass<"gpu-decompose-memrefs"> {
+  let summary = "Decomposes memref index computation into explicit ops.";
+  let description = [{
+    This pass decomposes memref index computation into explicit computations on
+    sizes/strides, obtained from `memref.extract_strided_metadata`, which it
+    tries to place outside of the `gpu.launch` body. Memrefs are then
+    reconstructed using `memref.reinterpret_cast`.
+    This is needed because some targets (e.g. SPIR-V) lower memrefs to bare
+    pointers, and sizes/strides of dynamically-sized memrefs are not available
+    inside the `gpu.launch` body.
+  }];
+  let constructor = "mlir::createGpuDecomposeMemrefsPass()";
+  let dependentDialects = [
+    "mlir::gpu::GPUDialect", "mlir::memref::MemRefDialect",
+    "mlir::affine::AffineDialect"
+  ];
+}
+
 #endif // MLIR_DIALECT_GPU_PASSES
diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
--- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -229,6 +229,16 @@
 SmallVector<int64_t> getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront = 0,
                                     unsigned dropBack = 0);
 
+/// Compute the linear index from the provided strides and indices, assuming a
+/// strided layout.
+/// Returns an AffineExpr and a list of values to apply to it, e.g.:
+///
+///   auto &&[expr, values] = computeLinearIndex(...);
+///   offset = affine::makeComposedFoldedAffineApply(builder, loc, expr, values);
+std::pair<AffineExpr, SmallVector<OpFoldResult>>
+computeLinearIndex(OpFoldResult sourceOffset, ArrayRef<OpFoldResult> strides,
+                   ArrayRef<OpFoldResult> indices);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_UTILS_INDEXINGUTILS_H
diff --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt
--- a/mlir/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/CMakeLists.txt
@@ -47,14 +47,15 @@
 add_mlir_dialect_library(MLIRGPUTransforms
   Transforms/AllReduceLowering.cpp
   Transforms/AsyncRegionRewriter.cpp
+  Transforms/DecomposeMemrefs.cpp
   Transforms/GlobalIdRewriter.cpp
   Transforms/KernelOutlining.cpp
   Transforms/MemoryPromotion.cpp
   Transforms/ParallelLoopMapper.cpp
-  Transforms/ShuffleRewriter.cpp
   Transforms/SerializeToBlob.cpp
   Transforms/SerializeToCubin.cpp
   Transforms/SerializeToHsaco.cpp
+  Transforms/ShuffleRewriter.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/GPU
diff --git a/mlir/lib/Dialect/GPU/Transforms/DecomposeMemrefs.cpp b/mlir/lib/Dialect/GPU/Transforms/DecomposeMemrefs.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/GPU/Transforms/DecomposeMemrefs.cpp
@@ -0,0 +1,234 @@
+//===- DecomposeMemrefs.cpp - Decompose memrefs pass implementation -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the decompose memrefs pass.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/GPU/Transforms/Passes.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_GPUDECOMPOSEMEMREFSPASS
+#include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+static void setInsertionPointToStart(OpBuilder &builder, Value val) {
+  if (auto parentOp = val.getDefiningOp()) {
+    builder.setInsertionPointAfter(parentOp);
+  } else {
+    builder.setInsertionPointToStart(val.getParentBlock());
+  }
+}
+
+static bool isInsideLaunch(Operation *op) {
+  return op->getParentOfType<gpu::LaunchOp>();
+}
+
+static std::tuple<Value, OpFoldResult, SmallVector<OpFoldResult>>
+getFlatOffsetAndStrides(OpBuilder &rewriter, Location loc, Value source,
+                        ArrayRef<OpFoldResult> subOffsets,
+                        ArrayRef<OpFoldResult> subStrides = std::nullopt) {
+  auto sourceType = cast<MemRefType>(source.getType());
+  auto sourceRank = static_cast<unsigned>(sourceType.getRank());
+
+  memref::ExtractStridedMetadataOp newExtractStridedMetadata;
+  {
+    OpBuilder::InsertionGuard g(rewriter);
+    setInsertionPointToStart(rewriter, source);
+    newExtractStridedMetadata =
+        rewriter.create<memref::ExtractStridedMetadataOp>(loc, source);
+  }
+
+  auto &&[sourceStrides, sourceOffset] = getStridesAndOffset(sourceType);
+
+  auto getDim = [&](int64_t dim, Value dimVal) -> OpFoldResult {
+    return ShapedType::isDynamic(dim) ? getAsOpFoldResult(dimVal)
+                                      : rewriter.getIndexAttr(dim);
+  };
+
+  OpFoldResult origOffset =
+      getDim(sourceOffset, newExtractStridedMetadata.getOffset());
+  ValueRange sourceStridesVals = newExtractStridedMetadata.getStrides();
+
+  SmallVector<OpFoldResult> origStrides;
+  origStrides.reserve(sourceRank);
+
+  SmallVector<OpFoldResult> strides;
+  strides.reserve(sourceRank);
+
+  AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
+  AffineExpr s1 = rewriter.getAffineSymbolExpr(1);
+  for (auto i : llvm::seq(0u, sourceRank)) {
+    OpFoldResult origStride = getDim(sourceStrides[i], sourceStridesVals[i]);
+
+    if (!subStrides.empty()) {
+      strides.push_back(affine::makeComposedFoldedAffineApply(
+          rewriter, loc, s0 * s1, {subStrides[i], origStride}));
+    }
+
+    origStrides.emplace_back(origStride);
+  }
+
+  auto &&[expr, values] =
+      computeLinearIndex(origOffset, origStrides, subOffsets);
+  OpFoldResult finalOffset =
+      affine::makeComposedFoldedAffineApply(rewriter, loc, expr, values);
+  return {newExtractStridedMetadata.getBaseBuffer(), finalOffset, strides};
+}
+
+static Value getFlatMemref(OpBuilder &rewriter, Location loc, Value source,
+                           ValueRange offsets) {
+  SmallVector<OpFoldResult> offsetsTemp = getAsOpFoldResult(offsets);
+  auto &&[base, offset, ignore] =
+      getFlatOffsetAndStrides(rewriter, loc, source, offsetsTemp);
+  auto retType = cast<MemRefType>(base.getType());
+  return rewriter.create<memref::ReinterpretCastOp>(loc, retType, base, offset,
+                                                    std::nullopt, std::nullopt);
+}
+
+static bool needFlatten(Value val) {
+  auto type = cast<MemRefType>(val.getType());
+  return type.getRank() != 0;
+}
+
+static bool checkLayout(Value val) {
+  auto type = cast<MemRefType>(val.getType());
+  return type.getLayout().isIdentity() ||
+         isa<StridedLayoutAttr>(type.getLayout());
+}
+
+namespace {
+struct FlattenLoad : public OpRewritePattern<memref::LoadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::LoadOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!isInsideLaunch(op))
+      return rewriter.notifyMatchFailure(op, "not inside gpu.launch");
+
+    Value memref = op.getMemref();
+    if (!needFlatten(memref))
+      return rewriter.notifyMatchFailure(op, "nothing to do");
+
+    if (!checkLayout(memref))
+      return rewriter.notifyMatchFailure(op, "unsupported layout");
+
+    Location loc = op.getLoc();
+    Value flatMemref = getFlatMemref(rewriter, loc, memref, op.getIndices());
+    rewriter.replaceOpWithNewOp<memref::LoadOp>(op, flatMemref);
+    return success();
+  }
+};
+
+struct FlattenStore : public OpRewritePattern<memref::StoreOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::StoreOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!isInsideLaunch(op))
+      return rewriter.notifyMatchFailure(op, "not inside gpu.launch");
+
+    Value memref = op.getMemref();
+    if (!needFlatten(memref))
+      return rewriter.notifyMatchFailure(op, "nothing to do");
+
+    if (!checkLayout(memref))
+      return rewriter.notifyMatchFailure(op, "unsupported layout");
+
+    Location loc = op.getLoc();
+    Value flatMemref = getFlatMemref(rewriter, loc, memref, op.getIndices());
+    Value value = op.getValue();
+    rewriter.replaceOpWithNewOp<memref::StoreOp>(op, value, flatMemref);
+    return success();
+  }
+};
+
+struct FlattenSubview : public OpRewritePattern<memref::SubViewOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::SubViewOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!isInsideLaunch(op))
+      return rewriter.notifyMatchFailure(op, "not inside gpu.launch");
+
+    Value memref = op.getSource();
+    if (!needFlatten(memref))
+      return rewriter.notifyMatchFailure(op, "nothing to do");
+
+    if (!checkLayout(memref))
+      return rewriter.notifyMatchFailure(op, "unsupported layout");
+
+    Location loc = op.getLoc();
+    SmallVector<OpFoldResult> subOffsets = op.getMixedOffsets();
+    SmallVector<OpFoldResult> subSizes = op.getMixedSizes();
+    SmallVector<OpFoldResult> subStrides = op.getMixedStrides();
+    auto &&[base, finalOffset, strides] =
+        getFlatOffsetAndStrides(rewriter, loc, memref, subOffsets, subStrides);
+
+    auto srcType = cast<MemRefType>(memref.getType());
+    auto resultType = cast<MemRefType>(op.getType());
+    unsigned subRank = static_cast<unsigned>(resultType.getRank());
+
+    llvm::SmallBitVector droppedDims = op.getDroppedDims();
+
+    SmallVector<OpFoldResult> finalSizes;
+    finalSizes.reserve(subRank);
+
+    SmallVector<OpFoldResult> finalStrides;
+    finalStrides.reserve(subRank);
+
+    for (auto i : llvm::seq(0u, static_cast<unsigned>(srcType.getRank()))) {
+      if (droppedDims.test(i))
+        continue;
+
+      finalSizes.push_back(subSizes[i]);
+      finalStrides.push_back(strides[i]);
+    }
+
+    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
+        op, resultType, base, finalOffset, finalSizes, finalStrides);
+    return success();
+  }
+};
+
+struct GpuDecomposeMemrefsPass
+    : public impl::GpuDecomposeMemrefsPassBase<GpuDecomposeMemrefsPass> {
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+
+    populateGpuDecomposeMemrefsPatterns(patterns);
+
+    if (failed(
+            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+
+} // namespace
+
+void mlir::populateGpuDecomposeMemrefsPatterns(RewritePatternSet &patterns) {
+  patterns.insert<FlattenLoad, FlattenStore, FlattenSubview>(
+      patterns.getContext());
+}
+
+std::unique_ptr<Pass> mlir::createGpuDecomposeMemrefsPass() {
+  return std::make_unique<GpuDecomposeMemrefsPass>();
+}
diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
--- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -261,3 +261,44 @@
     res.push_back((*it).getValue().getSExtValue());
   return res;
 }
+
+// TODO: do we have any common utility for this?
+static MLIRContext *getContext(OpFoldResult val) {
+  assert(val && "Invalid value");
+  if (auto attr = dyn_cast<Attribute>(val)) {
+    return attr.getContext();
+  } else {
+    return cast<Value>(val).getContext();
+  }
+}
+
+std::pair<AffineExpr, SmallVector<OpFoldResult>>
+mlir::computeLinearIndex(OpFoldResult sourceOffset,
+                         ArrayRef<OpFoldResult> strides,
+                         ArrayRef<OpFoldResult> indices) {
+  assert(strides.size() == indices.size());
+  auto sourceRank = static_cast<unsigned>(strides.size());
+
+  // Hold the affine symbols and values for the computation of the offset.
+  SmallVector<OpFoldResult> values(2 * sourceRank + 1);
+  SmallVector<AffineExpr> symbols(2 * sourceRank + 1);
+
+  bindSymbolsList(getContext(sourceOffset), MutableArrayRef{symbols});
+  AffineExpr expr = symbols.front();
+  values[0] = sourceOffset;
+
+  for (unsigned i = 0; i < sourceRank; ++i) {
+    // Compute the stride.
+    OpFoldResult origStride = strides[i];
+
+    // Build up the computation of the offset.
+    unsigned baseIdxForDim = 1 + 2 * i;
+    unsigned subOffsetForDim = baseIdxForDim;
+    unsigned origStrideForDim = baseIdxForDim + 1;
+    expr = expr + symbols[subOffsetForDim] * symbols[origStrideForDim];
+    values[subOffsetForDim] = indices[i];
+    values[origStrideForDim] = origStride;
+  }
+
+  return {expr, values};
+}
diff --git a/mlir/test/Dialect/GPU/decompose-memrefs.mlir b/mlir/test/Dialect/GPU/decompose-memrefs.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/GPU/decompose-memrefs.mlir
@@ -0,0 +1,137 @@
+// RUN: mlir-opt -gpu-decompose-memrefs -allow-unregistered-dialect -split-input-file %s | FileCheck %s
+
+// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> (s0 * s1 + s2 * s3 + s4)>
+// CHECK: @decompose_store
+// CHECK-SAME: (%[[VAL:.*]]: f32, %[[MEM:.*]]: memref<?x?x?xf32>)
+// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[MEM]]
+// CHECK: gpu.launch
+// CHECK-SAME: threads(%[[TX:.*]], %[[TY:.*]], %[[TZ:.*]]) in
+// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[TX]], %[[STRIDES]]#0, %[[TY]], %[[STRIDES]]#1, %[[TZ]]]
+// CHECK: %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [], strides: [] : memref<f32> to memref<f32, strided<[], offset: ?>>
+// CHECK: memref.store %[[VAL]], %[[PTR]][] : memref<f32, strided<[], offset: ?>>
+func.func @decompose_store(%arg0 : f32, %arg1 : memref<?x?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %block_dim0 = memref.dim %arg1, %c0 : memref<?x?x?xf32>
+  %block_dim1 = memref.dim %arg1, %c1 : memref<?x?x?xf32>
+  %block_dim2 = memref.dim %arg1, %c2 : memref<?x?x?xf32>
+  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
+             threads(%tx, %ty, %tz) in (%block_x = %block_dim0, %block_y = %block_dim1, %block_z = %block_dim2) {
+    memref.store %arg0, %arg1[%tx, %ty, %tz] : memref<?x?x?xf32>
+    gpu.terminator
+  }
+  return
+}
+
+// -----
+
+// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1, s2, s3, s4, s5, s6] -> (s0 + s1 * s2 + s3 * s4 + s5 * s6)>
+// CHECK: @decompose_store_strided
+// CHECK-SAME: (%[[VAL:.*]]: f32, %[[MEM:.*]]: memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>)
+// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[MEM]]
+// CHECK: gpu.launch
+// CHECK-SAME: threads(%[[TX:.*]], %[[TY:.*]], %[[TZ:.*]]) in
+// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[OFFSET]], %[[TX]], %[[STRIDES]]#0, %[[TY]], %[[STRIDES]]#1, %[[TZ]], %[[STRIDES]]#2]
+// CHECK: %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [], strides: [] : memref<f32> to memref<f32, strided<[], offset: ?>>
+// CHECK: memref.store %[[VAL]], %[[PTR]][] : memref<f32, strided<[], offset: ?>>
+func.func @decompose_store_strided(%arg0 : f32, %arg1 : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %block_dim0 = memref.dim %arg1, %c0 : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+  %block_dim1 = memref.dim %arg1, %c1 : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+  %block_dim2 = memref.dim %arg1, %c2 : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
+             threads(%tx, %ty, %tz) in (%block_x = %block_dim0, %block_y = %block_dim1, %block_z = %block_dim2) {
+    memref.store %arg0, %arg1[%tx, %ty, %tz] : memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+    gpu.terminator
+  }
+  return
+}
+
+// -----
+
+// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> (s0 * s1 + s2 * s3 + s4)>
+// CHECK: @decompose_load
+// CHECK-SAME: (%[[MEM:.*]]: memref<?x?x?xf32>)
+// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[MEM]]
+// CHECK: gpu.launch
+// CHECK-SAME: threads(%[[TX:.*]], %[[TY:.*]], %[[TZ:.*]]) in
+// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[TX]], %[[STRIDES]]#0, %[[TY]], %[[STRIDES]]#1, %[[TZ]]]
+// CHECK: %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [], strides: [] : memref<f32> to memref<f32, strided<[], offset: ?>>
+// CHECK: %[[RES:.*]] = memref.load %[[PTR]][] : memref<f32, strided<[], offset: ?>>
+// CHECK: "test.test"(%[[RES]]) : (f32) -> ()
+func.func @decompose_load(%arg0 : memref<?x?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %block_dim0 = memref.dim %arg0, %c0 : memref<?x?x?xf32>
+  %block_dim1 = memref.dim %arg0, %c1 : memref<?x?x?xf32>
+  %block_dim2 = memref.dim %arg0, %c2 : memref<?x?x?xf32>
+  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
+             threads(%tx, %ty, %tz) in (%block_x = %block_dim0, %block_y = %block_dim1, %block_z = %block_dim2) {
+    %res = memref.load %arg0[%tx, %ty, %tz] : memref<?x?x?xf32>
+    "test.test"(%res) : (f32) -> ()
+    gpu.terminator
+  }
+  return
+}
+
+// -----
+
+// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> (s0 * s1 + s2 * s3 + s4)>
+// CHECK: @decompose_subview
+// CHECK-SAME: (%[[MEM:.*]]: memref<?x?x?xf32>)
+// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[MEM]]
+// CHECK: gpu.launch
+// CHECK-SAME: threads(%[[TX:.*]], %[[TY:.*]], %[[TZ:.*]]) in
+// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[TX]], %[[STRIDES]]#0, %[[TY]], %[[STRIDES]]#1, %[[TZ]]]
+// CHECK: %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX]]], sizes: [%{{.*}}, %{{.*}}, %{{.*}}], strides: [%[[STRIDES]]#0, %[[STRIDES]]#1, 1]
+// CHECK: "test.test"(%[[PTR]]) : (memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>) -> ()
+func.func @decompose_subview(%arg0 : memref<?x?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %block_dim0 = memref.dim %arg0, %c0 : memref<?x?x?xf32>
+  %block_dim1 = memref.dim %arg0, %c1 : memref<?x?x?xf32>
+  %block_dim2 = memref.dim %arg0, %c2 : memref<?x?x?xf32>
+  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
+             threads(%tx, %ty, %tz) in (%block_x = %block_dim0, %block_y = %block_dim1, %block_z = %block_dim2) {
+    %res = memref.subview %arg0[%tx, %ty, %tz] [%c2, %c2, %c2] [%c1, %c1, %c1] : memref<?x?x?xf32> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+    "test.test"(%res) : (memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>) -> ()
+    gpu.terminator
+  }
+  return
+}
+
+// -----
+
+// CHECK: #[[MAP:.*]] = affine_map<()[s0] -> (s0 * 2)>
+// CHECK: #[[MAP1:.*]] = affine_map<()[s0] -> (s0 * 3)>
+// CHECK: #[[MAP2:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> (s0 * s1 + s2 * s3 + s4)>
+// CHECK: @decompose_subview_strided
+// CHECK-SAME: (%[[MEM:.*]]: memref<?x?x?xf32>)
+// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:3, %[[STRIDES:.*]]:3 = memref.extract_strided_metadata %[[MEM]]
+// CHECK: gpu.launch
+// CHECK-SAME: threads(%[[TX:.*]], %[[TY:.*]], %[[TZ:.*]]) in
+// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[STRIDES]]#0]
+// CHECK: %[[IDX1:.*]] = affine.apply #[[MAP1]]()[%[[STRIDES]]#1]
+// CHECK: %[[IDX2:.*]] = affine.apply #[[MAP2]]()[%[[TX]], %[[STRIDES]]#0, %[[TY]], %[[STRIDES]]#1, %[[TZ]]]
+// CHECK: %[[PTR:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[IDX2]]], sizes: [%{{.*}}, %{{.*}}, %{{.*}}], strides: [%[[IDX]], %[[IDX1]], 4]
+// CHECK: "test.test"(%[[PTR]]) : (memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>) -> ()
+func.func @decompose_subview_strided(%arg0 : memref<?x?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %block_dim0 = memref.dim %arg0, %c0 : memref<?x?x?xf32>
+  %block_dim1 = memref.dim %arg0, %c1 : memref<?x?x?xf32>
+  %block_dim2 = memref.dim %arg0, %c2 : memref<?x?x?xf32>
+  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
+             threads(%tx, %ty, %tz) in (%block_x = %block_dim0, %block_y = %block_dim1, %block_z = %block_dim2) {
+    %res = memref.subview %arg0[%tx, %ty, %tz] [%c2, %c2, %c2] [2, 3, 4] : memref<?x?x?xf32> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+    "test.test"(%res) : (memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>) -> ()
+    gpu.terminator
+  }
+  return
+}
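
Illustrative sketch (not part of the patch, reconstructed from the CHECK lines of @decompose_store above; the SSA names %mem, %val are placeholders): the pass hoists the metadata extraction out of the `gpu.launch` body, computes the linear offset with `affine.apply`, and rebuilds the access through a 0-d `memref.reinterpret_cast`, so only the bare base pointer plus index arithmetic remain inside the launch:

  %base, %offset, %sizes:3, %strides:3 = memref.extract_strided_metadata %mem
      : memref<?x?x?xf32> -> memref<f32>, index, index, index, index, index, index, index
  gpu.launch blocks(%bx, %by, %bz) in (...)
             threads(%tx, %ty, %tz) in (...) {
    %idx = affine.apply affine_map<()[s0, s1, s2, s3, s4] -> (s0 * s1 + s2 * s3 + s4)>()[%tx, %strides#0, %ty, %strides#1, %tz]
    %flat = memref.reinterpret_cast %base to offset: [%idx], sizes: [], strides: []
        : memref<f32> to memref<f32, strided<[], offset: ?>>
    memref.store %val, %flat[] : memref<f32, strided<[], offset: ?>>
    gpu.terminator
  }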