diff --git a/mlir/include/mlir/Dialect/GPU/GPUBase.td b/mlir/include/mlir/Dialect/GPU/GPUBase.td
--- a/mlir/include/mlir/Dialect/GPU/GPUBase.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUBase.td
@@ -57,6 +57,16 @@
     GPU_Dialect, CPred<"$_self.isa<::mlir::gpu::AsyncTokenType>()">,
     "async token type">,
   BuildableType<"mlir::gpu::AsyncTokenType::get($_builder.getContext())">;

+def IsMMAFragmentTypePred : CPred<"$_self.isa<::mlir::gpu::MMAFragmentType>()">;
+
+def GPU_MMAFragment : DialectType<
+  GPU_Dialect, IsMMAFragmentTypePred, "mmafragment type">;
+
+class MMAFragmentOf<list<Type> allowedTypes> :
+  ContainerType<AnyTypeOf<allowedTypes>, IsMMAFragmentTypePred,
+  "$_self.cast<::mlir::gpu::MMAFragmentType>().getElementType()",
+  "gpu.mmafragment", "::mlir::gpu::MMAFragmentType">;
+
 def GPU_AsyncOpInterface : OpInterface<"AsyncOpInterface"> {
   let description = [{
     Interface for GPU operations that execute asynchronously on the device.
@@ -102,4 +112,34 @@
   ];
 }

+// Cases of the String enum Attribute for SubgroupMmaOpLayout, representing
+// the layouts of the operands supported by the ops that use this attribute.
+def RowMajor: StrEnumAttrCase<"RowMajor", 0>;
+def ColMajor: StrEnumAttrCase<"ColMajor", 1>;
+
+// Specifies a String enum Attribute for Warp wide matrix operations,
+// representing the layout of the respective operands in the GPU dialect. The
+// layout later governs the lowerings to the appropriate intrinsics.
+def SubgroupMmaOpLayout: StrEnumAttr<"Layout", "Specifies whether op is row/col major",
+  [RowMajor, ColMajor]> {
+  let stringToSymbolFnName = "LayoutStrToEnum";
+  let symbolToStringFnName = "EnumToLayoutStr";
+}
+
+// Cases of the String enum Attribute for SubgroupMmaOperand, representing
+// the operand tags supported by the ops that use this attribute.
+def DOp: StrEnumAttrCase<"DOp", 0>;
+def Aop: StrEnumAttrCase<"AOp", 1>;
+def Bop: StrEnumAttrCase<"BOp", 2>;
+def Cop: StrEnumAttrCase<"COp", 3>;
+
+// Specifies a String enum Attribute for Warp wide matrix operations,
+// representing the operands of the operations. The operands later
+// govern the lowerings to the appropriate intrinsics.
+def SubgroupMmaOperand: StrEnumAttr<"WmmaOp", "Specifies which subgroup operand this is",
+  [DOp, Aop, Bop, Cop]> {
+  let stringToSymbolFnName = "OpStrToEnum";
+  let symbolToStringFnName = "EnumToOpStr";
+}
+
 #endif // GPU_BASE
diff --git a/mlir/include/mlir/Dialect/GPU/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/GPUDialect.h
--- a/mlir/include/mlir/Dialect/GPU/GPUDialect.h
+++ b/mlir/include/mlir/Dialect/GPU/GPUDialect.h
@@ -43,6 +43,68 @@
   using Base::Base;
 };

+/// MMAFragmentType storage and uniquing.
+struct MMAFragmentStorageType : public TypeStorage {
+  MMAFragmentStorageType(int64_t size, Type elementType)
+      : size(size), elementType(elementType) {}
+
+  /// The hash key for uniquing.
+  using KeyTy = std::pair<int64_t, Type>;
+  bool operator==(const KeyTy &key) const {
+    return key == KeyTy(size, elementType);
+  }
+
+  /// Construction.
+  static MMAFragmentStorageType *construct(TypeStorageAllocator &allocator,
+                                           const KeyTy &key) {
+    return new (allocator.allocate<MMAFragmentStorageType>())
+        MMAFragmentStorageType(key.first, key.second);
+  }
+
+  /// Number of elements held in the fragment.
+  int64_t size;
+
+  /// Element type of the elements held in the fragment.
+  Type elementType;
+};
+
+/// MMAFragment represents a fragment or collection of elements held by a
+/// thread for matrix-matrix multiply accumulate operations. MMAFragments are
+/// taken as direct operands by these operations and are also produced as
+/// results. These fragments are meant to reside in registers.
+/// A limited number of pointwise operations can be performed on these
+/// fragments, i.e., operations which operate uniformly on all the elements in
+/// the fragment and do not change the order of matrix elements in the
+/// fragment. These restrictions exist because the layout of the matrix
+/// elements inside the fragment is opaque, i.e., the elements may be present
+/// in the fragment in any order.
+class MMAFragmentType
+    : public Type::TypeBase<MMAFragmentType, Type, MMAFragmentStorageType> {
+public:
+  using Base::Base;
+
+  /// Get MMAFragmentType and verify construction invariants.
+  static MMAFragmentType get(int64_t shape, Type elementType);
+
+  /// Get MMAFragmentType at a particular location and verify construction
+  /// invariants.
+  static MMAFragmentType getChecked(Location loc, int64_t shape,
+                                    Type elementType);
+
+  /// Check if a type is a valid MMAFragmentType element type.
+  static bool isValidElementType(Type elementType);
+
+  /// Verify that shape and elementType are actually allowed for the
+  /// MMAFragmentType.
+  static LogicalResult verifyConstructionInvariants(Location loc, int64_t shape,
+                                                    Type elementType);
+
+  /// Get the size of the MMAFragment in number of elements.
+  int64_t getSize() const;
+
+  /// Get the elementType of a single element in the MMAFragment.
+  Type getElementType() const;
+};
+
 // Adds a `gpu.async.token` to the front of the argument list.
 void addAsyncDependency(Operation *op, Value token);
diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td
--- a/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -909,4 +909,124 @@
   let verifier = [{ return ::verify(*this); }];
 }

+def GPU_SubgroupMmaLoadMatrixOp : GPU_Op<"subgroup_mma_load_matrix",
+    [MemoryEffects<[MemRead]>]>{
+
+  let summary = "GPU warp synchronous matrix load";
+
+  let description = [{
+    The `gpu.subgroup_mma_load_matrix` operation loads a matrix collectively
+    using all the threads in a subgroup.
+
+    This operation takes a memref as its argument: the source matrix from which
+    data is to be loaded. The op returns a `!gpu.mmafragment`. The source memref
+    can be in the global or shared memory space. The starting load address is
+    determined using the provided indices. Which matrix is being loaded is
+    specified by the `operand` attribute. This attribute is necessary because
+    there is a different LLVM intrinsic for loading each operand, presumably
+    because each operand needs to be laid out in a specific way in the
+    registers for the operation. The `leadDimension` attribute specifies the
+    leading dimension of the source matrix.
+
+    This op is meant to be used along with `gpu.subgroup_mma_store_matrix` and
+    `gpu.subgroup_mma_compute`.
+
+    Example:
+
+    ```mlir
+    %0 = gpu.subgroup_mma_load_matrix %src[%i, %j] {operand = "AOp", leadDimension = 32 : index}
+      : memref<32x32xf16, 3> -> !gpu.mmafragment<8, vector<2xf16>>
+    ```
+  }];
+
+  let arguments = (ins Arg, "", [MemRead]>:$srcMemref,
+                   Variadic:$indices,
+                   IndexAttr:$leadDimension,
+                   SubgroupMmaOperand:$operand);
+
+  let results = (outs GPU_MMAFragment:$res);
+
+  let assemblyFormat = [{
+    $srcMemref`[`$indices`]` attr-dict `:` type($srcMemref) `->` type($res)
+  }];
+
+  let verifier = [{ return ::verify(*this); }];
+}
+
+def GPU_SubgroupMmaStoreMatrixOp : GPU_Op<"subgroup_mma_store_matrix",
+    [MemoryEffects<[MemWrite]>]>{
+
+  let summary = "GPU warp synchronous matrix store";
+
+  let description = [{
+    The `gpu.subgroup_mma_store_matrix` operation stores a matrix collectively
+    using all the threads in a subgroup.
+
+    This operation takes a `!gpu.mmafragment` and a memref as arguments. The
+    `!gpu.mmafragment` is the source and contains the data to be stored. The
+    destination can be in the global or shared memory space. The starting store
+    address is determined using the provided indices. The `leadDimension`
+    attribute specifies the leading dimension of the destination matrix.
+
+    This op is meant to be used along with `gpu.subgroup_mma_load_matrix` and
+    `gpu.subgroup_mma_compute`.
+
+    Example:
+
+    ```mlir
+    gpu.subgroup_mma_store_matrix %D, %sg[%i, %j] {leadDimension = 32 : index}
+      : !gpu.mmafragment<4, vector<2xf16>>, memref<32x32xf16, 3>
+    ```
+  }];
+
+  let arguments = (ins Arg]>>:$src,
+                   Arg, "",[MemWrite]>:$dstMemref,
+                   Variadic:$indices,
+                   IndexAttr:$leadDimension);
+
+  let assemblyFormat = [{
+    $src`,` $dstMemref`[`$indices`]` attr-dict `:` type($src)`,` type($dstMemref)
+  }];
+
+  let verifier = [{ return ::verify(*this); }];
+}
+
+def GPU_SubgroupMmaComputeOp : GPU_Op<"subgroup_mma_compute", []>{
+
+  let summary = "GPU warp synchronous matrix multiply accumulate";
+
+  let description = [{
+    The `gpu.subgroup_mma_compute` operation performs a matrix-multiply
+    accumulate (mma) operation using all the threads in a subgroup.
+
+    This operation takes three `!gpu.mmafragment`s as arguments: the `A`, `B`
+    and `C` operands of the mma operation. The operation performed is
+    represented as `D = A * B + C`. The op returns a `!gpu.mmafragment` which
+    contains the part of the result held by the current thread.
+
+    This op is meant to be used along with `gpu.subgroup_mma_store_matrix` and
+    `gpu.subgroup_mma_load_matrix`.
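+
+    As an illustrative sketch, a subgroup-level multiply-accumulate could
+    combine the three ops as follows (the SSA names are hypothetical, and the
+    fragment sizes and the 32x32 f16 tiles are assumptions borrowed from the
+    tests in this patch):
+
+    ```mlir
+    // Load the A, B and C operands, multiply-accumulate, and store the result.
+    %A = gpu.subgroup_mma_load_matrix %srcA[%i, %j] {operand = "AOp", leadDimension = 32 : index}
+      : memref<32x32xf16, 3> -> !gpu.mmafragment<8, vector<2xf16>>
+    %B = gpu.subgroup_mma_load_matrix %srcB[%i, %j] {operand = "BOp", leadDimension = 32 : index}
+      : memref<32x32xf16, 3> -> !gpu.mmafragment<8, vector<2xf16>>
+    %C = gpu.subgroup_mma_load_matrix %srcC[%i, %j] {operand = "COp", leadDimension = 32 : index}
+      : memref<32x32xf16, 3> -> !gpu.mmafragment<4, vector<2xf16>>
+    %D = gpu.subgroup_mma_compute %A, %B, %C
+      : !gpu.mmafragment<8, vector<2xf16>>, !gpu.mmafragment<8, vector<2xf16>>,
+        !gpu.mmafragment<4, vector<2xf16>> -> !gpu.mmafragment<4, vector<2xf16>>
+    gpu.subgroup_mma_store_matrix %D, %dst[%i, %j] {leadDimension = 32 : index}
+      : !gpu.mmafragment<4, vector<2xf16>>, memref<32x32xf16, 3>
+    ```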
+
+    Example:
+
+    ```mlir
+    %D = gpu.subgroup_mma_compute %A, %B, %C
+      : !gpu.mmafragment<8, vector<2xf16>>, !gpu.mmafragment<8, vector<2xf16>>,
+        !gpu.mmafragment<4, vector<2xf16>> -> !gpu.mmafragment<4, vector<2xf16>>
+    ```
+  }];
+
+  let arguments = (ins Arg]>>:$opA,
+                   Arg]>>:$opB,
+                   Arg]>>:$opC);
+
+  let results = (outs GPU_MMAFragment:$res);
+
+  let assemblyFormat = [{
+    $opA`,` $opB`,` $opC attr-dict `:` type($opA)`,` type($opB)`,` type($opC) `->` type($res)
+  }];
+
+  let verifier = [{ return ::verify(*this); }];
+}
+
 #endif // GPU_OPS
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -144,4 +144,202 @@
   let verifier = [{ return ::verify(*this); }];
 }

+class NVVM_WMMALoadOp : NVVM_Op,
+  Results<(outs LLVM_AnyStruct:$res)>,
+  Arguments<(ins Variadic:$args)> {
+
+  let summary = "Warp synchronous matrix load";
+
+  string baseDescription = [{"The `nvvm.wmma.m16n16k16.load.[a, b, c]` operation"
+    "loads a matrix collectively using all the threads in a warp."
+
+    "The operation takes two arguments: the address from which the matrix"
+    "elements are to be loaded, and a stride. The stride argument"
+    "represents the leading dimension of the source matrix. The address and"
+    "the stride are required to be the same across all threads in the warp."
+    "Each thread in a warp holds a certain number of elements. The op returns"
+    "an LLVMStruct which holds the elements of the matrix held by this thread."
+
+    "This op is meant to be used along with `nvvm.wmma.m16n16k16.store` and"
+    "`nvvm.wmma.m16n16k16.mma`."}];
+
+  let assemblyFormat = "$args attr-dict `:` functional-type($args, $res)";
+}
+
+def NVVM_WMMALoadAOp :
+  NVVM_WMMALoadOp<"wmma.m16n16k16.load.a.f16.row.stride">{
+
+  string llvmBuilder = [{
+    $res = createNvvmIntrinsicCall(
+      builder, llvm::Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_row_stride, $args);
+  }];
+
+  string opDescription = [{
+    Example:
+
+    ```mlir
+    %2 = nvvm.wmma.m16n16k16.load.a.f16.row.stride %0, %1
+      : (!llvm.ptr<i32, 3>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>,
+        vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>,
+        vector<2xf16>, vector<2xf16>)>
+    ```
+  }];
+
+  let description = !strconcat(baseDescription, opDescription);
+
+  let verifier = [{ return ::verify(*this); }];
+}
+
+def NVVM_WMMALoadBOp :
+  NVVM_WMMALoadOp<"wmma.m16n16k16.load.b.f16.row.stride">{
+
+  string llvmBuilder = [{
+    $res = createNvvmIntrinsicCall(
+      builder, llvm::Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row_stride, $args);
+  }];
+
+  string opDescription = [{
+    Example:
+
+    ```mlir
+    %2 = nvvm.wmma.m16n16k16.load.b.f16.row.stride %0, %1
+      : (!llvm.ptr<i32, 3>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>,
+        vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>,
+        vector<2xf16>, vector<2xf16>)>
+    ```
+  }];
+
+  let description = !strconcat(baseDescription, opDescription);
+
+  let verifier = [{ return ::verify(*this); }];
+}
+
+def NVVM_WMMALoadCOp :
+  NVVM_WMMALoadOp<"wmma.m16n16k16.load.c.f16.row.stride">{
+  string llvmBuilder = [{
+    $res = createNvvmIntrinsicCall(
+      builder, llvm::Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row_stride, $args);
+  }];
+
+  string opDescription = [{
+    Example:
+
+    ```mlir
+    %2 = nvvm.wmma.m16n16k16.load.c.f16.row.stride %0, %1
+      : (!llvm.ptr<i32, 3>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>,
+        vector<2xf16>, vector<2xf16>)>
+    ```
+  }];
+
+  let description = !strconcat(baseDescription, opDescription);
+
+  let verifier = [{ return ::verify(*this); }];
+}
+
+def NVVM_WMMAStoreOp :
+  NVVM_Op<"wmma.m16n16k16.store.d.f16.row.stride">,
+  Arguments<(ins Variadic:$args)>{
+  string llvmBuilder = [{
+    createNvvmIntrinsicCall(
+      builder, llvm::Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row_stride, $args);
+  }];
+
+  let summary = "Warp synchronous matrix store";
+
+  let description = [{
+    The `nvvm.wmma.m16n16k16.store` operation stores a matrix collectively using
+    all the threads in a warp.
+
+    The operation takes as arguments the address to which the matrix elements
+    are to be stored, a stride, and the elements to store held by the current
+    thread. The stride argument represents the leading dimension of the
+    destination matrix. The address and the stride are required to be the same
+    across all threads in the warp.
+
+    This op is meant to be used along with `nvvm.wmma.m16n16k16.load` and
+    `nvvm.wmma.m16n16k16.mma`.
+
+    Example:
+
+    ```mlir
+    nvvm.wmma.m16n16k16.store.d.f16.row.stride %0, %1, %2, %3, %4, %5
+      : !llvm.ptr<i32, 3>, vector<2xf16>, vector<2xf16>, vector<2xf16>,
+        vector<2xf16>, i32
+    ```
+  }];
+
+  let assemblyFormat = "$args attr-dict `:` type($args)";
+
+  let verifier = [{ return ::verify(*this); }];
+}
+
+def NVVM_WMMAMmaOp :
+  NVVM_Op<"wmma.m16n16k16.mma.row.row.f16.f16">,
+  Results<(outs LLVM_AnyStruct:$res)>,
+  Arguments<(ins Variadic:$args)>{
+  string llvmBuilder = [{
+    $res = createNvvmIntrinsicCall(
+      builder, llvm::Intrinsic::nvvm_wmma_m16n16k16_mma_row_row_f16_f16, $args);
+  }];
+
+  let summary = "Warp synchronous matrix-multiply accumulate using tensor cores.";
+
+  let description = [{
+    The `nvvm.wmma.m16n16k16.mma` operation performs a matrix-multiply
+    accumulate (mma) operation using all the threads in a warp.
+
+    The operation performed is represented as `D = A * B + C`. The operation
+    takes as arguments the elements of the matrices `A`, `B` and `C`, held by
+    the current thread: the 8 elements of `A`, followed by the 8 elements of
+    `B`, followed by the 4 elements of `C`. The op returns an LLVM struct which
+    holds the part of the result `D` held by the current thread.
+
+    This op is meant to be used along with `nvvm.wmma.m16n16k16.load` and
+    `nvvm.wmma.m16n16k16.store`.
+ + Example: + + ```mlir + %20 = nvvm.wmma.m16n16k16.mma %0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, + %13, %14, %15, %16, %17, %18, %19 : vector<2xf16> -> !llvm.struct<(vector<2xf16>, + vector<2xf16>, vector<2xf16>, vector<2xf16>)> + ``` + }]; + + let parser = [{ + SmallVector operands; + ::llvm::SMLoc operandsLoc; + Type operandType; + Type resType; + + operandsLoc = parser.getCurrentLocation(); + if (parser.parseOperandList(operands) + || parser.parseOptionalAttrDict(result.attributes) + || parser.parseColon() + || parser.parseType(operandType) + || parser.parseArrow()) + return failure(); + + unsigned numOperands = operands.size(); + SmallVector operandTypes(numOperands, operandType); + if (parser.parseType(resType)) + return failure(); + result.addTypes(resType); + if (parser.resolveOperands(operands, operandTypes, operandsLoc, result.operands)) + return failure(); + return success(); + }]; + + let printer = [{ + p << getOperationName(); + p << ' '; + p << args(); + p.printOptionalAttrDict(getAttrs(), {}); + p << " : "; + p << getOperation()->getOperand(0).getType(); + p << ' ' << "->"; + p << ' '; + p << ::llvm::ArrayRef<::mlir::Type>(res().getType()); + }]; + + let verifier = [{ return ::verify(*this); }]; +} + #endif // NVVMIR_OPS diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -28,10 +28,59 @@ using namespace mlir; using namespace mlir::gpu; +//===----------------------------------------------------------------------===// +// MMAFragmentType +//===----------------------------------------------------------------------===// + +MMAFragmentType MMAFragmentType::get(int64_t size, Type elementType) { + return Base::get(elementType.getContext(), size, elementType); +} + +MMAFragmentType MMAFragmentType::getChecked(Location loc, int64_t size, + Type elementType) { + return Base::getChecked(loc, size, elementType); +} + +int64_t MMAFragmentType::getSize() const { return getImpl()->size; } + +Type MMAFragmentType::getElementType() const { return getImpl()->elementType; } + +bool MMAFragmentType::isValidElementType(Type elementType) { + + if (auto vectorElemType = elementType.dyn_cast()) + return vectorElemType.getRank() == 1 && vectorElemType.getDimSize(0) == 2; + else + return false; +} + +LogicalResult MMAFragmentType::verifyConstructionInvariants(Location loc, + int64_t size, + Type elementType) { + if (size <= 0) + return emitError(loc, "MMAFragmentType size must be atleast one"); + + if (!MMAFragmentType::isValidElementType(elementType)) + return emitError(loc, "MMAFragmentType elements must be vector<2xf16>"); + + return success(); +} + //===----------------------------------------------------------------------===// // GPUDialect //===----------------------------------------------------------------------===// +/// GPU memory space identifiers. +enum GPUMemorySpace { + /// Generic memory space identifier. + kGenericMemorySpace = 0, + + /// Global memory space identifier. + kGlobalMemorySpace = 1, + + /// Shared memory space identifier. 
+ kSharedMemorySpace = 3 +}; + bool GPUDialect::isKernel(Operation *op) { UnitAttr isKernelAttr = op->getAttrOfType(getKernelFuncAttrName()); return static_cast(isKernelAttr); @@ -39,6 +88,7 @@ void GPUDialect::initialize() { addTypes(); + addTypes(); addOperations< #define GET_OP_LIST #include "mlir/Dialect/GPU/GPUOps.cpp.inc" @@ -56,6 +106,29 @@ if (keyword == "async.token") return AsyncTokenType::get(context); + if (keyword == "mmafragment") { + llvm::SMLoc beginLoc = parser.getNameLoc(); + + // Parse '<'. + if (parser.parseLess()) + return nullptr; + + // Parse the size and elementType. + int64_t size; + VectorType elementType; + if (parser.parseInteger(size) || parser.parseComma() || + parser.parseType(elementType)) { + return nullptr; + } + + // Parse '>'. + if (parser.parseGreater()) + return nullptr; + + return MMAFragmentType::getChecked(parser.getEncodedSourceLoc(beginLoc), + size, elementType); + } + parser.emitError(parser.getNameLoc(), "unknown gpu type: " + keyword); return Type(); } @@ -63,6 +136,10 @@ void GPUDialect::printType(Type type, DialectAsmPrinter &os) const { TypeSwitch(type) .Case([&](Type) { os << "async.token"; }) + .Case([&](MMAFragmentType fragTy) { + os << "mmafragment<" << fragTy.getSize() << ", " + << fragTy.getElementType() << ">"; + }) .Default([](Type) { llvm_unreachable("unexpected 'gpu' type kind"); }); } @@ -138,7 +215,8 @@ return walkResult.wasInterrupted() ? failure() : success(); } -template static LogicalResult verifyIndexOp(T op) { +template +static LogicalResult verifyIndexOp(T op) { auto dimension = op.dimension(); if (dimension != "x" && dimension != "y" && dimension != "z") return op.emitError("dimension \"") << dimension << "\" is invalid"; @@ -885,6 +963,91 @@ printer << "]"; } +//===----------------------------------------------------------------------===// +// GPU_SubgroupMmaLoadMatrixOp +//===----------------------------------------------------------------------===// + +static LogicalResult verify(SubgroupMmaLoadMatrixOp op) { + auto srcType = op.srcMemref().getType(); + auto srcMemrefType = srcType.cast(); + auto srcMemSpace = srcMemrefType.getMemorySpace(); + + if (!srcMemrefType.getAffineMaps().empty() && + !srcMemrefType.getAffineMaps().front().isIdentity()) + return op.emitError("expected identity layout map for source memref"); + + if (srcMemSpace != kGenericMemorySpace && srcMemSpace != kSharedMemorySpace && + srcMemSpace != kGlobalMemorySpace) + return op.emitError( + "source memorySpace kGenericMemorySpace, kSharedMemorySpace or " + "kGlobalMemorySpace only allowed"); + + return success(); +} + +//===----------------------------------------------------------------------===// +// GPU_SubgroupMmaStoreMatrixOp +//===----------------------------------------------------------------------===// + +static LogicalResult verify(SubgroupMmaStoreMatrixOp op) { + auto srcType = op.src().getType(); + auto dstType = op.dstMemref().getType(); + auto srcFragType = srcType.cast(); + auto dstMemrefType = dstType.cast(); + auto dstMemSpace = dstMemrefType.getMemorySpace(); + + if (!dstMemrefType.getAffineMaps().empty() && + !dstMemrefType.getAffineMaps().front().isIdentity()) + return op.emitError("expected identity layout map for destination memref"); + + if (dstMemSpace != kGenericMemorySpace && dstMemSpace != kSharedMemorySpace && + dstMemSpace != kGlobalMemorySpace) + return op.emitError( + "destination memorySpace of kGenericMemorySpace, " + "kGlobalMemorySpace or kSharedMemorySpace only allowed"); + + if (srcFragType.getSize() != 4) + return 
op.emitError( + "operand should be of type !gpu.mmafragment<4xvector<2xf16>>"); + + return success(); +} + +//===----------------------------------------------------------------------===// +// GPU_SubgroupMmaComputeOp +//===----------------------------------------------------------------------===// + +static LogicalResult verify(SubgroupMmaComputeOp op) { + enum OperandMap { A, B, C }; + SmallVector opTypes; + SmallVector elemTypes; + + auto populateOpInfo = [&opTypes, &elemTypes, &op]() { + opTypes.push_back(op.opA().getType().cast()); + opTypes.push_back(op.opB().getType().cast()); + opTypes.push_back(op.opC().getType().cast()); + + for (MMAFragmentType opType : opTypes) { + elemTypes.push_back(opType.getElementType()); + } + }; + populateOpInfo(); + + auto isValidElementType = [](Type &elemType, MMAFragmentType &fragTy, + unsigned numElems) { + return fragTy.getSize() == numElems; + }; + + if (!isValidElementType(elemTypes[A], opTypes[A], 8) || + !isValidElementType(elemTypes[B], opTypes[B], 8) || + !isValidElementType(elemTypes[C], opTypes[C], 4)) + return op.emitError( + "A and B must be of type !gpu.mmafragment<8xvector<2xf16>> and C " + "must of type !gpu.mmafragment<4xvector<2xf16>>"); + + return success(); +} + #include "mlir/Dialect/GPU/GPUOpInterfaces.cpp.inc" #define GET_OP_CLASSES diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -94,12 +94,12 @@ auto f32x8StructTy = LLVM::LLVMStructType::getLiteral( context, {f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty}); - SmallVector operand_types(op.getOperandTypes().begin(), - op.getOperandTypes().end()); - if (operand_types != SmallVector(8, f16x2Ty) && - operand_types != SmallVector{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, - f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, - f32Ty, f32Ty, f32Ty}) { + SmallVector operandTypes(op.getOperandTypes().begin(), + op.getOperandTypes().end()); + if (operandTypes != SmallVector(8, f16x2Ty) && + operandTypes != SmallVector{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, + f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, + f32Ty, f32Ty, f32Ty}) { return op.emitOpError( "expected operands to be 4 s followed by either " "4 s or 8 floats"); @@ -120,9 +120,9 @@ "\"row\" or \"col\""); } - if (operand_types == SmallVector{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, - f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, - f32Ty, f32Ty, f32Ty} && + if (operandTypes == SmallVector{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, + f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, + f32Ty, f32Ty, f32Ty} && op.getType() == f32x8StructTy && alayout.getValue() == "row" && blayout.getValue() == "col") { return success(); @@ -130,6 +130,99 @@ return op.emitOpError("unimplemented mma.sync variant"); } +template +static LogicalResult verifyWMMALoadOp(T op, StringRef operand) { + MLIRContext *context = op.getContext(); + auto i32Ty = IntegerType::get(context, 32); + auto i32Ptr1Ty = LLVM::LLVMPointerType::get(i32Ty, 1); + auto i32Ptr3Ty = LLVM::LLVMPointerType::get(i32Ty, 3); + auto i32Ptr0Ty = LLVM::LLVMPointerType::get(i32Ty, 0); + auto f16Ty = FloatType::getF16(context); + auto f16x2Ty = VectorType::get(2, f16Ty); + auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral( + context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty}); + auto f16x2x8StructTy = LLVM::LLVMStructType::getLiteral( + context, + {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty}); + + SmallVector operandTypes(op.getOperandTypes().begin(), + 
op.getOperandTypes().end()); + if (operandTypes != SmallVector{i32Ptr1Ty, i32Ty} && + operandTypes != SmallVector{i32Ptr3Ty, i32Ty} && + operandTypes != SmallVector{i32Ptr0Ty, i32Ty}) { + return op.emitOpError("expected operands to be a source pointer in memory " + "space 0, 1, 3 followed by ldm of the source"); + } + + if (operand.equals("AOp") || operand.equals("BOp")) { + if (op.getType() != f16x2x8StructTy) { + return op.emitOpError("expected result type of loadAOp and loadBOp to be " + "a struct of 8 s"); + } + } else if (operand.equals("COp")) { + if (op.getType() != f16x2x4StructTy) { + return op.emitOpError( + "expected result type of loadCOp to be a struct of 4 s"); + } + } + + return success(); +} + +static LogicalResult verify(WMMALoadAOp op) { + return verifyWMMALoadOp(op, "AOp"); +} + +static LogicalResult verify(WMMALoadBOp op) { + return verifyWMMALoadOp(op, "BOp"); +} + +static LogicalResult verify(WMMALoadCOp op) { + return verifyWMMALoadOp(op, "COp"); +} + +static LogicalResult verify(WMMAStoreOp op) { + MLIRContext *context = op.getContext(); + auto i32Ty = IntegerType::get(context, 32); + auto i32Ptr1Ty = LLVM::LLVMPointerType::get(i32Ty, 1); + auto i32Ptr3Ty = LLVM::LLVMPointerType::get(i32Ty, 3); + auto i32Ptr0Ty = LLVM::LLVMPointerType::get(i32Ty, 0); + auto f16Ty = FloatType::getF16(context); + auto f16x2Ty = VectorType::get(2, f16Ty); + + SmallVector operandTypes(op.getOperandTypes().begin(), + op.getOperandTypes().end()); + if (operandTypes != SmallVector{i32Ptr1Ty, f16x2Ty, f16x2Ty, f16x2Ty, + f16x2Ty, i32Ty} && + operandTypes != SmallVector{i32Ptr3Ty, f16x2Ty, f16x2Ty, f16x2Ty, + f16x2Ty, i32Ty} && + operandTypes != SmallVector{i32Ptr0Ty, f16x2Ty, f16x2Ty, f16x2Ty, + f16x2Ty, i32Ty}) { + return op.emitOpError("expected operands to be a source pointer in memory " + "space 0, 1, 3 followed by ldm of the source"); + } + + return success(); +} + +static LogicalResult verify(WMMAMmaOp op) { + MLIRContext *context = op.getContext(); + auto f16Ty = FloatType::getF16(context); + auto f16x2Ty = VectorType::get(2, f16Ty); + auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral( + context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty}); + + SmallVector operandTypes(op.getOperandTypes().begin(), + op.getOperandTypes().end()); + if (operandTypes != SmallVector(20, f16x2Ty)) + return op.emitOpError("expected 20 s as operands"); + + if (op.getResult().getType() != f16x2x4StructTy) + return op.emitOpError("expected result type to be a struct of 4 s"); + + return success(); +} + //===----------------------------------------------------------------------===// // NVVMDialect initialization, type parsing, and registration. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp --- a/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp @@ -35,6 +35,27 @@ return builder.CreateCall(fn, args); } +static llvm::Value *createNvvmIntrinsicCall(llvm::IRBuilder<> &builder, + llvm::Intrinsic::ID intrinsic, + ArrayRef args = {}) { + llvm::Module *module = builder.GetInsertBlock()->getModule(); + llvm::Function *fn; + if (llvm::Intrinsic::isOverloaded(intrinsic)) { + if (intrinsic != llvm::Intrinsic::nvvm_wmma_m16n16k16_mma_row_row_f16_f16) { + // NVVM load and store instrinsic names are overloaded on the + // source/destination pointer type. Pointer is the first argument in the + // corresponding NVVM Op. 
+ fn = llvm::Intrinsic::getDeclaration(module, intrinsic, + {args[0]->getType()}); + } else { + fn = llvm::Intrinsic::getDeclaration(module, intrinsic, {}); + } + } else { + fn = llvm::Intrinsic::getDeclaration(module, intrinsic); + } + return builder.CreateCall(fn, args); +} + static llvm::Intrinsic::ID getShflBflyIntrinsicId(llvm::Type *resultType, bool withPredicate) { if (withPredicate) { diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir --- a/mlir/test/Dialect/GPU/invalid.mlir +++ b/mlir/test/Dialect/GPU/invalid.mlir @@ -458,3 +458,86 @@ // expected-error @+1 {{'gpu.memcpy' op arguments have incompatible shape}} gpu.memcpy %dst, %src : memref<7xf32>, memref<9xf32> } + +// ----- + +func @mmfragment_invalid_elem_type(){ + %wg = alloca() {alignment = 32} : memref<32x32xf16, 3> + %i = constant 16 : index + // expected-error @+1 {{MMAFragmentType elements must be vector<2xf16>}} + %0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {operand = "AOp", leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mmafragment<8, vector<4xf16>> + return +} + +// ----- + +func @mmfragment_invalid_size(){ + %wg = alloca() {alignment = 32} : memref<32x32xf16, 3> + %i = constant 16 : index + // expected-error @+1 {{MMAFragmentType size must be atleast one}} + %0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {operand = "AOp", leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mmafragment<0, vector<4xf16>> + return +} + +// ----- +#layout_map_col_major = affine_map<(i, j) -> (j, i)> + +func @mmaLoadOp_identity_layout(){ + %wg = alloca() {alignment = 32} : memref<32x32xf16, #layout_map_col_major, 3> + %i = constant 16 : index + // expected-error @+1 {{expected identity layout map for source memref}} + %0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {operand = "AOp", leadDimension = 32 : index} : memref<32x32xf16, #layout_map_col_major, 3> -> !gpu.mmafragment<8, vector<2xf16>> + return +} + +// ----- + +func @mmaLoadOp_invalid_mem_space(){ + %wg = alloca() {alignment = 32} : memref<32x32xf16, 5> + %i = constant 16 : index + // expected-error @+1 {{source memorySpace kGenericMemorySpace, kSharedMemorySpace or kGlobalMemorySpace only allowed}} + %0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {operand = "AOp", leadDimension = 32 : index} : memref<32x32xf16, 5> -> !gpu.mmafragment<8, vector<2xf16>> + return +} + +// ----- +#layout_map_col_major = affine_map<(i, j) -> (j, i)> + +func @wmmaStoreOp_invalid_map(%arg0 : !gpu.mmafragment<4, vector<2xf16>>) -> () { + %sg = alloca(){alignment = 32} : memref<32x32xf16, #layout_map_col_major, 3> + %i = constant 16 : index + %j = constant 16 : index + // expected-error @+1 {{expected identity layout map for destination memref}} + gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mmafragment<4, vector<2xf16>>, memref<32x32xf16,#layout_map_col_major, 3> + return +} + +// ----- + +func @wmmaStoreOp_invalid_mem_space(%arg0 : !gpu.mmafragment<4, vector<2xf16>>) -> () { + %sg = alloca(){alignment = 32} : memref<32x32xf16, 5> + %i = constant 16 : index + %j = constant 16 : index + // expected-error @+1 {{destination memorySpace of kGenericMemorySpace, kGlobalMemorySpace or kSharedMemorySpace only allowed}} + gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mmafragment<4, vector<2xf16>>, memref<32x32xf16, 5> + return +} + +// ----- + +func @wmmaStoreOp_invalid_mem_space(%arg0 : !gpu.mmafragment<5, vector<2xf16>>) -> () { + %sg = alloca(){alignment = 32} : memref<32x32xf16, 3> + %i 
= constant 16 : index + %j = constant 16 : index + // expected-error @+1 {{operand should be of type !gpu.mmafragment<4xvector<2xf16>>}} + gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mmafragment<5, vector<2xf16>>, memref<32x32xf16, 3> + return +} + +// ----- + +func @wmmaMmaOp_invalid_operand_shape(%A : !gpu.mmafragment<3, vector<2xf16>>, %B : !gpu.mmafragment<8, vector<2xf16>>, %C : !gpu.mmafragment<4, vector<2xf16>>) -> () { + // expected-error @+1 {{A and B must be of type !gpu.mmafragment<8xvector<2xf16>> and C must of type !gpu.mmafragment<4xvector<2xf16>>}} + %D = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mmafragment<3, vector<2xf16>>, !gpu.mmafragment<8, vector<2xf16>>, !gpu.mmafragment<4, vector<2xf16>> -> !gpu.mmafragment<4, vector<2xf16>> + return +} diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -667,3 +667,102 @@ ^bb2(%1: i32, %2: i32): // pred: ^bb0 llvm.return } + +// ----- + +llvm.func @wmmaLoadOp_invalid_mem_space(%arg0: !llvm.ptr, %arg1: i32) { + // expected-error@+1 {{expected operands to be a source pointer in memory space 0, 1, 3 followed by ldm of the source}} + %0 = nvvm.wmma.m16n16k16.load.a.f16.row.stride %arg0, %arg1 : (!llvm.ptr, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + + llvm.return +} + +// ----- + +llvm.func @wmmaLoadOp_invalid_missing_ldm(%arg0: !llvm.ptr, %arg1: i32) { + // expected-error@+1 {{expected operands to be a source pointer in memory space 0, 1, 3 followed by ldm of the source}} + %0 = nvvm.wmma.m16n16k16.load.a.f16.row.stride %arg0: (!llvm.ptr) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + + llvm.return +} + +// ----- + +llvm.func @wmmaLoadOp_invalid_AOp(%arg0: !llvm.ptr, %arg1: i32) { + // expected-error@+1 {{nvvm.wmma.m16n16k16.load.a.f16.row.stride' op expected result type of loadAOp and loadBOp to be a struct of 8 s}} + %0 = nvvm.wmma.m16n16k16.load.a.f16.row.stride %arg0, %arg1 : (!llvm.ptr, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + + llvm.return +} + +// ----- + +llvm.func @wmmaLoadOp_invalid_BOp(%arg0: !llvm.ptr, %arg1: i32) { + // expected-error@+1 {{nvvm.wmma.m16n16k16.load.b.f16.row.stride' op expected result type of loadAOp and loadBOp to be a struct of 8 s}} + %0 = nvvm.wmma.m16n16k16.load.b.f16.row.stride %arg0, %arg1 : (!llvm.ptr, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + + llvm.return +} + +// ----- + +llvm.func @wmmaLoadOp_invalid_COp(%arg0: !llvm.ptr, %arg1: i32) { + // expected-error@+1 {{nvvm.wmma.m16n16k16.load.c.f16.row.stride' op expected result type of loadCOp to be a struct of 4 s}} + %0 = nvvm.wmma.m16n16k16.load.c.f16.row.stride %arg0, %arg1 : (!llvm.ptr, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + + llvm.return +} + +// ----- + +llvm.func @wmmaStoreOp_invalid_mem_space(%arg0: !llvm.ptr, %arg1: vector<2 x f16>, + %arg2: vector<2 x f16>, %arg3: vector<2 x f16>, + %arg4: vector<2 xf16>, %arg5: i32) { + // expected-error@+1 {{expected operands to be a source pointer in memory space 0, 1, 3 followed by ldm of the source}} + nvvm.wmma.m16n16k16.store.d.f16.row.stride %arg0, 
%arg1, %arg2, %arg3, %arg4, %arg5 : !llvm.ptr, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, i32 + llvm.return +} + +// ----- + +llvm.func @wmmaStoreOp_invalid_missing_ldm(%arg0: !llvm.ptr, %arg1: vector<2 x f16>, + %arg2: vector<2 x f16>, %arg3: vector<2 x f16>, + %arg4: vector<2 xf16>, %arg5: i32) { + // expected-error@+1 {{expected operands to be a source pointer in memory space 0, 1, 3 followed by ldm of the source}} + nvvm.wmma.m16n16k16.store.d.f16.row.stride %arg0, %arg1, %arg2, %arg3, %arg4 : !llvm.ptr, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16> + llvm.return +} + +// ----- + +llvm.func @gpu_wmma_mma_op(%arg0: vector<2 x f16>, %arg1: vector<2 x f16>, + %arg2: vector<2 x f16>, %arg3: vector<2 x f16>, + %arg4: vector<2 x f16>, %arg5: vector<2 x f16>, + %arg6: vector<2 x f16>, %arg7: vector<2 x f16>, + %arg8: vector<2 x f16>, %arg9: vector<2 x f16>, + %arg10: vector<2 x f16>, %arg11: vector<2 x f16>, + %arg12: vector<2 x f16>, %arg13: vector<2 x f16>, + %arg14: vector<2 x f16>, %arg15: vector<2 x f16>, + %arg16: vector<2 x f16>, %arg17: vector<2 x f16>, + %arg18: vector<2 x f16>, %arg19: vector<2 x f16>) { + // expected-error@+1 {{expected 20 s as operands}} + %0 = nvvm.wmma.m16n16k16.mma.row.row.f16.f16 %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18 : vector<2 x f16> -> !llvm.struct<(vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>)> + llvm.return +} + +// ----- + +llvm.func @gpu_wmma_mma_op(%arg0: vector<2 x f16>, %arg1: vector<2 x f16>, + %arg2: vector<2 x f16>, %arg3: vector<2 x f16>, + %arg4: vector<2 x f16>, %arg5: vector<2 x f16>, + %arg6: vector<2 x f16>, %arg7: vector<2 x f16>, + %arg8: vector<2 x f16>, %arg9: vector<2 x f16>, + %arg10: vector<2 x f16>, %arg11: vector<2 x f16>, + %arg12: vector<2 x f16>, %arg13: vector<2 x f16>, + %arg14: vector<2 x f16>, %arg15: vector<2 x f16>, + %arg16: vector<2 x f16>, %arg17: vector<2 x f16>, + %arg18: vector<2 x f16>, %arg19: vector<2 x f16>) { + // expected-error@+1 {{expected result type to be a struct of 4 s}} + %0 = nvvm.wmma.m16n16k16.mma.row.row.f16.f16 %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19 : vector<2 x f16> -> !llvm.struct<(vector<2 x f16>, vector<2 x f16>, vector<2 x f16>)> + llvm.return +} diff --git a/mlir/test/Target/nvvmir.mlir b/mlir/test/Target/nvvmir.mlir --- a/mlir/test/Target/nvvmir.mlir +++ b/mlir/test/Target/nvvmir.mlir @@ -73,6 +73,43 @@ llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> } +// The test below checks the correct mapping of the nvvm.wmma.*.load.* op to the correct intrinsic +// in the LLVM NVPTX backend. +llvm.func @gpu_wmma_load_op(%arg0: !llvm.ptr, %arg1: i32) { + // CHECK: call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.load.a.row.stride.f16.p3i32(i32 addrspace(3)* %{{.*}}, i32 %{{.*}}) + %0 = nvvm.wmma.m16n16k16.load.a.f16.row.stride %arg0, %arg1 : (!llvm.ptr, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + + llvm.return +} + +// The test below checks the correct mapping of the nvvm.wmma.*.store.* op to the correct intrinsic +// in the LLVM NVPTX backend. 
+llvm.func @gpu_wmma_store_op(%arg0: !llvm.ptr, %arg1: vector<2 x f16>, + %arg2: vector<2 x f16>, %arg3: vector<2 x f16>, + %arg4: vector<2 xf16>, %arg5: i32) { + // CHECK: call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f16.p3i32(i32 addrspace(3)* %{{.*}}, <2 x half> {{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, i32 %{{.*}}) + nvvm.wmma.m16n16k16.store.d.f16.row.stride %arg0, %arg1, %arg2, %arg3, %arg4, %arg5 : !llvm.ptr, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, i32 + llvm.return +} + +// The test below checks the correct mapping of the nvvm.wmma.*.mma.* op to the correct intrinsic +// in the LLVM NVPTX backend. +llvm.func @gpu_wmma_mma_op(%arg0: vector<2 x f16>, %arg1: vector<2 x f16>, + %arg2: vector<2 x f16>, %arg3: vector<2 x f16>, + %arg4: vector<2 x f16>, %arg5: vector<2 x f16>, + %arg6: vector<2 x f16>, %arg7: vector<2 x f16>, + %arg8: vector<2 x f16>, %arg9: vector<2 x f16>, + %arg10: vector<2 x f16>, %arg11: vector<2 x f16>, + %arg12: vector<2 x f16>, %arg13: vector<2 x f16>, + %arg14: vector<2 x f16>, %arg15: vector<2 x f16>, + %arg16: vector<2 x f16>, %arg17: vector<2 x f16>, + %arg18: vector<2 x f16>, %arg19: vector<2 x f16>) { + // CHECK: call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.mma.row.row.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) + %0 = nvvm.wmma.m16n16k16.mma.row.row.f16.f16 %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19 : vector<2 x f16> -> !llvm.struct<(vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>)> + + llvm.return +} + // This function has the "kernel" attribute attached and should appear in the // NVVM annotations after conversion. llvm.func @kernel_func() attributes {gpu.kernel} {