diff --git a/mlir/include/mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h b/mlir/include/mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h
@@ -0,0 +1,24 @@
+//===- NVVMToLLVM.h - Convert NVVM to LLVM dialect --------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_NVVMTOLLVM_NVVMTOLLVM_H_
+#define MLIR_CONVERSION_NVVMTOLLVM_NVVMTOLLVM_H_
+
+#include <memory>
+
+namespace mlir {
+
+class LLVMTypeConverter;
+class RewritePatternSet;
+class Pass;
+
+#define GEN_PASS_DECL_CONVERTNVVMTOLLVMPASS
+#include "mlir/Conversion/Passes.h.inc"
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_NVVMTOLLVM_NVVMTOLLVM_H_
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -40,6 +40,7 @@
 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
 #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
 #include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
+#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
 #include "mlir/Conversion/OpenACCToSCF/ConvertOpenACCToSCF.h"
 #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
 #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -709,6 +709,22 @@
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// NVVMToLLVM
+//===----------------------------------------------------------------------===//
+
+def ConvertNVVMToLLVMPass : Pass<"convert-nvvm-to-llvm"> {
+  let summary = "Convert NVVM dialect to LLVM dialect";
+  let description = [{
+    This pass generates inline assembly for the NVVM ops which are not
+    implemented in LLVM core.
+  }];
+  let dependentDialects = [
+    "LLVM::LLVMDialect",
+    "NVVM::NVVMDialect",
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // NVGPUToNVVM
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
--- a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
@@ -52,6 +52,8 @@
 mlir_tablegen(NVVMConversions.inc -gen-llvmir-conversions)
 mlir_tablegen(NVVMOpsEnums.h.inc -gen-enum-decls)
 mlir_tablegen(NVVMOpsEnums.cpp.inc -gen-enum-defs)
+mlir_tablegen(NVVMOpsInterface.h.inc -gen-op-interface-decls)
+mlir_tablegen(NVVMOpsInterface.cpp.inc -gen-op-interface-defs)
 mlir_tablegen(NVVMOpsAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=nvvm)
 mlir_tablegen(NVVMOpsAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=nvvm)
 add_public_tablegen_target(MLIRNVVMConversionsIncGen)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
@@ -26,6 +26,8 @@
 namespace mlir {
 namespace NVVM {
 
+#include "mlir/Dialect/LLVMIR/NVVMOpsInterface.h.inc"
+
 /// NVVM memory space identifiers.
 enum NVVMMemorySpace {
   /// Global memory space identifier.
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -81,6 +81,55 @@
   let mnemonic = attrMnemonic;
 }
 
+//===----------------------------------------------------------------------===//
+// Basic PTX Builder Interface
+//===----------------------------------------------------------------------===//
+
+def BasicPtxBuilderOpInterface : OpInterface<"BasicPtxBuilderInterface"> {
+  let description = [{
+    Interface to generate inline assembly with PTX for basic operations.
+
+    It uses the `getPtx` interface method to read the PTX code. The
+    `hasSideEffect` method reports whether the PTX code has memory side effects.
+
+    PTX arguments are numbered starting with the op results, followed by the
+    operands. Within each group, the op's declaration order is kept.
+
+    Example:
+
+    If we have the following Op definition:
+    ```tablegen
+    def NVVM_MyOp : NVVM_Op<"myop",
+      [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
+      Results<(outs LLVM_Type:$res)>,
+      Arguments<(ins LLVM_i64ptr_any:$arg1, I32:$arg2)> {
+      let extraClassDefinition = [{
+        bool $cppClass::hasSideEffect() { return true; }
+        const char* $cppClass::getPtx() { return "ptx.code %0, [%1], %2;"; }
+      }\];
+    }
+    ```
+
+    The NVVM Op will look like below:
+    ```mlir
+    nvvm.myop %1, %2 : !llvm.ptr, i32 -> i32
+    ```
+
+    The `convert-nvvm-to-llvm` Pass returns the PTX code below. The order of
+    arguments is kept the same. The read/write modifiers are set based on the
+    input and result types.
+    ```mlir
+    llvm.inline_asm has_side_effects asm_dialect = att
+      "ptx.code %0, [%1], %2;", "=r,l,r" %1, %2 : (!llvm.ptr, i32) -> i32
+    ```
+  }];
+  let methods = [
+    InterfaceMethod<"Returns PTX code", "const char* ", "getPtx">,
+    InterfaceMethod<"Returns memory side effect of the PTX code",
+      "bool", "hasSideEffect">,
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // NVVM intrinsic operations
 //===----------------------------------------------------------------------===//
@@ -249,6 +298,62 @@
   let assemblyFormat = "$addr `,` $count attr-dict `:` type(operands) `->` type($res)";
 }
 
+def NVVM_MBarrierArriveExpectTxOp : NVVM_Op<"mbarrier.arrive.expect_tx",
+  [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
+  Results<(outs LLVM_Type:$res)>,
+  Arguments<(ins LLVM_i64ptr_any:$addr, I32:$txcount)> {
+  let assemblyFormat = "$addr `,` $txcount attr-dict `:` type(operands) `->` type($res)";
+  let extraClassDefinition = [{
+    bool $cppClass::hasSideEffect() { return true; }
+    const char* $cppClass::getPtx() { return "mbarrier.arrive.expect_tx.b64 %0, [%1], %2;"; }
+  }];
+}
+
+def NVVM_MBarrierArriveExpectTxSharedOp : NVVM_Op<"mbarrier.arrive.expect_tx.shared",
+  [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
+  Results<(outs LLVM_Type:$res)>,
+  Arguments<(ins LLVM_i64ptr_shared:$addr, I32:$txcount)> {
+  let assemblyFormat = "$addr `,` $txcount attr-dict `:` type(operands) `->` type($res)";
+  let extraClassDefinition = [{
+    bool $cppClass::hasSideEffect() { return true; }
+    const char* $cppClass::getPtx() { return "mbarrier.arrive.expect_tx.shared.b64 %0, [%1], %2;"; }
+  }];
+}
+
+def NVVM_MBarrierTryWaitParityOp : NVVM_Op<"mbarrier.try_wait.parity",
+  [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
+  Results<(outs LLVM_Type:$res)>,
+  Arguments<(ins LLVM_i64ptr_any:$addr, LLVM_Type:$token)> {
+  let assemblyFormat = "$addr `,` $token attr-dict `:` type(operands) `->` type($res)";
+  let extraClassDefinition = [{
+    bool $cppClass::hasSideEffect() { return true; }
+    const char* $cppClass::getPtx() {
+      return "{\n\t"
+              ".reg .pred P1; \n\t"
+              "mbarrier.try_wait.parity.b64 P1, [%1], %2; \n\t"
+              "selp.b32 %0, 1, 0, P1; \n\t"
+              "}";
+    }
+  }];
+}
+
+def NVVM_MBarrierTryWaitParitySharedOp : NVVM_Op<"mbarrier.try_wait.parity.shared",
+  [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
+  Results<(outs LLVM_Type:$res)>,
+  Arguments<(ins LLVM_i64ptr_shared:$addr, LLVM_Type:$token)> {
+  let assemblyFormat = "$addr `,` $token attr-dict `:` type(operands) `->` type($res)";
+  let extraClassDefinition = [{
+    bool $cppClass::hasSideEffect() { return true; }
+    const char* $cppClass::getPtx() {
+      return "{\n\t"
+              ".reg .pred P1; \n\t"
+              "mbarrier.try_wait.parity.shared.b64 P1, [%1], %2; \n\t"
+              "selp.b32 %0, 1, 0, P1; \n\t"
+              "}";
+    }
+  }];
+}
+
 def NVVM_MBarrierTestWaitOp : NVVM_Op<"mbarrier.test.wait">,
   Results<(outs LLVM_Type:$res)>,
   Arguments<(ins LLVM_i64ptr_any:$addr, LLVM_Type:$token)> {
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -30,6 +30,7 @@
 add_subdirectory(MemRefToLLVM)
 add_subdirectory(MemRefToSPIRV)
 add_subdirectory(NVGPUToNVVM)
+add_subdirectory(NVVMToLLVM)
 add_subdirectory(OpenACCToSCF)
 add_subdirectory(OpenMPToLLVM)
 add_subdirectory(PDLToPDLInterp)
diff --git a/mlir/lib/Conversion/NVVMToLLVM/CMakeLists.txt b/mlir/lib/Conversion/NVVMToLLVM/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/NVVMToLLVM/CMakeLists.txt
@@ -0,0 +1,21 @@
+add_mlir_conversion_library(MLIRNVVMToLLVM
+  NVVMToLLVM.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/NVVMToLLVM
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRGPUDialect
+  MLIRLLVMCommonConversion
+  MLIRLLVMDialect
+  MLIRNVVMDialect
+  MLIRNVGPUDialect
+  MLIRPass
+  MLIRTransforms
+  )
diff --git a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
@@ -0,0 +1,237 @@
+//===- NVVMToLLVM.cpp - NVVM to LLVM dialect conversion -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a translation of NVVM ops which are not supported in
+// LLVM core. Such ops implement `BasicPtxBuilderInterface` and are lowered to
+// `llvm.inline_asm` carrying the PTX code returned by `getPtx()`.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
+
+#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+#include <optional>
+#include <string>
+
+#define DEBUG_TYPE "nvvm-to-llvm"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define DBGSNL() (llvm::dbgs() << "\n")
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTNVVMTOLLVMPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace NVVM;
+
+// NOTE(review): the interface definitions are emitted into this conversion
+// library, so any other user of BasicPtxBuilderInterface must link it.
+// Consider moving this include next to the NVVM dialect definitions.
+#include "mlir/Dialect/LLVMIR/NVVMOpsInterface.cpp.inc"
+
+/// Returns the inline-asm register constraint letter for `v`, or std::nullopt
+/// if the type of `v` is not handled yet.
+/// https://docs.nvidia.com/cuda/inline-ptx-assembly/index.html#constraints
+static std::optional<char> getRegisterType(Value v) {
+  Type type = v.getType();
+  // Integer constants are passed as immediates.
+  if (v.getDefiningOp<LLVM::ConstantOp>() && type.isa<IntegerType>())
+    return 'n';
+  if (type.isInteger(16))
+    return 'h';
+  if (type.isInteger(32))
+    return 'r';
+  if (type.isInteger(64))
+    return 'l';
+  if (type.isF32())
+    return 'f';
+  if (type.isF64())
+    return 'd';
+  if (auto ptr = type.dyn_cast<LLVM::LLVMPointerType>()) {
+    // Shared address space is addressed with 32-bit pointers.
+    if (ptr.getAddressSpace() == NVVM::kSharedMemorySpace)
+      return 'r';
+    return 'l';
+  }
+  return std::nullopt;
+}
+
+namespace {
+
+/// How the PTX code accesses an inline-asm argument.
+enum InputModifier { read, readwrite, write };
+
+/// Accumulates the arguments and constraint string of an `llvm.inline_asm`
+/// op and then uses it to replace `op`.
+class PtxBuilder {
+  Operation *op;
+  PatternRewriter &rewriter;
+  const char *asmStr;
+  SmallVector<Value> asmVals;
+  std::string asmConstraints;
+  bool sideEffects;
+  bool hasResult = false;
+
+  /// Materializes `v` as an i32 `llvm.mlir.constant` at the location of `op`.
+  Value makeConstantValue(unsigned v) {
+    return rewriter.create<LLVM::ConstantOp>(
+        op->getLoc(),
+        IntegerAttr::get(IntegerType::get(op->getContext(), 32), v));
+  }
+
+public:
+  PtxBuilder(Operation *op, PatternRewriter &rewriter, const char *ptxAsm,
+             bool sideEffects = false)
+      : op(op), rewriter(rewriter), asmStr(ptxAsm), sideEffects(sideEffects) {}
+
+  /// Appends `v` as an inline-asm argument. `read` and `readwrite` values are
+  /// passed as operands; `write` values only add an output constraint. The
+  /// type of `v` must be supported by getRegisterType().
+  void insertValue(Value v, InputModifier itype = InputModifier::read) {
+    std::optional<char> regType = getRegisterType(v);
+    assert(regType && "Register type is not handled yet");
+    llvm::raw_string_ostream ss(asmConstraints);
+    switch (itype) {
+    case InputModifier::read:
+      asmVals.push_back(v);
+      break;
+    case InputModifier::readwrite:
+      asmVals.push_back(v);
+      ss << "+";
+      hasResult = true;
+      break;
+    case InputModifier::write:
+      ss << "=";
+      hasResult = true;
+      break;
+    }
+    ss << *regType << ",";
+    ss.flush();
+  }
+
+  /// Appends the integer `v`, materialized as an i32 constant, as an argument.
+  void insertValue(unsigned v, InputModifier itype = InputModifier::read) {
+    insertValue(makeConstantValue(v), itype);
+  }
+
+  /// Creates the `llvm.inline_asm` op. Its result type is the type of the
+  /// first result of `op` if an output was inserted, void otherwise.
+  LLVM::InlineAsmOp build() {
+    auto asmDialectAttr =
+        LLVM::AsmDialectAttr::get(op->getContext(), LLVM::AsmDialect::AD_ATT);
+    Type resultType = hasResult ? op->getResult(0).getType()
+                                : LLVM::LLVMVoidType::get(op->getContext());
+
+    // Remove the trailing comma. The string is empty when the op has no PTX
+    // arguments, so check before looking at the last character.
+    if (!asmConstraints.empty() && asmConstraints.back() == ',')
+      asmConstraints.pop_back();
+
+    return rewriter.create<LLVM::InlineAsmOp>(
+        op->getLoc(), resultType,
+        /*operands=*/asmVals,
+        /*asm_string=*/asmStr,
+        /*constraints=*/asmConstraints.data(),
+        /*has_side_effects=*/sideEffects,
+        /*is_align_stack=*/false,
+        /*asm_dialect=*/asmDialectAttr,
+        /*operand_attrs=*/ArrayAttr());
+  }
+
+  /// Builds the inline asm and replaces `op` with it; if the result counts
+  /// differ, `op` is erased instead.
+  void buildAndReplaceOp() {
+    LLVM::InlineAsmOp inlineAsmOp = build();
+    LLVM_DEBUG(DBGS() << inlineAsmOp << "\n");
+    if (inlineAsmOp->getNumResults() == op->getNumResults())
+      rewriter.replaceOp(op, inlineAsmOp);
+    else
+      rewriter.eraseOp(op);
+  }
+};
+
+/// Lowers ops implementing BasicPtxBuilderInterface to `llvm.inline_asm`.
+/// Results are bound first as outputs (`=`), then operands as inputs, which
+/// fixes the numbering of the %N placeholders in the PTX string.
+struct PtxLowering
+    : public OpInterfaceRewritePattern<NVVM::BasicPtxBuilderInterface> {
+  using OpInterfaceRewritePattern<
+      NVVM::BasicPtxBuilderInterface>::OpInterfaceRewritePattern;
+
+  LogicalResult matchAndRewrite(NVVM::BasicPtxBuilderInterface op,
+                                PatternRewriter &rewriter) const override {
+    // The inline asm takes the type of the single op result; multiple results
+    // would need a struct return type, which is not handled.
+    if (op->getNumResults() > 1)
+      return rewriter.notifyMatchFailure(op, "expected at most one result");
+    // Bail out before creating any IR if a type has no register constraint.
+    auto isUnsupported = [](Value v) { return !getRegisterType(v); };
+    if (llvm::any_of(op->getResults(), isUnsupported) ||
+        llvm::any_of(op->getOperands(), isUnsupported))
+      return rewriter.notifyMatchFailure(op, "unsupported register type");
+
+    PtxBuilder generator(op, rewriter, op.getPtx(), op.hasSideEffect());
+    LLVM_DEBUG(DBGS() << "Ptx Builder Lowering :\n" << op << "\n");
+
+    LLVM_DEBUG(DBGS() << "Results:\n");
+    for (Value val : op->getResults()) {
+      LLVM_DEBUG(DBGS() << val << "\n");
+      generator.insertValue(val, write);
+    }
+
+    LLVM_DEBUG(DBGS() << "Operands:\n");
+    for (Value val : op->getOperands()) {
+      LLVM_DEBUG(DBGS() << val << "\n");
+      generator.insertValue(val);
+    }
+
+    generator.buildAndReplaceOp();
+    return success();
+  }
+};
+
+/// Applies PtxLowering with the LLVM dialect marked legal; ops the pattern
+/// does not rewrite are left in place (partial conversion).
+struct ConvertNVVMToLLVMPass
+    : public impl::ConvertNVVMToLLVMPassBase<ConvertNVVMToLLVMPass> {
+  using Base::Base;
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<LLVM::LLVMDialect, NVVM::NVVMDialect>();
+  }
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    patterns.add<PtxLowering>(patterns.getContext());
+    ConversionTarget target(getContext());
+    target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      signalPassFailure();
+  }
+};
+
+} // namespace
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -0,0 +1,29 @@
+// RUN: mlir-opt --convert-nvvm-to-llvm --split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: @init_mbarrier_arrive_expect_tx(
+llvm.func @init_mbarrier_arrive_expect_tx(%barrier : !llvm.ptr<3>, %txcount : i32) -> i32 {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.shared.b64 %0, [%1], %2;", "=r,r,r" %arg0, %arg1 : (!llvm.ptr<3>, i32) -> i32
+  %res = nvvm.mbarrier.arrive.expect_tx.shared %barrier, %txcount : !llvm.ptr<3>, i32 -> i32
+  llvm.return %res : i32
+}
+
+// CHECK-LABEL: @init_mbarrier_arrive_expect_tx_generic(
+llvm.func @init_mbarrier_arrive_expect_tx_generic(%barrier : !llvm.ptr, %txcount : i32) -> i32 {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.b64 %0, [%1], %2;", "=r,l,r" %arg0, %arg1 : (!llvm.ptr, i32) -> i32
+  %res = nvvm.mbarrier.arrive.expect_tx %barrier, %txcount : !llvm.ptr, i32 -> i32
+  llvm.return %res : i32
+}
+
+// CHECK-LABEL: @init_mbarrier_try_wait.parity.shared(
+llvm.func @init_mbarrier_try_wait.parity.shared(%barrier : !llvm.ptr<3>, %token : i32) -> i32 {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09.reg .pred P1; \0A\09mbarrier.try_wait.parity.shared.b64 P1, [%1], %2; \0A\09selp.b32 %0, 1, 0, P1; \0A\09}", "=r,r,r" %arg0, %arg1 : (!llvm.ptr<3>, i32) -> i32
+  %res = nvvm.mbarrier.try_wait.parity.shared %barrier, %token : !llvm.ptr<3>, i32 -> i32
+  llvm.return %res : i32
+}
+
+// CHECK-LABEL: @init_mbarrier_try_wait.parity(
+llvm.func @init_mbarrier_try_wait.parity(%barrier : !llvm.ptr, %token : i32) -> i32 {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09.reg .pred P1; \0A\09mbarrier.try_wait.parity.b64 P1, [%1], %2; \0A\09selp.b32 %0, 1, 0, P1; \0A\09}", "=r,l,r" %arg0, %arg1 : (!llvm.ptr, i32) -> i32
+  %res = nvvm.mbarrier.try_wait.parity %barrier, %token : !llvm.ptr, i32 -> i32
+  llvm.return %res : i32
+}