diff --git a/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f16-f16-accum.mlir b/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f16-f16-accum.mlir --- a/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f16-f16-accum.mlir +++ b/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f16-f16-accum.mlir @@ -1,22 +1,7 @@ // RUN: mlir-opt %s \ // RUN: -test-transform-dialect-interpreter \ // RUN: -test-transform-dialect-erase-schedule \ -// RUN: -gpu-kernel-outlining \ -// RUN: -convert-scf-to-cf \ -// RUN: -convert-vector-to-llvm \ -// RUN: -convert-math-to-llvm \ -// RUN: -expand-strided-metadata \ -// RUN: -lower-affine \ -// RUN: -convert-index-to-llvm=index-bitwidth=32 \ -// RUN: -convert-arith-to-llvm \ -// RUN: -finalize-memref-to-llvm \ -// RUN: -convert-func-to-llvm \ -// RUN: -canonicalize \ -// RUN: | mlir-opt -pass-pipeline='builtin.module(gpu.module(strip-debuginfo,convert-gpu-to-nvvm,convert-nvgpu-to-nvvm{use-opaque-pointers=1},lower-affine,convert-scf-to-cf,convert-vector-to-llvm,convert-math-to-llvm,expand-strided-metadata,lower-affine,convert-index-to-llvm{index-bitwidth=32},convert-arith-to-llvm,reconcile-unrealized-casts,gpu-to-cubin{chip=sm_80 features=+ptx76}))' \ -// RUN: | mlir-opt -convert-index-to-llvm=index-bitwidth=32 \ -// RUN: -gpu-to-llvm \ -// RUN: -convert-func-to-llvm \ -// RUN: -reconcile-unrealized-casts \ +// RUN: -test-lower-to-nvvm="kernel-index-bitwidth=32 cubin-chip=sm_80 cubin-features=+ptx76" \ // RUN: | mlir-cpu-runner \ // RUN: --shared-libs=%mlir_cuda_runtime \ // RUN: --shared-libs=%mlir_runner_utils \ diff --git a/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f32.mlir b/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f32.mlir --- a/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f32.mlir +++ 
b/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f32.mlir @@ -11,22 +11,7 @@ // RUN: mlir-opt %s \ // RUN: -test-transform-dialect-interpreter \ // RUN: -test-transform-dialect-erase-schedule \ -// RUN: -gpu-kernel-outlining \ -// RUN: -convert-scf-to-cf \ -// RUN: -convert-vector-to-llvm \ -// RUN: -convert-math-to-llvm \ -// RUN: -expand-strided-metadata \ -// RUN: -lower-affine \ -// RUN: -convert-index-to-llvm=index-bitwidth=32 \ -// RUN: -convert-arith-to-llvm \ -// RUN: -finalize-memref-to-llvm \ -// RUN: -convert-func-to-llvm \ -// RUN: -canonicalize \ -// RUN: | mlir-opt -pass-pipeline='builtin.module(gpu.module(strip-debuginfo,convert-gpu-to-nvvm,convert-nvgpu-to-nvvm{use-opaque-pointers=1},lower-affine,convert-scf-to-cf,convert-vector-to-llvm,convert-math-to-llvm,expand-strided-metadata,lower-affine,convert-index-to-llvm{index-bitwidth=32},convert-arith-to-llvm,reconcile-unrealized-casts,gpu-to-cubin{chip=sm_80 features=+ptx76}))' \ -// RUN: | mlir-opt -convert-index-to-llvm=index-bitwidth=32 \ -// RUN: -gpu-to-llvm \ -// RUN: -convert-func-to-llvm \ -// RUN: -reconcile-unrealized-casts \ +// RUN: -test-lower-to-nvvm="kernel-index-bitwidth=32 cubin-chip=sm_80 cubin-features=+ptx76" \ // RUN: | mlir-cpu-runner \ // RUN: --shared-libs=%mlir_cuda_runtime \ // RUN: --shared-libs=%mlir_runner_utils \ diff --git a/mlir/test/lib/Dialect/CMakeLists.txt b/mlir/test/lib/Dialect/CMakeLists.txt --- a/mlir/test/lib/Dialect/CMakeLists.txt +++ b/mlir/test/lib/Dialect/CMakeLists.txt @@ -10,6 +10,7 @@ add_subdirectory(Math) add_subdirectory(MemRef) add_subdirectory(NVGPU) +add_subdirectory(NVVM) add_subdirectory(SCF) add_subdirectory(Shape) add_subdirectory(SPIRV) diff --git a/mlir/test/lib/Dialect/NVVM/CMakeLists.txt b/mlir/test/lib/Dialect/NVVM/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Dialect/NVVM/CMakeLists.txt @@ -0,0 +1,56 @@ +if (MLIR_ENABLE_CUDA_RUNNER) + # Configure CUDA support. 
Using check_language first allows us to give a + # custom error message. + include(CheckLanguage) + check_language(CUDA) + if (CMAKE_CUDA_COMPILER) + enable_language(CUDA) + else() + message(SEND_ERROR + "Building the mlir cuda runner requires a working CUDA install") + endif() + + # We need the libcuda.so library. + find_library(CUDA_RUNTIME_LIBRARY cuda HINTS ${CMAKE_CUDA_IMPLICIT_LINK_DIRECTORIES} REQUIRED) + + get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) + set(LIBS + ${conversion_libs} + + MLIRAnalysis + MLIRArithDialect + MLIRBuiltinToLLVMIRTranslation + MLIRExecutionEngine + MLIRFuncDialect + MLIRGPUDialect + MLIRIR + MLIRJitRunner + MLIRLLVMDialect + MLIRLLVMCommonConversion + MLIRLLVMToLLVMIRTranslation + MLIRToLLVMIRTranslationRegistration + MLIRMemRefDialect + MLIRMemRefToLLVM + MLIRParser + MLIRSPIRVDialect + MLIRSPIRVTransforms + MLIRSupport + MLIRTargetLLVMIRExport + MLIRTransforms + MLIRTranslateLib + MLIRVectorDialect + MLIRVectorToLLVM + ) + + # NOTE(review): mlir-opt links this only if MLIR_ENABLE_CUDA_RUNNER, yet calls + # registerTestLowerToNVVM() if MLIR_CUDA_CONVERSIONS_ENABLED; keep in sync. + # Exclude tests from libMLIR.so + add_mlir_library(MLIRNVVMTestPasses + TestLowerToNVVM.cpp + + EXCLUDE_FROM_LIBMLIR + + LINK_LIBS PUBLIC + ${LIBS} + ) +endif() diff --git a/mlir/test/lib/Dialect/NVVM/TestLowerToNVVM.cpp b/mlir/test/lib/Dialect/NVVM/TestLowerToNVVM.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Dialect/NVVM/TestLowerToNVVM.cpp @@ -0,0 +1,321 @@ +//===- TestLowerToNVVM.cpp - Test lowering to NVVM as a sink pass ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass pipeline for testing the lowering to NVVM as a +// generally usable sink pipeline. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" +#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" +#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" +#include "mlir/Conversion/GPUCommon/GPUCommonPass.h" +#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" +#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h" +#include "mlir/Conversion/LLVMCommon/LoweringOptions.h" +#include "mlir/Conversion/MathToLLVM/MathToLLVM.h" +#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" +#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h" +#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" +#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" +#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" +#include "mlir/Conversion/VectorToSCF/VectorToSCF.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/GPU/Transforms/Passes.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Pass/PassOptions.h" +#include "mlir/Transforms/Passes.h" + +using namespace mlir; +
+namespace { +struct TestLowerToNVVMOptions + : public PassPipelineOptions<TestLowerToNVVMOptions> { + PassOptions::Option<int64_t> hostIndexBitWidth{ + *this, "host-index-bitwidth", + llvm::cl::desc("Bitwidth of the index type for the host (warning this " + "should be 64 until the GPU layering is fixed)"), + llvm::cl::init(64)}; + PassOptions::Option<bool> hostUseBarePtrCallConv{ + *this, "host-bare-ptr-calling-convention", + llvm::cl::desc( + "Whether to use the bareptr calling convention on the host (warning " + "this should be false until the GPU layering is fixed)"), + llvm::cl::init(false)}; + PassOptions::Option<int64_t> kernelIndexBitWidth{ + *this, "kernel-index-bitwidth", + llvm::cl::desc("Bitwidth of the index type for the GPU kernels"), + llvm::cl::init(64)}; + PassOptions::Option<bool> kernelUseBarePtrCallConv{ + *this, "kernel-bare-ptr-calling-convention", + llvm::cl::desc( + "Whether to use the bareptr calling convention on the kernel " + "(warning this should be false until the GPU layering is fixed)"), + llvm::cl::init(false)}; + PassOptions::Option<std::string> cubinTriple{ + *this, "cubin-triple", + llvm::cl::desc("Triple to use to serialize to cubin."), + llvm::cl::init("nvptx64-nvidia-cuda")}; + PassOptions::Option<std::string> cubinChip{ + *this, "cubin-chip", llvm::cl::desc("Chip to use to serialize to cubin."), + llvm::cl::init("sm_80")}; + PassOptions::Option<std::string> cubinFeatures{ + *this, "cubin-features", + llvm::cl::desc("Features to use to serialize to cubin."), + llvm::cl::init("+ptx76")}; +}; + +//===----------------------------------------------------------------------===// +// Kernel-side (gpu.module) lowering to NVVM, then serialization to cubin. +//===----------------------------------------------------------------------===// +void buildGpuPassPipeline(OpPassManager &pm, + const TestLowerToNVVMOptions &options) { + pm.addNestedPass<gpu::GPUModuleOp>(createStripDebugInfoPass()); + + pm.addNestedPass<gpu::GPUModuleOp>(createConvertVectorToSCFPass()); + // Blanket-convert any remaining linalg ops to loops if any remain. 
+
+ pm.addNestedPass<gpu::GPUModuleOp>(createConvertLinalgToLoopsPass()); + // Convert SCF to CF (always needed). + pm.addNestedPass<gpu::GPUModuleOp>(createConvertSCFToCFPass()); + // Convert Math to LLVM (always needed). + pm.addNestedPass<gpu::GPUModuleOp>(createConvertMathToLLVMPass()); + // Expand complicated MemRef operations before lowering them. + pm.addNestedPass<gpu::GPUModuleOp>(memref::createExpandStridedMetadataPass()); + // The expansion may create affine expressions. Get rid of them. + pm.addNestedPass<gpu::GPUModuleOp>(createLowerAffinePass()); + + // Convert MemRef to LLVM (always needed). + // TODO: C++20 designated initializers. + FinalizeMemRefToLLVMConversionPassOptions + finalizeMemRefToLLVMConversionPassOptions; + // Kernel side: use the kernel index bitwidth. Only the host must stay 64b + // (things don't compose properly around gpu::LaunchOp/gpu::HostRegisterOp). + // TODO: fix GPU layering. + finalizeMemRefToLLVMConversionPassOptions.indexBitwidth = + options.kernelIndexBitWidth; + finalizeMemRefToLLVMConversionPassOptions.useOpaquePointers = true; + pm.addNestedPass<gpu::GPUModuleOp>(createFinalizeMemRefToLLVMConversionPass( + finalizeMemRefToLLVMConversionPassOptions)); + + // Convert Func to LLVM (always needed). + // TODO: C++20 designated initializers. + ConvertFuncToLLVMPassOptions convertFuncToLLVMPassOptions; + // Kernel side: use the kernel index bitwidth. Only the host must stay 64b + // (things don't compose properly around gpu::LaunchOp/gpu::HostRegisterOp). + // TODO: fix GPU layering. + convertFuncToLLVMPassOptions.indexBitwidth = options.kernelIndexBitWidth; + convertFuncToLLVMPassOptions.useBarePtrCallConv = + options.kernelUseBarePtrCallConv; + convertFuncToLLVMPassOptions.useOpaquePointers = true; + pm.addNestedPass<gpu::GPUModuleOp>( + createConvertFuncToLLVMPass(convertFuncToLLVMPassOptions)); + + // TODO: C++20 designated initializers. + ConvertIndexToLLVMPassOptions convertIndexToLLVMPassOpt; + // Kernel side: use the kernel index bitwidth. 
Only the host must stay 64b + // (things don't compose properly around gpu::LaunchOp/gpu::HostRegisterOp). + // TODO: fix GPU layering.
+ convertIndexToLLVMPassOpt.indexBitwidth = options.kernelIndexBitWidth; + pm.addNestedPass<gpu::GPUModuleOp>( + createConvertIndexToLLVMPass(convertIndexToLLVMPassOpt)); + + // TODO: C++20 designated initializers. + // The following pass is inconsistent: it is built from a raw bitwidth. + // ConvertGpuOpsToNVVMOpsOptions convertGpuOpsToNVVMOpsOptions; + // convertGpuOpsToNVVMOpsOptions.indexBitwidth = + // options.kernelIndexBitWidth; + pm.addNestedPass<gpu::GPUModuleOp>( + // TODO: fix inconsistency. + createLowerGpuOpsToNVVMOpsPass(/*indexBitWidth=*/ + options.kernelIndexBitWidth)); + + // TODO: C++20 designated initializers. + ConvertNVGPUToNVVMPassOptions convertNVGPUToNVVMPassOptions; + convertNVGPUToNVVMPassOptions.useOpaquePointers = true; + pm.addNestedPass<gpu::GPUModuleOp>( + createConvertNVGPUToNVVMPass(convertNVGPUToNVVMPassOptions)); + pm.addNestedPass<gpu::GPUModuleOp>(createConvertSCFToCFPass()); + + // TODO: C++20 designated initializers. + GpuToLLVMConversionPassOptions gpuToLLVMConversionOptions; + // Note: hostBarePtrCallConv must be false for now otherwise + // gpu::HostRegister is ill-defined: it wants unranked memrefs but can't + // lower them to bare ptrs. + gpuToLLVMConversionOptions.hostBarePtrCallConv = + options.hostUseBarePtrCallConv; + gpuToLLVMConversionOptions.kernelBarePtrCallConv = + options.kernelUseBarePtrCallConv; + gpuToLLVMConversionOptions.useOpaquePointers = true; + + // TODO: something useful here. + // gpuToLLVMConversionOptions.gpuBinaryAnnotation = ""; + pm.addNestedPass<gpu::GPUModuleOp>( + createGpuToLLVMConversionPass(gpuToLLVMConversionOptions)); + + // Convert vector to LLVM (always needed). + // TODO: C++20 designated initializers. + ConvertVectorToLLVMPassOptions convertVectorToLLVMPassOptions; + convertVectorToLLVMPassOptions.reassociateFPReductions = true; + pm.addNestedPass<gpu::GPUModuleOp>( + createConvertVectorToLLVMPass(convertVectorToLLVMPassOptions)); + + // Sprinkle some cleanups. 
+ pm.addPass(createCanonicalizerPass()); + pm.addPass(createCSEPass()); + + // Finally we can reconcile unrealized casts. +
+ pm.addNestedPass<gpu::GPUModuleOp>(createReconcileUnrealizedCastsPass()); + pm.addNestedPass<gpu::GPUModuleOp>(createGpuSerializeToCubinPass( + options.cubinTriple, options.cubinChip, options.cubinFeatures)); +} + +void buildLowerToNVVMPassPipeline(OpPassManager &pm, + const TestLowerToNVVMOptions &options) { + //===----------------------------------------------------------------------===// + // Host-specific stuff. + //===----------------------------------------------------------------------===// + // Important, must be run at the top-level. + pm.addPass(createGpuKernelOutliningPass()); + + // Important, all host passes must be run at the func level so that host + // conversions can remain with 64 bit indices without polluting the GPU + // kernel that may have 32 bit indices. + // Must be 64b on the host, things don't compose properly around + // gpu::LaunchOp and gpu::HostRegisterOp. + // TODO: fix GPU layering. + pm.addNestedPass<func::FuncOp>(createConvertVectorToSCFPass()); + // Blanket-convert any remaining linalg ops to loops if any remain. + pm.addNestedPass<func::FuncOp>(createConvertLinalgToLoopsPass()); + // Convert SCF to CF (always needed). + pm.addNestedPass<func::FuncOp>(createConvertSCFToCFPass()); + // Convert Math to LLVM (always needed). + pm.addNestedPass<func::FuncOp>(createConvertMathToLLVMPass()); + // Expand complicated MemRef operations before lowering them. + pm.addNestedPass<func::FuncOp>(memref::createExpandStridedMetadataPass()); + // The expansion may create affine expressions. Get rid of them. + pm.addNestedPass<func::FuncOp>(createLowerAffinePass()); + + // Convert MemRef to LLVM (always needed). + // TODO: C++20 designated initializers. + FinalizeMemRefToLLVMConversionPassOptions + finalizeMemRefToLLVMConversionPassOptions; + finalizeMemRefToLLVMConversionPassOptions.useAlignedAlloc = true; + // Must be 64b on the host, things don't compose properly around + // gpu::LaunchOp and gpu::HostRegisterOp. + // TODO: fix GPU layering. +
+ finalizeMemRefToLLVMConversionPassOptions.indexBitwidth = + options.hostIndexBitWidth; + finalizeMemRefToLLVMConversionPassOptions.useOpaquePointers = true; + pm.addNestedPass<func::FuncOp>(createFinalizeMemRefToLLVMConversionPass( + finalizeMemRefToLLVMConversionPassOptions)); + + // Convert Func to LLVM (always needed). + // TODO: C++20 designated initializers. + ConvertFuncToLLVMPassOptions convertFuncToLLVMPassOptions; + // Must be 64b on the host, things don't compose properly around + // gpu::LaunchOp and gpu::HostRegisterOp. + // TODO: fix GPU layering. + convertFuncToLLVMPassOptions.indexBitwidth = options.hostIndexBitWidth; + convertFuncToLLVMPassOptions.useBarePtrCallConv = + options.hostUseBarePtrCallConv; + convertFuncToLLVMPassOptions.useOpaquePointers = true; + pm.addNestedPass<func::FuncOp>( + createConvertFuncToLLVMPass(convertFuncToLLVMPassOptions)); + + // TODO: C++20 designated initializers. + ConvertIndexToLLVMPassOptions convertIndexToLLVMPassOpt; + // Must be 64b on the host, things don't compose properly around + // gpu::LaunchOp and gpu::HostRegisterOp. + // TODO: fix GPU layering. + convertIndexToLLVMPassOpt.indexBitwidth = options.hostIndexBitWidth; + pm.addNestedPass<func::FuncOp>( + createConvertIndexToLLVMPass(convertIndexToLLVMPassOpt)); + + pm.addNestedPass<func::FuncOp>(createArithToLLVMConversionPass()); + + // Sprinkle some cleanups. + pm.addNestedPass<func::FuncOp>(createCanonicalizerPass()); + pm.addNestedPass<func::FuncOp>(createCSEPass()); + + //===----------------------------------------------------------------------===// + // GPUModule-specific stuff. + //===----------------------------------------------------------------------===// + buildGpuPassPipeline(pm, options); + + //===----------------------------------------------------------------------===// + // Host post-GPUModule-specific stuff. + //===----------------------------------------------------------------------===// + // Convert vector to LLVM (always needed). + // TODO: C++20 designated initializers. 
+
+ ConvertVectorToLLVMPassOptions convertVectorToLLVMPassOptions; + convertVectorToLLVMPassOptions.reassociateFPReductions = true; + pm.addNestedPass<func::FuncOp>( + createConvertVectorToLLVMPass(convertVectorToLLVMPassOptions)); + + ConvertIndexToLLVMPassOptions convertIndexToLLVMPassOpt3; + // Must be 64b on the host, things don't compose properly around + // gpu::LaunchOp and gpu::HostRegisterOp. + // TODO: fix GPU layering. + convertIndexToLLVMPassOpt3.indexBitwidth = options.hostIndexBitWidth; + pm.addPass(createConvertIndexToLLVMPass(convertIndexToLLVMPassOpt3)); + + // This must happen after cubin translation otherwise gpu.launch_func is + // illegal if no cubin annotation is present. + // TODO: C++20 designated initializers. + GpuToLLVMConversionPassOptions gpuToLLVMConversionOptions; + // Note: hostBarePtrCallConv must be false for now otherwise + // gpu::HostRegister is ill-defined: it wants unranked memrefs but can't + // lower them to bare ptrs. + gpuToLLVMConversionOptions.hostBarePtrCallConv = + options.hostUseBarePtrCallConv; + gpuToLLVMConversionOptions.kernelBarePtrCallConv = + options.kernelUseBarePtrCallConv; + gpuToLLVMConversionOptions.useOpaquePointers = true; + // TODO: something useful here. + // gpuToLLVMConversionOptions.gpuBinaryAnnotation = ""; + pm.addPass(createGpuToLLVMConversionPass(gpuToLLVMConversionOptions)); + + // Convert Func to LLVM (always needed). + // TODO: C++20 designated initializers. + ConvertFuncToLLVMPassOptions convertFuncToLLVMPassOptions2; + // Must be 64b on the host, things don't compose properly around + // gpu::LaunchOp and gpu::HostRegisterOp. + convertFuncToLLVMPassOptions2.indexBitwidth = options.hostIndexBitWidth; + convertFuncToLLVMPassOptions2.useBarePtrCallConv = + options.hostUseBarePtrCallConv; + convertFuncToLLVMPassOptions2.useOpaquePointers = true; + pm.addPass(createConvertFuncToLLVMPass(convertFuncToLLVMPassOptions2)); + + // Sprinkle some cleanups. 
+
+ pm.addPass(createCanonicalizerPass()); + pm.addPass(createCSEPass()); + + // Finally we can reconcile unrealized casts. + pm.addPass(createReconcileUnrealizedCastsPass()); +} +} // namespace + +namespace mlir { +namespace test { +void registerTestLowerToNVVM() { + PassPipelineRegistration<TestLowerToNVVMOptions>( + "test-lower-to-nvvm", + "An example of pipeline to lower the main dialects (arith, linalg, " + "memref, scf, vector) down to NVVM.", + buildLowerToNVVMPassPipeline); +} +} // namespace test +} // namespace mlir diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt --- a/mlir/tools/mlir-opt/CMakeLists.txt +++ b/mlir/tools/mlir-opt/CMakeLists.txt @@ -12,7 +12,13 @@ ) if(MLIR_INCLUDE_TESTS) + if(MLIR_ENABLE_CUDA_RUNNER) + set(cuda_test_libs + MLIRNVVMTestPasses + ) + endif() set(test_libs + ${cuda_test_libs} MLIRTestFuncToLLVM MLIRAffineTransformsTestPasses MLIRArithTestPasses diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -109,6 +109,7 @@ void registerTestLoopMappingPass(); void registerTestLoopUnrollingPass(); void registerTestLowerToLLVM(); +void registerTestLowerToNVVM(); void registerTestMakeIsolatedFromAbovePass(); void registerTestMatchReductionPass(); void registerTestMathAlgebraicSimplificationPass(); @@ -199,6 +200,7 @@ mlir::test::registerTestDialectConversionPasses(); #if MLIR_CUDA_CONVERSIONS_ENABLED mlir::test::registerTestGpuSerializeToCubinPass(); + mlir::test::registerTestLowerToNVVM(); #endif #if MLIR_ROCM_CONVERSIONS_ENABLED mlir::test::registerTestGpuSerializeToHsacoPass(); @@ -273,6 +275,6 @@ ::test::registerTestTransformDialectExtension(registry); ::test::registerTestDynDialect(registry); #endif - return mlir::asMainReturnCode( - mlir::MlirOptMain(argc, argv, "MLIR modular optimizer driver\n", registry)); + return mlir::asMainReturnCode(mlir::MlirOptMain( + argc, argv, "MLIR modular optimizer driver\n", 
registry)); }