Index: mlir/include/mlir/Dialect/GPU/CMakeLists.txt
===================================================================
--- mlir/include/mlir/Dialect/GPU/CMakeLists.txt
+++ mlir/include/mlir/Dialect/GPU/CMakeLists.txt
@@ -22,4 +22,9 @@
 mlir_tablegen(Passes.capi.cpp.inc -gen-pass-capi-impl --prefix GPU)
 add_public_tablegen_target(MLIRGPUPassIncGen)
 
+set(LLVM_TARGET_DEFINITIONS GPUOps.td)
+mlir_tablegen(GPUOpsEnums.h.inc -gen-enum-decls)
+mlir_tablegen(GPUOpsEnums.cpp.inc -gen-enum-defs)
+add_public_tablegen_target(MLIRGPUOpsEnumsGen)
+
 add_mlir_doc(Passes GPUPasses ./ -gen-pass-doc)
Index: mlir/include/mlir/Dialect/GPU/GPUBase.td
===================================================================
--- mlir/include/mlir/Dialect/GPU/GPUBase.td
+++ mlir/include/mlir/Dialect/GPU/GPUBase.td
@@ -113,18 +113,4 @@
   ];
 }
 
-// Cases of the String enum Attribute for SubgroupMmaOpLayout, representing
-// the layouts of the operands supported by the ops that use this attribute.
-def RowMajor: StrEnumAttrCase<"RowMajor", 0>;
-def ColMajor: StrEnumAttrCase<"ColMajor", 1>;
-
-// Specifies a String enum Attribute for Warp wide matrix operations,
-// representing the layout of respective operands. The layout later governs
-// the lowerings to appropriate intrinsics.
-def SubgroupMmaOpLayout: StrEnumAttr<"Layout", "Specifies whether op is row/col major",
-  [RowMajor, ColMajor]> {
-  let stringToSymbolFnName = "LayoutStrToEnum";
-  let symbolToStringFnName = "EnumToLayoutStr";
-}
-
 #endif // GPU_BASE
Index: mlir/include/mlir/Dialect/GPU/GPUDialect.h
===================================================================
--- mlir/include/mlir/Dialect/GPU/GPUDialect.h
+++ mlir/include/mlir/Dialect/GPU/GPUDialect.h
@@ -24,11 +24,14 @@
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
+#include "mlir/Dialect/GPU/GPUOpsEnums.h.inc"
+
 namespace mlir {
 class FuncOp;
 
 namespace gpu {
+struct MMAMatrixStorageType;
 
 /// Utility class for the GPU dialect to represent triples of `Value`s
 /// accessible through `.x`, `.y`, and `.z` similarly to CUDA notation.
 struct KernelDim3 {
@@ -44,53 +47,6 @@
   using Base::Base;
 };
 
-/// MMAMatrixType storage and uniquing. Array is uniqued based on its shape
-/// and type.
-struct MMAMatrixStorageType : public TypeStorage {
-  MMAMatrixStorageType(unsigned numDims, const int64_t *dimShapes,
-                       Type elementType, StringRef operand)
-      : dimShapes(dimShapes), numDims(numDims), elementType(elementType),
-        operand(operand) {}
-
-  /// The hash key for uniquing.
-  using KeyTy = std::tuple<ArrayRef<int64_t>, Type, StringRef>;
-  bool operator==(const KeyTy &key) const {
-    return key == KeyTy(getShape(), elementType, operand);
-  }
-
-  /// Construction.
-  static MMAMatrixStorageType *construct(TypeStorageAllocator &allocator,
-                                         const KeyTy &key) {
-    ArrayRef<int64_t> shape = allocator.copyInto(std::get<0>(key));
-    StringRef operand = allocator.copyInto(std::get<2>(key));
-
-    return new (allocator.allocate<MMAMatrixStorageType>())
-        MMAMatrixStorageType(shape.size(), shape.data(), std::get<1>(key),
-                             operand);
-  }
-
-  ArrayRef<int64_t> getShape() const {
-    return ArrayRef<int64_t>(dimShapes, numDims);
-  }
-
-  StringRef getOperand() const { return operand; }
-
-  /// Reference to the shape of the MMA matrix.
-  const int64_t *dimShapes;
-
-  /// Number of dimensions in the MMA matrix.
-  unsigned numDims;
-
-  /// Element type of elements held in the MMA matrix.
-  Type elementType;
-
-  /// MMA operand that this MMAMatrix holds. The general form of operation this
-  /// type supports is given by the equation D = (alpha*(A*B)) + (beta*C). This
-  /// field specifies which operand in the given equation is held by this type.
-  /// The valid values are "AOp", "BOp", "COp" and "DOp".
-  StringRef operand;
-};
-
 /// MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply
 /// accumulate operations. MMAMatrices are taken as direct operands by these
 /// operations and are also produced as results. These matrices are meant to
@@ -126,14 +82,14 @@
   using Base::Base;
 
   /// Get MMAMatrixType and verify construction invariants.
-  static MMAMatrixType get(ArrayRef<int64_t> shape, Type elementType,
-                           StringRef operand);
+  static MMAMatrixType get(int64_t row, int64_t column, Type elementType,
+                           MMAFragType fragType);
 
   /// Get MMAMatrixType at a particular location and verify construction
   /// invariants.
   static MMAMatrixType getChecked(function_ref<InFlightDiagnostic()> emitError,
-                                  ArrayRef<int64_t> shape, Type elementType,
-                                  StringRef operand);
+                                  int64_t row, int64_t column,
+                                  Type elementType, MMAFragType fragType);
 
   /// Check if a type is a valid MMAMatrixType elementType.
   static bool isValidElementType(Type elementType);
@@ -141,14 +97,12 @@
   /// Verify that shape and elementType are actually allowed for the
   /// MMAMatrixType.
   static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              ArrayRef<int64_t> shape, Type elementType,
-                              StringRef operand);
-
-  /// Get number of dims.
-  unsigned getNumDims() const;
+                              int64_t row, int64_t column, Type elementType,
+                              MMAFragType mmaFragType);
 
   /// Get shape of the matrix.
-  ArrayRef<int64_t> getShape() const;
+  int64_t getRow() const;
+  int64_t getColumn() const;
 
   /// Get elementType of a single element.
   Type getElementType() const;
@@ -157,7 +111,7 @@
   /// The general form of operation this type supports is given by the equation
   /// D = (alpha*(A*B)) + (beta*C). This function returns which operand in the
-  /// given equation is held by this type. String returned can be one of "AOp",
-  /// "BOp", "COp" and "DOp".
-  StringRef getOperand() const;
+  /// given equation is held by this type. The result is one of
+  /// MMAFragType::AOp, BOp, COp or DOp.
+  MMAFragType getMMAFragType() const;
 };
 
 // Adds a `gpu.async.token` to the front of the argument list.
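
For context, here is a minimal sketch of how a client builds the type after this change. The helper function is illustrative and not part of the patch, and it assumes the tablegen-generated MMAFragType enum is visible the same way the patch's own code uses it:

    #include "mlir/Dialect/GPU/GPUDialect.h"
    #include "mlir/IR/BuiltinTypes.h"

    using namespace mlir;
    using namespace mlir::gpu; // Mirrors GPUDialect.cpp; makes MMAFragType visible.

    // Builds the 16x16xf16 "A" fragment. Before this patch the equivalent call
    // was MMAMatrixType::get({16, 16}, f16, "AOp"); the dimensions are now
    // explicit integers and the operand is a typed enum instead of a string.
    static MMAMatrixType buildAFragment(MLIRContext *context) {
      Type f16 = FloatType::getF16(context);
      return MMAMatrixType::get(/*row=*/16, /*column=*/16, f16, MMAFragType::AOp);
    }
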
Index: mlir/include/mlir/Dialect/GPU/GPUOps.td
===================================================================
--- mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -572,13 +572,13 @@
 }
 
 // add, mul mirror the XLA ComparisonDirection enum.
-def GPU_AllReduceOpAdd : StrEnumAttrCase<"add">;
-def GPU_AllReduceOpAnd : StrEnumAttrCase<"and">;
-def GPU_AllReduceOpMax : StrEnumAttrCase<"max">;
-def GPU_AllReduceOpMin : StrEnumAttrCase<"min">;
-def GPU_AllReduceOpMul : StrEnumAttrCase<"mul">;
-def GPU_AllReduceOpOr : StrEnumAttrCase<"or">;
-def GPU_AllReduceOpXor : StrEnumAttrCase<"xor">;
+def GPU_AllReduceOpAdd : StrEnumAttrCase<"ADD", -1, "add">;
+def GPU_AllReduceOpAnd : StrEnumAttrCase<"AND", -1, "and">;
+def GPU_AllReduceOpMax : StrEnumAttrCase<"MAX", -1, "max">;
+def GPU_AllReduceOpMin : StrEnumAttrCase<"MIN", -1, "min">;
+def GPU_AllReduceOpMul : StrEnumAttrCase<"MUL", -1, "mul">;
+def GPU_AllReduceOpOr : StrEnumAttrCase<"OR", -1, "or">;
+def GPU_AllReduceOpXor : StrEnumAttrCase<"XOR", -1, "xor">;
 
 def GPU_AllReduceOperationAttr : StrEnumAttr<"AllReduceOperationAttr",
     "built-in reduction operations supported by gpu.allreduce.",
@@ -625,7 +625,7 @@
   let verifier = [{ return ::verifyAllReduce(*this); }];
 }
 
-def GPU_ShuffleOpXor : StrEnumAttrCase<"xor">;
+def GPU_ShuffleOpXor : StrEnumAttrCase<"XOR", -1, "xor">;
 
 def GPU_ShuffleModeAttr : StrEnumAttr<"ShuffleModeAttr",
     "Indexing modes supported by gpu.shuffle.",
@@ -902,6 +902,35 @@
   let verifier = [{ return ::verify(*this); }];
 }
 
+//===----------------------------------------------------------------------===//
+// MMA Dialect operations.
+//===----------------------------------------------------------------------===//
+
+// Cases of the String enum Attribute for SubgroupMmaOpLayout, representing
+// the layouts of the operands supported by the ops that use this attribute.
+def RowMajor: StrEnumAttrCase<"RowMajor", 0>;
+def ColMajor: StrEnumAttrCase<"ColMajor", 1>;
+
+// Specifies a String enum Attribute for Warp wide matrix operations,
+// representing the layout of respective operands. The layout later governs
+// the lowerings to appropriate intrinsics.
+def SubgroupMmaOpLayout: StrEnumAttr<"Layout", "Specifies whether op is row/col major",
+  [RowMajor, ColMajor]> {
+  let stringToSymbolFnName = "LayoutStrToEnum";
+  let symbolToStringFnName = "EnumToLayoutStr";
+}
+
+// Cases of the integer enum attribute MMAFragTypeEnum, naming the WMMA
+// fragment (matrix operand) that an mma_matrix value holds.
+def MMAFragTypeAOp : I32EnumAttrCase<"AOp", 0>;
+def MMAFragTypeBOp : I32EnumAttrCase<"BOp", 1>;
+def MMAFragTypeCOp : I32EnumAttrCase<"COp", 2>;
+def MMAFragTypeDOp : I32EnumAttrCase<"DOp", 3>;
+
+def MMAFragTypeEnum : IntEnumAttr<I32, "MMAFragType",
+    "MMA fragment held by an mma_matrix value",
+    [MMAFragTypeAOp, MMAFragTypeBOp, MMAFragTypeCOp, MMAFragTypeDOp]> {
+  let underlyingType = "uint8_t";
+}
+
 def GPU_SubgroupMmaLoadMatrixOp : GPU_Op<"subgroup_mma_load_matrix",
     [MemoryEffects<[MemRead]>]> {
Index: mlir/include/mlir/IR/OpBase.td
===================================================================
--- mlir/include/mlir/IR/OpBase.td
+++ mlir/include/mlir/IR/OpBase.td
@@ -1110,11 +1110,12 @@
   string str = strVal;
 }
 
-// An enum attribute case stored with StringAttr.
-class StrEnumAttrCase<string sym, int intVal = -1> :
-    EnumAttrCaseInfo<sym, intVal, sym>,
+// An enum attribute case stored with StringAttr. By default, the string
+// representation is the same as the C++ symbol name.
+class StrEnumAttrCase<string sym, int intVal = -1, string str = sym> :
+    EnumAttrCaseInfo<sym, intVal, str>,
     StringBasedAttr<
-      CPred<"$_self.cast<::mlir::StringAttr>().getValue() == \"" # sym # "\"">,
+      CPred<"$_self.cast<::mlir::StringAttr>().getValue() == \"" # str # "\"">,
       "case " # sym>;
 
 // An enum attribute case stored with IntegerAttr, which has an integer value,
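
The extra `str` parameter decouples the C++ enumerant symbol from the string stored in the IR, which is what lets the GPU cases above keep "add"/"xor" in assembly while gaining conventional ADD/XOR C++ names. Below is a hand-written sketch of the mapping this produces; it imitates, but is not, the verbatim mlir-tblgen output:

    #include <cstdint>
    #include "llvm/ADT/Optional.h"
    #include "llvm/ADT/StringRef.h"
    #include "llvm/ADT/StringSwitch.h"
    #include "llvm/Support/ErrorHandling.h"

    enum class ShuffleMode : uint32_t { XOR };

    // Symbolization matches on the `str` form that appears in the IR ("xor"),
    // not the C++ symbol ("XOR").
    llvm::Optional<ShuffleMode> symbolizeShuffleMode(llvm::StringRef str) {
      return llvm::StringSwitch<llvm::Optional<ShuffleMode>>(str)
          .Case("xor", ShuffleMode::XOR)
          .Default(llvm::None);
    }

    // Stringification round-trips back to the same `str` form.
    llvm::StringRef stringifyShuffleMode(ShuffleMode mode) {
      switch (mode) {
      case ShuffleMode::XOR:
        return "xor";
      }
      llvm_unreachable("unknown ShuffleMode");
    }
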
Index: mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
===================================================================
--- mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -130,27 +130,27 @@
   converter.addConversion([&](gpu::MMAMatrixType type) -> Type {
     // The number of items in structToReturn depends on the dataType and the
     // MMA operand that this operation is associated with.
-    llvm::DenseMap<StringRef, unsigned> numElemsPerThreadF16,
+    llvm::DenseMap<MMAFragType, unsigned> numElemsPerThreadF16,
         numElemsPerThreadF32;
-    numElemsPerThreadF16["AOp"] = 8;
-    numElemsPerThreadF16["BOp"] = 8;
-    numElemsPerThreadF16["COp"] = 4;
-    numElemsPerThreadF16["DOp"] = 4;
-    numElemsPerThreadF32["AOp"] = 8;
-    numElemsPerThreadF32["BOp"] = 8;
-    numElemsPerThreadF32["COp"] = 8;
-    numElemsPerThreadF32["DOp"] = 8;
+    numElemsPerThreadF16[MMAFragType::AOp] = 8;
+    numElemsPerThreadF16[MMAFragType::BOp] = 8;
+    numElemsPerThreadF16[MMAFragType::COp] = 4;
+    numElemsPerThreadF16[MMAFragType::DOp] = 4;
+    numElemsPerThreadF32[MMAFragType::AOp] = 8;
+    numElemsPerThreadF32[MMAFragType::BOp] = 8;
+    numElemsPerThreadF32[MMAFragType::COp] = 8;
+    numElemsPerThreadF32[MMAFragType::DOp] = 8;
     Type structToReturn;
     if (type.getElementType().isF16()) {
       // Number of f16's in 32-bit.
       unsigned vecSize = 2;
       Type vec = VectorType::get(vecSize, FloatType::getF16(&getContext()));
-      unsigned size = numElemsPerThreadF16[type.getOperand()];
+      unsigned size = numElemsPerThreadF16[type.getMMAFragType()];
       SmallVector<Type> elements(size, vec);
       structToReturn =
           LLVM::LLVMStructType::getLiteral(&getContext(), elements);
     } else if (type.getElementType().isF32()) {
-      unsigned size = numElemsPerThreadF32[type.getOperand()];
+      unsigned size = numElemsPerThreadF32[type.getMMAFragType()];
       SmallVector<Type> elements(size, FloatType::getF32(&getContext()));
       structToReturn =
           LLVM::LLVMStructType::getLiteral(&getContext(), elements);
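
The two DenseMaps encode the per-thread register footprint of each WMMA fragment: A and B fragments always occupy 8 items per thread, while f16 accumulators occupy 4 because each item packs a vector of two halves. Restated as a switch for readability (the helper is ours, not part of the patch):

    #include "mlir/Dialect/GPU/GPUDialect.h"
    #include "llvm/Support/ErrorHandling.h"

    using namespace mlir::gpu;

    // Number of struct elements per thread for a 16x16x16 WMMA fragment, as
    // hard-coded by the type conversion above.
    static unsigned numElemsPerThread(MMAFragType frag, bool isF16) {
      switch (frag) {
      case MMAFragType::AOp:
      case MMAFragType::BOp:
        return 8; // Same count for f16 (as vector<2xf16>) and f32 operands.
      case MMAFragType::COp:
      case MMAFragType::DOp:
        return isF16 ? 4 : 8; // f16 accumulators pack two halves per element.
      }
      llvm_unreachable("unknown MMAFragType");
    }
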
Index: mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
===================================================================
--- mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -163,11 +163,10 @@
     // choose which intrinsic this op will be lowered to.
     gpu::MMAMatrixType retType =
         subgroupMmaLoadMatrixOp.res().getType().cast<gpu::MMAMatrixType>();
-    ArrayRef<int64_t> retTypeShape = retType.getShape();
 
     Type resType;
-    StringRef operandStr = retType.getOperand();
-    if (operandStr.equals("AOp") || operandStr.equals("BOp")) {
+    MMAFragType fragType = retType.getMMAFragType();
+    if (fragType == MMAFragType::AOp || fragType == MMAFragType::BOp) {
       resType = fragArrayABTy;
     } else {
       if (srcMemrefType.getElementType().isF16())
@@ -180,8 +179,8 @@
     // Create nvvm.mma_load op according to the operand types.
     SmallVector<Value> loadOpOperands({loadAddressCasted, leadingDim32});
-    if (operandStr.equals("AOp")) {
-      if (retTypeShape[0] == 16 && retTypeShape[1] == 16) {
+    if (fragType == MMAFragType::AOp) {
+      if (retType.getRow() == 16 && retType.getColumn() == 16) {
         NVVM::WMMALoadAM16N16K16Op wmmaLoadAOp =
             rewriter.create<NVVM::WMMALoadAM16N16K16Op>(loc, resType,
                                                         loadOpOperands);
       } else {
         return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
       }
-    } else if (operandStr.equals("BOp")) {
-      if (retTypeShape[0] == 16 && retTypeShape[1] == 16) {
+    } else if (fragType == MMAFragType::BOp) {
+      if (retType.getRow() == 16 && retType.getColumn() == 16) {
         NVVM::WMMALoadBM16N16K16Op wmmaLoadBOp =
             rewriter.create<NVVM::WMMALoadBM16N16K16Op>(loc, resType,
                                                         loadOpOperands);
       } else {
         return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
       }
     } else {
-      if (retTypeShape[0] == 16 && retTypeShape[1] == 16) {
+      if (retType.getRow() == 16 && retType.getColumn() == 16) {
         if (srcMemrefType.getElementType().isF16()) {
           NVVM::WMMALoadCF16M16N16K16Op wmmaLoadCOp =
               rewriter.create<NVVM::WMMALoadCF16M16N16K16Op>(loc, resType,
@@ -298,7 +297,8 @@
     // choose which intrinsic this op will be lowered to.
     gpu::MMAMatrixType srcType =
         subgroupMmaStoreMatrixOp.src().getType().cast<gpu::MMAMatrixType>();
-    ArrayRef<int64_t> srcTypeShape = srcType.getShape();
+    int64_t numRow = srcType.getRow();
+    int64_t numColumn = srcType.getColumn();
 
     // Unpack the results from the source.
     if (subgroupMmaStoreMatrixOp.src()
@@ -313,7 +313,7 @@
       storeOpOperands.push_back(leadingDim32);
 
       // Create nvvm.mma_store op.
-      if (srcTypeShape[0] == 16 && srcTypeShape[1] == 16) {
+      if (numRow == 16 && numColumn == 16) {
        rewriter.create<NVVM::WMMAStoreF16M16N16K16Op>(loc, storeOpOperands);
       } else {
         return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
       }
@@ -332,7 +332,7 @@
      storeOpOperands.push_back(leadingDim32);
 
       // Create nvvm.mma_store op.
-      if (srcTypeShape[0] == 16 && srcTypeShape[1] == 16)
+      if (numRow == 16 && numColumn == 16)
         rewriter.create<NVVM::WMMAStoreF32M16N16K16Op>(loc, storeOpOperands);
       else {
         return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
       }
@@ -384,13 +384,10 @@
     // choose which intrinsic this op will be lowered to.
     gpu::MMAMatrixType aType =
         subgroupMmaComputeOp.opA().getType().cast<gpu::MMAMatrixType>();
-    ArrayRef<int64_t> aTypeShape = aType.getShape();
     gpu::MMAMatrixType bType =
-        subgroupMmaComputeOp.opA().getType().cast<gpu::MMAMatrixType>();
-    ArrayRef<int64_t> bTypeShape = bType.getShape();
+        subgroupMmaComputeOp.opB().getType().cast<gpu::MMAMatrixType>();
     gpu::MMAMatrixType cType =
-        subgroupMmaComputeOp.opA().getType().cast<gpu::MMAMatrixType>();
-    ArrayRef<int64_t> cTypeShape = cType.getShape();
+        subgroupMmaComputeOp.opC().getType().cast<gpu::MMAMatrixType>();
 
     gpu::SubgroupMmaComputeOpAdaptor transformedOperands(operands);
     if (subgroupMmaComputeOp.opC()
@@ -401,8 +398,9 @@
       unpackOp(B, transformedOperands.opB(), numHalfsInOpFrags[B], f16x2Ty);
       unpackOp(C, transformedOperands.opC(), numHalfsInOpFrags[C], f16x2Ty);
 
-      if (aTypeShape[0] == 16 && aTypeShape[1] == 16 && bTypeShape[0] == 16 &&
-          bTypeShape[1] == 16 && cTypeShape[0] == 16 && cTypeShape[1] == 16) {
+      if (aType.getRow() == 16 && aType.getColumn() == 16 &&
+          bType.getRow() == 16 && bType.getColumn() == 16 &&
+          cType.getRow() == 16 && cType.getColumn() == 16) {
         // Create nvvm.wmma.mma op.
         NVVM::WMMAMmaF16F16M16N16K16Op wmmaMmaOp =
             rewriter.create<NVVM::WMMAMmaF16F16M16N16K16Op>(loc, fragArrayCDTy,
@@ -421,8 +419,9 @@
       unpackOp(B, transformedOperands.opB(), numHalfsInOpFrags[B], f16x2Ty);
       unpackOp(C, transformedOperands.opC(), 8, f32Ty);
 
-      if (aTypeShape[0] == 16 && aTypeShape[1] == 16 && bTypeShape[0] == 16 &&
-          bTypeShape[1] == 16 && cTypeShape[0] == 16 && cTypeShape[1] == 16) {
+      if (aType.getRow() == 16 && aType.getColumn() == 16 &&
+          bType.getRow() == 16 && bType.getColumn() == 16 &&
+          cType.getRow() == 16 && cType.getColumn() == 16) {
         // Create nvvm.wmma.mma op.
         NVVM::WMMAMmaF32F32M16N16K16Op wmmaMmaOp =
             rewriter.create<NVVM::WMMAMmaF32F32M16N16K16Op>(
Index: mlir/lib/Dialect/GPU/CMakeLists.txt
===================================================================
--- mlir/lib/Dialect/GPU/CMakeLists.txt
+++ mlir/lib/Dialect/GPU/CMakeLists.txt
@@ -38,6 +38,7 @@
 
   DEPENDS
   MLIRGPUOpsIncGen
+  MLIRGPUOpsEnumsGen
   MLIRGPUOpInterfacesIncGen
   MLIRGPUPassIncGen
   MLIRParallelLoopMapperAttrGen
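
The `getRow() == 16 && getColumn() == 16` guard now appears in the load path, both store paths, and both compute paths. A small predicate (ours, not in the patch) would keep the supported-geometry check in one place if more WMMA shapes are added later:

    #include "mlir/Dialect/GPU/GPUDialect.h"

    // True if the fragment has the only geometry the NVVM lowering currently
    // supports (M = N = K = 16).
    static bool hasSupported16x16Shape(mlir::gpu::MMAMatrixType type) {
      return type.getRow() == 16 && type.getColumn() == 16;
    }
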
Index: mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
===================================================================
--- mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -28,48 +28,76 @@
 using namespace mlir;
 using namespace mlir::gpu;
 
+namespace mlir {
+namespace gpu {
+
+/// MMAMatrixType storage and uniquing. MMAMatrixType is uniqued based on its
+/// row/column counts, element type, and fragment type.
+struct MMAMatrixStorageType : public TypeStorage {
+  MMAMatrixStorageType(int64_t row, int64_t column, Type elementType,
+                       MMAFragType mmaType)
+      : row(row), column(column), mmaType(mmaType), elementType(elementType) {}
+
+  /// The hash key for uniquing.
+  using KeyTy = std::tuple<int64_t, int64_t, Type, MMAFragType>;
+  bool operator==(const KeyTy &key) const {
+    return key == KeyTy(row, column, elementType, mmaType);
+  }
+
+  /// Construction.
+  static MMAMatrixStorageType *construct(TypeStorageAllocator &allocator,
+                                         const KeyTy &key) {
+    return new (allocator.allocate<MMAMatrixStorageType>())
+        MMAMatrixStorageType(std::get<0>(key), std::get<1>(key),
+                             std::get<2>(key), std::get<3>(key));
+  }
+
+  /// Number of rows in the MMA matrix.
+  int64_t row;
+
+  /// Number of columns in the MMA matrix.
+  int64_t column;
+
+  /// MMA fragment (operand in D = (alpha*(A*B)) + (beta*C)) held by this type.
+  MMAFragType mmaType;
+
+  /// Element type of the MMA matrix.
+  Type elementType;
+};
+
+} // namespace gpu
+} // namespace mlir
+
 //===----------------------------------------------------------------------===//
 // MMAMatrixType
 //===----------------------------------------------------------------------===//
 
-MMAMatrixType MMAMatrixType::get(ArrayRef<int64_t> shape, Type elementType,
-                                 StringRef operand) {
-  return Base::get(elementType.getContext(), shape, elementType, operand);
+MMAMatrixType MMAMatrixType::get(int64_t row, int64_t column, Type elementType,
+                                 MMAFragType fragType) {
+  return Base::get(elementType.getContext(), row, column, elementType,
+                   fragType);
 }
 
 MMAMatrixType
 MMAMatrixType::getChecked(function_ref<InFlightDiagnostic()> emitError,
-                          ArrayRef<int64_t> shape, Type elementType,
-                          StringRef operand) {
-  return Base::getChecked(emitError, elementType.getContext(), shape,
-                          elementType, operand);
+                          int64_t row, int64_t column, Type elementType,
+                          MMAFragType fragType) {
+  return Base::getChecked(emitError, elementType.getContext(), row, column,
+                          elementType, fragType);
 }
 
-unsigned MMAMatrixType::getNumDims() const { return getImpl()->numDims; }
+int64_t MMAMatrixType::getRow() const { return getImpl()->row; }
 
-ArrayRef<int64_t> MMAMatrixType::getShape() const {
-  return getImpl()->getShape();
-}
+int64_t MMAMatrixType::getColumn() const { return getImpl()->column; }
 
 Type MMAMatrixType::getElementType() const { return getImpl()->elementType; }
 
-StringRef MMAMatrixType::getOperand() const { return getImpl()->getOperand(); }
+MMAFragType MMAMatrixType::getMMAFragType() const {
+  return getImpl()->mmaType;
+}
 
 bool MMAMatrixType::isValidElementType(Type elementType) {
   return elementType.isF16() || elementType.isF32();
 }
 
 LogicalResult
-MMAMatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
-                      ArrayRef<int64_t> shape, Type elementType,
-                      StringRef operand) {
-  if (!operand.equals("AOp") && !operand.equals("BOp") &&
-      !operand.equals("COp") && !operand.equals("DOp"))
+MMAMatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
+                      int64_t row, int64_t column, Type elementType,
+                      MMAFragType mmaFragType) {
+  if (mmaFragType != MMAFragType::AOp && mmaFragType != MMAFragType::BOp &&
+      mmaFragType != MMAFragType::COp && mmaFragType != MMAFragType::DOp)
     return emitError() << "operand expected to be one of AOp, BOp, COp or DOp";
 
-  if (shape.size() != 2)
-    return emitError() << "MMAMatrixType must have exactly two dimensions";
-
   if (!MMAMatrixType::isValidElementType(elementType))
     return emitError() << "MMAMatrixType elements must be F16 or F32";
@@ -139,14 +167,27 @@
     StringRef operand;
     if (failed(parser.parseOptionalString(&operand)))
       return nullptr;
+    Optional<MMAFragType> fragType = symbolizeMMAFragType(operand);
+    if (!fragType) {
+      parser.emitError(parser.getNameLoc(),
+                       "operand expected to be one of AOp, BOp, COp or DOp");
+      return nullptr;
+    }
 
     // Parse '>'.
     if (parser.parseGreater())
      return nullptr;
 
+    if (shape.size() != 2) {
+      parser.emitError(parser.getNameLoc(),
+                       "MMAMatrixType must have exactly two dimensions");
+      return nullptr;
+    }
+
    return MMAMatrixType::getChecked(mlir::detail::getDefaultDiagnosticEmitFn(
                                         parser.getEncodedSourceLoc(beginLoc)),
-                                    shape, elementType, operand);
+                                    shape[0], shape[1], elementType,
+                                    *fragType);
   }
 
   parser.emitError(parser.getNameLoc(), "unknown gpu type: " + keyword);
@@ -158,11 +199,10 @@
       .Case<AsyncTokenType>([&](Type) { os << "async.token"; })
       .Case<MMAMatrixType>([&](MMAMatrixType fragTy) {
         os << "mma_matrix<";
-        auto shape = fragTy.getShape();
-        for (auto dim = shape.begin(), e = shape.end() - 1; dim != e; ++dim)
-          os << *dim << 'x';
-        os << shape.back() << 'x' << fragTy.getElementType();
-        os << ", \"" << fragTy.getOperand() << "\"" << '>';
+        os << fragTy.getRow() << 'x' << fragTy.getColumn();
+        os << 'x' << fragTy.getElementType();
+        os << ", \"" << stringifyMMAFragType(fragTy.getMMAFragType()) << "\""
+           << '>';
       })
       .Default([](Type) { llvm_unreachable("unexpected 'gpu' type kind"); });
@@ -982,10 +1022,10 @@
 //===----------------------------------------------------------------------===//
 
 static LogicalResult verify(SubgroupMmaLoadMatrixOp op) {
-  auto srcType = op.srcMemref().getType();
-  auto resType = op.res().getType();
+  Type srcType = op.srcMemref().getType();
+  Type resType = op.res().getType();
   auto resMatrixType = resType.cast<gpu::MMAMatrixType>();
-  auto operand = resMatrixType.getOperand();
+  MMAFragType mmaFragType = resMatrixType.getMMAFragType();
   auto srcMemrefType = srcType.cast<MemRefType>();
   auto srcMemSpace = srcMemrefType.getMemorySpaceAsInt();
@@ -999,8 +1039,8 @@
         "source memorySpace kGenericMemorySpace, kSharedMemorySpace or "
         "kGlobalMemorySpace only allowed");
 
-  if (!operand.equals("AOp") && !operand.equals("BOp") &&
-      !operand.equals("COp"))
+  if (mmaFragType != MMAFragType::AOp && mmaFragType != MMAFragType::BOp &&
+      mmaFragType != MMAFragType::COp)
     return op.emitError("only AOp, BOp and COp can be loaded");
 
   return success();
@@ -1027,7 +1067,7 @@
         "destination memorySpace of kGenericMemorySpace, "
         "kGlobalMemorySpace or kSharedMemorySpace only allowed");
 
-  if (!srcMatrixType.getOperand().equals("DOp"))
+  if (srcMatrixType.getMMAFragType() != MMAFragType::DOp)
     return op.emitError(
         "expected the operand matrix being stored to have 'DOp' operand type");
@@ -1049,18 +1089,14 @@
   };
   populateOpInfo();
 
-  if (!opTypes[A].getOperand().equals("AOp") ||
-      !opTypes[B].getOperand().equals("BOp") ||
-      !opTypes[C].getOperand().equals("COp"))
+  if (opTypes[A].getMMAFragType() != MMAFragType::AOp ||
+      opTypes[B].getMMAFragType() != MMAFragType::BOp ||
+      opTypes[C].getMMAFragType() != MMAFragType::COp)
     return op.emitError("operands must be in the order AOp, BOp, COp");
 
-  ArrayRef<int64_t> aShape, bShape, cShape;
-  aShape = opTypes[A].getShape();
-  bShape = opTypes[B].getShape();
-  cShape = opTypes[C].getShape();
-
-  if (aShape[1] != bShape[0] || aShape[0] != cShape[0] ||
-      bShape[1] != cShape[1])
+  if (opTypes[A].getColumn() != opTypes[B].getRow() ||
+      opTypes[A].getRow() != opTypes[C].getRow() ||
+      opTypes[B].getColumn() != opTypes[C].getColumn())
     return op.emitError("operand shapes do not satisfy matmul constraints");
 
   return success();
@@ -1070,3 +1106,5 @@
 
 #define GET_OP_CLASSES
 #include "mlir/Dialect/GPU/GPUOps.cpp.inc"
+
+#include "mlir/Dialect/GPU/GPUOpsEnums.cpp.inc"
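
Taken together, the parser and printer changes keep the textual form stable (!gpu.mma_matrix<16x16xf16, "AOp"> still round-trips) while the in-memory operand becomes a typed enum. A test-style sketch of the round-trip property the new code relies on; the function is ours, and the symbolize/stringify helpers are assumed visible as in the patch's own uses:

    #include "mlir/Dialect/GPU/GPUDialect.h"
    #include "llvm/ADT/Optional.h"
    #include "llvm/ADT/StringRef.h"

    using namespace mlir::gpu; // Mirrors GPUDialect.cpp.

    // "AOp", "BOp", "COp", and "DOp" must survive symbolize -> stringify,
    // since the type parser and printer use these helpers as inverses.
    static bool fragStringRoundTrips(llvm::StringRef str) {
      llvm::Optional<MMAFragType> frag = symbolizeMMAFragType(str);
      return frag.hasValue() && stringifyMMAFragType(*frag) == str;
    }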