diff --git a/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp b/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp
--- a/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp
+++ b/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp
@@ -105,6 +105,15 @@
   return result;
 }
 
+static llvm::cl::opt<std::string>
+    clSMVersion("sm", llvm::cl::desc("SM version to target"),
+                llvm::cl::init("sm_35"));
+
+static llvm::cl::opt<unsigned>
+    clIndexWidth("index-bitwidth",
+                 llvm::cl::desc("Bitwidth of index type to use for lowering"),
+                 llvm::cl::init(32));
+
 static LogicalResult runMLIRPasses(ModuleOp m) {
   PassManager pm(m.getContext());
   applyPassManagerCLOptions(pm);
@@ -113,10 +122,10 @@
   pm.addPass(createGpuKernelOutliningPass());
   auto &kernelPm = pm.nest<gpu::GPUModuleOp>();
   kernelPm.addPass(createStripDebugInfoPass());
-  kernelPm.addPass(createLowerGpuOpsToNVVMOpsPass());
+  kernelPm.addPass(createLowerGpuOpsToNVVMOpsPass(clIndexWidth.getValue()));
   kernelPm.addPass(createConvertGPUKernelToBlobPass(
       translateModuleToNVVMIR, compilePtxToCubin, "nvptx64-nvidia-cuda",
-      "sm_35", "+ptx60", gpuBinaryAnnotation));
+      clSMVersion.getValue(), "+ptx60", gpuBinaryAnnotation));
   pm.addPass(createGpuToLLVMConversionPass(gpuBinaryAnnotation));
 
   return pm.run(m);
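
Note (not part of the patch): the two options above are plain LLVM cl::opt flags, so they surface on the runner's command line as --sm=... and --index-bitwidth=..., falling back to "sm_35" and 32 when unset. Below is a minimal standalone sketch of that mechanism, assuming only LLVM's Support library; the main() driver is illustrative and is not the runner's real entry point (mlir-cuda-runner parses these flags through its own JitRunner machinery).

// Standalone sketch demonstrating the cl::opt flags added by the patch.
// Build against LLVM's Support library; main() here is a stand-in driver.
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/raw_ostream.h"

static llvm::cl::opt<std::string>
    clSMVersion("sm", llvm::cl::desc("SM version to target"),
                llvm::cl::init("sm_35"));

static llvm::cl::opt<unsigned>
    clIndexWidth("index-bitwidth",
                 llvm::cl::desc("Bitwidth of index type to use for lowering"),
                 llvm::cl::init(32));

int main(int argc, char **argv) {
  // Parses argv; any flag not given on the command line keeps its
  // cl::init default, mirroring the behavior inside the runner.
  llvm::cl::ParseCommandLineOptions(argc, argv, "cl::opt sketch\n");
  llvm::outs() << "sm=" << clSMVersion.getValue()
               << " index-bitwidth=" << clIndexWidth.getValue() << "\n";
  return 0;
}

Invoking the runner with, e.g., --sm=sm_70 --index-bitwidth=64 would thus select the compute capability handed to createConvertGPUKernelToBlobPass and the index width handed to createLowerGpuOpsToNVVMOpsPass, instead of the previously hard-coded "sm_35" and the default bitwidth.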