Index: mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp =================================================================== --- mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp +++ mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp @@ -20,51 +20,6 @@ namespace { -/// Contains all the common LLVM types which are used across the lowerings of -/// GPU subgroup ops to NVVM dialect. -struct CommonLLVMAndBuiltInMLIRTypes { -public: - CommonLLVMAndBuiltInMLIRTypes(MLIRContext *context) { - numHalfsInOpFrags.resize(4); - numHalfsInOpFrags[A] = 8; - numHalfsInOpFrags[B] = 8; - numHalfsInOpFrags[C] = 4; - i32Ty = IntegerType::get(context, 32); - f16Ty = FloatType::getF16(context); - f32Ty = FloatType::getF32(context); - f16x2Ty = VectorType::get(2, f16Ty); - fragArrayABTy = LLVM::LLVMStructType::getLiteral( - context, SmallVector(8, f16x2Ty)); - fragArrayCDTy = LLVM::LLVMStructType::getLiteral( - context, SmallVector(4, f16x2Ty)); - fragArrayCDF32Ty = - LLVM::LLVMStructType::getLiteral(context, SmallVector(8, f32Ty)); - }; - - Type i32Ty; - Type f16Ty; - Type f32Ty; - Type f16x2Ty; - /// Type for the fragment of A and B operands that a single thread holds for - /// fp16 data type in a WMMA operation of the form D = (alpha*(A*B)) + - /// (beta*C). - Type fragArrayABTy; - /// Type for the fragment of C and D operands that a single thread holds for - /// fp16 data type in a WMMA operation of the form D = (alpha*(A*B)) + - /// (beta*C). - Type fragArrayCDTy; - /// Type for the fragment of C and D operands that a single thread holds for - /// fp32 data type in a WMMA operation of the form D = (alpha*(A*B)) + - /// (beta*C). - Type fragArrayCDF32Ty; - /// Represents the number of f16 elements a single thread holds in a WMMA - /// operation of the form D = (alpha*(A*B)) + (beta*C) . - SmallVector numHalfsInOpFrags; - /// Represents the operands of a MMA operation of the form D = (alpha*(A*B)) + - /// (beta*C). - enum OperandMap { A, B, C }; -}; - /// Checks if all the operands of the op being lowered are of LLVM Types. The /// types are expected to be converted by the `LLVMTypeConverter` before the op /// is actually lowered. If the type of an operands is not already converted it @@ -85,18 +40,32 @@ static constexpr StringRef kInvalidCaseStr = "Unimplemented WMMA variant, Only M16N16K16 version implemented."; +/// Return the LLVMStructureType corresponding to the MMAMatrixType `type`. +static LLVM::LLVMStructType convertMMAToLLVMType(gpu::MMAMatrixType type) { + StringRef operandStr = type.getOperand(); + assert(type.getElementType().isa()); + Type baseType = type.getElementType().isF16() + ? VectorType::get(2, type.getElementType()) + : type.getElementType(); + auto getLLVMType = [&](int64_t numElements) { + return LLVM::LLVMStructType::getLiteral( + type.getContext(), SmallVector(numElements, baseType)); + }; + if (operandStr.equals("AOp") || operandStr.equals("BOp")) + return getLLVMType(8); + if (type.getElementType().isF16()) + return getLLVMType(4); + return getLLVMType(8); +} + /// This class implements the conversion of GPU MMA loadOp to wmma.load op /// in the NVVM dialect. The conversion not only emits the NVVM op but also /// emits code that is necessary to store the data in the destination memref /// after it has been loaded. struct WmmaLoadOpToNVVMLowering - : public ConvertOpToLLVMPattern, - private CommonLLVMAndBuiltInMLIRTypes { -public: - explicit WmmaLoadOpToNVVMLowering(LLVMTypeConverter &typeConverter) - : ConvertOpToLLVMPattern(typeConverter), - CommonLLVMAndBuiltInMLIRTypes(&this->getTypeConverter()->getContext()) { - } + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + gpu::SubgroupMmaLoadMatrixOp>::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp subgroupMmaLoadMatrixOp, @@ -134,6 +103,7 @@ SmallVector indices(subgroupMmaLoadMatrixOp.indices()); Value srcOffsetIVal = indices[0]; Value srcOffsetJVal = indices[1]; + Type i32Ty = rewriter.getI32Type(); Value leadingDim32 = rewriter.create(loc, i32Ty, leadDimension); Value numElemsLeadDim = @@ -147,7 +117,8 @@ promotedSrcOpToUse); Value loadAddress = rewriter.create( loc, - LLVM::LLVMPointerType::get(f16Ty, srcMemrefType.getMemorySpaceAsInt()), + LLVM::LLVMPointerType::get(rewriter.getF16Type(), + srcMemrefType.getMemorySpaceAsInt()), promotedSrcOp.alignedPtr(rewriter, loc), ArrayRef{actualOffset}); // Bitcast the base address pointer of the destination memref, So that @@ -164,51 +135,33 @@ subgroupMmaLoadMatrixOp.res().getType().cast(); ArrayRef retTypeShape = retType.getShape(); - Type resType; + Type resType = convertMMAToLLVMType(retType); StringRef operandStr = retType.getOperand(); - if (operandStr.equals("AOp") || operandStr.equals("BOp")) { - resType = fragArrayABTy; - } else { - if (srcMemrefType.getElementType().isF16()) - resType = fragArrayCDTy; - else if (srcMemrefType.getElementType().isF32()) - resType = fragArrayCDF32Ty; - else - return failure(); - } // Create nvvm.mma_load op according to the operand types. SmallVector loadOpOperands({loadAddressCasted, leadingDim32}); if (operandStr.equals("AOp")) { if (retTypeShape[0] == 16 && retTypeShape[1] == 16) { - NVVM::WMMALoadAM16N16K16Op wmmaLoadAOp = - rewriter.create(loc, resType, - loadOpOperands); - rewriter.replaceOp(op, wmmaLoadAOp.getResult()); + rewriter.replaceOpWithNewOp(op, resType, + loadOpOperands); } else { return rewriter.notifyMatchFailure(op, kInvalidCaseStr); } } else if (operandStr.equals("BOp")) { if (retTypeShape[0] == 16 && retTypeShape[1] == 16) { - NVVM::WMMALoadBM16N16K16Op wmmaLoadBOp = - rewriter.create(loc, resType, - loadOpOperands); - rewriter.replaceOp(op, wmmaLoadBOp.getResult()); + rewriter.replaceOpWithNewOp(op, resType, + loadOpOperands); } else { return rewriter.notifyMatchFailure(op, kInvalidCaseStr); } } else { if (retTypeShape[0] == 16 && retTypeShape[1] == 16) { if (srcMemrefType.getElementType().isF16()) { - NVVM::WMMALoadCF16M16N16K16Op wmmaLoadCOp = - rewriter.create(loc, resType, - loadOpOperands); - rewriter.replaceOp(op, wmmaLoadCOp.getResult()); + rewriter.replaceOpWithNewOp( + op, resType, loadOpOperands); } else if (srcMemrefType.getElementType().isF32()) { - NVVM::WMMALoadCF32M16N16K16Op wmmaLoadCOp = - rewriter.create(loc, resType, - loadOpOperands); - rewriter.replaceOp(op, wmmaLoadCOp.getResult()); + rewriter.replaceOpWithNewOp( + op, resType, loadOpOperands); } } else { return rewriter.notifyMatchFailure(op, kInvalidCaseStr); @@ -223,13 +176,9 @@ /// emits code that is necessary to unpack the data in the source and /// convert the data in the format that is needed by the NVVM op. struct WmmaStoreOpToNVVMLowering - : public ConvertOpToLLVMPattern, - private CommonLLVMAndBuiltInMLIRTypes { -public: - explicit WmmaStoreOpToNVVMLowering(LLVMTypeConverter &typeConverter) - : ConvertOpToLLVMPattern(typeConverter), - CommonLLVMAndBuiltInMLIRTypes(&this->getTypeConverter()->getContext()) { - } + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + gpu::SubgroupMmaStoreMatrixOp>::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp subgroupMmaStoreMatrixOp, @@ -266,6 +215,7 @@ SmallVector indices(subgroupMmaStoreMatrixOp.indices()); Value dstOffsetIVal = indices[0]; Value dstOffsetJVal = indices[1]; + Type i32Ty = rewriter.getI32Type(); Value leadingDim32 = rewriter.create(loc, i32Ty, leadDimension); Value numElemsLeadDim = @@ -279,7 +229,8 @@ promotedDstOpToUse); Value storeAddress = rewriter.create( loc, - LLVM::LLVMPointerType::get(f16Ty, dstMemrefType.getMemorySpaceAsInt()), + LLVM::LLVMPointerType::get(rewriter.getF16Type(), + dstMemrefType.getMemorySpaceAsInt()), promotedDstOp.alignedPtr(rewriter, loc), ArrayRef{actualOffset}); // Bitcast the base address pointer of the destination memref, So that @@ -299,18 +250,16 @@ subgroupMmaStoreMatrixOp.src().getType().cast(); ArrayRef srcTypeShape = srcType.getShape(); + auto matrixType = operands[0].getType().cast(); + for (unsigned i = 0, e = matrixType.getBody().size(); i < e; ++i) { + Value toUse = rewriter.create( + loc, matrixType.getBody()[i], operands[0], + rewriter.getI32ArrayAttr(i)); + storeOpOperands.push_back(toUse); + } + storeOpOperands.push_back(leadingDim32); // Unpack the results from the source. - if (subgroupMmaStoreMatrixOp.src() - .getType() - .cast() - .getElementType() == f16Ty) { - for (unsigned i = 0, e = numHalfsInOpFrags[C]; i < e; ++i) { - Value toUse = rewriter.create( - loc, f16x2Ty, operands[0], rewriter.getI32ArrayAttr(i)); - storeOpOperands.push_back(toUse); - } - storeOpOperands.push_back(leadingDim32); - + if (srcType.getElementType().isF16()) { // Create nvvm.mma_store op. if (srcTypeShape[0] == 16 && srcTypeShape[1] == 16) { rewriter.create(loc, storeOpOperands); @@ -319,17 +268,8 @@ } rewriter.eraseOp(op); return success(); - } else if (subgroupMmaStoreMatrixOp.src() - .getType() - .cast() - .getElementType() == f32Ty) { - for (unsigned i = 0, e = 8; i < e; ++i) { - Value toUse = rewriter.create( - loc, f32Ty, operands[0], rewriter.getI32ArrayAttr(i)); - storeOpOperands.push_back(toUse); - } - storeOpOperands.push_back(leadingDim32); - + } + if (srcType.getElementType().isF32()) { // Create nvvm.mma_store op. if (srcTypeShape[0] == 16 && srcTypeShape[1] == 16) rewriter.create(loc, storeOpOperands); @@ -339,7 +279,6 @@ rewriter.eraseOp(op); return success(); } - return failure(); } }; @@ -347,12 +286,9 @@ /// This class implements the conversion of GPU MMA computeOp to wmma.mma op /// in the NVVM dialect. struct WmmaMmaOpToNVVMLowering - : public ConvertOpToLLVMPattern, - private CommonLLVMAndBuiltInMLIRTypes { - explicit WmmaMmaOpToNVVMLowering(LLVMTypeConverter &typeConverter) - : ConvertOpToLLVMPattern(typeConverter), - CommonLLVMAndBuiltInMLIRTypes(&this->getTypeConverter()->getContext()) { - } + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + gpu::SubgroupMmaComputeOp>::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp, @@ -370,11 +306,11 @@ // values form lowered memrefs. SmallVector unpackedOps; - auto unpackOp = [&](CommonLLVMAndBuiltInMLIRTypes::OperandMap op, - Value operand, unsigned numElems, Type elemType) { - for (unsigned i = 0; i < numElems; ++i) { + auto unpackOp = [&](Value operand) { + auto structType = operand.getType().cast(); + for (size_t i = 0, e = structType.getBody().size(); i < e; ++i) { Value toUse = rewriter.create( - loc, elemType, operand, rewriter.getI32ArrayAttr(i)); + loc, structType.getBody()[i], operand, rewriter.getI32ArrayAttr(i)); unpackedOps.push_back(toUse); } }; @@ -385,55 +321,39 @@ subgroupMmaComputeOp.opA().getType().cast(); ArrayRef aTypeShape = aType.getShape(); gpu::MMAMatrixType bType = - subgroupMmaComputeOp.opA().getType().cast(); + subgroupMmaComputeOp.opB().getType().cast(); ArrayRef bTypeShape = bType.getShape(); gpu::MMAMatrixType cType = - subgroupMmaComputeOp.opA().getType().cast(); + subgroupMmaComputeOp.opC().getType().cast(); ArrayRef cTypeShape = cType.getShape(); gpu::SubgroupMmaComputeOpAdaptor transformedOperands(operands); - if (subgroupMmaComputeOp.opC() - .getType() - .cast() - .getElementType() == f16Ty) { - unpackOp(A, transformedOperands.opA(), numHalfsInOpFrags[A], f16x2Ty); - unpackOp(B, transformedOperands.opB(), numHalfsInOpFrags[B], f16x2Ty); - unpackOp(C, transformedOperands.opC(), numHalfsInOpFrags[C], f16x2Ty); + unpackOp(transformedOperands.opA()); + unpackOp(transformedOperands.opB()); + unpackOp(transformedOperands.opC()); + if (cType.getElementType().isF16()) { if (aTypeShape[0] == 16 && aTypeShape[1] == 16 && bTypeShape[0] == 16 && bTypeShape[1] == 16 && cTypeShape[0] == 16 && cTypeShape[1] == 16) { // Create nvvm.wmma.mma op. - NVVM::WMMAMmaF16F16M16N16K16Op wmmaMmaOp = - rewriter.create(loc, fragArrayCDTy, - unpackedOps); + rewriter.replaceOpWithNewOp( + op, transformedOperands.opC().getType(), unpackedOps); - rewriter.replaceOp(op, wmmaMmaOp.getResult()); return success(); - } else { - return rewriter.notifyMatchFailure(op, kInvalidCaseStr); } - } else if (subgroupMmaComputeOp.opC() - .getType() - .cast() - .getElementType() == f32Ty) { - unpackOp(A, transformedOperands.opA(), numHalfsInOpFrags[A], f16x2Ty); - unpackOp(B, transformedOperands.opB(), numHalfsInOpFrags[B], f16x2Ty); - unpackOp(C, transformedOperands.opC(), 8, f32Ty); - + return rewriter.notifyMatchFailure(op, kInvalidCaseStr); + } + if (cType.getElementType().isF32()) { if (aTypeShape[0] == 16 && aTypeShape[1] == 16 && bTypeShape[0] == 16 && bTypeShape[1] == 16 && cTypeShape[0] == 16 && cTypeShape[1] == 16) { // Create nvvm.wmma.mma op. - NVVM::WMMAMmaF32F32M16N16K16Op wmmaMmaOp = - rewriter.create( - loc, fragArrayCDF32Ty, unpackedOps); + rewriter.replaceOpWithNewOp( + op, transformedOperands.opC().getType(), unpackedOps); - rewriter.replaceOp(op, wmmaMmaOp.getResult()); return success(); - } else { - return rewriter.notifyMatchFailure(op, kInvalidCaseStr); } + return rewriter.notifyMatchFailure(op, kInvalidCaseStr); } - return failure(); } }; @@ -443,8 +363,7 @@ namespace mlir { void populateGpuWMMAToNVVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns) { - patterns.insert(converter); - patterns.insert(converter); - patterns.insert(converter); + patterns.insert(converter); } } // namespace mlir