diff --git a/mlir/include/mlir/Dialect/GPU/GPUBase.td b/mlir/include/mlir/Dialect/GPU/GPUBase.td
--- a/mlir/include/mlir/Dialect/GPU/GPUBase.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUBase.td
@@ -57,6 +57,17 @@
   GPU_Dialect, CPred<"$_self.isa<::mlir::gpu::AsyncTokenType>()">, "async token type">,
              BuildableType<"mlir::gpu::AsyncTokenType::get($_builder.getContext())">;
+// Predicate to check whether a type is gpu::MMAMatrixType.
+def IsMMAMatrixTypePred : CPred<"$_self.isa<::mlir::gpu::MMAMatrixType>()">;
+
+def GPU_MMAMatrix : DialectType<
+  GPU_Dialect, IsMMAMatrixTypePred, "MMAMatrix type">;
+
+class MMAMatrixOf<list<Type> allowedTypes> :
+  ContainerType<AnyTypeOf<allowedTypes>, IsMMAMatrixTypePred,
+  "$_self.cast<::mlir::gpu::MMAMatrixType>().getElementType()",
+  "gpu.mma_matrix", "::mlir::gpu::MMAMatrixType">;
+
 def GPU_AsyncOpInterface : OpInterface<"AsyncOpInterface"> {
   let description = [{
     Interface for GPU operations that execute asynchronously on the device.
@@ -102,4 +113,18 @@
   ];
 }
+// Cases of the string enum attribute for SubgroupMmaOpLayout, representing
+// the layouts of the operands supported by the ops that use this attribute.
+def RowMajor: StrEnumAttrCase<"RowMajor", 0>;
+def ColMajor: StrEnumAttrCase<"ColMajor", 1>;
+
+// Specifies a string enum attribute for warp-wide matrix operations,
+// representing the layout of the respective operands. The layout later
+// governs the lowerings to the appropriate intrinsics.
+def SubgroupMmaOpLayout: StrEnumAttr<"Layout", "Specifies whether op is row/col major",
+  [RowMajor, ColMajor]> {
+  let stringToSymbolFnName = "LayoutStrToEnum";
+  let symbolToStringFnName = "EnumToLayoutStr";
+}
+
 #endif // GPU_BASE
diff --git a/mlir/include/mlir/Dialect/GPU/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/GPUDialect.h
--- a/mlir/include/mlir/Dialect/GPU/GPUDialect.h
+++ b/mlir/include/mlir/Dialect/GPU/GPUDialect.h
@@ -44,6 +44,122 @@
   using Base::Base;
 };
+/// MMAMatrixType storage and uniquing. The type is uniqued based on its
+/// shape, element type, and operand.
+struct MMAMatrixStorageType : public TypeStorage {
+  MMAMatrixStorageType(unsigned numDims, const int64_t *dimShapes,
+                       Type elementType, StringRef operand)
+      : dimShapes(dimShapes), numDims(numDims), elementType(elementType),
+        operand(operand) {}
+
+  /// The hash key for uniquing.
+  using KeyTy = std::tuple<ArrayRef<int64_t>, Type, StringRef>;
+  bool operator==(const KeyTy &key) const {
+    return key == KeyTy(getShape(), elementType, operand);
+  }
+
+  /// Construction.
+  static MMAMatrixStorageType *construct(TypeStorageAllocator &allocator,
+                                         const KeyTy &key) {
+    ArrayRef<int64_t> shape = allocator.copyInto(std::get<0>(key));
+    StringRef operand = allocator.copyInto(std::get<2>(key));
+
+    return new (allocator.allocate<MMAMatrixStorageType>())
+        MMAMatrixStorageType(shape.size(), shape.data(), std::get<1>(key),
+                             operand);
+  }
+
+  ArrayRef<int64_t> getShape() const {
+    return ArrayRef<int64_t>(dimShapes, numDims);
+  }
+
+  StringRef getOperand() const { return operand; }
+
+  /// Reference to the shape of the MMA matrix.
+  const int64_t *dimShapes;
+
+  /// Number of dimensions in the MMA matrix.
+  unsigned numDims;
+
+  /// Element type of the elements held in the MMA matrix.
+  Type elementType;
+
+  /// MMA operand that this MMAMatrix holds. The general form of the operation
+  /// this type supports is given by the equation D = (alpha*(A*B)) + (beta*C).
+  /// This field specifies which operand in the given equation is held by this
+  /// type. The valid values are "AOp", "BOp", "COp" and "DOp".
+  StringRef operand;
+};
+
+/// MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply
+/// accumulate operations.
MMAMatrices are taken as direct operands by these +/// operations and are also produced as results. These matrices are meant to +/// reside in the registers. A limited number of pointwise operations can be +/// performed on these matrices, i.e., operations which operate uniformly on +/// all the elements in the matrix and do not change the order of matrix +/// elements. The above conditions exist because the layout of matrix elements +/// inside the matrix is opaque i.e., the elements may be present in the +/// matrix in any order. The general usage of this type is shown as follows:- +/// +/// %0 = gpu.subgroup_mma_load_matrix %arg0[%c0, %c0] {leadDimension = 16 : +/// index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp"> +/// +/// The MMAMatrixType describes the shape of the matrix being loaded and the +/// operand being loaded too. The operand needs to be specified to aid the +/// lowering of this type to dialects such as NVVM where each workitem may +/// hold different amount of elements depending on the elementType of the +/// matrix. For e.g., Each workitem holds 4 vector<2xf16>s for f16 data type +/// and 8 f32s for f32 data type of MMAMatrix. Some other instances of usage +/// are:- +/// +/// %3 = gpu.subgroup_mma_compute %0, %1, %2 : !gpu.mma_matrix<16x16xf16, +/// "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">, !gpu.mma_matrix<16x16xf32, +/// "COp"> -> !gpu.mma_matrix<16x16xf32, "DOp"> +/// +/// +/// gpu.subgroup_mma_store_matrix %3, %arg22[%c0, %c0] {leadDimension = 16 +/// : index}: !gpu.mma_matrix<16x16xf32, "DOp">, memref<16x16xf32> +// TODO: consider moving this to ODS. +class MMAMatrixType + : public Type::TypeBase { +public: + using Base::Base; + + /// Get MMAMatrixType and verify construction Invariants. + static MMAMatrixType get(ArrayRef shape, Type elementType, + StringRef operand); + + /// Get MMAMatrixType at a particular location and verify construction + /// Invariants. + static MMAMatrixType getChecked(function_ref emitError, + ArrayRef shape, Type elementType, + StringRef operand); + + /// Check if a type is valid a MMAMatrixType elementType. + static bool isValidElementType(Type elementType); + + /// Verify that shape and elementType are actually allowed for the + /// MMAMatrixType. + static LogicalResult verify(function_ref emitError, + ArrayRef shape, Type elementType, + StringRef operand); + + /// Get number of dims. + unsigned getNumDims() const; + + /// Get shape of the matrix. + ArrayRef getShape() const; + + /// Get elementType of a single element. + Type getElementType() const; + + /// The general form of operation this type supports is given by the equation + /// D = (alpha*(A*B)) + (beta*C). This function returns which operand in the + /// given equation is held by this type. String returned can be one of"AOp", + /// "BOp", "COp" and "DOp". + StringRef getOperand() const; +}; + // Adds a `gpu.async.token` to the front of the argument list. void addAsyncDependency(Operation *op, Value token); diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td --- a/mlir/include/mlir/Dialect/GPU/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td @@ -912,4 +912,122 @@ let verifier = [{ return ::verify(*this); }]; } +def GPU_SubgroupMmaLoadMatrixOp : GPU_Op<"subgroup_mma_load_matrix", + [MemoryEffects<[MemRead]>]>{ + + let summary = "GPU warp synchronous matrix load"; + + let description = [{ + The `gpu.subgroup_mma_load_matrix` operation loads a matrix collectively + using all the threads in a subgroup. 
+ + This operation takes a memref as argument. It is the source matrix from which + data is to be loaded. The op returns a `!gpu.mma_matrix`. The source memref + can be in the global or shared memory space. The starting of the load address + is determined using indices provided. The matrix being loaded is specified in + the result type. This attribute is necessary because there exists a different + LLVM intrinsic for loading each operand, This is probably because all operands + need to be laid out in a specific/different way for the operation in the registers. + `leadDimension` attribute specifies the leading dimension of the source matrix. + + This op is meant to be used along with `gpu.subgroup_mma_store_matrix` and + `gpu.subgroup_mma_compute`. + + Example: + + ```mlir + %0 = gpu.subgroup_mma_load_matrix src[%i,%j] : {leadDimension = 32 + : i32} : memref<32x32xf16, 3>, !gpu.mma_matrix<16x16xf16, "AOp"> + ``` + }]; + + let arguments = (ins Arg, "", [MemRead]>:$srcMemref, + Variadic:$indices, + IndexAttr:$leadDimension); + + let results = (outs GPU_MMAMatrix:$res); + + let assemblyFormat = [{ + $srcMemref`[`$indices`]` attr-dict `:` type($srcMemref) `->` type($res) + }]; + + let verifier = [{ return ::verify(*this); }]; +} + +def GPU_SubgroupMmaStoreMatrixOp : GPU_Op<"subgroup_mma_store_matrix", + [MemoryEffects<[MemWrite]>]>{ + + let summary = "GPU warp synchronous matrix store"; + + let description = [{ + The `gpu.subgroup_mma_store_matrix` operation stores a matrix collectively + using all the threads in a subgroup. + + This operation takes a `!gpu.mma_matrix` and a memref as arguments. + `!gpu.mma_matrix` is the source which contains the data to be stored. + The destination can be in the global or shared memory space. The starting + of store address is determined using indices provided. The `leadDimension` + attribute specifies the leading dimension of the destination matrix. + + This op is meant to be used along with `gpu.subgroup_mma_load_matrix` and + `gpu.subgroup_mma_compute`. + + Example: + + ```mlir + gpu.subgroup_mma_store_matrix %D, %sg[%i,%j] : { leadDimension = 32 : i32} : + !gpu.mma_matrix<16x16xf16, "DOp">, memref<32x32xf16, 3> + ``` + }]; + + let arguments = (ins Arg>:$src, + Arg, "",[MemWrite]>:$dstMemref, + Variadic:$indices, + IndexAttr:$leadDimension); + + let assemblyFormat = [{ + $src`,` $dstMemref`[`$indices`]` attr-dict `:` type($src)`,` type($dstMemref) + }]; + + let verifier = [{ return ::verify(*this); }]; +} + +def GPU_SubgroupMmaComputeOp : GPU_Op<"subgroup_mma_compute", []>{ + + let summary = "GPU warp synchronous matrix multiply accumulate"; + + let description = [{ + The `gpu.subgroup_mma_compute` operation performs a matrix-multiply accumulate(mma) + operation using all the threads in a subgroup. + + This operation takes three `!gpu.mma_matrix`s as arguments. All of them hold `A`, + `B` and `C`operands for the mma operation. The operation performed is represented + as `D = A * B + C`. The op returns a `!gpu.mma_matrix` which contains the result of + the operation held by the current thread. + + This op is meant to be used along with `gpu.subgroup_mma_store_matrix` and + `gpu.subgroup_mma_load_matrix`. 
+ + Example: + + ```mlir + %D = gpu.subgroup_mma_compute_matrix %A, %B, %C : + !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">, + !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "DOp"> + ``` + }]; + + let arguments = (ins Arg>:$opA, + Arg>:$opB, + Arg>:$opC); + + let results = (outs GPU_MMAMatrix:$res); + + let assemblyFormat = [{ + $opA`,` $opB`,` $opC attr-dict `:` type($opA)`,` type($opB)`,` type($opC) `->` type($res) + }]; + + let verifier = [{ return ::verify(*this); }]; +} + #endif // GPU_OPS diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -151,4 +151,254 @@ let verifier = [{ return ::verify(*this); }]; } +// Base class for all the variants of WMMA loadOps that may be defined. +class NVVM_WMMALoadOp : NVVM_Op, + Results<(outs LLVM_AnyStruct:$res)>, + Arguments<(ins Variadic:$args)> { + + let summary = "Warp synchronous matrix load"; + + string baseDescription = [{"The `nvvm.wmma.m*n*k*.load.[a, b, c]` operation" + "loads a matrix collectively using all the threads in a warp." + + "The operation takes two arguments, the address from where the matrix" + "elements are to be loaded from and a stride. The stride argument" + "represents the leading dimension of the source matrix. The address and" + "the stride are required to be the same across all threads in the warp." + "Each thread in a warp holds a certain number of elements. The Op returns" + "a LLVMStruct which holds the elements of the matrix held by this thread." + + "This op is meant to be used along with `nvvm.wmma.m*n*k*.store` and" + "`nvvm.wmma.m*n*k*.mma`."}]; + + let assemblyFormat = "$args attr-dict `:` functional-type($args, $res)"; +} + +def NVVM_WMMALoadAM16N16K16Op : + NVVM_WMMALoadOp<"wmma.m16n16k16.load.a.f16.row.stride">{ + + string llvmBuilder = [{ + $res = createNvvmIntrinsicCall( + builder, llvm::Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_row_stride, $args); + }]; + + string opDescription = [{ + Example: + + ```mlir + %2 = nvvm.wmma.m16n16k16.load.a %0, %1 : !llvm.ptr, !llvm.i32 -> + !llvm.struct<(vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>, + vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>)> + ``` + }]; + + let description = !strconcat(baseDescription, opDescription); + + let verifier = [{ return ::verify(*this); }]; +} + +def NVVM_WMMALoadBM16N16K16Op : + NVVM_WMMALoadOp<"wmma.m16n16k16.load.b.f16.row.stride">{ + + string llvmBuilder = [{ + $res = createNvvmIntrinsicCall( + builder, llvm::Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row_stride, $args); + }]; + + string opDescription = [{ + Example: + + ```mlir + %2 = nvvm.wmma.m16n16k16.load.b %0, %1 : !llvm.ptr, !llvm.i32 -> + !llvm.struct<(vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>, + vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>)> + ``` + }]; + + let description = !strconcat(baseDescription, opDescription); + + let verifier = [{ return ::verify(*this); }]; +} + +def NVVM_WMMALoadCF16M16N16K16Op : + NVVM_WMMALoadOp<"wmma.m16n16k16.load.c.f16.row.stride">{ + string llvmBuilder = [{ + $res = createNvvmIntrinsicCall( + builder, llvm::Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row_stride, $args); + }]; + + string opDescription = [{ + Example: + + ```mlir + %2 = nvvm.wmma.m16n16k16.load.c.f16.row.stride %0, %1 : !llvm.ptr, !llvm.i32 -> + !llvm.struct<(vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>)> + ``` + }]; + 
+ let description = !strconcat(baseDescription, opDescription); + + let verifier = [{ return ::verify(*this); }]; +} + +def NVVM_WMMALoadCF32M16N16K16Op : + NVVM_WMMALoadOp<"wmma.m16n16k16.load.c.f32.row.stride">{ + string llvmBuilder = [{ + $res = createNvvmIntrinsicCall( + builder, llvm::Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row_stride, $args); + }]; + + string opDescription = [{ + Example: + + ```mlir + %2 = nvvm.wmma.m16n16k16.load.c.f32.row.stride %0, %1 : !llvm.ptr, !llvm.i32 -> + !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> + ``` + }]; + + let description = !strconcat(baseDescription, opDescription); + + let verifier = [{ return ::verify(*this); }]; +} + +// Base class for all the variants of WMMA storeOps that may be defined. +class NVVM_WMMAStoreOp : NVVM_Op, + Arguments<(ins Variadic:$args)>{ + let summary = "Warp synchronous matrix store"; + + string baseDescription = [{ + The `nvvm.wmma.m*n*k*.store` operation stores a matrix collectively using + all the threads in a warp. + + The operation takes as arguments the address to where the matrix elements are + to be stored, a stride and the elements to store, held by the current thread. + The stride argument represents the leading dimension of the destination matrix. + The address and the stride are required to be the same across all threads in the + warp. + + This op is meant to be used along with `nvvm.wmma.m16n16k16.load` and + `nvvm.wmma.m16n16k16.mma`. + }]; + + let assemblyFormat = "$args attr-dict `:` type($args)"; +} + +def NVVM_WMMAStoreF16M16N16K16Op : NVVM_WMMAStoreOp<"wmma.m16n16k16.store.d.f16.row.stride"> { + string llvmBuilder = [{ + createNvvmIntrinsicCall( + builder, llvm::Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row_stride, $args); + }]; + + string opDescription = [{ + Example: + + ```mlir + nvvm.wmma.m16n16k16.stored.f16.row.stride %0, %1, %2, %3, %4, %5, %6 : !llvm.ptr, + !llvm.struct<(vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>)>, !llvm.i32 + ``` + }]; + + let description = !strconcat(baseDescription, opDescription); + + let verifier = [{ return ::verify(*this); }]; +} + +def NVVM_WMMAStoreF32M16N16K16Op : NVVM_WMMAStoreOp<"wmma.m16n16k16.store.d.f32.row.stride"> { + string llvmBuilder = [{ + createNvvmIntrinsicCall( + builder, llvm::Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row_stride, $args); + }]; + + string opDescription = [{ + Example: + + ```mlir + nvvm.wmma.m16n16k16.store.d.f32.row.stride %0, %1, %2, %3, %4, %5, %6, %7, %8, %9, + %10 : !llvm.ptr, !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>, + !llvm.i32 + ``` + }]; + + let description = !strconcat(baseDescription, opDescription); + + let verifier = [{ return ::verify(*this); }]; +} + +// Base class for all the variants of WMMA mmaOps that may be defined. +class NVVM_WMMAMmaOp : NVVM_Op, + Results<(outs LLVM_AnyStruct:$res)>, + Arguments<(ins Variadic:$args)>{ + let summary = "Warp synchronous matrix-multiply accumulate using tensor cores."; + + string baseDescription = [{ + The `nvvm.wmma.m*n*k*.mma` operation performs a matrix-multiply accumulate + (mma) operation using all the threads in a warp. + + The operation performed is represented as `D = A * B + C`. The operation takes + as arguments the elements of the matrices `A`, `B`, `C` and `D`, held by the + current thread. The op returns a LLVM struct which holds a part of the result + held by the current thread. + + This op is meant to be used along with `nvvm.wmma.m16n16k16.load` and `nvvm.wmma. + m16n16k16.store`. 
+ }]; +} + +def NVVM_WMMAMmaF16F16M16N16K16Op : NVVM_WMMAMmaOp<"wmma.m16n16k16.mma.row.row.f16.f16">{ + string llvmBuilder = [{ + $res = createNvvmIntrinsicCall( + builder, llvm::Intrinsic::nvvm_wmma_m16n16k16_mma_row_row_f16_f16, $args); + }]; + + string opDescription = [{ + Example: + + ```mlir + %20 = nvvm.wmma.m16n16k16.mma.row.row.f16.f16 %0, %1, %2, %3, %4, %5, %6, %7, %8, + %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19 : vector<2xf16> -> !llvm.struct + <(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + ``` + }]; + + let parser = [{ + return parseWMMAMmaF16F16M16N16K16Op(parser, result); + }]; + + let printer = [{ + printWMMAMmaF16F16M16N16K16Op(p, *this); + }]; + + let description = !strconcat(baseDescription, opDescription); + + let verifier = [{ return ::verify(*this); }]; +} + +def NVVM_WMMAMmaF32F32M16N16K16Op : NVVM_WMMAMmaOp<"wmma.m16n16k16.mma.row.row.f32.f32">{ + string llvmBuilder = [{ + $res = createNvvmIntrinsicCall( + builder, llvm::Intrinsic::nvvm_wmma_m16n16k16_mma_row_row_f32_f32, $args); + }]; + + string opDescription = [{ + Example: + + ```mlir + %24 = nvvm.wmma.m16n16k16.mma.row.row.f32.f32 %0, %1, %2, %3, %4, %5, %6, %7, %8 + %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23 : + (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, + vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, + vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, + vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32, + f32, f32, f32, f32, f32, f32, f32)> + ``` + }]; + + let assemblyFormat = "$args attr-dict `:` functional-type($args, $res)"; + + let description = !strconcat(baseDescription, opDescription); + + let verifier = [{ return ::verify(*this); }]; +} + #endif // NVVMIR_OPS diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h --- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h @@ -257,6 +257,14 @@ llvm::Intrinsic::ID intrinsic, ArrayRef args = {}, ArrayRef tys = {}); + +/// Creates a call to an LLVM IR intrinsic function with the given arguments +/// for NVVM WMMA ops. Handles cases where the intrinsic name is overloaded +/// using the types of arguments supplied. Selects the correct intrinsic +/// by inspecting the argument types. 
+llvm::Value *createNvvmIntrinsicCall(llvm::IRBuilderBase &builder, + llvm::Intrinsic::ID intrinsic, + ArrayRef args = {}); } // namespace detail } // namespace LLVM diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -28,10 +28,70 @@ using namespace mlir; using namespace mlir::gpu; +//===----------------------------------------------------------------------===// +// MMAMatrixType +//===----------------------------------------------------------------------===// + +MMAMatrixType MMAMatrixType::get(ArrayRef shape, Type elementType, + StringRef operand) { + return Base::get(elementType.getContext(), shape, elementType, operand); +} + +MMAMatrixType +MMAMatrixType::getChecked(function_ref emitError, + ArrayRef shape, Type elementType, + StringRef operand) { + return Base::getChecked(emitError, elementType.getContext(), shape, + elementType, operand); +} + +unsigned MMAMatrixType::getNumDims() const { return getImpl()->numDims; } + +ArrayRef MMAMatrixType::getShape() const { + return getImpl()->getShape(); +} + +Type MMAMatrixType::getElementType() const { return getImpl()->elementType; } + +StringRef MMAMatrixType::getOperand() const { return getImpl()->getOperand(); } + +bool MMAMatrixType::isValidElementType(Type elementType) { + return elementType.isF16() || elementType.isF32(); +} + +LogicalResult +MMAMatrixType::verify(function_ref emitError, + ArrayRef shape, Type elementType, + StringRef operand) { + if (!operand.equals("AOp") && !operand.equals("BOp") && + !operand.equals("COp") && !operand.equals("DOp")) + return emitError() << "operand expected to be one of AOp, BOp, COp or DOp"; + + if (shape.size() != 2) + return emitError() << "MMAMatrixType must have exactly two dimensions"; + + if (!MMAMatrixType::isValidElementType(elementType)) + return emitError() << "MMAMatrixType elements must be F16 or F32"; + + return success(); +} + //===----------------------------------------------------------------------===// // GPUDialect //===----------------------------------------------------------------------===// +/// GPU memory space identifiers. +enum GPUMemorySpace { + /// Generic memory space identifier. + kGenericMemorySpace = 0, + + /// Global memory space identifier. + kGlobalMemorySpace = 1, + + /// Shared memory space identifier. + kSharedMemorySpace = 3 +}; + bool GPUDialect::isKernel(Operation *op) { UnitAttr isKernelAttr = op->getAttrOfType(getKernelFuncAttrName()); return static_cast(isKernelAttr); @@ -39,6 +99,7 @@ void GPUDialect::initialize() { addTypes(); + addTypes(); addOperations< #define GET_OP_LIST #include "mlir/Dialect/GPU/GPUOps.cpp.inc" @@ -56,6 +117,38 @@ if (keyword == "async.token") return AsyncTokenType::get(context); + if (keyword == "mma_matrix") { + llvm::SMLoc beginLoc = parser.getNameLoc(); + + // Parse '<'. + if (parser.parseLess()) + return nullptr; + + // Parse the size and elementType. + SmallVector shape; + Type elementType; + if (parser.parseDimensionList(shape, /*allowDynamic=*/false) || + parser.parseType(elementType)) + return nullptr; + + // Parse ',' + if (parser.parseComma()) + return nullptr; + + // Parse operand. + StringRef operand; + if (failed(parser.parseOptionalString(&operand))) + return nullptr; + + // Parse '>'. 
+ if (parser.parseGreater()) + return nullptr; + + return MMAMatrixType::getChecked(mlir::detail::getDefaultDiagnosticEmitFn( + parser.getEncodedSourceLoc(beginLoc)), + shape, elementType, operand); + } + parser.emitError(parser.getNameLoc(), "unknown gpu type: " + keyword); return Type(); } @@ -63,6 +156,14 @@ void GPUDialect::printType(Type type, DialectAsmPrinter &os) const { TypeSwitch(type) .Case([&](Type) { os << "async.token"; }) + .Case([&](MMAMatrixType fragTy) { + os << "mma_matrix<"; + auto shape = fragTy.getShape(); + for (auto dim = shape.begin(), e = shape.end() - 1; dim != e; ++dim) + os << *dim << 'x'; + os << shape.back() << 'x' << fragTy.getElementType(); + os << ", \"" << fragTy.getOperand() << "\"" << '>'; + }) .Default([](Type) { llvm_unreachable("unexpected 'gpu' type kind"); }); } @@ -138,7 +239,8 @@ return walkResult.wasInterrupted() ? failure() : success(); } -template static LogicalResult verifyIndexOp(T op) { +template +static LogicalResult verifyIndexOp(T op) { auto dimension = op.dimension(); if (dimension != "x" && dimension != "y" && dimension != "z") return op.emitError("dimension \"") << dimension << "\" is invalid"; @@ -885,6 +987,95 @@ printer << "]"; } +//===----------------------------------------------------------------------===// +// GPU_SubgroupMmaLoadMatrixOp +//===----------------------------------------------------------------------===// + +static LogicalResult verify(SubgroupMmaLoadMatrixOp op) { + auto srcType = op.srcMemref().getType(); + auto resType = op.res().getType(); + auto resMatrixType = resType.cast(); + auto operand = resMatrixType.getOperand(); + auto srcMemrefType = srcType.cast(); + auto srcMemSpace = srcMemrefType.getMemorySpaceAsInt(); + + if (!srcMemrefType.getAffineMaps().empty() && + !srcMemrefType.getAffineMaps().front().isIdentity()) + return op.emitError("expected identity layout map for source memref"); + + if (srcMemSpace != kGenericMemorySpace && srcMemSpace != kSharedMemorySpace && + srcMemSpace != kGlobalMemorySpace) + return op.emitError( + "source memorySpace kGenericMemorySpace, kSharedMemorySpace or " + "kGlobalMemorySpace only allowed"); + + if (!operand.equals("AOp") && !operand.equals("BOp") && + !operand.equals("COp")) + return op.emitError("only AOp, BOp and COp can be loaded"); + + return success(); +} + +//===----------------------------------------------------------------------===// +// GPU_SubgroupMmaStoreMatrixOp +//===----------------------------------------------------------------------===// + +static LogicalResult verify(SubgroupMmaStoreMatrixOp op) { + auto srcType = op.src().getType(); + auto dstType = op.dstMemref().getType(); + auto srcMatrixType = srcType.cast(); + auto dstMemrefType = dstType.cast(); + auto dstMemSpace = dstMemrefType.getMemorySpaceAsInt(); + + if (!dstMemrefType.getAffineMaps().empty() && + !dstMemrefType.getAffineMaps().front().isIdentity()) + return op.emitError("expected identity layout map for destination memref"); + + if (dstMemSpace != kGenericMemorySpace && dstMemSpace != kSharedMemorySpace && + dstMemSpace != kGlobalMemorySpace) + return op.emitError( + "destination memorySpace of kGenericMemorySpace, " + "kGlobalMemorySpace or kSharedMemorySpace only allowed"); + + if (!srcMatrixType.getOperand().equals("DOp")) + return op.emitError( + "expected the operand matrix being stored to have 'DOp' operand type"); + + return success(); +} + +//===----------------------------------------------------------------------===// +// GPU_SubgroupMmaComputeOp 
+//===----------------------------------------------------------------------===// + +static LogicalResult verify(SubgroupMmaComputeOp op) { + enum OperandMap { A, B, C }; + SmallVector opTypes; + + auto populateOpInfo = [&opTypes, &op]() { + opTypes.push_back(op.opA().getType().cast()); + opTypes.push_back(op.opB().getType().cast()); + opTypes.push_back(op.opC().getType().cast()); + }; + populateOpInfo(); + + if (!opTypes[A].getOperand().equals("AOp") || + !opTypes[B].getOperand().equals("BOp") || + !opTypes[C].getOperand().equals("COp")) + return op.emitError("operands must be in the order AOp, BOp, COp"); + + ArrayRef aShape, bShape, cShape; + aShape = opTypes[A].getShape(); + bShape = opTypes[B].getShape(); + cShape = opTypes[C].getShape(); + + if (aShape[1] != bShape[0] || aShape[0] != cShape[0] || + bShape[1] != cShape[1]) + return op.emitError("operand shapes do not satisfy matmul constraints"); + + return success(); +} + #include "mlir/Dialect/GPU/GPUOpInterfaces.cpp.inc" #define GET_OP_CLASSES diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -94,12 +94,12 @@ auto f32x8StructTy = LLVM::LLVMStructType::getLiteral( context, {f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty}); - SmallVector operand_types(op.getOperandTypes().begin(), - op.getOperandTypes().end()); - if (operand_types != SmallVector(8, f16x2Ty) && - operand_types != SmallVector{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, - f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, - f32Ty, f32Ty, f32Ty}) { + SmallVector operandTypes(op.getOperandTypes().begin(), + op.getOperandTypes().end()); + if (operandTypes != SmallVector(8, f16x2Ty) && + operandTypes != SmallVector{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, + f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, + f32Ty, f32Ty, f32Ty}) { return op.emitOpError( "expected operands to be 4 s followed by either " "4 s or 8 floats"); @@ -120,9 +120,9 @@ "\"row\" or \"col\""); } - if (operand_types == SmallVector{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, - f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, - f32Ty, f32Ty, f32Ty} && + if (operandTypes == SmallVector{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, + f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, + f32Ty, f32Ty, f32Ty} && op.getType() == f32x8StructTy && alayout.getValue() == "row" && blayout.getValue() == "col") { return success(); @@ -130,6 +130,205 @@ return op.emitOpError("unimplemented mma.sync variant"); } +template +static LogicalResult verifyWMMALoadOp(T op, StringRef operand) { + MLIRContext *context = op.getContext(); + auto i32Ty = IntegerType::get(context, 32); + auto i32Ptr1Ty = LLVM::LLVMPointerType::get(i32Ty, 1); + auto i32Ptr3Ty = LLVM::LLVMPointerType::get(i32Ty, 3); + auto i32Ptr0Ty = LLVM::LLVMPointerType::get(i32Ty, 0); + auto f16Ty = FloatType::getF16(context); + auto f32Ty = FloatType::getF32(context); + auto f16x2Ty = VectorType::get(2, f16Ty); + auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral( + context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty}); + auto f16x2x8StructTy = LLVM::LLVMStructType::getLiteral( + context, + {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty}); + auto f32x8StructTy = LLVM::LLVMStructType::getLiteral( + context, {f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty}); + + SmallVector operandTypes(op.getOperandTypes().begin(), + op.getOperandTypes().end()); + if (operandTypes != SmallVector{i32Ptr1Ty, i32Ty} && + operandTypes != SmallVector{i32Ptr3Ty, i32Ty} && + operandTypes != 
SmallVector{i32Ptr0Ty, i32Ty}) { + return op.emitOpError("expected operands to be a source pointer in memory " + "space 0, 1, 3 followed by ldm of the source"); + } + + if (operand.equals("AOp") || operand.equals("BOp")) { + if (op.getType() != f16x2x8StructTy) { + return op.emitOpError("expected result type of loadAOp and loadBOp to be " + "a struct of 8 s"); + } + } else if (operand.equals("COp")) { + if (op.getType() != f16x2x4StructTy && op.getType() != f32x8StructTy) { + return op.emitOpError("expected result type of loadCOp to be a struct of " + "4 s or 8 f32s"); + } + } + + return success(); +} + +static LogicalResult verify(WMMALoadAM16N16K16Op op) { + return verifyWMMALoadOp(op, "AOp"); +} + +static LogicalResult verify(WMMALoadBM16N16K16Op op) { + return verifyWMMALoadOp(op, "BOp"); +} + +static LogicalResult verify(WMMALoadCF16M16N16K16Op op) { + return verifyWMMALoadOp(op, "COp"); +} + +static LogicalResult verify(WMMALoadCF32M16N16K16Op op) { + return verifyWMMALoadOp(op, "COp"); +} + +template +static bool verifyWMMAStoreOp(T op, SmallVector &containedElems) { + SmallVector operandTypes(op.getOperandTypes().begin(), + op.getOperandTypes().end()); + if (operandTypes == containedElems) + return true; + + return false; +} + +static LogicalResult verify(WMMAStoreF16M16N16K16Op op) { + MLIRContext *context = op.getContext(); + auto i32Ty = IntegerType::get(context, 32); + auto i32Ptr1Ty = LLVM::LLVMPointerType::get(i32Ty, 1); + auto i32Ptr3Ty = LLVM::LLVMPointerType::get(i32Ty, 3); + auto i32Ptr0Ty = LLVM::LLVMPointerType::get(i32Ty, 0); + auto f16Ty = FloatType::getF16(context); + auto f16x2Ty = VectorType::get(2, f16Ty); + SmallVector type1{i32Ptr1Ty, f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, i32Ty}; + SmallVector type0{i32Ptr0Ty, f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, i32Ty}; + SmallVector type3{i32Ptr3Ty, f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, i32Ty}; + if (verifyWMMAStoreOp(op, type1) || verifyWMMAStoreOp(op, type0) || + verifyWMMAStoreOp(op, type3)) + return success(); + + return op.emitOpError("expected operands to be a source pointer in memory" + "space 0, 1, 3 followed by ldm of the source"); +} + +static LogicalResult verify(WMMAStoreF32M16N16K16Op op) { + MLIRContext *context = op.getContext(); + auto i32Ty = IntegerType::get(context, 32); + auto i32Ptr1Ty = LLVM::LLVMPointerType::get(i32Ty, 1); + auto i32Ptr3Ty = LLVM::LLVMPointerType::get(i32Ty, 3); + auto i32Ptr0Ty = LLVM::LLVMPointerType::get(i32Ty, 0); + auto f32Ty = FloatType::getF32(context); + + SmallVector type1{i32Ptr1Ty, f32Ty, f32Ty, f32Ty, f32Ty, + f32Ty, f32Ty, f32Ty, f32Ty, i32Ty}; + SmallVector type0{i32Ptr0Ty, f32Ty, f32Ty, f32Ty, f32Ty, + f32Ty, f32Ty, f32Ty, f32Ty, i32Ty}; + SmallVector type3{i32Ptr3Ty, f32Ty, f32Ty, f32Ty, f32Ty, + f32Ty, f32Ty, f32Ty, f32Ty, i32Ty}; + if (verifyWMMAStoreOp(op, type0) || verifyWMMAStoreOp(op, type1) || + verifyWMMAStoreOp(op, type3)) + return success(); + + return op.emitOpError("expected operands to be a source pointer in memory" + "space 0, 1, 3 followed by ldm of the source"); +} + +static LogicalResult verify(WMMAMmaF16F16M16N16K16Op op) { + MLIRContext *context = op.getContext(); + auto f16Ty = FloatType::getF16(context); + auto f16x2Ty = VectorType::get(2, f16Ty); + auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral( + context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty}); + + SmallVector operandTypes(op.getOperandTypes().begin(), + op.getOperandTypes().end()); + if (operandTypes != SmallVector(20, f16x2Ty)) + return op.emitOpError("expected 20 s as operands"); + + if 
(op.getResult().getType() != f16x2x4StructTy) + return op.emitOpError("expected result type to be a struct of 4 s"); + + return success(); +} + +static LogicalResult parseWMMAMmaF16F16M16N16K16Op(OpAsmParser &parser, + OperationState &result) { + SmallVector operands; + ::llvm::SMLoc operandsLoc; + Type operandType; + Type resType; + + operandsLoc = parser.getCurrentLocation(); + if (parser.parseOperandList(operands) || + parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || + parser.parseType(operandType) || parser.parseArrow()) + return failure(); + + unsigned numOperands = operands.size(); + SmallVector operandTypes(numOperands, operandType); + if (parser.parseType(resType)) + return failure(); + result.addTypes(resType); + if (parser.resolveOperands(operands, operandTypes, operandsLoc, + result.operands)) + return failure(); + return success(); +} + +static void printWMMAMmaF16F16M16N16K16Op(OpAsmPrinter &p, + WMMAMmaF16F16M16N16K16Op &op) { + p << op.getOperationName(); + p << ' '; + p << op.args(); + p.printOptionalAttrDict(op->getAttrs(), {}); + p << " : "; + p << op->getOperand(0).getType(); + p << ' ' << "->"; + p << ' '; + p << ::llvm::ArrayRef<::mlir::Type>(op.res().getType()); +} + +static LogicalResult verify(WMMAMmaF32F32M16N16K16Op op) { + unsigned numABOperands = 16; + unsigned numCOperands = 8; + MLIRContext *context = op.getContext(); + auto f16Ty = FloatType::getF16(context); + auto f32Ty = FloatType::getF32(context); + auto f16x2Ty = VectorType::get(2, f16Ty); + auto f32x8StructTy = LLVM::LLVMStructType::getLiteral( + context, {f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty}); + + SmallVector abOpTypes; + SmallVector bOpTypes; + SmallVector cOpTypes; + + for (auto operand : op->getOperands().take_front(numABOperands)) { + abOpTypes.push_back(operand.getType()); + } + + for (auto operand : + op->getOperands().drop_front(numABOperands).take_front(numCOperands)) { + cOpTypes.push_back(operand.getType()); + } + + if (abOpTypes != SmallVector(16, f16x2Ty)) + return op.emitOpError("expected 16 s for `a` and `b` operand"); + + if (cOpTypes != SmallVector(8, f32Ty)) + return op.emitOpError("expected 8 f32s for `c` operand"); + + if (op.getResult().getType() != f32x8StructTy) + return op.emitOpError("expected result type to be a struct of 8 f32s"); + + return success(); +} + //===----------------------------------------------------------------------===// // NVVMDialect initialization, type parsing, and registration. //===----------------------------------------------------------------------===// @@ -141,7 +340,8 @@ #include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc" >(); - // Support unknown operations because not all NVVM operations are registered. + // Support unknown operations because not all NVVM operations are + // registered. 
  allowUnknownOperations();
 }
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -22,6 +22,7 @@
 using namespace mlir;
 using namespace mlir::LLVM;
 using mlir::LLVM::detail::createIntrinsicCall;
+using mlir::LLVM::detail::createNvvmIntrinsicCall;
 
 static llvm::Intrinsic::ID getShflBflyIntrinsicId(llvm::Type *resultType,
                                                   bool withPredicate) {
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -35,6 +35,7 @@
 #include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/InlineAsm.h"
+#include "llvm/IR/IntrinsicsNVPTX.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Module.h"
@@ -300,6 +301,29 @@
   return builder.CreateCall(fn, args);
 }
 
+llvm::Value *
+mlir::LLVM::detail::createNvvmIntrinsicCall(llvm::IRBuilderBase &builder,
+                                            llvm::Intrinsic::ID intrinsic,
+                                            ArrayRef<llvm::Value *> args) {
+  llvm::Module *module = builder.GetInsertBlock()->getModule();
+  llvm::Function *fn;
+  if (llvm::Intrinsic::isOverloaded(intrinsic)) {
+    if (intrinsic != llvm::Intrinsic::nvvm_wmma_m16n16k16_mma_row_row_f16_f16 &&
+        intrinsic != llvm::Intrinsic::nvvm_wmma_m16n16k16_mma_row_row_f32_f32) {
+      // NVVM load and store intrinsic names are overloaded on the
+      // source/destination pointer type. The pointer is the first argument in
+      // the corresponding NVVM op.
+      fn = llvm::Intrinsic::getDeclaration(module, intrinsic,
+                                           {args[0]->getType()});
+    } else {
+      fn = llvm::Intrinsic::getDeclaration(module, intrinsic, {});
+    }
+  } else {
+    fn = llvm::Intrinsic::getDeclaration(module, intrinsic);
+  }
+  return builder.CreateCall(fn, args);
+}
+
 /// Given a single MLIR operation, create the corresponding LLVM IR operation
 /// using the `builder`.
LogicalResult diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir --- a/mlir/test/Dialect/GPU/invalid.mlir +++ b/mlir/test/Dialect/GPU/invalid.mlir @@ -458,3 +458,116 @@ // expected-error @+1 {{'gpu.memcpy' op arguments have incompatible shape}} gpu.memcpy %dst, %src : memref<7xf32>, memref<9xf32> } + +// ----- + +func @mmamatrix_invalid_shape(){ + %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3> + %i = constant 16 : index + // expected-error @+1 {{MMAMatrixType must have exactly two dimensions}} + %0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16x16xf16, "AOp"> + return +} + +// ----- + +func @mmamatrix_operand_type(){ + %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3> + %i = constant 16 : index + // expected-error @+1 {{operand expected to be one of AOp, BOp, COp or DOp}} + %0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "EOp"> + return +} + +// ----- + +func @mmamatrix_invalid_element_type(){ + %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3> + %i = constant 16 : index + // expected-error @+1 {{MMAMatrixType elements must be F16 or F32}} + %0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xi32, "AOp"> + return +} + +// ----- + +#layout_map_col_major = affine_map<(i, j) -> (j, i)> + +func @mmaLoadOp_identity_layout(){ + %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, #layout_map_col_major, 3> + %i = constant 16 : index + // expected-error @+1 {{expected identity layout map for source memref}} + %0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, #layout_map_col_major, 3> -> !gpu.mma_matrix<16x16xf16, "AOp"> + return +} + +// ----- + +func @mmaLoadOp_invalid_mem_space(){ + %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 5> + %i = constant 16 : index + // expected-error @+1 {{source memorySpace kGenericMemorySpace, kSharedMemorySpace or kGlobalMemorySpace only allowed}} + %0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 5> -> !gpu.mma_matrix<16x16xf16, "AOp"> + return +} + +// ----- + +func @mmaLoadOp_operand_type(){ + %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3> + %i = constant 16 : index + // expected-error @+1 {{only AOp, BOp and COp can be loaded}} + %0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "DOp"> + return +} + +// ----- + +#layout_map_col_major = affine_map<(i, j) -> (j, i)> + +func @wmmaStoreOp_invalid_map(%arg0 : !gpu.mma_matrix<16x16xf16, "DOp">) -> () { + %sg = memref.alloca(){alignment = 32} : memref<32x32xf16, #layout_map_col_major, 3> + %i = constant 16 : index + %j = constant 16 : index + // expected-error @+1 {{expected identity layout map for destination memref}} + gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "DOp">, memref<32x32xf16,#layout_map_col_major, 3> + return +} + +// ----- + +func @wmmaStoreOp_invalid_mem_space(%arg0 : !gpu.mma_matrix<16x16xf16, "DOp">) -> () { + %sg = memref.alloca(){alignment = 32} : memref<32x32xf16, 5> + %i = constant 16 : index + %j = constant 16 : index + // expected-error @+1 {{destination memorySpace of kGenericMemorySpace, kGlobalMemorySpace or kSharedMemorySpace only allowed}} 
+ gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "DOp">, memref<32x32xf16, 5> + return +} + +// ----- + +func @wmmaStoreOp_invalid_store_operand(%arg0 : !gpu.mma_matrix<16x16xf16, "AOp">) -> () { + %sg = memref.alloca(){alignment = 32} : memref<32x32xf16, 3> + %i = constant 16 : index + %j = constant 16 : index + // expected-error @+1 {{expected the operand matrix being stored to have 'DOp' operand type}} + gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "AOp">, memref<32x32xf16, 3> + return +} + +// ----- + +func @wmmaMmaOp_invalid_operand_order(%A : !gpu.mma_matrix<16x16xf16, "AOp">, %B : !gpu.mma_matrix<16x16xf16, "BOp">, %C : !gpu.mma_matrix<16x16xf16, "COp">) -> () { + // expected-error @+1 {{operands must be in the order AOp, BOp, COp}} + %D = gpu.subgroup_mma_compute %B, %A, %C : !gpu.mma_matrix<16x16xf16, "BOp">, !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "DOp"> + return +} + +// ----- + +func @wmmaMmaOp_invalid_operand_shapes(%A : !gpu.mma_matrix<16x32xf16, "AOp">, %B : !gpu.mma_matrix<16x16xf16, "BOp">, %C : !gpu.mma_matrix<16x16xf16, "COp">) -> () { + // expected-error @+1 {{operand shapes do not satisfy matmul constraints}} + %D = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x32xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">, !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "DOp"> + return +} diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir --- a/mlir/test/Dialect/GPU/ops.mlir +++ b/mlir/test/Dialect/GPU/ops.mlir @@ -194,4 +194,15 @@ %1 = gpu.memcpy async [%0] %dst, %src : memref<3x7xf32>, memref<3x7xf32, 1> return } + + func @mmamatrix_valid_element_type(){ + // CHECK-LABEL: func @mmamatrix_valid_element_type + %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3> + // CHECK: %[[wg:.*]] = memref.alloca() + %i = constant 16 : index + // CHECK: %[[i:.*]] = constant 16 : index + %0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp"> + // CHECK: gpu.subgroup_mma_load_matrix %[[wg]][%[[i]], %[[i]]] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp"> + return + } } diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -843,3 +843,162 @@ llvm.return } } + +// ----- + +llvm.func @wmmaLoadOp_invalid_mem_space(%arg0: !llvm.ptr, %arg1: i32) { + // expected-error@+1 {{'nvvm.wmma.m16n16k16.load.a.f16.row.stride' op expected operands to be a source pointer in memory space 0, 1, 3 followed by ldm of the source}} + %0 = nvvm.wmma.m16n16k16.load.a.f16.row.stride %arg0, %arg1 : (!llvm.ptr, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + + llvm.return +} + +// ----- + +llvm.func @wmmaLoadOp_invalid_missing_ldm(%arg0: !llvm.ptr, %arg1: i32) { + // expected-error@+1 {{'nvvm.wmma.m16n16k16.load.a.f16.row.stride' op expected operands to be a source pointer in memory space 0, 1, 3 followed by ldm of the source}} + %0 = nvvm.wmma.m16n16k16.load.a.f16.row.stride %arg0: (!llvm.ptr) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + + llvm.return +} + 
+// ----- + +llvm.func @wmmaLoadOp_invalid_AOp(%arg0: !llvm.ptr, %arg1: i32) { + // expected-error@+1 {{'nvvm.wmma.m16n16k16.load.a.f16.row.stride' op expected result type of loadAOp and loadBOp to be a struct of 8 s}} + %0 = nvvm.wmma.m16n16k16.load.a.f16.row.stride %arg0, %arg1 : (!llvm.ptr, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + + llvm.return +} + +// ----- + +llvm.func @wmmaLoadOp_invalid_AOp(%arg0: !llvm.ptr, %arg1: i32) { + // expected-error@+1 {{nvvm.wmma.m16n16k16.load.a.f16.row.stride' op expected result type of loadAOp and loadBOp to be a struct of 8 s}} + %0 = nvvm.wmma.m16n16k16.load.a.f16.row.stride %arg0, %arg1 : (!llvm.ptr, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + + llvm.return +} + +// ----- + +llvm.func @wmmaLoadOp_invalid_BOp(%arg0: !llvm.ptr, %arg1: i32) { + // expected-error@+1 {{'nvvm.wmma.m16n16k16.load.b.f16.row.stride' op expected result type of loadAOp and loadBOp to be a struct of 8 s}} + %0 = nvvm.wmma.m16n16k16.load.b.f16.row.stride %arg0, %arg1 : (!llvm.ptr, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + + llvm.return +} + +// ----- + +llvm.func @wmmaLoadOp_invalid_COp(%arg0: !llvm.ptr, %arg1: i32) { + // expected-error@+1 {{'nvvm.wmma.m16n16k16.load.c.f16.row.stride' op expected result type of loadCOp to be a struct of 4 s or 8 f32s}} + %0 = nvvm.wmma.m16n16k16.load.c.f16.row.stride %arg0, %arg1 : (!llvm.ptr, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + + llvm.return +} + +// ----- + +llvm.func @wmmaStoreOp_invalid_mem_space(%arg0: !llvm.ptr, %arg1: vector<2 x f16>, + %arg2: vector<2 x f16>, %arg3: vector<2 x f16>, + %arg4: vector<2 xf16>, %arg5: i32) { + // expected-error@+1 {{'nvvm.wmma.m16n16k16.store.d.f16.row.stride' op expected operands to be a source pointer in memoryspace 0, 1, 3 followed by ldm of the source}} + nvvm.wmma.m16n16k16.store.d.f16.row.stride %arg0, %arg1, %arg2, %arg3, %arg4, %arg5 : !llvm.ptr, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, i32 + llvm.return +} + +// ----- + +llvm.func @wmmaStoreOp_invalid_missing_ldm(%arg0: !llvm.ptr, %arg1: vector<2 x f16>, + %arg2: vector<2 x f16>, %arg3: vector<2 x f16>, + %arg4: vector<2 xf16>, %arg5: i32) { + // expected-error@+1 {{'nvvm.wmma.m16n16k16.store.d.f16.row.stride' op expected operands to be a source pointer in memoryspace 0, 1, 3 followed by ldm of the source}} + nvvm.wmma.m16n16k16.store.d.f16.row.stride %arg0, %arg1, %arg2, %arg3, %arg4 : !llvm.ptr, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16> + llvm.return +} + +// ----- + +llvm.func @gpu_wmma_mma_op_invalid_operands(%arg0: vector<2 x f16>, %arg1: vector<2 x f16>, + %arg2: vector<2 x f16>, %arg3: vector<2 x f16>, + %arg4: vector<2 x f16>, %arg5: vector<2 x f16>, + %arg6: vector<2 x f16>, %arg7: vector<2 x f16>, + %arg8: vector<2 x f16>, %arg9: vector<2 x f16>, + %arg10: vector<2 x f16>, %arg11: vector<2 x f16>, + %arg12: vector<2 x f16>, %arg13: vector<2 x f16>, + %arg14: vector<2 x f16>, %arg15: vector<2 x f16>, + %arg16: vector<2 x f16>, %arg17: vector<2 x f16>, + %arg18: vector<2 x f16>) { + // expected-error@+1 {{'nvvm.wmma.m16n16k16.mma.row.row.f16.f16' op expected 20 s as operands}} + %0 = nvvm.wmma.m16n16k16.mma.row.row.f16.f16 %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, 
%arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18 : vector<2 x f16> -> !llvm.struct<(vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>)> + llvm.return +} + +// ----- + +llvm.func @gpu_wmma_mma_op_results(%arg0: vector<2 x f16>, %arg1: vector<2 x f16>, + %arg2: vector<2 x f16>, %arg3: vector<2 x f16>, + %arg4: vector<2 x f16>, %arg5: vector<2 x f16>, + %arg6: vector<2 x f16>, %arg7: vector<2 x f16>, + %arg8: vector<2 x f16>, %arg9: vector<2 x f16>, + %arg10: vector<2 x f16>, %arg11: vector<2 x f16>, + %arg12: vector<2 x f16>, %arg13: vector<2 x f16>, + %arg14: vector<2 x f16>, %arg15: vector<2 x f16>, + %arg16: vector<2 x f16>, %arg17: vector<2 x f16>, + %arg18: vector<2 x f16>, %arg19: vector<2 x f16>) { + // expected-error@+1 {{expected result type to be a struct of 4 s}} + %0 = nvvm.wmma.m16n16k16.mma.row.row.f16.f16 %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19 : vector<2 x f16> -> !llvm.struct<(vector<2 x f16>, vector<2 x f16>, vector<2 x f16>)> + llvm.return +} + +// ----- + +llvm.func @gpu_wmma_mma_op_invalid_ab_operands(%arg0: vector<2 x f16>, %arg1: vector<2 x f16>, + %arg2: vector<2 x f16>, %arg3: vector<2 x f16>, + %arg4: vector<2 x f16>, %arg5: vector<2 x f16>, + %arg6: vector<2 x f16>, %arg7: vector<2 x f16>, + %arg8: vector<2 x f16>, %arg9: vector<2 x f16>, + %arg10: vector<2 x f16>, %arg11: vector<2 x f16>, + %arg12: vector<2 x f16>, %arg13: vector<2 x f16>, + %arg14: vector<2 x f16>, %arg15: f32, + %arg16: f32, %arg17: f32, %arg18: f32, %arg19: f32, + %arg20: f32, %arg21: f32, %arg22: f32, %arg23: f32) { + // expected-error@+1 {{'nvvm.wmma.m16n16k16.mma.row.row.f32.f32' op expected 16 s for `a` and `b` operand}} + %0 = nvvm.wmma.m16n16k16.mma.row.row.f32.f32 %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23 : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> + llvm.return +} + +// ----- + +llvm.func @gpu_wmma_mma_op_invalid_c_operand(%arg0: vector<2 x f16>, %arg1: vector<2 x f16>, + %arg2: vector<2 x f16>, %arg3: vector<2 x f16>, + %arg4: vector<2 x f16>, %arg5: vector<2 x f16>, + %arg6: vector<2 x f16>, %arg7: vector<2 x f16>, + %arg8: vector<2 x f16>, %arg9: vector<2 x f16>, + %arg10: vector<2 x f16>, %arg11: vector<2 x f16>, + %arg12: vector<2 x f16>, %arg13: vector<2 x f16>, + %arg14: vector<2 x f16>, %arg15: vector<2xf16>, + %arg16: f32, %arg17: f32, %arg18: f32, %arg19: f32, + %arg20: f32, %arg21: f32, %arg22: f32, %arg23: vector<2xf16>) { + // expected-error@+1 {{'nvvm.wmma.m16n16k16.mma.row.row.f32.f32' op expected 8 f32s for `c` operand}} + %0 = nvvm.wmma.m16n16k16.mma.row.row.f32.f32 %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23 : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, 
vector<2xf16>) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> + llvm.return +} + +// ----- + +llvm.func @gpu_wmma_mma_op_invalid_result(%arg0: vector<2 x f16>, %arg1: vector<2 x f16>, + %arg2: vector<2 x f16>, %arg3: vector<2 x f16>, + %arg4: vector<2 x f16>, %arg5: vector<2 x f16>, + %arg6: vector<2 x f16>, %arg7: vector<2 x f16>, + %arg8: vector<2 x f16>, %arg9: vector<2 x f16>, + %arg10: vector<2 x f16>, %arg11: vector<2 x f16>, + %arg12: vector<2 x f16>, %arg13: vector<2 x f16>, + %arg14: vector<2 x f16>, %arg15: vector<2xf16>, + %arg16: f32, %arg17: f32, %arg18: f32, %arg19: f32, + %arg20: f32, %arg21: f32, %arg22: f32, %arg23: f32) { + // expected-error@+1 {{'nvvm.wmma.m16n16k16.mma.row.row.f32.f32' op expected result type to be a struct of 8 f32s}} + %0 = nvvm.wmma.m16n16k16.mma.row.row.f32.f32 %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23 : (vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, f32, f32, f32, f32, f32, f32, f32, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, vector<2xf16>)> + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir --- a/mlir/test/Target/LLVMIR/nvvmir.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir.mlir @@ -73,6 +73,43 @@ llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> } +// The test below checks the correct mapping of the nvvm.wmma.*.load.* op to the correct intrinsic +// in the LLVM NVPTX backend. +llvm.func @gpu_wmma_load_op(%arg0: !llvm.ptr, %arg1: i32) { + // CHECK: call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.load.a.row.stride.f16.p3i32(i32 addrspace(3)* %{{.*}}, i32 %{{.*}}) + %0 = nvvm.wmma.m16n16k16.load.a.f16.row.stride %arg0, %arg1 : (!llvm.ptr, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + + llvm.return +} + +// The test below checks the correct mapping of the nvvm.wmma.*.store.* op to the correct intrinsic +// in the LLVM NVPTX backend. +llvm.func @gpu_wmma_store_op(%arg0: !llvm.ptr, %arg1: vector<2 x f16>, + %arg2: vector<2 x f16>, %arg3: vector<2 x f16>, + %arg4: vector<2 xf16>, %arg5: i32) { + // CHECK: call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f16.p3i32(i32 addrspace(3)* %{{.*}}, <2 x half> {{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, i32 %{{.*}}) + nvvm.wmma.m16n16k16.store.d.f16.row.stride %arg0, %arg1, %arg2, %arg3, %arg4, %arg5 : !llvm.ptr, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, i32 + llvm.return +} + +// The test below checks the correct mapping of the nvvm.wmma.*.mma.* op to the correct intrinsic +// in the LLVM NVPTX backend. 
+llvm.func @gpu_wmma_mma_op(%arg0: vector<2 x f16>, %arg1: vector<2 x f16>, + %arg2: vector<2 x f16>, %arg3: vector<2 x f16>, + %arg4: vector<2 x f16>, %arg5: vector<2 x f16>, + %arg6: vector<2 x f16>, %arg7: vector<2 x f16>, + %arg8: vector<2 x f16>, %arg9: vector<2 x f16>, + %arg10: vector<2 x f16>, %arg11: vector<2 x f16>, + %arg12: vector<2 x f16>, %arg13: vector<2 x f16>, + %arg14: vector<2 x f16>, %arg15: vector<2 x f16>, + %arg16: vector<2 x f16>, %arg17: vector<2 x f16>, + %arg18: vector<2 x f16>, %arg19: vector<2 x f16>) { + // CHECK: call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.mma.row.row.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}) + %0 = nvvm.wmma.m16n16k16.mma.row.row.f16.f16 %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19 : vector<2 x f16> -> !llvm.struct<(vector<2 x f16>, vector<2 x f16>, vector<2 x f16>, vector<2 x f16>)> + + llvm.return +} + // This function has the "kernel" attribute attached and should appear in the // NVVM annotations after conversion. llvm.func @kernel_func() attributes {nvvm.kernel} {
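For reference, the three new GPU dialect ops are meant to compose into a full load/compute/store sequence. The sketch below simply assembles the forms exercised by the `ops.mlir` and `invalid.mlir` tests above; the function name and the 32x32 shared-memory (space 3) buffers are illustrative assumptions, not part of the patch.

```mlir
// Hypothetical end-to-end use of the new ops: load the A, B and C fragments,
// compute D = A * B + C, and store D back to memory.
func @subgroup_mma_example(%a : memref<32x32xf16, 3>, %b : memref<32x32xf16, 3>,
                           %c : memref<32x32xf16, 3>, %d : memref<32x32xf16, 3>) {
  %c0 = constant 0 : index
  // Only "AOp", "BOp" and "COp" fragments may be loaded.
  %A = gpu.subgroup_mma_load_matrix %a[%c0, %c0] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
  %B = gpu.subgroup_mma_load_matrix %b[%c0, %c0] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "BOp">
  %C = gpu.subgroup_mma_load_matrix %c[%c0, %c0] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "COp">
  // Operands must appear in AOp, BOp, COp order and satisfy the matmul shape
  // constraints checked by the verifier.
  %D = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">, !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "DOp">
  // Only a "DOp" fragment may be stored.
  gpu.subgroup_mma_store_matrix %D, %d[%c0, %c0] {leadDimension = 32 : index} : !gpu.mma_matrix<16x16xf16, "DOp">, memref<32x32xf16, 3>
  return
}
```

The verifiers added in GPUDialect.cpp accept this shape of IR: the memrefs use the default identity layout in an allowed memory space, the compute operands arrive in AOp/BOp/COp order, and only the "DOp" result is written back.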
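At the NVVM level, the load and store intrinsics are overloaded on the address space of their pointer operand, which is what `createNvvmIntrinsicCall` resolves from the first argument. Below is a minimal sketch of the corresponding ops, assuming an `i32` pointer in address space 3 as in the `nvvmir.mlir` test; the pointer element type and the SSA names are assumptions carried over from that test, not new functionality.

```mlir
// Assumed: %ptr is an i32 pointer in shared memory (address space 3) and
// %ldm is the leading dimension, as in the translation test above.
llvm.func @wmma_load_store(%ptr : !llvm.ptr<i32, 3>, %ldm : i32,
                           %d0 : vector<2xf16>, %d1 : vector<2xf16>,
                           %d2 : vector<2xf16>, %d3 : vector<2xf16>) {
  // Loads a 16x16 f16 "A" fragment; the result struct holds the eight
  // <2 x half> values owned by the current thread.
  %a = nvvm.wmma.m16n16k16.load.a.f16.row.stride %ptr, %ldm : (!llvm.ptr<i32, 3>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
  // Stores an f16 "D" fragment (four <2 x half> values per thread).
  nvvm.wmma.m16n16k16.store.d.f16.row.stride %ptr, %d0, %d1, %d2, %d3, %ldm : !llvm.ptr<i32, 3>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, i32
  llvm.return
}
```

Translated through the new builders, the load and store map to `@llvm.nvvm.wmma.m16n16k16.load.a.row.stride.f16.p3i32` and `@llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f16.p3i32`, matching the CHECK lines in the translation test.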