diff --git a/mlir/include/mlir/Conversion/NVGPUToNVVM/BasicPtxBuilder.h b/mlir/include/mlir/Conversion/NVGPUToNVVM/BasicPtxBuilder.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Conversion/NVGPUToNVVM/BasicPtxBuilder.h
@@ -0,0 +1,87 @@
+//===- BasicPtxBuilder.h - Tool to generate Inline Assembly -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// PtxBuilder creates an LLVM::InlineAsmOp holding a PTX instruction string and
+// derives the inline-assembly constraint string from the values given to it.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_NVGPUTONVVM_BASICPTXBUILDER_H_
+#define MLIR_CONVERSION_NVGPUTONVVM_BASICPTXBUILDER_H_
+
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+namespace nvgpu {
+
+/// How the inline assembly accesses a value inserted into a PtxBuilder.
+enum class PTXRegisterMod {
+  /// Input operand.
+  Read,
+  /// Output only ("=" constraint); not passed as an operand.
+  Write,
+  /// Input and output ("+" constraint).
+  ReadWrite,
+};
+
+/// Builds an LLVM::InlineAsmOp for `op` from a PTX string and a list of
+/// values. Each insertValue call appends one entry to the constraint string,
+/// and the PTX string refers to the values using LLVM inline-asm operand
+/// syntax ($0, $1, ...).
+class PtxBuilder {
+  /// Operation the asm is built for; supplies the location, the context and,
+  /// when the asm produces a value, its result type.
+  Operation *op;
+  PatternRewriter &rewriter;
+  /// PTX string; held by pointer, so it must outlive build().
+  const char *asmStr;
+  /// Operands of the asm (Read and ReadWrite values, in insertion order).
+  SmallVector asmVals;
+  /// Comma-separated constraints, one per inserted value.
+  std::string asmConstraints;
+  bool sideEffects;
+  /// Set once a Write or ReadWrite value has been inserted.
+  bool hasResult = false;
+
+  /// Returns the inline-asm constraint letter used for `v`.
+  char getRegisterType(Value v);
+
+  /// Materializes `val` as an i32 constant at the location of `op`.
+  Value makeConstant(unsigned val) {
+    return rewriter.create(op->getLoc(),
+                                         rewriter.getIntegerType(32), val);
+  }
+
+public:
+  /// `sideEffects` must be set when the asm has effects its constraint list
+  /// does not express, such as memory writes.
+  PtxBuilder(Operation *op, PatternRewriter &rewriter, const char *ptxAsm,
+             bool sideEffects = false)
+      : op(op), rewriter(rewriter), asmStr(ptxAsm), sideEffects(sideEffects) {}
+
+  /// Appends `v` with access mode `itype` (Read by default).
+  void insertValue(Value v, PTXRegisterMod itype = PTXRegisterMod::Read);
+
+  /// Appends `v` as an i32 constant operand.
+  void insertValue(unsigned v) { insertValue(makeConstant(v)); }
+
+  /// Creates the LLVM::InlineAsmOp. Its result type is the type of the first
+  /// result of `op` if a Write or ReadWrite value was inserted, else void.
+  LLVM::InlineAsmOp build();
+
+  /// Calls build(), then replaces `op` with the asm when their result counts
+  /// match and erases `op` otherwise (sound only if its results are unused).
+  void buildAndReplaceOp();
+};
+} // namespace nvgpu
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_NVGPUTONVVM_BASICPTXBUILDER_H_
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/NVGPU/IR/CMakeLists.txt
--- a/mlir/include/mlir/Dialect/NVGPU/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/CMakeLists.txt
@@ -1,2 +1,10 @@
 add_mlir_dialect(NVGPU nvgpu)
 add_mlir_doc(NVGPU NVGPU Dialects/ -gen-dialect-doc)
+
+set(LLVM_TARGET_DEFINITIONS NVGPU.td)
+mlir_tablegen(NVGPUEnums.h.inc -gen-enum-decls)
+mlir_tablegen(NVGPUEnums.cpp.inc -gen-enum-defs)
+mlir_tablegen(NVGPUAttrDefs.h.inc -gen-attrdef-decls -attrdefs-dialect=nvgpu)
+mlir_tablegen(NVGPUAttrDefs.cpp.inc -gen-attrdef-defs -attrdefs-dialect=nvgpu)
+add_public_tablegen_target(MLIRNVGPUEnumsIncGen)
+add_dependencies(mlir-headers MLIRNVGPUEnumsIncGen)
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -23,6 +23,7 @@
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/AttrTypeBase.td"
 include "mlir/IR/OpBase.td"
+include "mlir/IR/EnumAttr.td"
 
 def NVGPU_Dialect : Dialect {
   let name = "nvgpu";
@@ -77,6 +78,29 @@
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// NVGPU Attribute Definitions
+//===----------------------------------------------------------------------===//
+
+// https://docs.nvidia.com/cuda/parallel-thread-execution/#cache-operators
+def LoadCacheModifierCA : I32EnumAttrCase<"CA", 0, "ca">;
+def LoadCacheModifierCG : I32EnumAttrCase<"CG", 1, "cg">;
+def LoadCacheModifierCS : I32EnumAttrCase<"CS", 2, "cs">;
+def LoadCacheModifierLU : I32EnumAttrCase<"LU", 3, "lu">;
+def LoadCacheModifierCV : I32EnumAttrCase<"CV", 4, "cv">;
+
+/// Cache operator applied to a load. Exactly one operator applies to a load,
+/// so this is a plain enum rather than a bit enum.
+def LoadCacheModifierKind : I32EnumAttr<"LoadCacheModifierKind",
+    "NVGPU load cache modifier kind",
+    [LoadCacheModifierCA, LoadCacheModifierCG, LoadCacheModifierCS,
+     LoadCacheModifierLU, LoadCacheModifierCV]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::nvgpu";
+}
+
+def LoadCacheModifierAttr : EnumAttr;
+
 //===----------------------------------------------------------------------===//
 // NVGPU Op Definitions
 //===----------------------------------------------------------------------===//
@@ -249,7 +273,7 @@
     `nvgpu.device_async_wait` to synchronize copies as explained in those ops
     descriptions.
 
-    `bypassL1` attribute is hint to the hardware to bypass the L1 cache during
-    async copy, this hint may be ignored by the hardware.
+    `cacheModifier` selects the cp.async cache operator: `ca` caches at all
+    levels, `cg` bypasses L1 and needs a 16-byte copy. Other values are invalid.
 
     `dstElements` attribute is the total number of elements written to
@@ -295,9 +319,9 @@
                        Variadic:$srcIndices,
                        IndexAttr:$dstElements,
                        Optional:$srcElements,
-                       OptionalAttr:$bypassL1);
+                       DefaultValuedAttr:$cacheModifier);
   let assemblyFormat = [{
-    $src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` `,` $dstElements (`,` $srcElements^)?
+    $src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` `,` $dstElements `,` `cache` `=` $cacheModifier (`,` $srcElements^)?
       attr-dict `:` type($src) `to` type($dst)
   }];
   let hasVerifier = 1;
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
@@ -19,6 +19,11 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
+#include "mlir/Dialect/NVGPU/IR/NVGPUEnums.h.inc"
+
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.h.inc"
+
 #define GET_TYPEDEF_CLASSES
 #include "mlir/Dialect/NVGPU/IR/NVGPUTypes.h.inc"
 
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/BasicPtxBuilder.cpp b/mlir/lib/Conversion/NVGPUToNVVM/BasicPtxBuilder.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/NVGPUToNVVM/BasicPtxBuilder.cpp
@@ -0,0 +1,88 @@
+//===- BasicPtxBuilder.cpp - Tool to generate Inline Assembly -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// PTX Builder with inline Assembly
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/NVGPUToNVVM/BasicPtxBuilder.h"
+#include 
+
+using namespace mlir;
+using namespace nvgpu;
+
+// Returns the inline-asm constraint letter for `v`, following
+// https://docs.nvidia.com/cuda/inline-ptx-assembly/index.html#constraints
+char mlir::nvgpu::PtxBuilder::getRegisterType(Value v) {
+  if (v.getDefiningOp())
+    return 'n';
+  if (v.getType().isInteger(16))
+    return 'h';
+  if (v.getType().isInteger(32))
+    return 'r';
+  if (v.getType().isInteger(64))
+    return 'l';
+  if (v.getType().isF32())
+    return 'f';
+  if (v.getType().isF64())
+    return 'd';
+  if (auto ptr = v.getType().dyn_cast()) {
+    // The shared address space is addressed with 32-bit pointers.
+    return ptr.getAddressSpace() == mlir::NVVM::kSharedMemorySpace ? 'r' : 'l';
+  }
+  assert(false && "Register type is not handled yet");
+  return ' ';
+}
+
+void mlir::nvgpu::PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
+  llvm::raw_string_ostream ss(asmConstraints);
+  if (itype == PTXRegisterMod::Read) {
+    asmVals.push_back(v);
+  } else if (itype == PTXRegisterMod::ReadWrite) {
+    asmVals.push_back(v);
+    ss << "+";
+    hasResult = true;
+  } else if (itype == PTXRegisterMod::Write) {
+    ss << "=";
+    hasResult = true;
+  }
+  ss << getRegisterType(v) << ",";
+  ss.flush();
+}
+
+LLVM::InlineAsmOp mlir::nvgpu::PtxBuilder::build() {
+  auto asmDialectAttr =
+      LLVM::AsmDialectAttr::get(op->getContext(), LLVM::AsmDialect::AD_ATT);
+  Type resultType = hasResult ? op->getResult(0).getType()
+                              : LLVM::LLVMVoidType::get(op->getContext());
+
+  // Drop the trailing comma; the string is empty when no value was inserted.
+  if (!asmConstraints.empty() && asmConstraints.back() == ',')
+    asmConstraints.pop_back();
+
+  return rewriter.create(
+      op->getLoc(), resultType,
+      /*operands=*/asmVals,
+      /*asm_string=*/asmStr,
+      /*constraints=*/asmConstraints.data(),
+      /*has_side_effects=*/sideEffects,
+      /*is_align_stack=*/false,
+      /*asm_dialect=*/asmDialectAttr,
+      /*operand_attrs=*/ArrayAttr());
+}
+
+void mlir::nvgpu::PtxBuilder::buildAndReplaceOp() {
+  LLVM::InlineAsmOp inlineAsmOp = build();
+  if (inlineAsmOp->getNumResults() == op->getNumResults()) {
+    rewriter.replaceOp(op, inlineAsmOp);
+    return;
+  }
+  // Erasing is only sound when nothing still uses the results of op.
+  assert(op->use_empty() && "cannot erase op whose results are still in use");
+  rewriter.eraseOp(op);
+}
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/CMakeLists.txt b/mlir/lib/Conversion/NVGPUToNVVM/CMakeLists.txt
--- a/mlir/lib/Conversion/NVGPUToNVVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/NVGPUToNVVM/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_conversion_library(MLIRNVGPUToNVVM
+  BasicPtxBuilder.cpp
   NVGPUToNVVM.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/NVGPUToNVVM
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -10,6 +10,7 @@
 
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Conversion/NVGPUToNVVM/BasicPtxBuilder.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
@@ -361,51 +362,6 @@
   }
 };
 
-static void emitCpAsyncOpZfillAsm(Location loc, Value dstPtr, Value srcPtr,
-                                  Value dstBytes, Value srcElements,
-                                  mlir::MemRefType elementType,
-                                  ConversionPatternRewriter &rewriter) {
-  auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
-                                                  LLVM::AsmDialect::AD_ATT);
-
-  const char *cpAsyncCgStr = "cp.async.cg.shared.global [$0], [$1], $2, $3;\n";
-  const char *cpAsyncCaStr = "cp.async.ca.shared.global [$0], [$1], $2, $3;\n";
-  const char *asmConstraints = "r,l,n,r";
-
-  Value c3I32 = rewriter.create(
-      loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(3));
-  Value bitwidth = rewriter.create(
-      loc, rewriter.getI32Type(),
-      rewriter.getI32IntegerAttr(elementType.getElementTypeBitWidth()));
-  Value srcElementsI32 =
-      rewriter.create(loc, rewriter.getI32Type(), srcElements);
-  Value srcBytes = rewriter.create(
-      loc, rewriter.create(loc, bitwidth, srcElementsI32), c3I32);
-
-  SmallVector asmVals{dstPtr, srcPtr, dstBytes, srcBytes};
-
-  // Pick the right asm string based on the dstBytes which is a compile-time
-  // constant.
-  auto dstByteConstOp =
-      dyn_cast(dstBytes.getDefiningOp());
-  auto dstByteAttr = dyn_cast(dstByteConstOp.getValue());
-  int64_t dstByteVal = dstByteAttr.getValue().getSExtValue();
-
-  assert((dstByteVal == 4 || dstByteVal == 8 || dstByteVal == 16) &&
-         "cp.async byte copy size must be 4, 8 or 16");
-  // Cache global (.cg) for 16 dst bytes, Cache all (.ca) for sizes other than
-  // 16 dst bytes.
-  const char *asmStr = (dstByteVal == 16) ? cpAsyncCgStr : cpAsyncCaStr;
-
-  rewriter.create(
-      loc, LLVM::LLVMVoidType::get(rewriter.getContext()),
-      /*operands=*/asmVals,
-      /*asm_string=*/asmStr,
-      /*constraints=*/asmConstraints, /*has_side_effects=*/true,
-      /*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
-      /*operand_attrs=*/ArrayAttr());
-}
-
 /// Returns the constraints for the sparse MMA inline assembly instruction.
 static std::string buildMmaSparseAsmConstraintString(unsigned matASize,
                                                      unsigned matBSize,
@@ -620,30 +576,49 @@
     int64_t dstElements = adaptor.getDstElements().getZExtValue();
     int64_t sizeInBytes =
         (dstMemrefType.getElementTypeBitWidth() * dstElements) / 8;
-    // bypass L1 is only supported for byte sizes of 16, we drop the hint
-    // otherwise.
-    UnitAttr bypassL1 =
-        sizeInBytes == 16 ? adaptor.getBypassL1Attr() : UnitAttr();
-
-    // When the optional SrcElements argument is present, the source (global
-    // memory) of CpAsyncOp is read only for SrcElements number of elements. The
-    // rest of the DstElements in the destination (shared memory) are filled
-    // with zeros.
-    if (op.getSrcElements())
-      emitCpAsyncOpZfillAsm(loc, dstPtr, scrPtr,
-                            rewriter.create(
-                                loc, rewriter.getI32Type(),
-                                rewriter.getI32IntegerAttr(sizeInBytes)),
-                            adaptor.getSrcElements(), srcMemrefType, rewriter);
 
     // When the optional SrcElements argument is *not* present, the regular
     // CpAsyncOp is generated. CopyAsyncOp reads bytes from source (global
-    // memory) to fill DstElements number of elements in the destination (shared
-    // memory).
-    else
-      rewriter.create(loc, dstPtr, scrPtr,
-                                        rewriter.getI32IntegerAttr(sizeInBytes),
-                                        bypassL1);
+    // memory) to fill DstElements number of elements in the destination
+    // (shared memory).
+    Value srcBytes = adaptor.getSrcElements();
+    if (srcBytes) {
+      // When the optional SrcElements argument is present, the source (global
+      // memory) of CpAsyncOp is read only for SrcElements number of elements.
+      // The rest of the DstElements in the destination (shared memory) are
+      // filled with zeros.
+      Value c3I32 = rewriter.create(
+          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(3));
+      Value bitwidth = rewriter.create(
+          loc, rewriter.getI32Type(),
+          rewriter.getI32IntegerAttr(srcMemrefType.getElementTypeBitWidth()));
+      Value srcElementsI32 =
+          rewriter.create(loc, rewriter.getI32Type(), srcBytes);
+      srcBytes = rewriter.create(
+          loc, rewriter.create(loc, bitwidth, srcElementsI32),
+          c3I32);
+
+      std::string ptx = "cp.async.";
+      ptx += nvgpu::stringifyLoadCacheModifierKind(op.getCacheModifier());
+      ptx += ".shared.global [$0], [$1], $2, $3;\n";
+
+      // The asm writes shared memory, which its constraints cannot express.
+      nvgpu::PtxBuilder builder(op, rewriter, ptx.data(), /*sideEffects=*/true);
+      builder.insertValue(dstPtr);
+      builder.insertValue(scrPtr);
+      builder.insertValue(sizeInBytes);
+      builder.insertValue(srcBytes);
+      builder.build();
+    } else {
+      // Only a 16-byte `cg` copy sets the bypass-L1 hint on the NVVM op.
+      UnitAttr bypassl1 = {};
+      if (op.getCacheModifier() == nvgpu::LoadCacheModifierKind::CG &&
+          sizeInBytes == 16)
+        bypassl1 = rewriter.getUnitAttr();
+
+      rewriter.create(loc, dstPtr, scrPtr,
+                      rewriter.getI32IntegerAttr(sizeInBytes), bypassl1);
+    }
 
     // Drop the result token.
     Value zero = rewriter.create(
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -665,7 +665,7 @@
         foldedDstIndices,
         (srcSubViewOp ? srcSubViewOp.getSource() : copyOp.getSrc()),
         foldedSrcIndices, copyOp.getDstElements(), copyOp.getSrcElements(),
-        copyOp.getBypassL1Attr());
+        copyOp.getCacheModifier());
     return success();
   }
 
diff --git a/mlir/lib/Dialect/NVGPU/IR/CMakeLists.txt b/mlir/lib/Dialect/NVGPU/IR/CMakeLists.txt
--- a/mlir/lib/Dialect/NVGPU/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/NVGPU/IR/CMakeLists.txt
@@ -6,7 +6,8 @@
 
   DEPENDS
   MLIRNVGPUIncGen
+  MLIRNVGPUEnumsIncGen
 
   LINK_LIBS PUBLIC
   MLIRGPUDialect
   MLIRIR
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -18,16 +18,24 @@
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/Verifier.h"
+#include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
 using namespace mlir::nvgpu;
 
+#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.cpp.inc"
+#include "mlir/Dialect/NVGPU/IR/NVGPUEnums.cpp.inc"
+
 void nvgpu::NVGPUDialect::initialize() {
   addTypes<
 #define GET_TYPEDEF_LIST
 #include "mlir/Dialect/NVGPU/IR/NVGPUTypes.cpp.inc"
       >();
+  addAttributes<
+#define GET_ATTRDEF_LIST
+#include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.cpp.inc"
+      >();
   addOperations<
 #define GET_OP_LIST
 #include "mlir/Dialect/NVGPU/IR/NVGPU.cpp.inc"
@@ -83,6 +91,24 @@
     return emitOpError() << "expected " << dstMemref.getRank()
                          << " destination indices, got "
                          << getDstIndices().size();
+  if (getCacheModifier() != LoadCacheModifierKind::CA &&
+      getCacheModifier() != LoadCacheModifierKind::CG) {
+    return emitOpError() << "unsupported cache modifier "
+                         << stringifyLoadCacheModifierKind(getCacheModifier());
+  }
+  if (getCacheModifier() == LoadCacheModifierKind::CG) {
+    // cp.async.cg is only defined for 16-byte copies.
+    int64_t dstElements = getDstElements().getZExtValue();
+    int64_t bitWidth = dstMemref.getElementTypeBitWidth();
+    int64_t sizeInBytes = (bitWidth * dstElements) / 8;
+    if (sizeInBytes != 16) {
+      return emitOpError() << stringifyLoadCacheModifierKind(getCacheModifier())
+                           << " cache does not satisfy alignment for "
+                           << dstMemref << " with destination element "
+                           << dstElements << ". Set CA cache or set "
+                           << "destination element to " << 16 * 8 / bitWidth;
+    }
+  }
   return success();
 }
 
@@ -303,7 +329,8 @@
 // TableGen'd dialect, type, and op definitions
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.cpp.inc"
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.cpp.inc"
 
 #define GET_OP_CLASSES
 #include "mlir/Dialect/NVGPU/IR/NVGPU.cpp.inc"