diff --git a/mlir/include/mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h b/mlir/include/mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h
@@ -0,0 +1,29 @@
+//===- NVVMToLLVM.h - Convert NVVM to LLVM dialect --------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_NVVMTOLLVM_NVVMTOLLVMPASS_H_
+#define MLIR_CONVERSION_NVVMTOLLVM_NVVMTOLLVMPASS_H_
+
+#include <memory>
+
+namespace mlir {
+
+class LLVMTypeConverter;
+class RewritePatternSet;
+class Pass;
+
+#define GEN_PASS_DECL_CONVERTNVVMTOLLVMPASS
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Populates `patterns` with conversion patterns that rewrite NVVM ops which
+/// have no LLVM core counterpart into `llvm.inline_asm` ops.
+void populateNVVMToLLVMConversionPatterns(LLVMTypeConverter &converter,
+                                          RewritePatternSet &patterns);
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_NVVMTOLLVM_NVVMTOLLVMPASS_H_
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -40,6 +40,7 @@
 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
 #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
 #include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
+#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
 #include "mlir/Conversion/OpenACCToSCF/ConvertOpenACCToSCF.h"
 #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
 #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -709,6 +709,21 @@
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// NVVMToLLVM
+//===----------------------------------------------------------------------===//
+
+def ConvertNVVMToLLVMPass : Pass<"convert-nvvm-to-llvm"> {
+  let summary = "Convert NVVM dialect to LLVM dialect";
+  let description = [{
+    This pass generates inline assembly for the NVVM ops that are not
+    implemented in LLVM core.
+  }];
+  let dependentDialects = [
+    "NVVM::NVVMDialect",
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // NVGPUToNVVM
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -245,6 +245,34 @@
   let assemblyFormat = "$addr `,` $count attr-dict `:` type(operands)";
 }
 
+def NVVM_MBarrierArriveExpectTxOp : NVVM_Op<"mbarrier.arrive.expect_tx">,
+  Results<(outs LLVM_Type:$res)>,
+  Arguments<(ins LLVM_i64ptr_any:$addr, I32:$txcount)> {
+  let summary = "MBarrier Arrive with Expected Transaction Count";
+  let assemblyFormat = "$addr `,` $txcount attr-dict `:` type(operands) `->` type($res)";
+}
+
+def NVVM_MBarrierArriveExpectTxSharedOp : NVVM_Op<"mbarrier.arrive.expect_tx.shared">,
+  Results<(outs LLVM_Type:$res)>,
+  Arguments<(ins LLVM_i64ptr_shared:$addr, I32:$txcount)> {
+  let summary = "Shared MBarrier Arrive with Expected Transaction Count";
+  let assemblyFormat = "$addr `,` $txcount attr-dict `:` type(operands) `->` type($res)";
+}
+
+def NVVM_MBarrierTryWaitParityOp : NVVM_Op<"mbarrier.try_wait.parity">,
+  Results<(outs LLVM_Type:$res)>,
+  Arguments<(ins LLVM_i64ptr_any:$addr, LLVM_Type:$token)> {
+  let summary = "MBarrier Potentially-Blocking Try Wait with Phase Parity";
+  let assemblyFormat = "$addr `,` $token attr-dict `:` type(operands) `->` type($res)";
+}
+
+def NVVM_MBarrierTryWaitParitySharedOp : NVVM_Op<"mbarrier.try_wait.parity.shared">,
+  Results<(outs LLVM_Type:$res)>,
+  Arguments<(ins LLVM_i64ptr_shared:$addr, LLVM_Type:$token)> {
+  let summary = "Shared MBarrier Potentially-Blocking Try Wait with Phase Parity";
+  let assemblyFormat = "$addr `,` $token attr-dict `:` type(operands) `->` type($res)";
+}
+
//===----------------------------------------------------------------------===// // NVVM synchronization op definitions //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -30,6 +30,7 @@ add_subdirectory(MemRefToLLVM) add_subdirectory(MemRefToSPIRV) add_subdirectory(NVGPUToNVVM) +add_subdirectory(NVVMToLLVM) add_subdirectory(OpenACCToSCF) add_subdirectory(OpenMPToLLVM) add_subdirectory(PDLToPDLInterp) diff --git a/mlir/lib/Conversion/NVVMToLLVM/CMakeLists.txt b/mlir/lib/Conversion/NVVMToLLVM/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/NVVMToLLVM/CMakeLists.txt @@ -0,0 +1,21 @@ +add_mlir_conversion_library(MLIRNVVMToLLVM + NVVMToLLVM.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/NVVMToLLVM + + DEPENDS + MLIRConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRGPUDialect + MLIRLLVMCommonConversion + MLIRLLVMDialect + MLIRNVVMDialect + MLIRNVGPUDialect + MLIRPass + MLIRTransforms + ) diff --git a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp @@ -0,0 +1,226 @@ +//===- NVVMToLLVM.cpp - NVVM to LLVM dialect conversion -----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a translation NVVM ops which is not supported in LLVM +// core. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h" + +#include "mlir/Conversion/LLVMCommon/ConversionTarget.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Value.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/raw_ostream.h" +#include + +namespace mlir { +#define GEN_PASS_DEF_CONVERTNVVMTOLLVMPASS +#include "mlir/Conversion/Passes.h.inc" +} // namespace mlir + +using namespace mlir; + +namespace { +struct ConvertNVVMToLLVMPass + : public impl::ConvertNVVMToLLVMPassBase { + using Base::Base; + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + LowerToLLVMOptions options(&getContext()); + RewritePatternSet patterns(&getContext()); + LLVMTypeConverter converter(&getContext(), options); + IRRewriter rewriter(&getContext()); + populateNVVMToLLVMConversionPatterns(converter, patterns); + LLVMConversionTarget target(getContext()); + target.addLegalDialect<::mlir::LLVM::LLVMDialect>(); + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) + signalPassFailure(); + } +}; + +enum InputType { read, readwrite, write }; +class GenerateInlineAsm { + const char *asmStr; + SmallVector asmVals; + std::string asmConstraints; + bool sideEffects; + bool 
hasResult = false; + + // https://docs.nvidia.com/cuda/inline-ptx-assembly/index.html#constraints + char getRegisterType(Value v) { + if (v.getType().isInteger(16)) + return 'h'; + if (v.getType().isInteger(32)) + return 'r'; + if (v.getType().isInteger(64)) + return 'l'; + if (v.getType().isF32()) + return 'f'; + if (v.getType().isF64()) + return 'd'; + if (v.getType().isa()) + return 'r'; + assert(false && "Register type is not handled yet"); + return ' '; + } + +public: + GenerateInlineAsm(const char *ptxBlock, bool sideEffects = false) + : asmStr(ptxBlock), sideEffects(sideEffects) {} + + void insertValue(Value v, InputType itype) { + llvm::raw_string_ostream ss(asmConstraints); + if (itype == InputType::read) { + asmVals.push_back(v); + } else if (itype == InputType::readwrite) { + asmVals.push_back(v); + ss << "+"; + hasResult = true; + } else if (itype == InputType::write) { + ss << "="; + hasResult = true; + } + ss << getRegisterType(v) << ","; + ss.flush(); + } + + LLVM::InlineAsmOp generate(ConversionPatternRewriter &rewriter, + Operation *op) { + auto asmDialectAttr = + LLVM::AsmDialectAttr::get(op->getContext(), LLVM::AsmDialect::AD_ATT); + Type resultType = hasResult ? op->getResult(0).getType() + : LLVM::LLVMVoidType::get(op->getContext()); + + // Remove the last comma from the constraints string. 
+ if (asmConstraints[asmConstraints.size() - 1] == ',') + asmConstraints.pop_back(); + + return rewriter.create( + op->getLoc(), resultType, + /*operands=*/asmVals, + /*asm_string=*/asmStr, + /*constraints=*/asmConstraints.data(), + /*has_side_effects=*/sideEffects, + /*is_align_stack=*/false, + /*asm_dialect=*/asmDialectAttr, + /*operand_attrs=*/ArrayAttr()); + } +}; + +struct MBarrierArriveExpectTxSharedOpLowering + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + NVVM::MBarrierArriveExpectTxSharedOp>::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(NVVM::MBarrierArriveExpectTxSharedOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + const char *ptx = "mbarrier.arrive.expect_tx.shared.b64 %0, [%1], %2;"; + GenerateInlineAsm generator(ptx, true); + generator.insertValue(op.getRes(), write); + generator.insertValue(op.getAddr(), read); + generator.insertValue(op.getTxcount(), read); + auto newOp = generator.generate(rewriter, op); + rewriter.replaceOp(op, newOp); + return success(); + } +}; + +struct MBarrierArriveExpectTxOpLowering + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + NVVM::MBarrierArriveExpectTxOp>::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(NVVM::MBarrierArriveExpectTxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + const char *ptx = "mbarrier.arrive.expect_tx.b64 %0, [%1], %2;"; + GenerateInlineAsm generator(ptx, true); + generator.insertValue(op.getRes(), write); + generator.insertValue(op.getAddr(), read); + generator.insertValue(op.getTxcount(), read); + auto newOp = generator.generate(rewriter, op); + rewriter.replaceOp(op, newOp); + return success(); + } +}; + +struct MBarrierTryWaitParitySharedOpLowering + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + NVVM::MBarrierTryWaitParitySharedOp>::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(NVVM::MBarrierTryWaitParitySharedOp op, 
OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + const char *ptx = "{\n\t" + ".reg .pred P1; \n\t" + "mbarrier.try_wait.parity.shared.b64 P1, [%1], %2; \n\t" + "selp.b32 %0, 1, 0, P1; \n\t" + "}"; + GenerateInlineAsm generator(ptx, true); + generator.insertValue(op.getRes(), write); + generator.insertValue(op.getAddr(), read); + generator.insertValue(op.getToken(), read); + auto newOp = generator.generate(rewriter, op); + rewriter.replaceOp(op, newOp); + return success(); + } +}; + +struct MBarrierTryWaitParityOpLowering + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + NVVM::MBarrierTryWaitParityOp>::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(NVVM::MBarrierTryWaitParityOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + const char *ptx = "{\n\t" + ".reg .pred P1; \n\t" + "mbarrier.try_wait.parity.b64 P1, [%1], %2; \n\t" + "selp.b32 %0, 1, 0, P1; \n\t" + "}"; + GenerateInlineAsm generator(ptx, true); + generator.insertValue(op.getRes(), write); + generator.insertValue(op.getAddr(), read); + generator.insertValue(op.getToken(), read); + auto newOp = generator.generate(rewriter, op); + rewriter.replaceOp(op, newOp); + return success(); + } +}; + +} // namespace + +void mlir::populateNVVMToLLVMConversionPatterns(LLVMTypeConverter &converter, + RewritePatternSet &patterns) { + patterns.add< + MBarrierArriveExpectTxSharedOpLowering, MBarrierArriveExpectTxOpLowering, + MBarrierTryWaitParitySharedOpLowering, MBarrierTryWaitParityOpLowering>( + converter); +} diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir @@ -0,0 +1,25 @@ +// RUN: mlir-opt --convert-nvvm-to-llvm --split-input-file %s | FileCheck %s + +llvm.func @init_mbarrier_arrive_expect_tx(%barrier : !llvm.ptr<3>, %txcount : i32) { + //CHECK : llvm.inline_asm 
has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.shared.b64 %0, [%1], %2;", "=l,r,r" %arg0, %arg1 : (!llvm.ptr<3>, i32) -> i64 + %8 = nvvm.mbarrier.arrive.expect_tx.shared %barrier, %txcount : !llvm.ptr<3>, i32 -> i64 + llvm.return +} + +llvm.func @init_mbarrier_arrive_expect_tx_generic(%barrier : !llvm.ptr, %txcount : i32) { + // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.b64 %0, [%1], %2;", "=l,r,r" %arg0, %arg1 : (!llvm.ptr, i32) -> i64 + %8 = nvvm.mbarrier.arrive.expect_tx %barrier, %txcount : !llvm.ptr, i32 -> i64 + llvm.return +} + +llvm.func @init_mbarrier_try_wait.parity.shared(%barrier : !llvm.ptr<3>, %token : i32) { + // CHECK : llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09.reg .pred P1; \0A\09mbarrier.try_wait.parity.shared.b64 P1, [%1], %2; \0A\09selp.b32 %0, 1, 0, P1; \0A\09}", "=r,r,r" %arg0, %arg1 : (!llvm.ptr<3>, i32) -> i32 + %11 = nvvm.mbarrier.try_wait.parity.shared %barrier, %token : !llvm.ptr<3>, i32 -> i32 + llvm.return +} + +llvm.func @init_mbarrier_try_wait.parity(%barrier : !llvm.ptr, %token : i32) { + // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09.reg .pred P1; \0A\09mbarrier.try_wait.parity.b64 P1, [%1], %2; \0A\09selp.b32 %0, 1, 0, P1; \0A\09}", "=r,r,r" %arg0, %arg1 : (!llvm.ptr, i32) -> i32 + %11 = nvvm.mbarrier.try_wait.parity %barrier, %token : !llvm.ptr, i32 -> i32 + llvm.return +}