NOTE(review): This patch text is damaged and will not apply as-is. Its line
breaks were collapsed, and its `<...>` template-argument lists appear to have
been stripped (e.g. `std::unique_ptr>`, `std::make_unique()`, and the pattern
type lists that `patterns.add(...)` needs). Regenerate it from the source
commit instead of hand-editing.
Summary: moves the gpu::AsyncTokenType type conversion and the GPU-to-LLVM
pattern registration out of the pass body into a new public function,
`mlir::populateGpuToLLVMConversionPatterns`, declared in GPUCommonPass.h. The
pass now calls that function with its `gpuBinaryAnnotation`. Presumably the
intent is to let other conversions reuse it; confirm with the commit message.
diff --git a/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h b/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h --- a/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h +++ b/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h @@ -52,6 +52,12 @@ /// typed ABI on top of GPU runtimes such as CUDA or ROCm (HIP). std::unique_ptr> createGpuToLLVMConversionPass(); +/// Collect patterns that lower GPU dialect ops to LLVM into `patterns`, and +/// register the gpu::AsyncTokenType -> LLVM i8* conversion on `converter`. +void populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter, + OwningRewritePatternList &patterns, + StringRef gpuBinaryAnnotation = {}); + } // namespace mlir #endif // MLIR_CONVERSION_GPUCOMMON_GPUCOMMONPASS_H_ diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp --- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp @@ -323,21 +323,7 @@ populateStdToLLVMConversionPatterns(converter, patterns); populateAsyncStructuralTypeConversionsAndLegality(converter, patterns, target); - - converter.addConversion( - [context = &converter.getContext()](gpu::AsyncTokenType type) -> Type { - return LLVM::LLVMPointerType::get(IntegerType::get(context, 8)); - }); - patterns.add(converter); - patterns.add(converter, - gpuBinaryAnnotation); - patterns.add(&converter.getContext()); + populateGpuToLLVMConversionPatterns(converter, patterns, gpuBinaryAnnotation); if (failed( applyPartialConversion(getOperation(), target, std::move(patterns)))) @@ -804,3 +790,22 @@ mlir::createGpuToLLVMConversionPass() { return std::make_unique(); } + +void mlir::populateGpuToLLVMConversionPatterns( + LLVMTypeConverter &converter, OwningRewritePatternList &patterns, + StringRef gpuBinaryAnnotation) { + converter.addConversion( + [context = &converter.getContext()](gpu::AsyncTokenType type) -> Type { + return LLVM::LLVMPointerType::get(IntegerType::get(context, 8)); + }); + 
patterns.add(converter); + patterns.add(converter, + gpuBinaryAnnotation); + patterns.add(&converter.getContext()); +}